Skip to main content

forge_lint/sol/gas/
cache_array_length.rs

1use super::CacheArrayLength;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::ElementaryType,
8    interface::{kw, sym},
9    sema::{
10        Gcx,
11        hir::{
12            self, BinOpKind, ExprKind, ItemId, LoopSource, Res, StateMutability, StmtKind,
13            VariableId,
14        },
15        ty::TyKind,
16    },
17};
18
// Static metadata for the `cache-array-length` lint: identifier constant, `Gas` severity,
// user-facing lint id, and the description passed to `declare_forge_lint!`.
declare_forge_lint!(
    CACHE_ARRAY_LENGTH,
    Severity::Gas,
    "cache-array-length",
    "array length read in loop condition should be cached outside the loop"
);
25
/// A `<base>.length` member access found in a loop condition.
#[derive(Clone, Copy)]
struct LengthRead<'hir> {
    /// The whole (paren-peeled) `.length` member expression; its span is what gets reported.
    expr: &'hir hir::Expr<'hir>,
    /// The expression `.length` is read from; checked for loop invariance before reporting.
    base: &'hir hir::Expr<'hir>,
}
31
/// Facts gathered while walking a loop; used to decide whether a length read is cacheable.
#[derive(Default)]
struct LoopFacts {
    /// Variables that are assigned, deleted, or modified by a side-effecting unary operator
    /// (directly or through an index/member/slice access). Kept free of duplicates.
    written_vars: Vec<VariableId>,
    /// Set when an array-like value is assigned, deleted, pushed to, or popped from.
    mutates_array_length: bool,
    /// Set when a call that may mutate state is found.
    has_state_mutating_call: bool,
}
38
39impl LoopFacts {
40    const fn should_skip(&self) -> bool {
41        self.mutates_array_length || self.has_state_mutating_call
42    }
43
44    fn push_written_var(&mut self, var_id: VariableId) {
45        if !self.written_vars.contains(&var_id) {
46            self.written_vars.push(var_id);
47        }
48    }
49}
50
51impl<'hir> LateLintPass<'hir> for CacheArrayLength {
52    fn check_stmt(
53        &mut self,
54        ctx: &LintContext,
55        gcx: Gcx<'hir>,
56        hir: &'hir hir::Hir<'hir>,
57        stmt: &'hir hir::Stmt<'hir>,
58    ) {
59        let StmtKind::Loop(block, LoopSource::For) = &stmt.kind else { return };
60        let Some((condition, body)) = for_loop_parts(*block) else { return };
61
62        let mut reads = Vec::new();
63        collect_condition_length_reads(gcx, condition, &mut reads);
64        if reads.is_empty() {
65            return;
66        }
67
68        let mut facts = LoopFacts::default();
69        collect_stmt_facts(gcx, hir, body, &mut facts);
70        if facts.should_skip() {
71            return;
72        }
73
74        for read in reads {
75            if expr_is_loop_invariant(gcx, hir, read.base, &facts.written_vars) {
76                ctx.emit(&CACHE_ARRAY_LENGTH, read.expr.span);
77            }
78        }
79    }
80}
81
82fn for_loop_parts<'hir>(
83    block: hir::Block<'hir>,
84) -> Option<(&'hir hir::Expr<'hir>, &'hir hir::Stmt<'hir>)> {
85    let first = block.stmts.first()?;
86    match &first.kind {
87        StmtKind::If(condition, _, Some(else_stmt)) => {
88            matches!(&else_stmt.kind, StmtKind::Break).then_some((*condition, first))
89        }
90        _ => None,
91    }
92}
93
94fn collect_condition_length_reads<'hir>(
95    gcx: Gcx<'hir>,
96    expr: &'hir hir::Expr<'hir>,
97    reads: &mut Vec<LengthRead<'hir>>,
98) {
99    match &expr.peel_parens().kind {
100        ExprKind::Binary(lhs, op, rhs) if is_comparison(op.kind) => {
101            if matches!(lhs.peel_parens().kind, ExprKind::Ident(_)) {
102                collect_state_array_length_read(gcx, rhs, reads);
103            }
104            if matches!(rhs.peel_parens().kind, ExprKind::Ident(_)) {
105                collect_state_array_length_read(gcx, lhs, reads);
106            }
107        }
108        _ => {}
109    }
110}
111
112fn collect_state_array_length_read<'hir>(
113    gcx: Gcx<'hir>,
114    expr: &'hir hir::Expr<'hir>,
115    reads: &mut Vec<LengthRead<'hir>>,
116) {
117    let expr = expr.peel_parens();
118    if let ExprKind::Member(base, member) = &expr.kind
119        && member.name == sym::length
120        && is_state_array(gcx, base)
121    {
122        reads.push(LengthRead { expr, base });
123    }
124}
125
126fn collect_stmt_facts<'hir>(
127    gcx: Gcx<'hir>,
128    hir: &'hir hir::Hir<'hir>,
129    stmt: &'hir hir::Stmt<'hir>,
130    facts: &mut LoopFacts,
131) {
132    match &stmt.kind {
133        StmtKind::DeclSingle(var_id) => {
134            if let Some(expr) = hir.variable(*var_id).initializer {
135                collect_expr_facts(gcx, hir, expr, facts);
136            }
137        }
138        StmtKind::DeclMulti(_, expr)
139        | StmtKind::Emit(expr)
140        | StmtKind::Revert(expr)
141        | StmtKind::Expr(expr) => collect_expr_facts(gcx, hir, expr, facts),
142        StmtKind::Return(expr) => {
143            if let Some(expr) = expr {
144                collect_expr_facts(gcx, hir, expr, facts);
145            }
146        }
147        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
148            for stmt in block.stmts {
149                collect_stmt_facts(gcx, hir, stmt, facts);
150            }
151        }
152        StmtKind::If(condition, then_stmt, else_stmt) => {
153            collect_expr_facts(gcx, hir, condition, facts);
154            collect_stmt_facts(gcx, hir, then_stmt, facts);
155            if let Some(else_stmt) = else_stmt {
156                collect_stmt_facts(gcx, hir, else_stmt, facts);
157            }
158        }
159        StmtKind::Try(stmt_try) => {
160            collect_expr_facts(gcx, hir, &stmt_try.expr, facts);
161            for clause in stmt_try.clauses {
162                for stmt in clause.block.stmts {
163                    collect_stmt_facts(gcx, hir, stmt, facts);
164                }
165            }
166        }
167        StmtKind::Break
168        | StmtKind::Continue
169        | StmtKind::Placeholder
170        | StmtKind::AssemblyBlock(_)
171        | StmtKind::Switch(_)
172        | StmtKind::Err(_) => {}
173    }
174}
175
176fn collect_expr_facts<'hir>(
177    gcx: Gcx<'hir>,
178    hir: &'hir hir::Hir<'hir>,
179    expr: &'hir hir::Expr<'hir>,
180    facts: &mut LoopFacts,
181) {
182    let expr = expr.peel_parens();
183    if array_length_mutated(gcx, expr) {
184        facts.mutates_array_length = true;
185    }
186
187    match &expr.kind {
188        ExprKind::Array(exprs) => {
189            for expr in *exprs {
190                collect_expr_facts(gcx, hir, expr, facts);
191            }
192        }
193        ExprKind::Assign(lhs, _, rhs) => {
194            collect_written_vars(lhs, facts);
195            collect_expr_facts(gcx, hir, lhs, facts);
196            collect_expr_facts(gcx, hir, rhs, facts);
197        }
198        ExprKind::Binary(lhs, _, rhs) => {
199            collect_expr_facts(gcx, hir, lhs, facts);
200            collect_expr_facts(gcx, hir, rhs, facts);
201        }
202        ExprKind::Call(callee, args, named_args) => {
203            if call_may_mutate_state(gcx, hir, callee) {
204                facts.has_state_mutating_call = true;
205            }
206            collect_expr_facts(gcx, hir, callee, facts);
207            for arg in args.exprs() {
208                collect_expr_facts(gcx, hir, arg, facts);
209            }
210            if let Some(named_args) = named_args {
211                for arg in named_args.args {
212                    collect_expr_facts(gcx, hir, &arg.value, facts);
213                }
214            }
215        }
216        ExprKind::Delete(inner) => {
217            collect_written_vars(inner, facts);
218            collect_expr_facts(gcx, hir, inner, facts);
219        }
220        ExprKind::Payable(inner) => collect_expr_facts(gcx, hir, inner, facts),
221        ExprKind::Unary(op, inner) => {
222            if op.kind.has_side_effects() {
223                collect_written_vars(inner, facts);
224            }
225            collect_expr_facts(gcx, hir, inner, facts);
226        }
227        ExprKind::Index(base, index) => {
228            collect_expr_facts(gcx, hir, base, facts);
229            if let Some(index) = index {
230                collect_expr_facts(gcx, hir, index, facts);
231            }
232        }
233        ExprKind::Slice(base, start, end) => {
234            collect_expr_facts(gcx, hir, base, facts);
235            if let Some(start) = start {
236                collect_expr_facts(gcx, hir, start, facts);
237            }
238            if let Some(end) = end {
239                collect_expr_facts(gcx, hir, end, facts);
240            }
241        }
242        ExprKind::Member(base, _) => collect_expr_facts(gcx, hir, base, facts),
243        ExprKind::Ternary(condition, then_expr, else_expr) => {
244            collect_expr_facts(gcx, hir, condition, facts);
245            collect_expr_facts(gcx, hir, then_expr, facts);
246            collect_expr_facts(gcx, hir, else_expr, facts);
247        }
248        ExprKind::Tuple(exprs) => {
249            for expr in exprs.iter().flatten() {
250                collect_expr_facts(gcx, hir, expr, facts);
251            }
252        }
253        ExprKind::Ident(_)
254        | ExprKind::Lit(_)
255        | ExprKind::New(_)
256        | ExprKind::TypeCall(_)
257        | ExprKind::Type(_)
258        | ExprKind::YulMember(..)
259        | ExprKind::Err(_) => {}
260    }
261}
262
263fn collect_written_vars(expr: &hir::Expr<'_>, facts: &mut LoopFacts) {
264    match &expr.peel_parens().kind {
265        ExprKind::Ident(resolutions) => {
266            if let Some(var_id) = variable_resolution(resolutions) {
267                facts.push_written_var(var_id);
268            }
269        }
270        ExprKind::Index(base, _) => {
271            collect_written_vars(base, facts);
272        }
273        ExprKind::Slice(base, _, _) => {
274            collect_written_vars(base, facts);
275        }
276        ExprKind::Member(base, _) | ExprKind::Payable(base) => collect_written_vars(base, facts),
277        ExprKind::Tuple(exprs) => {
278            for expr in exprs.iter().flatten() {
279                collect_written_vars(expr, facts);
280            }
281        }
282        _ => {}
283    }
284}
285
286fn array_length_mutated<'hir>(gcx: Gcx<'hir>, expr: &'hir hir::Expr<'hir>) -> bool {
287    match &expr.kind {
288        ExprKind::Assign(lhs, _, _) | ExprKind::Delete(lhs) => is_array_like(gcx, lhs),
289        ExprKind::Call(callee, _, _) => {
290            let ExprKind::Member(base, member) = &callee.peel_parens().kind else { return false };
291            matches!(member.name, sym::push | kw::Pop) && is_array_like(gcx, base)
292        }
293        _ => false,
294    }
295}
296
297fn call_may_mutate_state<'hir>(
298    gcx: Gcx<'hir>,
299    hir: &'hir hir::Hir<'hir>,
300    callee: &'hir hir::Expr<'hir>,
301) -> bool {
302    match &callee.peel_parens().kind {
303        ExprKind::Type(_) => false,
304        ExprKind::Ident(resolutions) => resolutions
305            .iter()
306            .find_map(|res| {
307                if let Res::Item(ItemId::Function(function_id)) = res {
308                    Some(hir.function(*function_id).mutates_state())
309                } else {
310                    None
311                }
312            })
313            .unwrap_or(true),
314        ExprKind::Member(base, member)
315            if matches!(member.name, sym::push | kw::Pop) && is_array_like(gcx, base) =>
316        {
317            false
318        }
319        _ => match gcx.type_of_expr(callee.peel_parens().id).map(|ty| ty.peel_refs().kind) {
320            Some(TyKind::Fn(function)) => function.state_mutability >= StateMutability::Payable,
321            _ => true,
322        },
323    }
324}
325
326fn expr_is_loop_invariant<'hir>(
327    gcx: Gcx<'hir>,
328    hir: &'hir hir::Hir<'hir>,
329    expr: &'hir hir::Expr<'hir>,
330    written_vars: &[VariableId],
331) -> bool {
332    match &expr.peel_parens().kind {
333        ExprKind::Ident(resolutions) => {
334            variable_resolution(resolutions).is_none_or(|var_id| !written_vars.contains(&var_id))
335        }
336        ExprKind::Lit(_) | ExprKind::Type(_) | ExprKind::TypeCall(_) => true,
337        ExprKind::Array(exprs) => {
338            exprs.iter().all(|expr| expr_is_loop_invariant(gcx, hir, expr, written_vars))
339        }
340        ExprKind::Binary(lhs, _, rhs) => {
341            expr_is_loop_invariant(gcx, hir, lhs, written_vars)
342                && expr_is_loop_invariant(gcx, hir, rhs, written_vars)
343        }
344        ExprKind::Call(callee, args, named_args) => {
345            call_is_safe_to_cache(gcx, hir, callee)
346                && expr_is_loop_invariant(gcx, hir, callee, written_vars)
347                && args.exprs().all(|arg| expr_is_loop_invariant(gcx, hir, arg, written_vars))
348                && named_args.is_none_or(|named_args| {
349                    named_args
350                        .args
351                        .iter()
352                        .all(|arg| expr_is_loop_invariant(gcx, hir, &arg.value, written_vars))
353                })
354        }
355        ExprKind::Index(base, index) => {
356            expr_is_loop_invariant(gcx, hir, base, written_vars)
357                && index.is_none_or(|index| expr_is_loop_invariant(gcx, hir, index, written_vars))
358        }
359        ExprKind::Slice(base, start, end) => {
360            expr_is_loop_invariant(gcx, hir, base, written_vars)
361                && start.is_none_or(|start| expr_is_loop_invariant(gcx, hir, start, written_vars))
362                && end.is_none_or(|end| expr_is_loop_invariant(gcx, hir, end, written_vars))
363        }
364        ExprKind::Member(base, _) | ExprKind::Payable(base) => {
365            expr_is_loop_invariant(gcx, hir, base, written_vars)
366        }
367        ExprKind::Ternary(condition, then_expr, else_expr) => {
368            expr_is_loop_invariant(gcx, hir, condition, written_vars)
369                && expr_is_loop_invariant(gcx, hir, then_expr, written_vars)
370                && expr_is_loop_invariant(gcx, hir, else_expr, written_vars)
371        }
372        ExprKind::Tuple(exprs) => {
373            exprs.iter().flatten().all(|expr| expr_is_loop_invariant(gcx, hir, expr, written_vars))
374        }
375        ExprKind::Unary(op, inner) => {
376            !op.kind.has_side_effects() && expr_is_loop_invariant(gcx, hir, inner, written_vars)
377        }
378        ExprKind::Assign(_, _, _)
379        | ExprKind::Delete(_)
380        | ExprKind::New(_)
381        | ExprKind::YulMember(..)
382        | ExprKind::Err(_) => false,
383    }
384}
385
386fn call_is_safe_to_cache<'hir>(
387    gcx: Gcx<'hir>,
388    hir: &'hir hir::Hir<'hir>,
389    callee: &'hir hir::Expr<'hir>,
390) -> bool {
391    match &callee.peel_parens().kind {
392        ExprKind::Type(_) => true,
393        ExprKind::Ident(resolutions) => resolutions
394            .iter()
395            .find_map(|res| {
396                if let Res::Item(ItemId::Function(function_id)) = res {
397                    Some(hir.function(*function_id).state_mutability <= StateMutability::View)
398                } else {
399                    None
400                }
401            })
402            .unwrap_or(false),
403        _ => match gcx.type_of_expr(callee.peel_parens().id).map(|ty| ty.peel_refs().kind) {
404            Some(TyKind::Fn(function)) => function.state_mutability <= StateMutability::View,
405            _ => false,
406        },
407    }
408}
409
410const fn is_comparison(op: BinOpKind) -> bool {
411    matches!(
412        op,
413        BinOpKind::Lt
414            | BinOpKind::Le
415            | BinOpKind::Gt
416            | BinOpKind::Ge
417            | BinOpKind::Eq
418            | BinOpKind::Ne
419    )
420}
421
422fn is_array_like<'hir>(gcx: Gcx<'hir>, expr: &'hir hir::Expr<'hir>) -> bool {
423    let Some(ty) = gcx.type_of_expr(expr.peel_parens().id) else { return false };
424    matches!(ty.peel_refs().kind, TyKind::DynArray(_) | TyKind::Elementary(ElementaryType::Bytes))
425}
426
427fn is_state_array<'hir>(gcx: Gcx<'hir>, expr: &'hir hir::Expr<'hir>) -> bool {
428    let ExprKind::Ident(resolutions) = &expr.peel_parens().kind else { return false };
429    let Some(var_id) = variable_resolution(resolutions) else { return false };
430    gcx.hir.variable(var_id).is_state_variable()
431        && matches!(
432            gcx.type_of_expr(expr.peel_parens().id).map(|ty| ty.peel_refs().kind),
433            Some(TyKind::DynArray(_))
434        )
435}
436
437fn variable_resolution(resolutions: &[Res]) -> Option<VariableId> {
438    resolutions.iter().find_map(|res| {
439        if let Res::Item(ItemId::Variable(var_id)) = res { Some(*var_id) } else { None }
440    })
441}