Skip to main content

forge_lint/sol/med/
div_mul.rs

1use super::DivideBeforeMultiply;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::UnOpKind,
8    sema::{
9        Gcx, Hir,
10        builtins::Builtin,
11        hir::{
12            BinOpKind, Block, Expr, ExprKind, Function, ItemId, Res, Stmt, StmtKind, VariableId,
13        },
14    },
15};
16use std::collections::HashSet;
17
// Declares the `DIVIDE_BEFORE_MULTIPLY` lint that this file reports through `ctx.emit`:
// medium severity, identifier "divide-before-multiply", followed by its description text.
// NOTE(review): the exact items generated by `declare_forge_lint!` are defined elsewhere —
// confirm against the macro definition.
declare_forge_lint!(
    DIVIDE_BEFORE_MULTIPLY,
    Severity::Med,
    "divide-before-multiply",
    "multiplication should occur before division to avoid loss of precision"
);
24
25impl<'hir> LateLintPass<'hir> for DivideBeforeMultiply {
26    fn check_function(
27        &mut self,
28        ctx: &LintContext,
29        _gcx: Gcx<'hir>,
30        hir: &'hir Hir<'hir>,
31        func: &'hir Function<'hir>,
32    ) {
33        if let Some(body) = func.body {
34            let mut tainted = HashSet::default();
35            check_block(ctx, hir, body, &mut tainted);
36        }
37    }
38}
39
40fn check_block<'hir>(
41    ctx: &LintContext,
42    hir: &'hir Hir<'hir>,
43    block: Block<'hir>,
44    tainted: &mut HashSet<VariableId>,
45) -> bool {
46    for stmt in block.stmts {
47        if !check_stmt(ctx, hir, stmt, tainted) {
48            return false;
49        }
50    }
51    true
52}
53
54fn check_stmt<'hir>(
55    ctx: &LintContext,
56    hir: &'hir Hir<'hir>,
57    stmt: &'hir Stmt<'hir>,
58    tainted: &mut HashSet<VariableId>,
59) -> bool {
60    match &stmt.kind {
61        StmtKind::DeclSingle(var_id) => {
62            if let Some(init) = hir.variable(*var_id).initializer {
63                check_expr(ctx, hir, init, tainted);
64                update_taint(
65                    hir,
66                    *var_id,
67                    expr_value_is_division_or_tainted(init, tainted),
68                    tainted,
69                );
70            }
71            true
72        }
73        StmtKind::DeclMulti(vars, expr) => {
74            check_expr(ctx, hir, expr, tainted);
75            update_multi_decl_taint(hir, vars, expr, tainted);
76            true
77        }
78        StmtKind::Expr(expr) => {
79            check_expr(ctx, hir, expr, tainted);
80            !is_revert_call(expr)
81        }
82        StmtKind::Emit(expr) => {
83            check_expr(ctx, hir, expr, tainted);
84            true
85        }
86        StmtKind::Revert(expr) | StmtKind::Return(Some(expr)) => {
87            check_expr(ctx, hir, expr, tainted);
88            false
89        }
90        StmtKind::If(cond, then_stmt, else_stmt) => {
91            check_expr(ctx, hir, cond, tainted);
92
93            let baseline = tainted.clone();
94            let mut merged_taint = HashSet::default();
95            let mut falls_through = false;
96
97            let mut then_tainted = baseline.clone();
98            if check_stmt(ctx, hir, then_stmt, &mut then_tainted) {
99                merged_taint = union_taints(&merged_taint, &then_tainted);
100                falls_through = true;
101            }
102
103            if let Some(else_stmt) = else_stmt {
104                let mut else_tainted = baseline;
105                if check_stmt(ctx, hir, else_stmt, &mut else_tainted) {
106                    merged_taint = union_taints(&merged_taint, &else_tainted);
107                    falls_through = true;
108                }
109            } else {
110                merged_taint = union_taints(&merged_taint, &baseline);
111                falls_through = true;
112            }
113
114            if falls_through {
115                *tainted = merged_taint;
116            }
117            falls_through
118        }
119        StmtKind::Loop(block, _) => {
120            let baseline = tainted.clone();
121            let mut loop_tainted = baseline.clone();
122            *tainted = if check_block(ctx, hir, *block, &mut loop_tainted) {
123                union_taints(&baseline, &loop_tainted)
124            } else {
125                baseline
126            };
127            true
128        }
129        StmtKind::Try(try_stmt) => {
130            check_expr(ctx, hir, &try_stmt.expr, tainted);
131            let mut merged_taint = tainted.clone();
132            for clause in try_stmt.clauses {
133                let mut clause_tainted = tainted.clone();
134                if check_block(ctx, hir, clause.block, &mut clause_tainted) {
135                    merged_taint = union_taints(&merged_taint, &clause_tainted);
136                }
137            }
138            *tainted = merged_taint;
139            true
140        }
141        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
142            check_block(ctx, hir, *block, tainted)
143        }
144        StmtKind::AssemblyBlock(block) => check_block(ctx, hir, *block, tainted),
145        StmtKind::Switch(switch) => {
146            check_expr(ctx, hir, switch.selector, tainted);
147            let mut merged_taint = tainted.clone();
148            for case in switch.cases {
149                let mut case_tainted = tainted.clone();
150                if check_block(ctx, hir, case.body, &mut case_tainted) {
151                    merged_taint = union_taints(&merged_taint, &case_tainted);
152                }
153            }
154            *tainted = merged_taint;
155            true
156        }
157        StmtKind::Return(None) => false,
158        StmtKind::Break | StmtKind::Continue | StmtKind::Placeholder | StmtKind::Err(_) => true,
159    }
160}
161
162fn check_expr<'hir>(
163    ctx: &LintContext,
164    hir: &'hir Hir<'hir>,
165    expr: &'hir Expr<'hir>,
166    tainted: &mut HashSet<VariableId>,
167) {
168    match &expr.peel_parens().kind {
169        ExprKind::Assign(lhs, op, rhs) => {
170            check_expr(ctx, hir, rhs, tainted);
171            check_expr(ctx, hir, lhs, tainted);
172
173            match op {
174                None => {
175                    update_assignment_taint(hir, lhs, rhs, tainted);
176                }
177                Some(op) if op.kind == BinOpKind::Mul => {
178                    let lhs_tainted = expr_is_division_result_or_tainted(lhs, tainted);
179                    let rhs_tainted = expr_is_division_result_or_tainted(rhs, tainted);
180                    if lhs_tainted || rhs_tainted {
181                        ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
182                    }
183                    update_lhs_taint(hir, lhs, lhs_tainted || rhs_tainted, tainted);
184                }
185                Some(op) if op.kind == BinOpKind::Div => {
186                    update_lhs_taint(hir, lhs, true, tainted);
187                }
188                Some(_) => update_lhs_taint(hir, lhs, false, tainted),
189            }
190        }
191        ExprKind::Binary(left, op, right) => {
192            check_expr(ctx, hir, left, tainted);
193            check_expr(ctx, hir, right, tainted);
194
195            if op.kind == BinOpKind::Mul
196                && (expr_is_division_result_or_tainted(left, tainted)
197                    || expr_is_division_result_or_tainted(right, tainted))
198            {
199                ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
200            }
201        }
202        ExprKind::Array(exprs) => {
203            for expr in *exprs {
204                check_expr(ctx, hir, expr, tainted);
205            }
206        }
207        ExprKind::Call(callee, args, named_args) => {
208            check_expr(ctx, hir, callee, tainted);
209            for arg in args.exprs() {
210                check_expr(ctx, hir, arg, tainted);
211            }
212            if let Some(named_args) = named_args {
213                for arg in named_args.args {
214                    check_expr(ctx, hir, &arg.value, tainted);
215                }
216            }
217
218            if is_yul_multiplication_call(expr)
219                && args.exprs().any(|arg| expr_is_division_result_or_tainted(arg, tainted))
220            {
221                ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
222            }
223        }
224        ExprKind::Delete(inner)
225        | ExprKind::Index(inner, None)
226        | ExprKind::Member(inner, _)
227        | ExprKind::Payable(inner) => check_expr(ctx, hir, inner, tainted),
228        ExprKind::Index(base, Some(index)) => {
229            check_expr(ctx, hir, base, tainted);
230            check_expr(ctx, hir, index, tainted);
231        }
232        ExprKind::Slice(base, start, end) => {
233            check_expr(ctx, hir, base, tainted);
234            if let Some(start) = start {
235                check_expr(ctx, hir, start, tainted);
236            }
237            if let Some(end) = end {
238                check_expr(ctx, hir, end, tainted);
239            }
240        }
241        ExprKind::Ternary(cond, then_expr, else_expr) => {
242            check_expr(ctx, hir, cond, tainted);
243            let mut then_tainted = tainted.clone();
244            check_expr(ctx, hir, then_expr, &mut then_tainted);
245            let mut else_tainted = tainted.clone();
246            check_expr(ctx, hir, else_expr, &mut else_tainted);
247            *tainted = union_taints(&then_tainted, &else_tainted);
248        }
249        ExprKind::Tuple(exprs) => {
250            for expr in exprs.iter().flatten() {
251                check_expr(ctx, hir, expr, tainted);
252            }
253        }
254        ExprKind::Unary(op, inner) => {
255            check_expr(ctx, hir, inner, tainted);
256            if is_inc_dec_op(op.kind) {
257                update_lhs_taint(hir, inner, false, tainted);
258            }
259        }
260        ExprKind::Ident(_)
261        | ExprKind::Lit(_)
262        | ExprKind::New(_)
263        | ExprKind::TypeCall(_)
264        | ExprKind::Type(_) => {}
265        ExprKind::YulMember(inner, _) => check_expr(ctx, hir, inner, tainted),
266        ExprKind::Err(_) => {}
267    }
268}
269
270fn update_multi_decl_taint(
271    hir: &Hir<'_>,
272    vars: &[Option<VariableId>],
273    expr: &Expr<'_>,
274    tainted: &mut HashSet<VariableId>,
275) {
276    if let ExprKind::Tuple(exprs) = &expr.peel_parens().kind
277        && exprs.len() == vars.len()
278    {
279        let rhs_taints: Vec<_> = exprs
280            .iter()
281            .map(|expr| expr.is_some_and(|expr| expr_value_is_division_or_tainted(expr, tainted)))
282            .collect();
283        for (var_id, rhs_tainted) in vars.iter().zip(rhs_taints) {
284            if let Some(var_id) = var_id {
285                update_taint(hir, *var_id, rhs_tainted, tainted);
286            }
287        }
288        return;
289    }
290
291    let rhs_tainted = expr_value_is_division_or_tainted(expr, tainted);
292    for var_id in vars.iter().flatten() {
293        update_taint(hir, *var_id, rhs_tainted, tainted);
294    }
295}
296
297fn update_assignment_taint(
298    hir: &Hir<'_>,
299    lhs: &Expr<'_>,
300    rhs: &Expr<'_>,
301    tainted: &mut HashSet<VariableId>,
302) {
303    if let (ExprKind::Tuple(lhs_exprs), ExprKind::Tuple(rhs_exprs)) =
304        (&lhs.peel_parens().kind, &rhs.peel_parens().kind)
305        && lhs_exprs.len() == rhs_exprs.len()
306    {
307        let rhs_taints: Vec<_> = rhs_exprs
308            .iter()
309            .map(|rhs| rhs.is_some_and(|rhs| expr_value_is_division_or_tainted(rhs, tainted)))
310            .collect();
311        for (lhs, rhs_tainted) in lhs_exprs.iter().zip(rhs_taints) {
312            if let Some(lhs) = lhs {
313                update_lhs_taint(hir, lhs, rhs_tainted, tainted);
314            }
315        }
316        return;
317    }
318
319    update_lhs_taint(hir, lhs, expr_value_is_division_or_tainted(rhs, tainted), tainted);
320}
321
322fn union_taints(left: &HashSet<VariableId>, right: &HashSet<VariableId>) -> HashSet<VariableId> {
323    left.union(right).copied().collect()
324}
325
326fn update_lhs_taint(
327    hir: &Hir<'_>,
328    lhs: &Expr<'_>,
329    is_tainted: bool,
330    tainted: &mut HashSet<VariableId>,
331) {
332    match &lhs.peel_parens().kind {
333        ExprKind::Ident(resolutions) => {
334            for res in *resolutions {
335                if let Res::Item(ItemId::Variable(var_id)) = res {
336                    update_taint(hir, *var_id, is_tainted, tainted);
337                }
338            }
339        }
340        ExprKind::Tuple(exprs) => {
341            for expr in exprs.iter().flatten() {
342                update_lhs_taint(hir, expr, is_tainted, tainted);
343            }
344        }
345        _ => {}
346    }
347}
348
349fn update_taint(
350    hir: &Hir<'_>,
351    var_id: VariableId,
352    is_tainted: bool,
353    tainted: &mut HashSet<VariableId>,
354) {
355    if !hir.variable(var_id).is_local_or_return() {
356        return;
357    }
358    if is_tainted {
359        tainted.insert(var_id);
360    } else {
361        tainted.remove(&var_id);
362    }
363}
364
365fn expr_value_is_division_or_tainted(expr: &Expr<'_>, tainted: &HashSet<VariableId>) -> bool {
366    match &expr.peel_parens().kind {
367        ExprKind::Binary(_, op, _) => op.kind == BinOpKind::Div,
368        ExprKind::Ident(resolutions) => resolutions.iter().any(
369            |res| matches!(res, Res::Item(ItemId::Variable(var_id)) if tainted.contains(var_id)),
370        ),
371        ExprKind::Call(..) => is_yul_division_call(expr),
372        ExprKind::Tuple([Some(inner)]) => expr_value_is_division_or_tainted(inner, tainted),
373        ExprKind::YulMember(inner, _) => expr_value_is_division_or_tainted(inner, tainted),
374        ExprKind::Array(_)
375        | ExprKind::Assign(..)
376        | ExprKind::Delete(_)
377        | ExprKind::Index(..)
378        | ExprKind::Lit(_)
379        | ExprKind::Member(_, _)
380        | ExprKind::New(_)
381        | ExprKind::Payable(_)
382        | ExprKind::Slice(..)
383        | ExprKind::Ternary(..)
384        | ExprKind::TypeCall(_)
385        | ExprKind::Type(_)
386        | ExprKind::Unary(_, _)
387        | ExprKind::Tuple(_) => false,
388        ExprKind::Err(_) => false,
389    }
390}
391
392fn expr_is_division_result_or_tainted(expr: &Expr<'_>, tainted: &HashSet<VariableId>) -> bool {
393    match &expr.peel_parens().kind {
394        ExprKind::Binary(_, op, _) => op.kind == BinOpKind::Div,
395        ExprKind::Call(..) => is_yul_division_call(expr),
396        ExprKind::Ident(resolutions) => resolutions.iter().any(
397            |res| matches!(res, Res::Item(ItemId::Variable(var_id)) if tainted.contains(var_id)),
398        ),
399        ExprKind::Tuple([Some(inner)]) => expr_is_division_result_or_tainted(inner, tainted),
400        _ => false,
401    }
402}
403
404fn is_yul_division_call(expr: &Expr<'_>) -> bool {
405    is_yul_builtin_call(expr, |builtin| matches!(builtin, Builtin::YulDiv | Builtin::YulSdiv))
406}
407
408fn is_yul_multiplication_call(expr: &Expr<'_>) -> bool {
409    is_yul_builtin_call(expr, |builtin| matches!(builtin, Builtin::YulMul))
410}
411
412fn is_revert_call(expr: &Expr<'_>) -> bool {
413    let ExprKind::Call(callee, _, _) = &expr.peel_parens().kind else { return false };
414    let ExprKind::Ident(resolutions) = &callee.peel_parens().kind else { return false };
415    resolutions.iter().any(|res| matches!(res, Res::Builtin(Builtin::Revert | Builtin::RevertMsg)))
416}
417
418const fn is_inc_dec_op(kind: UnOpKind) -> bool {
419    matches!(kind, UnOpKind::PreInc | UnOpKind::PostInc | UnOpKind::PreDec | UnOpKind::PostDec)
420}
421
422fn is_yul_builtin_call(expr: &Expr<'_>, predicate: impl Fn(Builtin) -> bool) -> bool {
423    let ExprKind::Call(callee, args, _) = &expr.peel_parens().kind else { return false };
424    if args.len() != 2 {
425        return false;
426    }
427    let ExprKind::Ident(resolutions) = &callee.peel_parens().kind else { return false };
428    resolutions.iter().any(|res| matches!(res, Res::Builtin(builtin) if predicate(*builtin)))
429}