Skip to main content

forge_lint/sol/high/reentrancy.rs

1use super::ReentrancyUnlimitedGas;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::{LitKind, StateMutability, UnOpKind, Visibility},
8    interface::{Span, kw, sym},
9    sema::hir::{self, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId},
10};
11use std::collections::{BTreeSet, HashSet};
12
13declare_forge_lint!(
14    REENTRANCY_UNLIMITED_GAS,
15    Severity::High,
16    "reentrancy-unlimited-gas",
17    "state read before uncapped ETH transfer is written after the transfer"
18);
19
20impl<'hir> LateLintPass<'hir> for ReentrancyUnlimitedGas {
21    fn check_function(
22        &mut self,
23        ctx: &LintContext,
24        hir: &'hir hir::Hir<'hir>,
25        func: &'hir hir::Function<'hir>,
26    ) {
27        if !is_entry_point(func) {
28            return;
29        }
30
31        let Some(body) = func.body else { return };
32
33        let mut analyzer = Analyzer::new(ctx, hir);
34        let mut state = FlowState::default();
35        analyzer.analyze_callable(func, body, &mut state);
36    }
37}
38
39fn is_entry_point(func: &hir::Function<'_>) -> bool {
40    if matches!(func.state_mutability, StateMutability::Pure | StateMutability::View) {
41        return false;
42    }
43    if func.is_constructor() {
44        return false;
45    }
46    if func.is_special() {
47        return true;
48    }
49    func.kind.is_function() && matches!(func.visibility, Visibility::Public | Visibility::External)
50}
51
52#[derive(Clone, Debug, Default)]
53struct FlowState {
54    state_reads: BTreeSet<VariableId>,
55    pending_value_calls: Vec<PendingValueCall>,
56}
57
58#[derive(Clone, Debug)]
59struct PendingValueCall {
60    span: Span,
61    state_reads: BTreeSet<VariableId>,
62}
63
64impl FlowState {
65    fn push_read(&mut self, var_id: VariableId) {
66        self.state_reads.insert(var_id);
67    }
68
69    fn push_call(&mut self, span: Span) {
70        if self.state_reads.is_empty() {
71            return;
72        }
73
74        if let Some(existing) = self.pending_value_calls.iter_mut().find(|call| call.span == span) {
75            existing.state_reads.extend(self.state_reads.iter().copied());
76        } else {
77            self.pending_value_calls
78                .push(PendingValueCall { span, state_reads: self.state_reads.clone() });
79        }
80    }
81}
82
83struct Analyzer<'ctx, 's, 'c, 'hir> {
84    ctx: &'ctx LintContext<'s, 'c>,
85    hir: &'hir hir::Hir<'hir>,
86    emitted: HashSet<Span>,
87    call_stack: Vec<FunctionId>,
88}
89
90impl<'ctx, 's, 'c, 'hir> Analyzer<'ctx, 's, 'c, 'hir> {
91    fn new(ctx: &'ctx LintContext<'s, 'c>, hir: &'hir hir::Hir<'hir>) -> Self {
92        Self { ctx, hir, emitted: HashSet::new(), call_stack: Vec::new() }
93    }
94
95    fn analyze_callable(
96        &mut self,
97        func: &'hir hir::Function<'hir>,
98        body: hir::Block<'hir>,
99        state: &mut FlowState,
100    ) -> bool {
101        self.analyze_modifier_chain(func.modifiers, 0, body, state)
102    }
103
104    fn analyze_modifier_chain(
105        &mut self,
106        modifiers: &'hir [hir::Modifier<'hir>],
107        index: usize,
108        body: hir::Block<'hir>,
109        state: &mut FlowState,
110    ) -> bool {
111        let Some(modifier) = modifiers.get(index) else {
112            return self.analyze_block(body, None, state);
113        };
114
115        for arg in modifier.args.exprs() {
116            self.analyze_expr(arg, state);
117        }
118
119        let Some(modifier_id) = modifier.id.as_function() else {
120            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
121        };
122
123        if self.call_stack.contains(&modifier_id) {
124            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
125        }
126
127        let modifier_func = self.hir.function(modifier_id);
128        let Some(modifier_body) = modifier_func.body else {
129            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
130        };
131
132        self.call_stack.push(modifier_id);
133        let falls_through =
134            self.analyze_block(modifier_body, Some((modifiers, index + 1, body)), state);
135        self.call_stack.pop();
136        falls_through
137    }
138
139    fn analyze_block(
140        &mut self,
141        block: hir::Block<'hir>,
142        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
143        state: &mut FlowState,
144    ) -> bool {
145        for stmt in block.stmts {
146            if !self.analyze_stmt(stmt, placeholder, state) {
147                return false;
148            }
149        }
150        true
151    }
152
153    fn analyze_stmt(
154        &mut self,
155        stmt: &'hir hir::Stmt<'hir>,
156        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
157        state: &mut FlowState,
158    ) -> bool {
159        match stmt.kind {
160            StmtKind::DeclSingle(var_id) => {
161                if let Some(init) = self.hir.variable(var_id).initializer {
162                    self.analyze_expr(init, state);
163                }
164                true
165            }
166            StmtKind::DeclMulti(_, expr) | StmtKind::Expr(expr) => {
167                self.analyze_expr(expr, state);
168                true
169            }
170            StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
171                self.analyze_block(block, placeholder, state)
172            }
173            StmtKind::Emit(expr) => {
174                self.analyze_expr(expr, state);
175                true
176            }
177            StmtKind::Revert(expr) => {
178                self.analyze_expr(expr, state);
179                false
180            }
181            StmtKind::Return(expr) => {
182                if let Some(expr) = expr {
183                    self.analyze_expr(expr, state);
184                }
185                false
186            }
187            StmtKind::Break | StmtKind::Continue => false,
188            StmtKind::Loop(block, _) => {
189                let before_loop = state.clone();
190                let mut body_state = state.clone();
191                self.analyze_block(block, placeholder, &mut body_state);
192                state.clear();
193                state.merge(&before_loop);
194                state.merge(&body_state);
195                true
196            }
197            StmtKind::If(cond, then_stmt, else_stmt) => {
198                self.analyze_expr(cond, state);
199
200                let mut then_state = state.clone();
201                let then_falls_through = self.analyze_stmt(then_stmt, placeholder, &mut then_state);
202
203                let mut else_state = state.clone();
204                let else_falls_through = if let Some(else_stmt) = else_stmt {
205                    self.analyze_stmt(else_stmt, placeholder, &mut else_state)
206                } else {
207                    true
208                };
209
210                state.clear();
211                if then_falls_through {
212                    state.merge(&then_state);
213                }
214                if else_falls_through {
215                    state.merge(&else_state);
216                }
217
218                then_falls_through || else_falls_through
219            }
220            StmtKind::Try(try_stmt) => {
221                self.analyze_expr(&try_stmt.expr, state);
222
223                let mut merged = FlowState::default();
224                let mut any_falls_through = false;
225                for clause in try_stmt.clauses {
226                    let mut clause_state = state.clone();
227                    let falls_through =
228                        self.analyze_block(clause.block, placeholder, &mut clause_state);
229                    if falls_through {
230                        merged.merge(&clause_state);
231                        any_falls_through = true;
232                    }
233                }
234
235                *state = merged;
236                any_falls_through
237            }
238            StmtKind::Placeholder => {
239                if let Some((modifiers, index, body)) = placeholder {
240                    self.analyze_modifier_chain(modifiers, index, body, state)
241                } else {
242                    true
243                }
244            }
245            StmtKind::Err(_) => true,
246        }
247    }
248
249    fn analyze_expr(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
250        match &expr.kind {
251            ExprKind::Assign(lhs, op, rhs) => {
252                if op.is_some() {
253                    self.analyze_expr(lhs, state);
254                }
255                self.analyze_expr(rhs, state);
256                let written_vars = state_write_lhs_vars(self.hir, lhs);
257                if !written_vars.is_empty() {
258                    self.emit_pending_calls(state, &written_vars);
259                }
260                self.analyze_lhs_indices(lhs, state);
261            }
262            ExprKind::Delete(inner) => {
263                let written_vars = state_write_lhs_vars(self.hir, inner);
264                if !written_vars.is_empty() {
265                    self.emit_pending_calls(state, &written_vars);
266                }
267                self.analyze_lhs_indices(inner, state);
268            }
269            ExprKind::Unary(op, inner)
270                if matches!(
271                    op.kind,
272                    UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
273                ) =>
274            {
275                self.analyze_expr(inner, state);
276                let written_vars = state_write_lhs_vars(self.hir, inner);
277                if !written_vars.is_empty() {
278                    self.emit_pending_calls(state, &written_vars);
279                }
280            }
281            ExprKind::Unary(_, inner) => {
282                self.analyze_expr(inner, state);
283            }
284            ExprKind::Call(callee, args, opts) => {
285                self.analyze_expr(callee, state);
286                if let Some(opts) = opts {
287                    for opt in *opts {
288                        self.analyze_expr(&opt.value, state);
289                    }
290                }
291                for arg in args.exprs() {
292                    self.analyze_expr(arg, state);
293                }
294
295                for func_id in resolved_function_ids(callee) {
296                    self.analyze_internal_call(func_id, state);
297                }
298                if is_uncapped_value_call(callee, *opts) {
299                    state.push_call(expr.span);
300                }
301            }
302            ExprKind::Binary(lhs, _, rhs) => {
303                self.analyze_expr(lhs, state);
304                self.analyze_expr(rhs, state);
305            }
306            ExprKind::Index(base, index) => {
307                self.analyze_expr(base, state);
308                if let Some(index) = index {
309                    self.analyze_expr(index, state);
310                }
311            }
312            ExprKind::Slice(base, start, end) => {
313                self.analyze_expr(base, state);
314                if let Some(start) = start {
315                    self.analyze_expr(start, state);
316                }
317                if let Some(end) = end {
318                    self.analyze_expr(end, state);
319                }
320            }
321            ExprKind::Ternary(cond, true_expr, false_expr) => {
322                self.analyze_expr(cond, state);
323
324                let mut true_state = state.clone();
325                self.analyze_expr(true_expr, &mut true_state);
326
327                let mut false_state = state.clone();
328                self.analyze_expr(false_expr, &mut false_state);
329
330                state.clear();
331                state.merge(&true_state);
332                state.merge(&false_state);
333            }
334            ExprKind::Array(exprs) => {
335                for expr in *exprs {
336                    self.analyze_expr(expr, state);
337                }
338            }
339            ExprKind::Tuple(exprs) => {
340                for expr in exprs.iter().copied().flatten() {
341                    self.analyze_expr(expr, state);
342                }
343            }
344            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
345                self.analyze_expr(base, state);
346            }
347            ExprKind::New(_) | ExprKind::TypeCall(_) | ExprKind::Type(_) => {}
348            ExprKind::Ident(reses) => {
349                for &res in *reses {
350                    if let Res::Item(ItemId::Variable(var_id)) = res
351                        && self.hir.variable(var_id).kind.is_state()
352                    {
353                        state.push_read(var_id);
354                    }
355                }
356            }
357            ExprKind::Lit(_) | ExprKind::Err(_) => {}
358        }
359    }
360
361    fn analyze_internal_call(&mut self, func_id: FunctionId, state: &mut FlowState) {
362        if self.call_stack.contains(&func_id) {
363            return;
364        }
365
366        let func = self.hir.function(func_id);
367        let Some(body) = func.body else { return };
368
369        self.call_stack.push(func_id);
370        self.analyze_callable(func, body, state);
371        self.call_stack.pop();
372    }
373
374    fn analyze_lhs_indices(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
375        match &expr.kind {
376            ExprKind::Index(base, index) => {
377                self.analyze_lhs_indices(base, state);
378                if let Some(index) = index {
379                    self.analyze_expr(index, state);
380                }
381            }
382            ExprKind::Slice(base, start, end) => {
383                self.analyze_lhs_indices(base, state);
384                if let Some(start) = start {
385                    self.analyze_expr(start, state);
386                }
387                if let Some(end) = end {
388                    self.analyze_expr(end, state);
389                }
390            }
391            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
392                self.analyze_lhs_indices(base, state);
393            }
394            ExprKind::Tuple(exprs) => {
395                for expr in exprs.iter().copied().flatten() {
396                    self.analyze_lhs_indices(expr, state);
397                }
398            }
399            _ => {}
400        }
401    }
402
403    fn emit_pending_calls(&mut self, state: &FlowState, written_vars: &[VariableId]) {
404        for call in &state.pending_value_calls {
405            if self.emitted.contains(&call.span) {
406                continue;
407            }
408
409            if let Some(var_id) =
410                written_vars.iter().find(|&&var_id| call.state_reads.contains(&var_id))
411            {
412                let name = self
413                    .hir
414                    .variable(*var_id)
415                    .name
416                    .map(|name| name.as_str().to_string())
417                    .unwrap_or_else(|| "state".to_string());
418                self.ctx.emit_with_msg(
419                    &REENTRANCY_UNLIMITED_GAS,
420                    call.span,
421                    format!("uncapped ETH transfer can be reentered before `{name}` is updated"),
422                );
423                self.emitted.insert(call.span);
424            }
425        }
426    }
427}
428
429impl FlowState {
430    fn clear(&mut self) {
431        self.state_reads.clear();
432        self.pending_value_calls.clear();
433    }
434
435    fn merge(&mut self, other: &Self) {
436        self.state_reads.extend(other.state_reads.iter().copied());
437        for call in &other.pending_value_calls {
438            if let Some(existing) =
439                self.pending_value_calls.iter_mut().find(|existing| existing.span == call.span)
440            {
441                existing.state_reads.extend(call.state_reads.iter().copied());
442            } else {
443                self.pending_value_calls.push(call.clone());
444            }
445        }
446    }
447}
448
449fn is_uncapped_value_call(callee: &hir::Expr<'_>, opts: Option<&[hir::NamedArg<'_>]>) -> bool {
450    let Some(opts) = opts else { return false };
451    let ExprKind::Member(_, member) = &callee.kind else { return false };
452    if member.name != kw::Call {
453        return false;
454    }
455
456    let mut value = None;
457    let mut gas = None;
458    for opt in opts {
459        if opt.name.name == sym::value {
460            value = Some(&opt.value);
461        } else if opt.name.name == kw::Gas {
462            gas = Some(&opt.value);
463        }
464    }
465
466    value.is_some_and(|value| !is_zero_literal(value)) && gas.is_none_or(gas_option_forwards_all)
467}
468
469fn is_zero_literal(expr: &hir::Expr<'_>) -> bool {
470    matches!(
471        &expr.peel_parens().kind,
472        ExprKind::Lit(lit) if matches!(lit.kind, LitKind::Number(value) if value.is_zero())
473    )
474}
475
476fn gas_option_forwards_all(expr: &hir::Expr<'_>) -> bool {
477    let ExprKind::Call(callee, args, opts) = &expr.peel_parens().kind else {
478        return false;
479    };
480    if opts.is_some() || args.exprs().next().is_some() {
481        return false;
482    }
483    matches!(
484        &callee.peel_parens().kind,
485        ExprKind::Ident(reses)
486            if reses.iter().any(|res| {
487                matches!(res, Res::Builtin(builtin) if builtin.name() == sym::gasleft)
488            })
489    )
490}
491
492fn resolved_function_ids<'hir>(
493    callee: &'hir hir::Expr<'hir>,
494) -> impl Iterator<Item = FunctionId> + 'hir {
495    let reses = match &callee.peel_parens().kind {
496        ExprKind::Ident(reses) => *reses,
497        _ => &[],
498    };
499    reses.iter().filter_map(|res| match res {
500        Res::Item(ItemId::Function(func_id)) => Some(*func_id),
501        _ => None,
502    })
503}
504
505fn state_write_lhs_vars(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Vec<VariableId> {
506    let mut vars = Vec::new();
507    collect_state_write_lhs_vars(hir, expr, &mut vars);
508    vars
509}
510
511fn collect_state_write_lhs_vars(
512    hir: &hir::Hir<'_>,
513    expr: &hir::Expr<'_>,
514    vars: &mut Vec<VariableId>,
515) {
516    match &expr.kind {
517        ExprKind::Ident(reses) => {
518            for &res in *reses {
519                if let Res::Item(ItemId::Variable(var_id)) = res
520                    && hir.variable(var_id).kind.is_state()
521                {
522                    push_unique(vars, var_id);
523                }
524            }
525        }
526        ExprKind::Index(base, _) | ExprKind::Slice(base, ..) => {
527            collect_state_write_lhs_vars(hir, base, vars);
528        }
529        ExprKind::Member(base, _)
530        | ExprKind::Payable(base)
531        | ExprKind::Unary(_, base)
532        | ExprKind::Delete(base) => collect_state_write_lhs_vars(hir, base, vars),
533        ExprKind::Tuple(exprs) => {
534            for expr in exprs.iter().copied().flatten() {
535                collect_state_write_lhs_vars(hir, expr, vars);
536            }
537        }
538        _ => {}
539    }
540}
541
/// Appends `item` to `items` unless an equal element is already present; insertion order is
/// preserved.
fn push_unique<T: Copy + Eq>(items: &mut Vec<T>, item: T) {
    let already_present = items.iter().any(|existing| *existing == item);
    if !already_present {
        items.push(item);
    }
}