Skip to main content

forge_lint/sol/gas/cache_array_length.rs

1use super::CacheArrayLength;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    interface::{Symbol, kw, sym},
8    sema::hir::{
9        self, BinOpKind, ElementaryType, ExprKind, ItemId, LoopSource, Res, StateMutability,
10        StmtKind, TypeKind, UnOpKind, VariableId,
11    },
12};
13
14declare_forge_lint!(
15    CACHE_ARRAY_LENGTH,
16    Severity::Gas,
17    "cache-array-length",
18    "array length read in loop condition should be cached outside the loop"
19);
20
21#[derive(Clone, Copy)]
22struct LengthRead<'hir> {
23    expr: &'hir hir::Expr<'hir>,
24    base: &'hir hir::Expr<'hir>,
25}
26
27#[derive(Default)]
28struct LoopFacts {
29    written_vars: Vec<VariableId>,
30    mutates_array_length: bool,
31    has_state_mutating_call: bool,
32}
33
34impl LoopFacts {
35    const fn should_skip(&self) -> bool {
36        self.mutates_array_length || self.has_state_mutating_call
37    }
38
39    fn push_written_var(&mut self, var_id: VariableId) {
40        if !self.written_vars.contains(&var_id) {
41            self.written_vars.push(var_id);
42        }
43    }
44}
45
46impl<'hir> LateLintPass<'hir> for CacheArrayLength {
47    fn check_stmt(
48        &mut self,
49        ctx: &LintContext,
50        hir: &'hir hir::Hir<'hir>,
51        stmt: &'hir hir::Stmt<'hir>,
52    ) {
53        let StmtKind::Loop(block, LoopSource::For) = &stmt.kind else { return };
54        let Some((condition, body)) = for_loop_parts(*block) else { return };
55
56        let mut reads = Vec::new();
57        collect_condition_length_reads(hir, condition, &mut reads);
58        if reads.is_empty() {
59            return;
60        }
61
62        let mut facts = LoopFacts::default();
63        collect_stmt_facts(hir, body, &mut facts);
64        if facts.should_skip() {
65            return;
66        }
67
68        for read in reads {
69            if expr_is_loop_invariant(hir, read.base, &facts.written_vars) {
70                ctx.emit(&CACHE_ARRAY_LENGTH, read.expr.span);
71            }
72        }
73    }
74}
75
76fn for_loop_parts<'hir>(
77    block: hir::Block<'hir>,
78) -> Option<(&'hir hir::Expr<'hir>, &'hir hir::Stmt<'hir>)> {
79    let first = block.stmts.first()?;
80    match &first.kind {
81        StmtKind::If(condition, _, Some(else_stmt)) => {
82            matches!(&else_stmt.kind, StmtKind::Break).then_some((*condition, first))
83        }
84        _ => None,
85    }
86}
87
88fn collect_condition_length_reads<'hir>(
89    hir: &'hir hir::Hir<'hir>,
90    expr: &'hir hir::Expr<'hir>,
91    reads: &mut Vec<LengthRead<'hir>>,
92) {
93    match &expr.peel_parens().kind {
94        ExprKind::Binary(lhs, op, rhs) if is_comparison(op.kind) => {
95            collect_length_reads(hir, lhs, reads);
96            collect_length_reads(hir, rhs, reads);
97        }
98        ExprKind::Binary(lhs, op, rhs) if matches!(op.kind, BinOpKind::And | BinOpKind::Or) => {
99            collect_condition_length_reads(hir, lhs, reads);
100            collect_condition_length_reads(hir, rhs, reads);
101        }
102        ExprKind::Unary(op, inner) if op.kind == UnOpKind::Not => {
103            collect_condition_length_reads(hir, inner, reads);
104        }
105        _ => {}
106    }
107}
108
109fn collect_length_reads<'hir>(
110    hir: &'hir hir::Hir<'hir>,
111    expr: &'hir hir::Expr<'hir>,
112    reads: &mut Vec<LengthRead<'hir>>,
113) {
114    let expr = expr.peel_parens();
115    if let ExprKind::Member(base, member) = &expr.kind
116        && member.name == sym::length
117        && is_array_like(hir, base)
118    {
119        reads.push(LengthRead { expr, base });
120        return;
121    }
122
123    match &expr.kind {
124        ExprKind::Array(exprs) => {
125            for expr in *exprs {
126                collect_length_reads(hir, expr, reads);
127            }
128        }
129        ExprKind::Assign(lhs, _, rhs) | ExprKind::Binary(lhs, _, rhs) => {
130            collect_length_reads(hir, lhs, reads);
131            collect_length_reads(hir, rhs, reads);
132        }
133        ExprKind::Call(callee, args, named_args) => {
134            collect_length_reads(hir, callee, reads);
135            for arg in args.exprs() {
136                collect_length_reads(hir, arg, reads);
137            }
138            if let Some(named_args) = named_args {
139                for arg in *named_args {
140                    collect_length_reads(hir, &arg.value, reads);
141                }
142            }
143        }
144        ExprKind::Delete(inner) | ExprKind::Payable(inner) | ExprKind::Unary(_, inner) => {
145            collect_length_reads(hir, inner, reads);
146        }
147        ExprKind::Index(base, index) => {
148            collect_length_reads(hir, base, reads);
149            if let Some(index) = index {
150                collect_length_reads(hir, index, reads);
151            }
152        }
153        ExprKind::Slice(base, start, end) => {
154            collect_length_reads(hir, base, reads);
155            if let Some(start) = start {
156                collect_length_reads(hir, start, reads);
157            }
158            if let Some(end) = end {
159                collect_length_reads(hir, end, reads);
160            }
161        }
162        ExprKind::Member(base, _) => collect_length_reads(hir, base, reads),
163        ExprKind::Ternary(condition, then_expr, else_expr) => {
164            collect_length_reads(hir, condition, reads);
165            collect_length_reads(hir, then_expr, reads);
166            collect_length_reads(hir, else_expr, reads);
167        }
168        ExprKind::Tuple(exprs) => {
169            for expr in exprs.iter().flatten() {
170                collect_length_reads(hir, expr, reads);
171            }
172        }
173        ExprKind::Ident(_)
174        | ExprKind::Lit(_)
175        | ExprKind::New(_)
176        | ExprKind::TypeCall(_)
177        | ExprKind::Type(_)
178        | ExprKind::Err(_) => {}
179    }
180}
181
182fn collect_stmt_facts<'hir>(
183    hir: &'hir hir::Hir<'hir>,
184    stmt: &'hir hir::Stmt<'hir>,
185    facts: &mut LoopFacts,
186) {
187    match &stmt.kind {
188        StmtKind::DeclSingle(var_id) => {
189            if let Some(expr) = hir.variable(*var_id).initializer {
190                collect_expr_facts(hir, expr, facts);
191            }
192        }
193        StmtKind::DeclMulti(_, expr)
194        | StmtKind::Emit(expr)
195        | StmtKind::Revert(expr)
196        | StmtKind::Expr(expr) => collect_expr_facts(hir, expr, facts),
197        StmtKind::Return(expr) => {
198            if let Some(expr) = expr {
199                collect_expr_facts(hir, expr, facts);
200            }
201        }
202        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
203            for stmt in block.stmts {
204                collect_stmt_facts(hir, stmt, facts);
205            }
206        }
207        StmtKind::If(condition, then_stmt, else_stmt) => {
208            collect_expr_facts(hir, condition, facts);
209            collect_stmt_facts(hir, then_stmt, facts);
210            if let Some(else_stmt) = else_stmt {
211                collect_stmt_facts(hir, else_stmt, facts);
212            }
213        }
214        StmtKind::Try(stmt_try) => {
215            collect_expr_facts(hir, &stmt_try.expr, facts);
216            for clause in stmt_try.clauses {
217                for stmt in clause.block.stmts {
218                    collect_stmt_facts(hir, stmt, facts);
219                }
220            }
221        }
222        StmtKind::Break | StmtKind::Continue | StmtKind::Placeholder | StmtKind::Err(_) => {}
223    }
224}
225
226fn collect_expr_facts<'hir>(
227    hir: &'hir hir::Hir<'hir>,
228    expr: &'hir hir::Expr<'hir>,
229    facts: &mut LoopFacts,
230) {
231    let expr = expr.peel_parens();
232    if array_length_mutated(hir, expr) {
233        facts.mutates_array_length = true;
234    }
235
236    match &expr.kind {
237        ExprKind::Array(exprs) => {
238            for expr in *exprs {
239                collect_expr_facts(hir, expr, facts);
240            }
241        }
242        ExprKind::Assign(lhs, _, rhs) => {
243            collect_written_vars(lhs, facts);
244            collect_expr_facts(hir, lhs, facts);
245            collect_expr_facts(hir, rhs, facts);
246        }
247        ExprKind::Binary(lhs, _, rhs) => {
248            collect_expr_facts(hir, lhs, facts);
249            collect_expr_facts(hir, rhs, facts);
250        }
251        ExprKind::Call(callee, args, named_args) => {
252            if call_may_mutate_state(hir, callee) {
253                facts.has_state_mutating_call = true;
254            }
255            collect_expr_facts(hir, callee, facts);
256            for arg in args.exprs() {
257                collect_expr_facts(hir, arg, facts);
258            }
259            if let Some(named_args) = named_args {
260                for arg in *named_args {
261                    collect_expr_facts(hir, &arg.value, facts);
262                }
263            }
264        }
265        ExprKind::Delete(inner) => {
266            collect_written_vars(inner, facts);
267            collect_expr_facts(hir, inner, facts);
268        }
269        ExprKind::Payable(inner) => collect_expr_facts(hir, inner, facts),
270        ExprKind::Unary(op, inner) => {
271            if op.kind.has_side_effects() {
272                collect_written_vars(inner, facts);
273            }
274            collect_expr_facts(hir, inner, facts);
275        }
276        ExprKind::Index(base, index) => {
277            collect_expr_facts(hir, base, facts);
278            if let Some(index) = index {
279                collect_expr_facts(hir, index, facts);
280            }
281        }
282        ExprKind::Slice(base, start, end) => {
283            collect_expr_facts(hir, base, facts);
284            if let Some(start) = start {
285                collect_expr_facts(hir, start, facts);
286            }
287            if let Some(end) = end {
288                collect_expr_facts(hir, end, facts);
289            }
290        }
291        ExprKind::Member(base, _) => collect_expr_facts(hir, base, facts),
292        ExprKind::Ternary(condition, then_expr, else_expr) => {
293            collect_expr_facts(hir, condition, facts);
294            collect_expr_facts(hir, then_expr, facts);
295            collect_expr_facts(hir, else_expr, facts);
296        }
297        ExprKind::Tuple(exprs) => {
298            for expr in exprs.iter().flatten() {
299                collect_expr_facts(hir, expr, facts);
300            }
301        }
302        ExprKind::Ident(_)
303        | ExprKind::Lit(_)
304        | ExprKind::New(_)
305        | ExprKind::TypeCall(_)
306        | ExprKind::Type(_)
307        | ExprKind::Err(_) => {}
308    }
309}
310
311fn collect_written_vars(expr: &hir::Expr<'_>, facts: &mut LoopFacts) {
312    match &expr.peel_parens().kind {
313        ExprKind::Ident(resolutions) => {
314            if let Some(var_id) = variable_resolution(resolutions) {
315                facts.push_written_var(var_id);
316            }
317        }
318        ExprKind::Index(base, _) => {
319            collect_written_vars(base, facts);
320        }
321        ExprKind::Slice(base, _, _) => {
322            collect_written_vars(base, facts);
323        }
324        ExprKind::Member(base, _) | ExprKind::Payable(base) => collect_written_vars(base, facts),
325        ExprKind::Tuple(exprs) => {
326            for expr in exprs.iter().flatten() {
327                collect_written_vars(expr, facts);
328            }
329        }
330        _ => {}
331    }
332}
333
334fn array_length_mutated<'hir>(hir: &'hir hir::Hir<'hir>, expr: &'hir hir::Expr<'hir>) -> bool {
335    match &expr.kind {
336        ExprKind::Assign(lhs, _, _) | ExprKind::Delete(lhs) => is_array_like(hir, lhs),
337        ExprKind::Call(callee, _, _) => {
338            let ExprKind::Member(base, member) = &callee.peel_parens().kind else { return false };
339            matches!(member.name, sym::push | kw::Pop) && is_array_like(hir, base)
340        }
341        _ => false,
342    }
343}
344
345fn call_may_mutate_state(hir: &hir::Hir<'_>, callee: &hir::Expr<'_>) -> bool {
346    match &callee.peel_parens().kind {
347        ExprKind::Type(_) => false,
348        ExprKind::Ident(resolutions) => resolutions
349            .iter()
350            .find_map(|res| {
351                if let Res::Item(ItemId::Function(function_id)) = res {
352                    Some(hir.function(*function_id).mutates_state())
353                } else {
354                    None
355                }
356            })
357            .unwrap_or(true),
358        ExprKind::Member(base, member)
359            if matches!(member.name, sym::push | kw::Pop) && is_array_like(hir, base) =>
360        {
361            false
362        }
363        _ => match &expr_type(hir, callee).map(|ty| &ty.kind) {
364            Some(TypeKind::Function(function)) => {
365                function.state_mutability >= StateMutability::Payable
366            }
367            _ => true,
368        },
369    }
370}
371
372fn expr_is_loop_invariant(
373    hir: &hir::Hir<'_>,
374    expr: &hir::Expr<'_>,
375    written_vars: &[VariableId],
376) -> bool {
377    match &expr.peel_parens().kind {
378        ExprKind::Ident(resolutions) => {
379            variable_resolution(resolutions).is_none_or(|var_id| !written_vars.contains(&var_id))
380        }
381        ExprKind::Lit(_) | ExprKind::Type(_) | ExprKind::TypeCall(_) => true,
382        ExprKind::Array(exprs) => {
383            exprs.iter().all(|expr| expr_is_loop_invariant(hir, expr, written_vars))
384        }
385        ExprKind::Binary(lhs, _, rhs) => {
386            expr_is_loop_invariant(hir, lhs, written_vars)
387                && expr_is_loop_invariant(hir, rhs, written_vars)
388        }
389        ExprKind::Call(callee, args, named_args) => {
390            call_is_safe_to_cache(hir, callee)
391                && expr_is_loop_invariant(hir, callee, written_vars)
392                && args.exprs().all(|arg| expr_is_loop_invariant(hir, arg, written_vars))
393                && named_args.is_none_or(|named_args| {
394                    named_args
395                        .iter()
396                        .all(|arg| expr_is_loop_invariant(hir, &arg.value, written_vars))
397                })
398        }
399        ExprKind::Index(base, index) => {
400            expr_is_loop_invariant(hir, base, written_vars)
401                && index.is_none_or(|index| expr_is_loop_invariant(hir, index, written_vars))
402        }
403        ExprKind::Slice(base, start, end) => {
404            expr_is_loop_invariant(hir, base, written_vars)
405                && start.is_none_or(|start| expr_is_loop_invariant(hir, start, written_vars))
406                && end.is_none_or(|end| expr_is_loop_invariant(hir, end, written_vars))
407        }
408        ExprKind::Member(base, _) | ExprKind::Payable(base) => {
409            expr_is_loop_invariant(hir, base, written_vars)
410        }
411        ExprKind::Ternary(condition, then_expr, else_expr) => {
412            expr_is_loop_invariant(hir, condition, written_vars)
413                && expr_is_loop_invariant(hir, then_expr, written_vars)
414                && expr_is_loop_invariant(hir, else_expr, written_vars)
415        }
416        ExprKind::Tuple(exprs) => {
417            exprs.iter().flatten().all(|expr| expr_is_loop_invariant(hir, expr, written_vars))
418        }
419        ExprKind::Unary(op, inner) => {
420            !op.kind.has_side_effects() && expr_is_loop_invariant(hir, inner, written_vars)
421        }
422        ExprKind::Assign(_, _, _) | ExprKind::Delete(_) | ExprKind::New(_) | ExprKind::Err(_) => {
423            false
424        }
425    }
426}
427
428fn call_is_safe_to_cache(hir: &hir::Hir<'_>, callee: &hir::Expr<'_>) -> bool {
429    match &callee.peel_parens().kind {
430        ExprKind::Type(_) => true,
431        ExprKind::Ident(resolutions) => resolutions
432            .iter()
433            .find_map(|res| {
434                if let Res::Item(ItemId::Function(function_id)) = res {
435                    Some(hir.function(*function_id).state_mutability <= StateMutability::View)
436                } else {
437                    None
438                }
439            })
440            .unwrap_or(false),
441        _ => match &expr_type(hir, callee).map(|ty| &ty.kind) {
442            Some(TypeKind::Function(function)) => {
443                function.state_mutability <= StateMutability::View
444            }
445            _ => false,
446        },
447    }
448}
449
450const fn is_comparison(op: BinOpKind) -> bool {
451    matches!(
452        op,
453        BinOpKind::Lt
454            | BinOpKind::Le
455            | BinOpKind::Gt
456            | BinOpKind::Ge
457            | BinOpKind::Eq
458            | BinOpKind::Ne
459    )
460}
461
462fn is_array_like(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
463    let Some(ty) = expr_type(hir, expr) else { return false };
464    match &ty.kind {
465        TypeKind::Array(array) => array.size.is_none(),
466        TypeKind::Elementary(ElementaryType::Bytes) => true,
467        _ => false,
468    }
469}
470
471fn expr_type<'hir>(
472    hir: &'hir hir::Hir<'hir>,
473    expr: &'hir hir::Expr<'hir>,
474) -> Option<&'hir hir::Type<'hir>> {
475    match &expr.peel_parens().kind {
476        ExprKind::Ident(resolutions) => {
477            let var_id = variable_resolution(resolutions)?;
478            Some(&hir.variable(var_id).ty)
479        }
480        ExprKind::Index(base, _) => match &expr_type(hir, base)?.kind {
481            TypeKind::Array(array) => Some(&array.element),
482            TypeKind::Mapping(mapping) => Some(&mapping.value),
483            _ => None,
484        },
485        ExprKind::Member(base, member) => {
486            struct_field_type(hir, expr_type(hir, base)?, member.name)
487        }
488        ExprKind::Call(callee, _, _) => call_return_type(hir, callee),
489        _ => None,
490    }
491}
492
493fn call_return_type<'hir>(
494    hir: &'hir hir::Hir<'hir>,
495    callee: &'hir hir::Expr<'hir>,
496) -> Option<&'hir hir::Type<'hir>> {
497    match &callee.peel_parens().kind {
498        ExprKind::Type(ty) => Some(ty),
499        ExprKind::Ident(resolutions) => {
500            let function_id = resolutions.iter().find_map(|res| {
501                if let Res::Item(ItemId::Function(function_id)) = res {
502                    Some(*function_id)
503                } else {
504                    None
505                }
506            })?;
507            function_return_type(hir, function_id)
508        }
509        _ => match &expr_type(hir, callee)?.kind {
510            TypeKind::Function(function) => {
511                let [return_id] = function.returns else { return None };
512                Some(&hir.variable(*return_id).ty)
513            }
514            _ => None,
515        },
516    }
517}
518
519fn function_return_type<'hir>(
520    hir: &'hir hir::Hir<'hir>,
521    function_id: hir::FunctionId,
522) -> Option<&'hir hir::Type<'hir>> {
523    let [return_id] = hir.function(function_id).returns else { return None };
524    Some(&hir.variable(*return_id).ty)
525}
526
527fn struct_field_type<'hir>(
528    hir: &'hir hir::Hir<'hir>,
529    ty: &'hir hir::Type<'hir>,
530    member: Symbol,
531) -> Option<&'hir hir::Type<'hir>> {
532    let TypeKind::Custom(ItemId::Struct(struct_id)) = &ty.kind else { return None };
533    hir.strukt(*struct_id)
534        .fields
535        .iter()
536        .map(|&field_id| hir.variable(field_id))
537        .find(|field| field.name.is_some_and(|name| name.name == member))
538        .map(|field| &field.ty)
539}
540
541fn variable_resolution(resolutions: &[Res]) -> Option<VariableId> {
542    resolutions.iter().find_map(|res| {
543        if let Res::Item(ItemId::Variable(var_id)) = res { Some(*var_id) } else { None }
544    })
545}