forge_lint/sol/codesize/
unwrapped_modifier_logic.rs
1use super::UnwrappedModifierLogic;
2use crate::{
3 linter::{LateLintPass, LintContext, Snippet},
4 sol::{Severity, SolLint},
5};
6use solar_ast::{self as ast, Span};
7use solar_sema::hir::{self, Res};
8
// Registers the `unwrapped-modifier-logic` lint with code-size severity. It is implemented by
// the `UnwrappedModifierLogic` pass in this file, which suggests moving modifier logic into
// internal helper functions.
declare_forge_lint!(
    UNWRAPPED_MODIFIER_LOGIC,
    Severity::CodeSize,
    "unwrapped-modifier-logic",
    "wrap modifier logic to reduce code size"
);
15
16impl<'hir> LateLintPass<'hir> for UnwrappedModifierLogic {
17 fn check_function(
18 &mut self,
19 ctx: &LintContext<'_>,
20 hir: &'hir hir::Hir<'hir>,
21 func: &'hir hir::Function<'hir>,
22 ) {
23 let (body, name) = match (func.kind, &func.body, func.name) {
25 (ast::FunctionKind::Modifier, Some(body), Some(name)) => (body, name),
26 _ => return,
27 };
28
29 let stmts = body.stmts[..].as_ref();
31 let (before, after) = stmts
32 .iter()
33 .position(|s| matches!(s.kind, hir::StmtKind::Placeholder))
34 .map_or((stmts, &[][..]), |idx| (&stmts[..idx], &stmts[idx + 1..]));
35
36 if let Some(snippet) = self.get_snippet(ctx, hir, func, before, after) {
38 ctx.emit_with_fix(&UNWRAPPED_MODIFIER_LOGIC, name.span, snippet);
39 }
40 }
41}
42
43impl UnwrappedModifierLogic {
44 fn is_valid_expr(&self, hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
46 if let hir::ExprKind::Call(func_expr, _, _) = &expr.kind {
47 if let hir::ExprKind::Ident(resolutions) = &func_expr.kind {
48 return !resolutions.iter().any(|r| matches!(r, Res::Builtin(_)));
49 }
50
51 if let hir::ExprKind::Member(base, _) = &func_expr.kind
52 && let hir::ExprKind::Ident(resolutions) = &base.kind
53 {
54 return resolutions.iter().any(|r| {
55 matches!(r, Res::Item(hir::ItemId::Contract(id)) if hir.contract(*id).kind == ast::ContractKind::Library)
56 });
57 }
58 }
59
60 false
61 }
62
63 fn stmts_require_wrapping(&self, hir: &hir::Hir<'_>, stmts: &[hir::Stmt<'_>]) -> bool {
73 let (mut res, mut has_valid_stmt) = (false, false);
74 for stmt in stmts {
75 match &stmt.kind {
76 hir::StmtKind::Placeholder => continue,
77 hir::StmtKind::Expr(expr) => {
78 if !self.is_valid_expr(hir, expr) || has_valid_stmt {
79 res = true;
80 }
81 has_valid_stmt = true;
82 }
83 hir::StmtKind::Err(_) => return false,
86 _ => res = true,
87 }
88 }
89
90 res
91 }
92
93 fn get_snippet<'a>(
94 &self,
95 ctx: &LintContext<'_>,
96 hir: &hir::Hir<'_>,
97 func: &hir::Function<'_>,
98 before: &'a [hir::Stmt<'a>],
99 after: &'a [hir::Stmt<'a>],
100 ) -> Option<Snippet> {
101 let wrap_before = !before.is_empty() && self.stmts_require_wrapping(hir, before);
102 let wrap_after = !after.is_empty() && self.stmts_require_wrapping(hir, after);
103
104 if !(wrap_before || wrap_after) {
105 return None;
106 }
107
108 let binding = func.name.unwrap();
109 let modifier_name = binding.name.as_str();
110 let mut param_list = vec![];
111 let mut param_decls = vec![];
112
113 for var_id in func.parameters {
114 let var = hir.variable(*var_id);
115 let ty = ctx
116 .span_to_snippet(var.ty.span)
117 .unwrap_or_else(|| "/* unknown type */".to_string());
118
119 if let Some(ident) = var.name {
121 param_list.push(ident.to_string());
122 param_decls.push(format!("{ty} {}", ident.to_string()));
123 }
124 }
125
126 let param_list = param_list.join(", ");
127 let param_decls = param_decls.join(", ");
128
129 let body_indent = " ".repeat(ctx.get_span_indentation(
130 before.first().or(after.first()).map(|stmt| stmt.span).unwrap_or(func.span),
131 ));
132 let body = match (wrap_before, wrap_after) {
133 (true, true) => format!(
134 "{body_indent}_{modifier_name}Before({param_list});\n{body_indent}_;\n{body_indent}_{modifier_name}After({param_list});"
135 ),
136 (true, false) => {
137 format!("{body_indent}_{modifier_name}({param_list});\n{body_indent}_;")
138 }
139 (false, true) => {
140 format!("{body_indent}_;\n{body_indent}_{modifier_name}({param_list});")
141 }
142 _ => unreachable!(),
143 };
144
145 let mod_indent = " ".repeat(ctx.get_span_indentation(func.span));
146 let mut replacement = format!(
147 "{mod_indent}modifier {modifier_name}({param_decls}) {{\n{body}\n{mod_indent}}}"
148 );
149
150 let build_func = |stmts: &[hir::Stmt<'_>], suffix: &str| {
151 let body_stmts = stmts
152 .iter()
153 .filter_map(|s| ctx.span_to_snippet(s.span))
154 .map(|code| format!("\n{body_indent}{code}"))
155 .collect::<String>();
156 format!(
157 "\n\n{mod_indent}function _{modifier_name}{suffix}({param_decls}) internal {{{body_stmts}\n{mod_indent}}}"
158 )
159 };
160
161 if wrap_before {
162 replacement.push_str(&build_func(before, if wrap_after { "Before" } else { "" }));
163 }
164 if wrap_after {
165 replacement.push_str(&build_func(after, if wrap_before { "After" } else { "" }));
166 }
167
168 Some(Snippet::Diff {
169 desc: Some("wrap modifier logic to reduce code size"),
170 span: Some(Span::new(func.span.lo(), func.body_span.hi())),
171 add: replacement,
172 trim_code: true,
173 })
174 }
175}