forge_lint/sol/codesize/
unwrapped_modifier_logic.rs1use super::UnwrappedModifierLogic;
2use crate::{
3 linter::{LateLintPass, LintContext, Snippet},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{self as ast, Span},
8 sema::hir::{self, Res},
9};

// Declares the `UNWRAPPED_MODIFIER_LOGIC` lint that `check_function` below reports via
// `emit_with_fix`, with `Severity::CodeSize`. The two string arguments are presumably the
// lint's id and its short description — confirm against the `declare_forge_lint!` definition.
declare_forge_lint!(
    UNWRAPPED_MODIFIER_LOGIC,
    Severity::CodeSize,
    "unwrapped-modifier-logic",
    "wrap modifier logic to reduce code size"
);

18impl<'hir> LateLintPass<'hir> for UnwrappedModifierLogic {
19 fn check_function(
20 &mut self,
21 ctx: &LintContext,
22 hir: &'hir hir::Hir<'hir>,
23 func: &'hir hir::Function<'hir>,
24 ) {
25 let (body, name) = match (func.kind, &func.body, func.name) {
27 (ast::FunctionKind::Modifier, Some(body), Some(name)) => (body, name),
28 _ => return,
29 };
30
31 let stmts = body.stmts[..].as_ref();
33 let (before, after) = stmts
34 .iter()
35 .position(|s| matches!(s.kind, hir::StmtKind::Placeholder))
36 .map_or((stmts, &[][..]), |idx| (&stmts[..idx], &stmts[idx + 1..]));
37
38 if let Some(snippet) = self.get_snippet(ctx, hir, func, before, after) {
40 ctx.emit_with_fix(&UNWRAPPED_MODIFIER_LOGIC, name.span, snippet);
41 }
42 }
43}
44
45impl UnwrappedModifierLogic {
46 fn is_valid_expr(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
48 if let hir::ExprKind::Call(func_expr, _, _) = &expr.kind {
49 if let hir::ExprKind::Ident(resolutions) = &func_expr.kind {
50 return !resolutions.iter().any(|r| matches!(r, Res::Builtin(_)));
51 }
52
53 if let hir::ExprKind::Member(base, _) = &func_expr.kind
54 && let hir::ExprKind::Ident(resolutions) = &base.kind
55 {
56 return resolutions.iter().any(|r| {
57 matches!(r, Res::Item(hir::ItemId::Contract(id)) if hir.contract(*id).kind == ast::ContractKind::Library)
58 });
59 }
60 }
61
62 false
63 }
64
65 fn stmts_require_wrapping(&self, hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> bool {
75 let (mut res, mut has_valid_stmt) = (false, false);
76 for stmt in stmts {
77 match &stmt.kind {
78 hir::StmtKind::Placeholder => continue,
79 hir::StmtKind::Expr(expr) => {
80 if !self.is_valid_expr(hir, expr) || has_valid_stmt {
81 res = true;
82 }
83 has_valid_stmt = true;
84 }
85 hir::StmtKind::Err(_) => return false,
88 _ => res = true,
89 }
90 }
91
92 res
93 }
94
95 fn get_snippet<'a>(
96 &self,
97 ctx: &LintContext,
98 hir: &hir::Hir<'_>,
99 func: &hir::Function<'_>,
100 before: &'a [hir::Stmt<'a>],
101 after: &'a [hir::Stmt<'a>],
102 ) -> Option<Snippet> {
103 let wrap_before = !before.is_empty() && self.stmts_require_wrapping(hir, before);
104 let wrap_after = !after.is_empty() && self.stmts_require_wrapping(hir, after);
105
106 if !(wrap_before || wrap_after) {
107 return None;
108 }
109
110 let binding = func.name.unwrap();
111 let modifier_name = binding.name.as_str();
112 let mut param_list = vec![];
113 let mut param_decls = vec![];
114
115 for var_id in func.parameters {
116 let var = hir.variable(*var_id);
117 let ty = ctx
118 .span_to_snippet(var.ty.span)
119 .unwrap_or_else(|| "/* unknown type */".to_string());
120
121 if let Some(ident) = var.name {
123 param_list.push(ident.to_string());
124 param_decls.push(format!("{ty} {}", ident.to_string()));
125 }
126 }
127
128 let param_list = param_list.join(", ");
129 let param_decls = param_decls.join(", ");
130
131 let body_indent = " ".repeat(ctx.get_span_indentation(
132 before.first().or(after.first()).map(|stmt| stmt.span).unwrap_or(func.span),
133 ));
134 let body = match (wrap_before, wrap_after) {
135 (true, true) => format!(
136 "{body_indent}_{modifier_name}Before({param_list});\n{body_indent}_;\n{body_indent}_{modifier_name}After({param_list});"
137 ),
138 (true, false) => {
139 format!("{body_indent}_{modifier_name}({param_list});\n{body_indent}_;")
140 }
141 (false, true) => {
142 format!("{body_indent}_;\n{body_indent}_{modifier_name}({param_list});")
143 }
144 _ => unreachable!(),
145 };
146
147 let mod_indent = " ".repeat(ctx.get_span_indentation(func.span));
148 let mut replacement = format!(
149 "{mod_indent}modifier {modifier_name}({param_decls}) {{\n{body}\n{mod_indent}}}"
150 );
151
152 let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str| {
153 let body_stmts = stmts
154 .iter()
155 .filter_map(|s| ctx.span_to_snippet(s.span))
156 .map(|code| format!("\n{body_indent}{code}"))
157 .collect::<String>();
158 format!(
159 "\n\n{mod_indent}function _{modifier_name}{suffix}({param_decls}) internal {{{body_stmts}\n{mod_indent}}}"
160 )
161 };
162
163 if wrap_before {
164 replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" }));
165 }
166 if wrap_after {
167 replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" }));
168 }
169
170 Some(Snippet::Diff {
171 desc: Some("wrap modifier logic to reduce code size"),
172 span: Some(Span::new(func.span.lo(), func.body_span.hi())),
173 add: replacement,
174 trim_code: true,
175 })
176 }
177}