1use super::CostlyLoop;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::sema::{
7 Hir,
8 hir::{Block, Expr, ExprKind, Function, ItemId, Res, Stmt, StmtKind},
9};
10
// Declares the `costly-loop` lint (gas severity). It is emitted for writes to state
// variables — assignments, side-effecting unary operators such as `++`/`--`, and `delete` —
// that occur inside a loop body, where they are repeated on every iteration.
declare_forge_lint!(COSTLY_LOOP, Severity::Gas, "costly-loop", "storage write inside a loop");
12
13impl<'hir> LateLintPass<'hir> for CostlyLoop {
14 fn check_function(
15 &mut self,
16 ctx: &LintContext,
17 hir: &'hir Hir<'hir>,
18 func: &'hir Function<'hir>,
19 ) {
20 if let Some(body) = func.body {
21 check_block(ctx, hir, body, 0);
22 }
23 }
24}
25
26fn check_block<'hir>(ctx: &LintContext, hir: &'hir Hir<'hir>, block: Block<'hir>, loop_depth: u32) {
27 for stmt in block.stmts {
28 check_stmt(ctx, hir, stmt, loop_depth);
29 }
30}
31
32fn check_stmt<'hir>(
33 ctx: &LintContext,
34 hir: &'hir Hir<'hir>,
35 stmt: &'hir Stmt<'hir>,
36 loop_depth: u32,
37) {
38 match &stmt.kind {
39 StmtKind::Loop(block, _) => check_block(ctx, hir, *block, loop_depth + 1),
40 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
41 check_block(ctx, hir, *block, loop_depth);
42 }
43 StmtKind::If(_, then_stmt, else_stmt) => {
44 check_stmt(ctx, hir, then_stmt, loop_depth);
45 if let Some(else_stmt) = else_stmt {
46 check_stmt(ctx, hir, else_stmt, loop_depth);
47 }
48 }
49 StmtKind::Try(stmt_try) => {
50 for clause in stmt_try.clauses {
51 check_block(ctx, hir, clause.block, loop_depth);
52 }
53 }
54 StmtKind::Expr(expr) if loop_depth > 0 => {
55 check_expr_for_writes(ctx, hir, expr);
56 }
57 StmtKind::DeclSingle(var_id) if loop_depth > 0 => {
58 if let Some(init) = hir.variable(*var_id).initializer {
59 check_expr_for_writes(ctx, hir, init);
60 }
61 }
62 StmtKind::DeclMulti(_, expr) if loop_depth > 0 => {
63 check_expr_for_writes(ctx, hir, expr);
64 }
65 StmtKind::Return(Some(expr)) if loop_depth > 0 => {
66 check_expr_for_writes(ctx, hir, expr);
67 }
68 StmtKind::Emit(expr) | StmtKind::Revert(expr) if loop_depth > 0 => {
69 check_expr_for_writes(ctx, hir, expr);
70 }
71 _ => {}
72 }
73}
74
75fn check_expr_for_writes<'hir>(ctx: &LintContext, hir: &'hir Hir<'hir>, expr: &'hir Expr<'hir>) {
76 match &expr.kind {
77 ExprKind::Assign(lhs, _, rhs) => {
78 if lvalue_is_state_var(hir, lhs) {
79 ctx.emit(&COSTLY_LOOP, expr.span);
80 }
81 check_expr_for_writes(ctx, hir, lhs);
82 check_expr_for_writes(ctx, hir, rhs);
83 }
84 ExprKind::Unary(op, inner) => {
85 if op.kind.has_side_effects() && lvalue_is_state_var(hir, inner) {
86 ctx.emit(&COSTLY_LOOP, expr.span);
87 }
88 check_expr_for_writes(ctx, hir, inner);
89 }
90 ExprKind::Delete(inner) => {
91 if lvalue_is_state_var(hir, inner) {
92 ctx.emit(&COSTLY_LOOP, expr.span);
93 }
94 check_expr_for_writes(ctx, hir, inner);
95 }
96 ExprKind::Binary(lhs, _, rhs) => {
97 check_expr_for_writes(ctx, hir, lhs);
98 check_expr_for_writes(ctx, hir, rhs);
99 }
100 ExprKind::Ternary(cond, then_expr, else_expr) => {
101 check_expr_for_writes(ctx, hir, cond);
102 check_expr_for_writes(ctx, hir, then_expr);
103 check_expr_for_writes(ctx, hir, else_expr);
104 }
105 ExprKind::Call(callee, args, named_args) => {
106 check_expr_for_writes(ctx, hir, callee);
107 for arg in args.exprs() {
108 check_expr_for_writes(ctx, hir, arg);
109 }
110 if let Some(named_args) = named_args {
111 for arg in *named_args {
112 check_expr_for_writes(ctx, hir, &arg.value);
113 }
114 }
115 }
116 ExprKind::Index(base, index) => {
117 check_expr_for_writes(ctx, hir, base);
118 if let Some(index) = index {
119 check_expr_for_writes(ctx, hir, index);
120 }
121 }
122 ExprKind::Slice(base, start, end) => {
123 check_expr_for_writes(ctx, hir, base);
124 if let Some(start) = start {
125 check_expr_for_writes(ctx, hir, start);
126 }
127 if let Some(end) = end {
128 check_expr_for_writes(ctx, hir, end);
129 }
130 }
131 ExprKind::Member(base, _) | ExprKind::Payable(base) => {
132 check_expr_for_writes(ctx, hir, base);
133 }
134 ExprKind::Tuple(exprs) => {
135 for e in exprs.iter().flatten() {
136 check_expr_for_writes(ctx, hir, e);
137 }
138 }
139 ExprKind::Array(exprs) => {
140 for e in *exprs {
141 check_expr_for_writes(ctx, hir, e);
142 }
143 }
144 ExprKind::Ident(_)
145 | ExprKind::Lit(_)
146 | ExprKind::New(_)
147 | ExprKind::TypeCall(_)
148 | ExprKind::Type(_)
149 | ExprKind::Err(_) => {}
150 }
151}
152
153fn lvalue_is_state_var(hir: &Hir<'_>, expr: &Expr<'_>) -> bool {
157 match &expr.peel_parens().kind {
158 ExprKind::Ident([Res::Item(ItemId::Variable(id)), ..]) => {
159 hir.variable(*id).is_state_variable()
160 }
161 ExprKind::Index(base, _)
162 | ExprKind::Slice(base, _, _)
163 | ExprKind::Member(base, _)
164 | ExprKind::Payable(base) => lvalue_is_state_var(hir, base),
165 ExprKind::Tuple(exprs) => exprs.iter().flatten().any(|e| lvalue_is_state_var(hir, e)),
166 _ => false,
167 }
168}