1use super::DivideBeforeMultiply;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::UnOpKind,
8 sema::{
9 Gcx, Hir,
10 builtins::Builtin,
11 hir::{
12 BinOpKind, Block, Expr, ExprKind, Function, ItemId, Res, Stmt, StmtKind, VariableId,
13 },
14 },
15};
16use std::collections::HashSet;
17
// Declares the `divide-before-multiply` lint (medium severity). It is emitted by the
// `LateLintPass` implementation in this file whenever a multiplication consumes a
// division result, either directly (`a / b * c`), through a local variable that still
// holds one, or via Yul `div`/`sdiv` feeding `mul`. The precision lost in the division
// is then scaled up by the multiplication.
declare_forge_lint!(
    DIVIDE_BEFORE_MULTIPLY,
    Severity::Med,
    "divide-before-multiply",
    "multiplication should occur before division to avoid loss of precision"
);
24
25impl<'hir> LateLintPass<'hir> for DivideBeforeMultiply {
26 fn check_function(
27 &mut self,
28 ctx: &LintContext,
29 _gcx: Gcx<'hir>,
30 hir: &'hir Hir<'hir>,
31 func: &'hir Function<'hir>,
32 ) {
33 if let Some(body) = func.body {
34 let mut tainted = HashSet::default();
35 check_block(ctx, hir, body, &mut tainted);
36 }
37 }
38}
39
40fn check_block<'hir>(
41 ctx: &LintContext,
42 hir: &'hir Hir<'hir>,
43 block: Block<'hir>,
44 tainted: &mut HashSet<VariableId>,
45) -> bool {
46 for stmt in block.stmts {
47 if !check_stmt(ctx, hir, stmt, tainted) {
48 return false;
49 }
50 }
51 true
52}
53
54fn check_stmt<'hir>(
55 ctx: &LintContext,
56 hir: &'hir Hir<'hir>,
57 stmt: &'hir Stmt<'hir>,
58 tainted: &mut HashSet<VariableId>,
59) -> bool {
60 match &stmt.kind {
61 StmtKind::DeclSingle(var_id) => {
62 if let Some(init) = hir.variable(*var_id).initializer {
63 check_expr(ctx, hir, init, tainted);
64 update_taint(
65 hir,
66 *var_id,
67 expr_value_is_division_or_tainted(init, tainted),
68 tainted,
69 );
70 }
71 true
72 }
73 StmtKind::DeclMulti(vars, expr) => {
74 check_expr(ctx, hir, expr, tainted);
75 update_multi_decl_taint(hir, vars, expr, tainted);
76 true
77 }
78 StmtKind::Expr(expr) => {
79 check_expr(ctx, hir, expr, tainted);
80 !is_revert_call(expr)
81 }
82 StmtKind::Emit(expr) => {
83 check_expr(ctx, hir, expr, tainted);
84 true
85 }
86 StmtKind::Revert(expr) | StmtKind::Return(Some(expr)) => {
87 check_expr(ctx, hir, expr, tainted);
88 false
89 }
90 StmtKind::If(cond, then_stmt, else_stmt) => {
91 check_expr(ctx, hir, cond, tainted);
92
93 let baseline = tainted.clone();
94 let mut merged_taint = HashSet::default();
95 let mut falls_through = false;
96
97 let mut then_tainted = baseline.clone();
98 if check_stmt(ctx, hir, then_stmt, &mut then_tainted) {
99 merged_taint = union_taints(&merged_taint, &then_tainted);
100 falls_through = true;
101 }
102
103 if let Some(else_stmt) = else_stmt {
104 let mut else_tainted = baseline;
105 if check_stmt(ctx, hir, else_stmt, &mut else_tainted) {
106 merged_taint = union_taints(&merged_taint, &else_tainted);
107 falls_through = true;
108 }
109 } else {
110 merged_taint = union_taints(&merged_taint, &baseline);
111 falls_through = true;
112 }
113
114 if falls_through {
115 *tainted = merged_taint;
116 }
117 falls_through
118 }
119 StmtKind::Loop(block, _) => {
120 let baseline = tainted.clone();
121 let mut loop_tainted = baseline.clone();
122 *tainted = if check_block(ctx, hir, *block, &mut loop_tainted) {
123 union_taints(&baseline, &loop_tainted)
124 } else {
125 baseline
126 };
127 true
128 }
129 StmtKind::Try(try_stmt) => {
130 check_expr(ctx, hir, &try_stmt.expr, tainted);
131 let mut merged_taint = tainted.clone();
132 for clause in try_stmt.clauses {
133 let mut clause_tainted = tainted.clone();
134 if check_block(ctx, hir, clause.block, &mut clause_tainted) {
135 merged_taint = union_taints(&merged_taint, &clause_tainted);
136 }
137 }
138 *tainted = merged_taint;
139 true
140 }
141 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
142 check_block(ctx, hir, *block, tainted)
143 }
144 StmtKind::AssemblyBlock(block) => check_block(ctx, hir, *block, tainted),
145 StmtKind::Switch(switch) => {
146 check_expr(ctx, hir, switch.selector, tainted);
147 let mut merged_taint = tainted.clone();
148 for case in switch.cases {
149 let mut case_tainted = tainted.clone();
150 if check_block(ctx, hir, case.body, &mut case_tainted) {
151 merged_taint = union_taints(&merged_taint, &case_tainted);
152 }
153 }
154 *tainted = merged_taint;
155 true
156 }
157 StmtKind::Return(None) => false,
158 StmtKind::Break | StmtKind::Continue | StmtKind::Placeholder | StmtKind::Err(_) => true,
159 }
160}
161
162fn check_expr<'hir>(
163 ctx: &LintContext,
164 hir: &'hir Hir<'hir>,
165 expr: &'hir Expr<'hir>,
166 tainted: &mut HashSet<VariableId>,
167) {
168 match &expr.peel_parens().kind {
169 ExprKind::Assign(lhs, op, rhs) => {
170 check_expr(ctx, hir, rhs, tainted);
171 check_expr(ctx, hir, lhs, tainted);
172
173 match op {
174 None => {
175 update_assignment_taint(hir, lhs, rhs, tainted);
176 }
177 Some(op) if op.kind == BinOpKind::Mul => {
178 let lhs_tainted = expr_is_division_result_or_tainted(lhs, tainted);
179 let rhs_tainted = expr_is_division_result_or_tainted(rhs, tainted);
180 if lhs_tainted || rhs_tainted {
181 ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
182 }
183 update_lhs_taint(hir, lhs, lhs_tainted || rhs_tainted, tainted);
184 }
185 Some(op) if op.kind == BinOpKind::Div => {
186 update_lhs_taint(hir, lhs, true, tainted);
187 }
188 Some(_) => update_lhs_taint(hir, lhs, false, tainted),
189 }
190 }
191 ExprKind::Binary(left, op, right) => {
192 check_expr(ctx, hir, left, tainted);
193 check_expr(ctx, hir, right, tainted);
194
195 if op.kind == BinOpKind::Mul
196 && (expr_is_division_result_or_tainted(left, tainted)
197 || expr_is_division_result_or_tainted(right, tainted))
198 {
199 ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
200 }
201 }
202 ExprKind::Array(exprs) => {
203 for expr in *exprs {
204 check_expr(ctx, hir, expr, tainted);
205 }
206 }
207 ExprKind::Call(callee, args, named_args) => {
208 check_expr(ctx, hir, callee, tainted);
209 for arg in args.exprs() {
210 check_expr(ctx, hir, arg, tainted);
211 }
212 if let Some(named_args) = named_args {
213 for arg in named_args.args {
214 check_expr(ctx, hir, &arg.value, tainted);
215 }
216 }
217
218 if is_yul_multiplication_call(expr)
219 && args.exprs().any(|arg| expr_is_division_result_or_tainted(arg, tainted))
220 {
221 ctx.emit(&DIVIDE_BEFORE_MULTIPLY, expr.span);
222 }
223 }
224 ExprKind::Delete(inner)
225 | ExprKind::Index(inner, None)
226 | ExprKind::Member(inner, _)
227 | ExprKind::Payable(inner) => check_expr(ctx, hir, inner, tainted),
228 ExprKind::Index(base, Some(index)) => {
229 check_expr(ctx, hir, base, tainted);
230 check_expr(ctx, hir, index, tainted);
231 }
232 ExprKind::Slice(base, start, end) => {
233 check_expr(ctx, hir, base, tainted);
234 if let Some(start) = start {
235 check_expr(ctx, hir, start, tainted);
236 }
237 if let Some(end) = end {
238 check_expr(ctx, hir, end, tainted);
239 }
240 }
241 ExprKind::Ternary(cond, then_expr, else_expr) => {
242 check_expr(ctx, hir, cond, tainted);
243 let mut then_tainted = tainted.clone();
244 check_expr(ctx, hir, then_expr, &mut then_tainted);
245 let mut else_tainted = tainted.clone();
246 check_expr(ctx, hir, else_expr, &mut else_tainted);
247 *tainted = union_taints(&then_tainted, &else_tainted);
248 }
249 ExprKind::Tuple(exprs) => {
250 for expr in exprs.iter().flatten() {
251 check_expr(ctx, hir, expr, tainted);
252 }
253 }
254 ExprKind::Unary(op, inner) => {
255 check_expr(ctx, hir, inner, tainted);
256 if is_inc_dec_op(op.kind) {
257 update_lhs_taint(hir, inner, false, tainted);
258 }
259 }
260 ExprKind::Ident(_)
261 | ExprKind::Lit(_)
262 | ExprKind::New(_)
263 | ExprKind::TypeCall(_)
264 | ExprKind::Type(_) => {}
265 ExprKind::YulMember(inner, _) => check_expr(ctx, hir, inner, tainted),
266 ExprKind::Err(_) => {}
267 }
268}
269
270fn update_multi_decl_taint(
271 hir: &Hir<'_>,
272 vars: &[Option<VariableId>],
273 expr: &Expr<'_>,
274 tainted: &mut HashSet<VariableId>,
275) {
276 if let ExprKind::Tuple(exprs) = &expr.peel_parens().kind
277 && exprs.len() == vars.len()
278 {
279 let rhs_taints: Vec<_> = exprs
280 .iter()
281 .map(|expr| expr.is_some_and(|expr| expr_value_is_division_or_tainted(expr, tainted)))
282 .collect();
283 for (var_id, rhs_tainted) in vars.iter().zip(rhs_taints) {
284 if let Some(var_id) = var_id {
285 update_taint(hir, *var_id, rhs_tainted, tainted);
286 }
287 }
288 return;
289 }
290
291 let rhs_tainted = expr_value_is_division_or_tainted(expr, tainted);
292 for var_id in vars.iter().flatten() {
293 update_taint(hir, *var_id, rhs_tainted, tainted);
294 }
295}
296
297fn update_assignment_taint(
298 hir: &Hir<'_>,
299 lhs: &Expr<'_>,
300 rhs: &Expr<'_>,
301 tainted: &mut HashSet<VariableId>,
302) {
303 if let (ExprKind::Tuple(lhs_exprs), ExprKind::Tuple(rhs_exprs)) =
304 (&lhs.peel_parens().kind, &rhs.peel_parens().kind)
305 && lhs_exprs.len() == rhs_exprs.len()
306 {
307 let rhs_taints: Vec<_> = rhs_exprs
308 .iter()
309 .map(|rhs| rhs.is_some_and(|rhs| expr_value_is_division_or_tainted(rhs, tainted)))
310 .collect();
311 for (lhs, rhs_tainted) in lhs_exprs.iter().zip(rhs_taints) {
312 if let Some(lhs) = lhs {
313 update_lhs_taint(hir, lhs, rhs_tainted, tainted);
314 }
315 }
316 return;
317 }
318
319 update_lhs_taint(hir, lhs, expr_value_is_division_or_tainted(rhs, tainted), tainted);
320}
321
322fn union_taints(left: &HashSet<VariableId>, right: &HashSet<VariableId>) -> HashSet<VariableId> {
323 left.union(right).copied().collect()
324}
325
326fn update_lhs_taint(
327 hir: &Hir<'_>,
328 lhs: &Expr<'_>,
329 is_tainted: bool,
330 tainted: &mut HashSet<VariableId>,
331) {
332 match &lhs.peel_parens().kind {
333 ExprKind::Ident(resolutions) => {
334 for res in *resolutions {
335 if let Res::Item(ItemId::Variable(var_id)) = res {
336 update_taint(hir, *var_id, is_tainted, tainted);
337 }
338 }
339 }
340 ExprKind::Tuple(exprs) => {
341 for expr in exprs.iter().flatten() {
342 update_lhs_taint(hir, expr, is_tainted, tainted);
343 }
344 }
345 _ => {}
346 }
347}
348
349fn update_taint(
350 hir: &Hir<'_>,
351 var_id: VariableId,
352 is_tainted: bool,
353 tainted: &mut HashSet<VariableId>,
354) {
355 if !hir.variable(var_id).is_local_or_return() {
356 return;
357 }
358 if is_tainted {
359 tainted.insert(var_id);
360 } else {
361 tainted.remove(&var_id);
362 }
363}
364
365fn expr_value_is_division_or_tainted(expr: &Expr<'_>, tainted: &HashSet<VariableId>) -> bool {
366 match &expr.peel_parens().kind {
367 ExprKind::Binary(_, op, _) => op.kind == BinOpKind::Div,
368 ExprKind::Ident(resolutions) => resolutions.iter().any(
369 |res| matches!(res, Res::Item(ItemId::Variable(var_id)) if tainted.contains(var_id)),
370 ),
371 ExprKind::Call(..) => is_yul_division_call(expr),
372 ExprKind::Tuple([Some(inner)]) => expr_value_is_division_or_tainted(inner, tainted),
373 ExprKind::YulMember(inner, _) => expr_value_is_division_or_tainted(inner, tainted),
374 ExprKind::Array(_)
375 | ExprKind::Assign(..)
376 | ExprKind::Delete(_)
377 | ExprKind::Index(..)
378 | ExprKind::Lit(_)
379 | ExprKind::Member(_, _)
380 | ExprKind::New(_)
381 | ExprKind::Payable(_)
382 | ExprKind::Slice(..)
383 | ExprKind::Ternary(..)
384 | ExprKind::TypeCall(_)
385 | ExprKind::Type(_)
386 | ExprKind::Unary(_, _)
387 | ExprKind::Tuple(_) => false,
388 ExprKind::Err(_) => false,
389 }
390}
391
392fn expr_is_division_result_or_tainted(expr: &Expr<'_>, tainted: &HashSet<VariableId>) -> bool {
393 match &expr.peel_parens().kind {
394 ExprKind::Binary(_, op, _) => op.kind == BinOpKind::Div,
395 ExprKind::Call(..) => is_yul_division_call(expr),
396 ExprKind::Ident(resolutions) => resolutions.iter().any(
397 |res| matches!(res, Res::Item(ItemId::Variable(var_id)) if tainted.contains(var_id)),
398 ),
399 ExprKind::Tuple([Some(inner)]) => expr_is_division_result_or_tainted(inner, tainted),
400 _ => false,
401 }
402}
403
404fn is_yul_division_call(expr: &Expr<'_>) -> bool {
405 is_yul_builtin_call(expr, |builtin| matches!(builtin, Builtin::YulDiv | Builtin::YulSdiv))
406}
407
408fn is_yul_multiplication_call(expr: &Expr<'_>) -> bool {
409 is_yul_builtin_call(expr, |builtin| matches!(builtin, Builtin::YulMul))
410}
411
412fn is_revert_call(expr: &Expr<'_>) -> bool {
413 let ExprKind::Call(callee, _, _) = &expr.peel_parens().kind else { return false };
414 let ExprKind::Ident(resolutions) = &callee.peel_parens().kind else { return false };
415 resolutions.iter().any(|res| matches!(res, Res::Builtin(Builtin::Revert | Builtin::RevertMsg)))
416}
417
418const fn is_inc_dec_op(kind: UnOpKind) -> bool {
419 matches!(kind, UnOpKind::PreInc | UnOpKind::PostInc | UnOpKind::PreDec | UnOpKind::PostDec)
420}
421
422fn is_yul_builtin_call(expr: &Expr<'_>, predicate: impl Fn(Builtin) -> bool) -> bool {
423 let ExprKind::Call(callee, args, _) = &expr.peel_parens().kind else { return false };
424 if args.len() != 2 {
425 return false;
426 }
427 let ExprKind::Ident(resolutions) = &callee.peel_parens().kind else { return false };
428 resolutions.iter().any(|res| matches!(res, Res::Builtin(builtin) if predicate(*builtin)))
429}