forge_lint/sol/codesize/unwrapped_modifier_logic.rs

1use super::UnwrappedModifierLogic;
2use crate::{
3    linter::{LateLintPass, LintContext, Snippet},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::{self as ast, Span},
8    sema::hir::{self, Res},
9};
10
11declare_forge_lint!(
12    UNWRAPPED_MODIFIER_LOGIC,
13    Severity::CodeSize,
14    "unwrapped-modifier-logic",
15    "wrap modifier logic to reduce code size"
16);
17
18impl<'hir> LateLintPass<'hir> for UnwrappedModifierLogic {
19    fn check_function(
20        &mut self,
21        ctx: &LintContext,
22        hir: &'hir hir::Hir<'hir>,
23        func: &'hir hir::Function<'hir>,
24    ) {
25        // Only check modifiers with a body and a name
26        let (body, name) = match (func.kind, &func.body, func.name) {
27            (ast::FunctionKind::Modifier, Some(body), Some(name)) => (body, name),
28            _ => return,
29        };
30
31        // Split statements into before and after the placeholder `_`.
32        let stmts = body.stmts[..].as_ref();
33        let (before, after) = stmts
34            .iter()
35            .position(|s| matches!(s.kind, hir::StmtKind::Placeholder))
36            .map_or((stmts, &[][..]), |idx| (&stmts[..idx], &stmts[idx + 1..]));
37
38        // Generate a fix snippet if the modifier logic should be wrapped.
39        if let Some(snippet) = self.get_snippet(ctx, hir, func, before, after) {
40            ctx.emit_with_fix(&UNWRAPPED_MODIFIER_LOGIC, name.span, snippet);
41        }
42    }
43}
44
45impl UnwrappedModifierLogic {
46    /// Returns `true` if an expr is not a built-in ('require' or 'assert') call or a lib function.
47    fn is_valid_expr(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
48        if let hir::ExprKind::Call(func_expr, _, _) = &expr.kind {
49            if let hir::ExprKind::Ident(resolutions) = &func_expr.kind {
50                return !resolutions.iter().any(|r| matches!(r, Res::Builtin(_)));
51            }
52
53            if let hir::ExprKind::Member(base, _) = &func_expr.kind
54                && let hir::ExprKind::Ident(resolutions) = &base.kind
55            {
56                return resolutions.iter().any(|r| {
57                    matches!(r, Res::Item(hir::ItemId::Contract(id)) if hir.contract(*id).kind == ast::ContractKind::Library)
58                });
59            }
60        }
61
62        false
63    }
64
65    /// Checks if a block of statements is complex and should be wrapped in a helper function.
66    ///
67    /// This always is 'false' the modifier contains assembly. We assume that if devs know how to
68    /// use assembly, they will also know how to reduce the codesize of their contracts and they
69    /// have a good reason to use it on their modifiers.
70    ///
71    /// This is 'true' if the block contains:
72    /// 1. Any statement that is not a placeholder or a valid expression.
73    /// 2. More than one simple call expression.
74    fn stmts_require_wrapping(&self, hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> bool {
75        let (mut res, mut has_valid_stmt) = (false, false);
76        for stmt in stmts {
77            match &stmt.kind {
78                hir::StmtKind::Placeholder => continue,
79                hir::StmtKind::Expr(expr) => {
80                    if !self.is_valid_expr(hir, expr) || has_valid_stmt {
81                        res = true;
82                    }
83                    has_valid_stmt = true;
84                }
85                // HIR doesn't support assembly yet:
86                // <https://github.com/paradigmxyz/solar/blob/d25bf38a5accd11409318e023f701313d98b9e1e/crates/sema/src/hir/mod.rs#L977-L982>
87                hir::StmtKind::Err(_) => return false,
88                _ => res = true,
89            }
90        }
91
92        res
93    }
94
95    fn get_snippet<'a>(
96        &self,
97        ctx: &LintContext,
98        hir: &hir::Hir<'_>,
99        func: &hir::Function<'_>,
100        before: &'a [hir::Stmt<'a>],
101        after: &'a [hir::Stmt<'a>],
102    ) -> Option<Snippet> {
103        let wrap_before = !before.is_empty() && self.stmts_require_wrapping(hir, before);
104        let wrap_after = !after.is_empty() && self.stmts_require_wrapping(hir, after);
105
106        if !(wrap_before || wrap_after) {
107            return None;
108        }
109
110        let binding = func.name.unwrap();
111        let modifier_name = binding.name.as_str();
112        let mut param_list = vec![];
113        let mut param_decls = vec![];
114
115        for var_id in func.parameters {
116            let var = hir.variable(*var_id);
117            let ty = ctx
118                .span_to_snippet(var.ty.span)
119                .unwrap_or_else(|| "/* unknown type */".to_string());
120
121            // solidity functions should always have named parameters
122            if let Some(ident) = var.name {
123                param_list.push(ident.to_string());
124                param_decls.push(format!("{ty} {}", ident.to_string()));
125            }
126        }
127
128        let param_list = param_list.join(", ");
129        let param_decls = param_decls.join(", ");
130
131        let body_indent = " ".repeat(ctx.get_span_indentation(
132            before.first().or(after.first()).map(|stmt| stmt.span).unwrap_or(func.span),
133        ));
134        let body = match (wrap_before, wrap_after) {
135            (true, true) => format!(
136                "{body_indent}_{modifier_name}Before({param_list});\n{body_indent}_;\n{body_indent}_{modifier_name}After({param_list});"
137            ),
138            (true, false) => {
139                format!("{body_indent}_{modifier_name}({param_list});\n{body_indent}_;")
140            }
141            (false, true) => {
142                format!("{body_indent}_;\n{body_indent}_{modifier_name}({param_list});")
143            }
144            _ => unreachable!(),
145        };
146
147        let mod_indent = " ".repeat(ctx.get_span_indentation(func.span));
148        let mut replacement = format!(
149            "{mod_indent}modifier {modifier_name}({param_decls}) {{\n{body}\n{mod_indent}}}"
150        );
151
152        let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str| {
153            let body_stmts = stmts
154                .iter()
155                .filter_map(|s| ctx.span_to_snippet(s.span))
156                .map(|code| format!("\n{body_indent}{code}"))
157                .collect::<String>();
158            format!(
159                "\n\n{mod_indent}function _{modifier_name}{suffix}({param_decls}) internal {{{body_stmts}\n{mod_indent}}}"
160            )
161        };
162
163        if wrap_before {
164            replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" }));
165        }
166        if wrap_after {
167            replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" }));
168        }
169
170        Some(Snippet::Diff {
171            desc: Some("wrap modifier logic to reduce code size"),
172            span: Some(Span::new(func.span.lo(), func.body_span.hi())),
173            add: replacement,
174            trim_code: true,
175        })
176    }
177}