// forge_lint/sol/codesize/unwrapped_modifier_logic.rs

1use super::UnwrappedModifierLogic;
2use crate::{
3    linter::{LateLintPass, LintContext, Snippet},
4    sol::{Severity, SolLint},
5};
6use solar_ast::{self as ast, Span};
7use solar_sema::hir::{self, Res};
8
9declare_forge_lint!(
10    UNWRAPPED_MODIFIER_LOGIC,
11    Severity::CodeSize,
12    "unwrapped-modifier-logic",
13    "wrap modifier logic to reduce code size"
14);
15
16impl<'hir> LateLintPass<'hir> for UnwrappedModifierLogic {
17    fn check_function(
18        &mut self,
19        ctx: &LintContext<'_>,
20        hir: &'hir hir::Hir<'hir>,
21        func: &'hir hir::Function<'hir>,
22    ) {
23        // Only check modifiers with a body and a name
24        let (body, name) = match (func.kind, &func.body, func.name) {
25            (ast::FunctionKind::Modifier, Some(body), Some(name)) => (body, name),
26            _ => return,
27        };
28
29        // Split statements into before and after the placeholder `_`.
30        let stmts = body.stmts[..].as_ref();
31        let (before, after) = stmts
32            .iter()
33            .position(|s| matches!(s.kind, hir::StmtKind::Placeholder))
34            .map_or((stmts, &[][..]), |idx| (&stmts[..idx], &stmts[idx + 1..]));
35
36        // Generate a fix snippet if the modifier logic should be wrapped.
37        if let Some(snippet) = self.get_snippet(ctx, hir, func, before, after) {
38            ctx.emit_with_fix(&UNWRAPPED_MODIFIER_LOGIC, name.span, snippet);
39        }
40    }
41}
42
43impl UnwrappedModifierLogic {
44    /// Returns `true` if an expr is not a built-in ('require' or 'assert') call or a lib function.
45    fn is_valid_expr(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
46        if let hir::ExprKind::Call(func_expr, _, _) = &expr.kind {
47            if let hir::ExprKind::Ident(resolutions) = &func_expr.kind {
48                return !resolutions.iter().any(|r| matches!(r, Res::Builtin(_)));
49            }
50
51            if let hir::ExprKind::Member(base, _) = &func_expr.kind
52                && let hir::ExprKind::Ident(resolutions) = &base.kind
53            {
54                return resolutions.iter().any(|r| {
55                    matches!(r, Res::Item(hir::ItemId::Contract(id)) if hir.contract(*id).kind == ast::ContractKind::Library)
56                });
57            }
58        }
59
60        false
61    }
62
63    /// Checks if a block of statements is complex and should be wrapped in a helper function.
64    ///
65    /// This always is 'false' the modifier contains assembly. We assume that if devs know how to
66    /// use assembly, they will also know how to reduce the codesize of their contracts and they
67    /// have a good reason to use it on their modifiers.
68    ///
69    /// This is 'true' if the block contains:
70    /// 1. Any statement that is not a placeholder or a valid expression.
71    /// 2. More than one simple call expression.
72    fn stmts_require_wrapping(&self, hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> bool {
73        let (mut res, mut has_valid_stmt) = (false, false);
74        for stmt in stmts {
75            match &stmt.kind {
76                hir::StmtKind::Placeholder => continue,
77                hir::StmtKind::Expr(expr) => {
78                    if !self.is_valid_expr(hir, expr) || has_valid_stmt {
79                        res = true;
80                    }
81                    has_valid_stmt = true;
82                }
83                // HIR doesn't support assembly yet:
84                // <https://github.com/paradigmxyz/solar/blob/d25bf38a5accd11409318e023f701313d98b9e1e/crates/sema/src/hir/mod.rs#L977-L982>
85                hir::StmtKind::Err(_) => return false,
86                _ => res = true,
87            }
88        }
89
90        res
91    }
92
93    fn get_snippet<'a>(
94        &self,
95        ctx: &LintContext<'_>,
96        hir: &hir::Hir<'_>,
97        func: &hir::Function<'_>,
98        before: &'a [hir::Stmt<'a>],
99        after: &'a [hir::Stmt<'a>],
100    ) -> Option<Snippet> {
101        let wrap_before = !before.is_empty() && self.stmts_require_wrapping(hir, before);
102        let wrap_after = !after.is_empty() && self.stmts_require_wrapping(hir, after);
103
104        if !(wrap_before || wrap_after) {
105            return None;
106        }
107
108        let binding = func.name.unwrap();
109        let modifier_name = binding.name.as_str();
110        let mut param_list = vec![];
111        let mut param_decls = vec![];
112
113        for var_id in func.parameters {
114            let var = hir.variable(*var_id);
115            let ty = ctx
116                .span_to_snippet(var.ty.span)
117                .unwrap_or_else(|| "/* unknown type */".to_string());
118
119            // solidity functions should always have named parameters
120            if let Some(ident) = var.name {
121                param_list.push(ident.to_string());
122                param_decls.push(format!("{ty} {}", ident.to_string()));
123            }
124        }
125
126        let param_list = param_list.join(", ");
127        let param_decls = param_decls.join(", ");
128
129        let body_indent = " ".repeat(ctx.get_span_indentation(
130            before.first().or(after.first()).map(|stmt| stmt.span).unwrap_or(func.span),
131        ));
132        let body = match (wrap_before, wrap_after) {
133            (true, true) => format!(
134                "{body_indent}_{modifier_name}Before({param_list});\n{body_indent}_;\n{body_indent}_{modifier_name}After({param_list});"
135            ),
136            (true, false) => {
137                format!("{body_indent}_{modifier_name}({param_list});\n{body_indent}_;")
138            }
139            (false, true) => {
140                format!("{body_indent}_;\n{body_indent}_{modifier_name}({param_list});")
141            }
142            _ => unreachable!(),
143        };
144
145        let mod_indent = " ".repeat(ctx.get_span_indentation(func.span));
146        let mut replacement = format!(
147            "{mod_indent}modifier {modifier_name}({param_decls}) {{\n{body}\n{mod_indent}}}"
148        );
149
150        let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str| {
151            let body_stmts = stmts
152                .iter()
153                .filter_map(|s| ctx.span_to_snippet(s.span))
154                .map(|code| format!("\n{body_indent}{code}"))
155                .collect::<String>();
156            format!(
157                "\n\n{mod_indent}function _{modifier_name}{suffix}({param_decls}) internal {{{body_stmts}\n{mod_indent}}}"
158            )
159        };
160
161        if wrap_before {
162            replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" }));
163        }
164        if wrap_after {
165            replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" }));
166        }
167
168        Some(Snippet::Diff {
169            desc: Some("wrap modifier logic to reduce code size"),
170            span: Some(Span::new(func.span.lo(), func.body_span.hi())),
171            add: replacement,
172            trim_code: true,
173        })
174    }
175}