Skip to main content

forge_lint/sol/high/
reentrancy.rs

1use super::ReentrancyEth;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{
5        Severity, SolLint,
6        analysis::helper_cache::{DEFAULT_HELPER_ANALYSIS_CACHE_LIMIT, HelperAnalysisCache},
7    },
8};
9use solar::{
10    ast::{
11        BinOpKind, DataLocation, ElementaryType, LitKind, StateMutability, StrKind, TypeSize,
12        UnOpKind, Visibility,
13    },
14    interface::{Span, kw, sym},
15    sema::{
16        Gcx, Ty,
17        hir::{
18            self, CallArgs, CallArgsKind, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId,
19        },
20        ty::{TyFnKind, TyKind},
21    },
22};
23use std::collections::{BTreeSet, HashMap, HashSet};
24
// High-severity lint: a state variable read before an uncapped, non-zero-`value` low-level
// `call` is written again after that call (check-effects-interactions violation).
declare_forge_lint!(
    REENTRANCY_ETH,
    Severity::High,
    "reentrancy-eth",
    "state read before ETH transfer is written after the transfer"
);

// Medium-severity companion lint for calls that send no ETH but can still reenter:
// low-level `call`/`callcode`/`delegatecall` on an address, or external calls to
// functions that are neither `view` nor `pure`.
declare_forge_lint!(
    REENTRANCY_NO_ETH,
    Severity::Med,
    "reentrancy-no-eth",
    "state read before external call is written after the call"
);
38
39impl<'hir> LateLintPass<'hir> for ReentrancyEth {
40    fn check_function(
41        &mut self,
42        ctx: &LintContext,
43        gcx: Gcx<'hir>,
44        hir: &'hir hir::Hir<'hir>,
45        func: &'hir hir::Function<'hir>,
46    ) {
47        if !is_entry_point(func) {
48            return;
49        }
50
51        let Some(body) = func.body else { return };
52
53        let mut analyzer = Analyzer::new(ctx, gcx, hir);
54        if !analyzer.has_enabled_lints() {
55            return;
56        }
57        let mut state = FlowState::default();
58        analyzer.analyze_callable(func, body, &mut state);
59    }
60}
61
62fn is_entry_point(func: &hir::Function<'_>) -> bool {
63    if matches!(func.state_mutability, StateMutability::Pure | StateMutability::View) {
64        return false;
65    }
66    if func.is_constructor() {
67        return false;
68    }
69    if func.is_special() {
70        return true;
71    }
72    func.kind.is_function() && matches!(func.visibility, Visibility::Public | Visibility::External)
73}
74
/// Abstract state carried along one control-flow path of an entry point.
///
/// Also used as part of the inline-call cache key, hence `Eq + Hash`.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
struct FlowState {
    /// State variables read so far on this path.
    state_reads: BTreeSet<VariableId>,
    /// Potentially reentrant calls seen so far, each with the reads that preceded it.
    pending_calls: Vec<PendingCall>,
}
80
/// A potentially reentrant call that happened after at least one state read.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct PendingCall {
    /// Span of the call expression; identifies the call and de-duplicates diagnostics.
    span: Span,
    /// Selects which lint reports this call.
    kind: ReentrantCallKind,
    /// State variables read before the call; a later write to any of them is reported.
    state_reads: BTreeSet<VariableId>,
}
87
/// Classification of a potentially reentrant call.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum ReentrantCallKind {
    /// Uncapped low-level `call` with a non-zero `value`; reported as `reentrancy-eth`.
    Eth,
    /// Reentrant call that sends no ETH; reported as `reentrancy-no-eth`.
    NoEth,
}
93
94impl FlowState {
95    fn push_read(&mut self, var_id: VariableId) {
96        self.state_reads.insert(var_id);
97    }
98
99    fn push_call(&mut self, span: Span, kind: ReentrantCallKind) {
100        if self.state_reads.is_empty() {
101            return;
102        }
103
104        if let Some(existing) =
105            self.pending_calls.iter_mut().find(|call| call.span == span && call.kind == kind)
106        {
107            existing.state_reads.extend(self.state_reads.iter().copied());
108        } else {
109            self.pending_calls.push(PendingCall {
110                span,
111                kind,
112                state_reads: self.state_reads.clone(),
113            });
114        }
115    }
116}
117
/// Flow-sensitive walker over one entry point that inlines modifiers and internal calls.
struct Analyzer<'ctx, 's, 'c, 'hir> {
    ctx: &'ctx LintContext<'s, 'c>,
    gcx: Gcx<'hir>,
    hir: &'hir hir::Hir<'hir>,
    /// Spans of calls already reported, so each call is diagnosed at most once.
    emitted: HashSet<Span>,
    /// Functions and modifiers currently being inlined; used to cut recursion.
    call_stack: Vec<FunctionId>,
    /// Memoized post-states of inlined internal calls.
    inline_cache: HelperAnalysisCache<InlineCallKey, FlowState>,
    /// Memoized results of `first_recursive_cut` (each vector holds zero or one id).
    recursive_cut_frontiers: HashMap<RecursiveFrontierKey, Vec<FunctionId>>,
    /// Memoized direct internal callees (including modifiers), sorted by id.
    direct_internal_calls: HashMap<FunctionId, Vec<FunctionId>>,
    /// Cached `is_lint_enabled` result for `reentrancy-eth`.
    reentrancy_eth_enabled: bool,
    /// Cached `is_lint_enabled` result for `reentrancy-no-eth`.
    reentrancy_no_eth_enabled: bool,
}
130
/// Cache key for the post-state of inlining one internal call.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct InlineCallKey {
    /// The inlined callee.
    func_id: FunctionId,
    /// First active function that can cut recursion from this callee.
    recursive_cut: Option<FunctionId>,
    /// Flow state on entry to the callee.
    state: FlowState,
}
138
/// Cache key for `first_recursive_cut`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RecursiveFrontierKey {
    /// Function whose reachable callees are searched.
    func_id: FunctionId,
    /// Active call stack, sorted and deduplicated so that key equality ignores order.
    active_call_stack: Vec<FunctionId>,
}
144
145impl<'ctx, 's, 'c, 'hir> Analyzer<'ctx, 's, 'c, 'hir> {
146    fn new(ctx: &'ctx LintContext<'s, 'c>, gcx: Gcx<'hir>, hir: &'hir hir::Hir<'hir>) -> Self {
147        Self {
148            ctx,
149            gcx,
150            hir,
151            emitted: HashSet::new(),
152            call_stack: Vec::new(),
153            inline_cache: HelperAnalysisCache::new(DEFAULT_HELPER_ANALYSIS_CACHE_LIMIT),
154            recursive_cut_frontiers: HashMap::new(),
155            direct_internal_calls: HashMap::new(),
156            reentrancy_eth_enabled: ctx.is_lint_enabled(REENTRANCY_ETH.id),
157            reentrancy_no_eth_enabled: ctx.is_lint_enabled(REENTRANCY_NO_ETH.id),
158        }
159    }
160
161    const fn has_enabled_lints(&self) -> bool {
162        self.reentrancy_eth_enabled || self.reentrancy_no_eth_enabled
163    }
164
165    fn analyze_callable(
166        &mut self,
167        func: &'hir hir::Function<'hir>,
168        body: hir::Block<'hir>,
169        state: &mut FlowState,
170    ) -> bool {
171        self.analyze_modifier_chain(func.modifiers, 0, body, state)
172    }
173
174    fn analyze_modifier_chain(
175        &mut self,
176        modifiers: &'hir [hir::Modifier<'hir>],
177        index: usize,
178        body: hir::Block<'hir>,
179        state: &mut FlowState,
180    ) -> bool {
181        let Some(modifier) = modifiers.get(index) else {
182            return self.analyze_block(body, None, state);
183        };
184
185        for arg in modifier.args.exprs() {
186            self.analyze_expr(arg, state);
187        }
188
189        let Some(modifier_id) = modifier.id.as_function() else {
190            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
191        };
192
193        if self.call_stack.contains(&modifier_id) {
194            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
195        }
196
197        let modifier_func = self.hir.function(modifier_id);
198        let Some(modifier_body) = modifier_func.body else {
199            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
200        };
201
202        self.call_stack.push(modifier_id);
203        let falls_through =
204            self.analyze_block(modifier_body, Some((modifiers, index + 1, body)), state);
205        self.call_stack.pop();
206        falls_through
207    }
208
209    fn analyze_block(
210        &mut self,
211        block: hir::Block<'hir>,
212        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
213        state: &mut FlowState,
214    ) -> bool {
215        for stmt in block.stmts {
216            if !self.analyze_stmt(stmt, placeholder, state) {
217                return false;
218            }
219        }
220        true
221    }
222
    /// Transfers `state` across a single statement.
    ///
    /// `placeholder` holds the remaining modifier chain and the function body. It is
    /// continued when a `_;` placeholder is reached inside a modifier body.
    ///
    /// Returns `true` if control can fall through to the next statement, and `false` if
    /// the statement always leaves the enclosing block (`return`, `revert`, `break`,
    /// `continue`, or an `if`/`try` none of whose branches fall through).
    fn analyze_stmt(
        &mut self,
        stmt: &'hir hir::Stmt<'hir>,
        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
        state: &mut FlowState,
    ) -> bool {
        match stmt.kind {
            StmtKind::DeclSingle(var_id) => {
                if let Some(init) = self.hir.variable(var_id).initializer {
                    self.analyze_expr(init, state);
                }
                true
            }
            StmtKind::DeclMulti(_, expr) | StmtKind::Expr(expr) => {
                self.analyze_expr(expr, state);
                true
            }
            StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
                self.analyze_block(block, placeholder, state)
            }
            StmtKind::Emit(expr) => {
                self.analyze_expr(expr, state);
                true
            }
            StmtKind::Revert(expr) => {
                self.analyze_expr(expr, state);
                false
            }
            StmtKind::Return(expr) => {
                if let Some(expr) = expr {
                    self.analyze_expr(expr, state);
                }
                false
            }
            StmtKind::Break | StmtKind::Continue => false,
            StmtKind::Loop(block, _) => {
                // The body is analyzed once, and the result is joined with the pre-loop
                // state (which models zero iterations).
                // NOTE(review): effects that only appear across iterations (e.g. a write at
                // the top of the body after a call at the bottom of the previous
                // iteration) are not seen by this single pass. An `if` arm ending in
                // `break`/`continue` is also dropped at the `If` join below, so its state
                // never reaches the post-loop state. Confirm that this under-approximation
                // is intended.
                let before_loop = state.clone();
                let mut body_state = state.clone();
                self.analyze_block(block, placeholder, &mut body_state);
                state.clear();
                state.merge(&before_loop);
                state.merge(&body_state);
                true
            }
            StmtKind::If(cond, then_stmt, else_stmt) => {
                self.analyze_expr(cond, state);

                // Each arm starts from the post-condition state.
                let mut then_state = state.clone();
                let then_falls_through = self.analyze_stmt(then_stmt, placeholder, &mut then_state);

                let mut else_state = state.clone();
                let else_falls_through = if let Some(else_stmt) = else_stmt {
                    self.analyze_stmt(else_stmt, placeholder, &mut else_state)
                } else {
                    true
                };

                // Join only the arms that can reach the statement after the `if`.
                state.clear();
                if then_falls_through {
                    state.merge(&then_state);
                }
                if else_falls_through {
                    state.merge(&else_state);
                }

                then_falls_through || else_falls_through
            }
            StmtKind::Try(try_stmt) => {
                self.analyze_expr(&try_stmt.expr, state);

                // Every clause starts from the post-call state; only clauses that fall
                // through contribute to the joined state.
                let mut merged = FlowState::default();
                let mut any_falls_through = false;
                for clause in try_stmt.clauses {
                    let mut clause_state = state.clone();
                    let falls_through =
                        self.analyze_block(clause.block, placeholder, &mut clause_state);
                    if falls_through {
                        merged.merge(&clause_state);
                        any_falls_through = true;
                    }
                }

                *state = merged;
                any_falls_through
            }
            StmtKind::Placeholder => {
                // `_;` inside a modifier: continue with the next modifier or the body.
                if let Some((modifiers, index, body)) = placeholder {
                    self.analyze_modifier_chain(modifiers, index, body, state)
                } else {
                    true
                }
            }
            // Inline assembly is not analyzed; it is treated as a no-op.
            StmtKind::AssemblyBlock(_) | StmtKind::Switch(_) | StmtKind::Err(_) => true,
        }
    }
318
319    fn analyze_expr(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
320        match &expr.kind {
321            ExprKind::Assign(lhs, op, rhs) => {
322                if op.is_some() {
323                    self.analyze_expr(lhs, state);
324                }
325                self.analyze_expr(rhs, state);
326                self.analyze_lhs_indices(lhs, state);
327                let written_vars = state_write_lhs_vars(self.hir, lhs);
328                if !written_vars.is_empty() {
329                    self.emit_pending_calls(state, &written_vars);
330                }
331            }
332            ExprKind::Delete(inner) => {
333                self.analyze_lhs_indices(inner, state);
334                let written_vars = state_write_lhs_vars(self.hir, inner);
335                if !written_vars.is_empty() {
336                    self.emit_pending_calls(state, &written_vars);
337                }
338            }
339            ExprKind::Unary(op, inner)
340                if matches!(
341                    op.kind,
342                    UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
343                ) =>
344            {
345                self.analyze_expr(inner, state);
346                let written_vars = state_write_lhs_vars(self.hir, inner);
347                if !written_vars.is_empty() {
348                    self.emit_pending_calls(state, &written_vars);
349                }
350            }
351            ExprKind::Unary(_, inner) => {
352                self.analyze_expr(inner, state);
353            }
354            ExprKind::Call(callee, args, opts) => {
355                self.analyze_expr(callee, state);
356                if let Some(opts) = opts {
357                    for opt in opts.args {
358                        self.analyze_expr(&opt.value, state);
359                    }
360                }
361                for arg in args.exprs() {
362                    self.analyze_expr(arg, state);
363                }
364
365                for func_id in resolved_function_ids(callee) {
366                    self.analyze_internal_call(func_id, state);
367                }
368                if !state.state_reads.is_empty()
369                    && let Some(kind) = self.reentrant_call_kind(callee, args, *opts)
370                {
371                    state.push_call(expr.span, kind);
372                }
373            }
374            ExprKind::Binary(lhs, _, rhs) => {
375                self.analyze_expr(lhs, state);
376                self.analyze_expr(rhs, state);
377            }
378            ExprKind::Index(base, index) => {
379                self.analyze_expr(base, state);
380                if let Some(index) = index {
381                    self.analyze_expr(index, state);
382                }
383            }
384            ExprKind::Slice(base, start, end) => {
385                self.analyze_expr(base, state);
386                if let Some(start) = start {
387                    self.analyze_expr(start, state);
388                }
389                if let Some(end) = end {
390                    self.analyze_expr(end, state);
391                }
392            }
393            ExprKind::Ternary(cond, true_expr, false_expr) => {
394                self.analyze_expr(cond, state);
395
396                let mut true_state = state.clone();
397                self.analyze_expr(true_expr, &mut true_state);
398
399                let mut false_state = state.clone();
400                self.analyze_expr(false_expr, &mut false_state);
401
402                state.clear();
403                state.merge(&true_state);
404                state.merge(&false_state);
405            }
406            ExprKind::Array(exprs) => {
407                for expr in *exprs {
408                    self.analyze_expr(expr, state);
409                }
410            }
411            ExprKind::Tuple(exprs) => {
412                for expr in exprs.iter().copied().flatten() {
413                    self.analyze_expr(expr, state);
414                }
415            }
416            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
417                self.analyze_expr(base, state);
418            }
419            ExprKind::New(_) | ExprKind::TypeCall(_) | ExprKind::Type(_) => {}
420            ExprKind::Ident(reses) => {
421                for &res in *reses {
422                    if let Res::Item(ItemId::Variable(var_id)) = res
423                        && self.hir.variable(var_id).kind.is_state()
424                    {
425                        state.push_read(var_id);
426                    }
427                }
428            }
429            ExprKind::Lit(_) | ExprKind::YulMember(..) | ExprKind::Err(_) => {}
430        }
431    }
432
433    fn analyze_internal_call(&mut self, func_id: FunctionId, state: &mut FlowState) {
434        if self.call_stack.contains(&func_id) {
435            return;
436        }
437
438        let func = self.hir.function(func_id);
439        let Some(body) = func.body else { return };
440
441        let key = InlineCallKey {
442            func_id,
443            recursive_cut: self.first_recursive_cut(func_id),
444            state: state.clone(),
445        };
446        if self.inline_cache.is_in_progress(&key) {
447            return;
448        }
449        if let Some(cached) = self.inline_cache.get(&key) {
450            *state = cached.clone();
451            return;
452        }
453
454        let mut after = state.clone();
455        self.inline_cache.start(key.clone());
456        self.call_stack.push(func_id);
457        self.analyze_callable(func, body, &mut after);
458        self.call_stack.pop();
459
460        self.inline_cache.finish(key, after.clone());
461        *state = after;
462    }
463
464    fn first_recursive_cut(&mut self, func_id: FunctionId) -> Option<FunctionId> {
465        let active_call_stack = self.call_stack.iter().copied().collect::<BTreeSet<_>>();
466        if active_call_stack.is_empty() {
467            return None;
468        }
469
470        let active_call_stack = active_call_stack.into_iter().collect::<Vec<_>>();
471        let key = RecursiveFrontierKey { func_id, active_call_stack };
472        if let Some(frontier) = self.recursive_cut_frontiers.get(&key) {
473            return frontier.first().copied();
474        }
475
476        let active_call_stack = key.active_call_stack.iter().copied().collect::<BTreeSet<_>>();
477        let mut seen = HashSet::new();
478        let cut = self.first_recursive_cut_function(func_id, &active_call_stack, &mut seen);
479        self.recursive_cut_frontiers.insert(key, cut.into_iter().collect::<Vec<_>>());
480        cut
481    }
482
483    fn first_recursive_cut_function(
484        &mut self,
485        func_id: FunctionId,
486        active_call_stack: &BTreeSet<FunctionId>,
487        seen: &mut HashSet<FunctionId>,
488    ) -> Option<FunctionId> {
489        if !seen.insert(func_id) {
490            return None;
491        }
492
493        for callee_id in self.direct_internal_calls(func_id) {
494            if active_call_stack.contains(&callee_id) {
495                return Some(callee_id);
496            }
497            if let Some(cut) = self.first_recursive_cut_function(callee_id, active_call_stack, seen)
498            {
499                return Some(cut);
500            }
501        }
502        None
503    }
504
505    fn direct_internal_calls(&mut self, func_id: FunctionId) -> Vec<FunctionId> {
506        if let Some(calls) = self.direct_internal_calls.get(&func_id) {
507            return calls.clone();
508        }
509
510        let mut calls = BTreeSet::new();
511        let func = self.hir.function(func_id);
512        for modifier in func.modifiers {
513            for arg in modifier.args.exprs() {
514                self.collect_direct_internal_calls_expr(arg, &mut calls);
515            }
516            if let Some(modifier_id) = modifier.id.as_function() {
517                calls.insert(modifier_id);
518            }
519        }
520        if let Some(body) = func.body {
521            self.collect_direct_internal_calls_block(body, &mut calls);
522        }
523
524        let calls = calls.into_iter().collect::<Vec<_>>();
525        self.direct_internal_calls.insert(func_id, calls.clone());
526        calls
527    }
528
529    fn collect_direct_internal_calls_block(
530        &mut self,
531        block: hir::Block<'hir>,
532        calls: &mut BTreeSet<FunctionId>,
533    ) {
534        for stmt in block.stmts {
535            self.collect_direct_internal_calls_stmt(stmt, calls);
536        }
537    }
538
    /// Adds the internal calls syntactically contained in `stmt` to `calls`.
    ///
    /// Unlike `analyze_stmt`, this ignores control flow: every branch, loop body, and
    /// `try` clause is visited.
    fn collect_direct_internal_calls_stmt(
        &mut self,
        stmt: &'hir hir::Stmt<'hir>,
        calls: &mut BTreeSet<FunctionId>,
    ) {
        match stmt.kind {
            StmtKind::DeclSingle(var_id) => {
                if let Some(init) = self.hir.variable(var_id).initializer {
                    self.collect_direct_internal_calls_expr(init, calls);
                }
            }
            StmtKind::DeclMulti(_, expr)
            | StmtKind::Expr(expr)
            | StmtKind::Emit(expr)
            | StmtKind::Revert(expr) => {
                self.collect_direct_internal_calls_expr(expr, calls);
            }
            StmtKind::Return(expr) => {
                if let Some(expr) = expr {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
                self.collect_direct_internal_calls_block(block, calls);
            }
            StmtKind::If(cond, then_stmt, else_stmt) => {
                self.collect_direct_internal_calls_expr(cond, calls);
                self.collect_direct_internal_calls_stmt(then_stmt, calls);
                if let Some(else_stmt) = else_stmt {
                    self.collect_direct_internal_calls_stmt(else_stmt, calls);
                }
            }
            StmtKind::Try(try_stmt) => {
                self.collect_direct_internal_calls_expr(&try_stmt.expr, calls);
                for clause in try_stmt.clauses {
                    self.collect_direct_internal_calls_block(clause.block, calls);
                }
            }
            // Placeholders contain no calls of their own. Calls inside inline assembly
            // are not collected.
            StmtKind::Break
            | StmtKind::Continue
            | StmtKind::Placeholder
            | StmtKind::AssemblyBlock(_)
            | StmtKind::Switch(_)
            | StmtKind::Err(_) => {}
        }
    }
585
    /// Adds every internal call target resolvable from call expressions nested in `expr`
    /// (via `resolved_function_ids`) to `calls`, visiting all sub-expressions.
    fn collect_direct_internal_calls_expr(
        &mut self,
        expr: &'hir hir::Expr<'hir>,
        calls: &mut BTreeSet<FunctionId>,
    ) {
        match &expr.kind {
            ExprKind::Assign(lhs, _, rhs) | ExprKind::Binary(lhs, _, rhs) => {
                self.collect_direct_internal_calls_expr(lhs, calls);
                self.collect_direct_internal_calls_expr(rhs, calls);
            }
            ExprKind::Unary(_, inner)
            | ExprKind::Delete(inner)
            | ExprKind::Member(inner, _)
            | ExprKind::Payable(inner) => {
                self.collect_direct_internal_calls_expr(inner, calls);
            }
            ExprKind::Call(callee, args, opts) => {
                self.collect_direct_internal_calls_expr(callee, calls);
                if let Some(opts) = opts {
                    for opt in opts.args {
                        self.collect_direct_internal_calls_expr(&opt.value, calls);
                    }
                }
                for arg in args.exprs() {
                    self.collect_direct_internal_calls_expr(arg, calls);
                }
                // The call's own target(s), in addition to calls nested in its operands.
                for func_id in resolved_function_ids(callee) {
                    calls.insert(func_id);
                }
            }
            ExprKind::Index(base, index) => {
                self.collect_direct_internal_calls_expr(base, calls);
                if let Some(index) = index {
                    self.collect_direct_internal_calls_expr(index, calls);
                }
            }
            ExprKind::Slice(base, start, end) => {
                self.collect_direct_internal_calls_expr(base, calls);
                if let Some(start) = start {
                    self.collect_direct_internal_calls_expr(start, calls);
                }
                if let Some(end) = end {
                    self.collect_direct_internal_calls_expr(end, calls);
                }
            }
            ExprKind::Ternary(cond, true_expr, false_expr) => {
                self.collect_direct_internal_calls_expr(cond, calls);
                self.collect_direct_internal_calls_expr(true_expr, calls);
                self.collect_direct_internal_calls_expr(false_expr, calls);
            }
            ExprKind::Array(exprs) => {
                for expr in *exprs {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            ExprKind::Tuple(exprs) => {
                for expr in exprs.iter().copied().flatten() {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            // Leaves: no nested expressions to visit.
            ExprKind::Ident(_)
            | ExprKind::Lit(_)
            | ExprKind::New(_)
            | ExprKind::TypeCall(_)
            | ExprKind::Type(_)
            | ExprKind::YulMember(..)
            | ExprKind::Err(_) => {}
        }
    }
655
656    fn analyze_lhs_indices(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
657        match &expr.kind {
658            ExprKind::Index(base, index) => {
659                self.analyze_lhs_indices(base, state);
660                if let Some(index) = index {
661                    self.analyze_expr(index, state);
662                }
663            }
664            ExprKind::Slice(base, start, end) => {
665                self.analyze_lhs_indices(base, state);
666                if let Some(start) = start {
667                    self.analyze_expr(start, state);
668                }
669                if let Some(end) = end {
670                    self.analyze_expr(end, state);
671                }
672            }
673            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
674                self.analyze_lhs_indices(base, state);
675            }
676            ExprKind::Tuple(exprs) => {
677                for expr in exprs.iter().copied().flatten() {
678                    self.analyze_lhs_indices(expr, state);
679                }
680            }
681            _ => {}
682        }
683    }
684
685    fn emit_pending_calls(&mut self, state: &FlowState, written_vars: &[VariableId]) {
686        for call in &state.pending_calls {
687            let (lint, msg_prefix) = match call.kind {
688                ReentrantCallKind::Eth => {
689                    (&REENTRANCY_ETH, "uncapped ETH transfer can be reentered before")
690                }
691                ReentrantCallKind::NoEth => {
692                    (&REENTRANCY_NO_ETH, "external call can be reentered before")
693                }
694            };
695            if !self.ctx.is_lint_enabled(lint.id) || self.emitted.contains(&call.span) {
696                continue;
697            }
698
699            if let Some(var_id) =
700                written_vars.iter().find(|&&var_id| call.state_reads.contains(&var_id))
701            {
702                let name = self
703                    .hir
704                    .variable(*var_id)
705                    .name
706                    .map(|name| name.as_str().to_string())
707                    .unwrap_or_else(|| "state".to_string());
708                self.ctx.emit_with_msg(
709                    lint,
710                    call.span,
711                    format!("{msg_prefix} `{name}` is updated"),
712                );
713                self.emitted.insert(call.span);
714            }
715        }
716    }
717
718    fn reentrant_call_kind(
719        &self,
720        callee: &'hir hir::Expr<'hir>,
721        args: &CallArgs<'hir>,
722        opts: Option<&hir::CallOptions<'hir>>,
723    ) -> Option<ReentrantCallKind> {
724        if self.reentrancy_eth_enabled && is_uncapped_value_call(self.hir, callee, opts) {
725            return Some(ReentrantCallKind::Eth);
726        }
727        if self.reentrancy_no_eth_enabled
728            && is_no_eth_reentrant_call(self.gcx, self.hir, callee, args, opts)
729        {
730            return Some(ReentrantCallKind::NoEth);
731        }
732        None
733    }
734}
735
736impl FlowState {
737    fn clear(&mut self) {
738        self.state_reads.clear();
739        self.pending_calls.clear();
740    }
741
742    fn merge(&mut self, other: &Self) {
743        self.state_reads.extend(other.state_reads.iter().copied());
744        for call in &other.pending_calls {
745            if let Some(existing) = self
746                .pending_calls
747                .iter_mut()
748                .find(|existing| existing.span == call.span && existing.kind == call.kind)
749            {
750                existing.state_reads.extend(call.state_reads.iter().copied());
751            } else {
752                self.pending_calls.push(call.clone());
753            }
754        }
755    }
756}
757
758fn is_uncapped_value_call(
759    hir: &hir::Hir<'_>,
760    callee: &hir::Expr<'_>,
761    opts: Option<&hir::CallOptions<'_>>,
762) -> bool {
763    let Some(opts) = opts else { return false };
764    let ExprKind::Member(_, member) = &callee.peel_parens().kind else { return false };
765    if member.name != kw::Call {
766        return false;
767    }
768
769    let mut value = None;
770    let mut gas = None;
771    for opt in opts.args {
772        if opt.name.name == sym::value {
773            value = Some(&opt.value);
774        } else if opt.name.name == kw::Gas {
775            gas = Some(&opt.value);
776        }
777    }
778
779    value.is_some_and(|value| !is_zero_value(hir, value)) && gas.is_none_or(gas_option_forwards_all)
780}
781
782fn is_no_eth_reentrant_call<'hir>(
783    gcx: Gcx<'hir>,
784    hir: &'hir hir::Hir<'hir>,
785    callee: &'hir hir::Expr<'hir>,
786    args: &CallArgs<'hir>,
787    opts: Option<&hir::CallOptions<'hir>>,
788) -> bool {
789    if call_sends_eth(hir, opts) {
790        return false;
791    }
792
793    match &callee.peel_parens().kind {
794        ExprKind::Member(receiver, member) => match member.name {
795            kw::Call | kw::Callcode | kw::Delegatecall => is_address_like(gcx, hir, receiver),
796            kw::Staticcall => false,
797            _ => external_member_call_can_reenter(gcx, hir, receiver, member.name, args),
798        },
799        _ => external_function_pointer_can_reenter(gcx, hir, callee, args),
800    }
801}
802
803fn call_sends_eth(hir: &hir::Hir<'_>, opts: Option<&hir::CallOptions<'_>>) -> bool {
804    opts.is_some_and(|opts| {
805        opts.args.iter().any(|opt| opt.name.name == sym::value && !is_zero_value(hir, &opt.value))
806    })
807}
808
809fn external_member_call_can_reenter<'hir>(
810    gcx: Gcx<'hir>,
811    hir: &'hir hir::Hir<'hir>,
812    receiver: &'hir hir::Expr<'hir>,
813    member: solar::interface::Symbol,
814    args: &CallArgs<'hir>,
815) -> bool {
816    if is_super(receiver) {
817        return false;
818    }
819
820    let Some(receiver_ty) = expr_ty(gcx, hir, receiver) else { return false };
821    gcx.members_of(receiver_ty, base_item_source(hir, receiver), base_contract(hir, receiver))
822        .filter(|candidate| candidate.name == member)
823        .any(|candidate| match (candidate.res, candidate.ty.kind) {
824            (Some(Res::Item(ItemId::Function(function_id))), _) => {
825                let function = hir.function(function_id);
826                is_externally_callable(function)
827                    && args_match_function(gcx, hir, args, function.parameters)
828                    && function.mutates_state()
829            }
830            (_, TyKind::Fn(function)) => {
831                is_externally_callable_fn_kind(function.kind)
832                    && args_match_types(gcx, hir, args, function.parameters)
833                    && !matches!(
834                        function.state_mutability,
835                        StateMutability::Pure | StateMutability::View
836                    )
837            }
838            _ => false,
839        })
840}
841
842fn external_function_pointer_can_reenter<'hir>(
843    gcx: Gcx<'hir>,
844    hir: &'hir hir::Hir<'hir>,
845    callee: &'hir hir::Expr<'hir>,
846    args: &CallArgs<'hir>,
847) -> bool {
848    let Some(ty) = expr_ty(gcx, hir, callee) else { return false };
849    let TyKind::Fn(function) = ty.kind else { return false };
850    function.kind == TyFnKind::External
851        && args_match_types(gcx, hir, args, function.parameters)
852        && !matches!(function.state_mutability, StateMutability::Pure | StateMutability::View)
853}
854
855const fn is_externally_callable(func: &hir::Function<'_>) -> bool {
856    matches!(func.visibility, Visibility::Public | Visibility::External)
857}
858
859const fn is_externally_callable_fn_kind(kind: TyFnKind) -> bool {
860    matches!(kind, TyFnKind::External | TyFnKind::Declaration | TyFnKind::DelegateCall)
861}
862
863fn args_match_function<'hir>(
864    gcx: Gcx<'hir>,
865    hir: &'hir hir::Hir<'hir>,
866    args: &CallArgs<'hir>,
867    params: &'hir [VariableId],
868) -> bool {
869    if args.len() != params.len() {
870        return false;
871    }
872
873    match args.kind {
874        CallArgsKind::Unnamed(exprs) => exprs.iter().zip(params).all(|(arg, &param)| {
875            let param = hir.variable(param);
876            let param_ty =
877                gcx.type_of_hir_ty(&param.ty).with_loc_if_ref_opt(gcx, param.data_location);
878            arg_matches_type(gcx, hir, arg, param_ty)
879        }),
880        CallArgsKind::Named(named_args) => named_args.iter().all(|arg| {
881            params
882                .iter()
883                .copied()
884                .find(|&param| {
885                    hir.variable(param).name.is_some_and(|name| name.name == arg.name.name)
886                })
887                .is_some_and(|param| {
888                    let param = hir.variable(param);
889                    let param_ty =
890                        gcx.type_of_hir_ty(&param.ty).with_loc_if_ref_opt(gcx, param.data_location);
891                    arg_matches_type(gcx, hir, &arg.value, param_ty)
892                })
893        }),
894    }
895}
896
897fn args_match_types<'hir>(
898    gcx: Gcx<'hir>,
899    hir: &'hir hir::Hir<'hir>,
900    args: &CallArgs<'hir>,
901    params: &'hir [Ty<'hir>],
902) -> bool {
903    if args.len() != params.len() {
904        return false;
905    }
906
907    match args.kind {
908        CallArgsKind::Unnamed(exprs) => {
909            exprs.iter().zip(params).all(|(arg, &param)| arg_matches_type(gcx, hir, arg, param))
910        }
911        CallArgsKind::Named(_) => false,
912    }
913}
914
915fn arg_matches_type<'hir>(
916    gcx: Gcx<'hir>,
917    hir: &'hir hir::Hir<'hir>,
918    arg: &'hir hir::Expr<'hir>,
919    param_ty: Ty<'hir>,
920) -> bool {
921    expr_ty(gcx, hir, arg).is_some_and(|arg_ty| arg_ty.convert_implicit_to(param_ty, gcx))
922}
923
924fn is_address_like<'hir>(
925    gcx: Gcx<'hir>,
926    hir: &'hir hir::Hir<'hir>,
927    expr: &'hir hir::Expr<'hir>,
928) -> bool {
929    match &expr.peel_parens().kind {
930        ExprKind::Payable(_) => true,
931        ExprKind::Call(callee, _, _) if is_address_type_expr(callee) => true,
932        _ => expr_ty(gcx, hir, expr).is_some_and(type_is_address_like),
933    }
934}
935
936fn is_address_type_expr(expr: &hir::Expr<'_>) -> bool {
937    matches!(
938        &expr.peel_parens().kind,
939        ExprKind::Type(hir::Type {
940            kind: hir::TypeKind::Elementary(ElementaryType::Address(_)),
941            ..
942        })
943    )
944}
945
946fn type_is_address_like(ty: Ty<'_>) -> bool {
947    matches!(ty.peel_refs().kind, TyKind::Elementary(ElementaryType::Address(_)))
948}
949
950fn expr_ty<'hir>(
951    gcx: Gcx<'hir>,
952    hir: &'hir hir::Hir<'hir>,
953    expr: &'hir hir::Expr<'hir>,
954) -> Option<Ty<'hir>> {
955    match &expr.peel_parens().kind {
956        ExprKind::Array(_) | ExprKind::YulMember(..) => None,
957        ExprKind::Call(callee, args, _) => {
958            let callee_ty = expr_ty(gcx, hir, callee)?;
959            match callee_ty.kind {
960                TyKind::Fn(func) => fn_call_return_type(gcx, func.returns),
961                TyKind::Type(to) => Some(explicit_cast_ty(gcx, to, args)),
962                _ => None,
963            }
964        }
965        ExprKind::Ident(reses) => {
966            let res = unique(reses.iter().filter(|res| !matches!(res, Res::Err(_))).copied())?;
967            match res {
968                Res::Builtin(builtin) if matches!(builtin.name(), sym::this | sym::super_) => None,
969                Res::Item(ItemId::Variable(var_id)) => Some(
970                    gcx.type_of_res(res)
971                        .with_loc_if_ref_opt(gcx, variable_data_location(hir, var_id)),
972                ),
973                _ => Some(gcx.type_of_res(res)),
974            }
975        }
976        ExprKind::Index(lhs, index) => {
977            let lhs_ty = expr_ty(gcx, hir, lhs)?;
978            if let Some(index) = index
979                && !expr_ty(gcx, hir, index)?.convert_implicit_to(gcx.types.uint(256), gcx)
980            {
981                return None;
982            }
983            index_ty(gcx, lhs_ty)
984        }
985        ExprKind::Lit(lit) => Some(match &lit.kind {
986            LitKind::Str(StrKind::Hex, s, _) => {
987                let size = TypeSize::try_new_fb_bytes(s.as_byte_str().len().min(32) as u8)?;
988                gcx.types.fixed_bytes(size.bytes())
989            }
990            LitKind::Str(_, s, _) => gcx.mk_ty_string_literal(s.as_byte_str()),
991            LitKind::Number(int) => gcx.mk_ty_int_literal(false, int.bit_len() as _)?,
992            LitKind::Rational(_) | LitKind::Err(_) => return None,
993            LitKind::Address(_) => gcx.types.address,
994            LitKind::Bool(_) => gcx.types.bool,
995        }),
996        ExprKind::Member(base, member) => member_ty(gcx, hir, base, member.name),
997        ExprKind::New(ty) => {
998            let ty = gcx.type_of_hir_ty(ty);
999            Some(gcx.mk_ty(TyKind::Type(ty)))
1000        }
1001        ExprKind::Payable(inner) => {
1002            let inner_ty = expr_ty(gcx, hir, inner)?;
1003            inner_ty
1004                .convert_explicit_to(gcx.types.address_payable, gcx)
1005                .then_some(gcx.types.address_payable)
1006        }
1007        ExprKind::Slice(lhs, ..) => {
1008            let lhs_ty = expr_ty(gcx, hir, lhs)?;
1009            lhs_ty.is_sliceable().then_some(gcx.mk_ty(TyKind::Slice(lhs_ty)))
1010        }
1011        ExprKind::Tuple(exprs) => {
1012            let tys = exprs
1013                .iter()
1014                .map(|expr| expr.and_then(|expr| expr_ty(gcx, hir, expr)))
1015                .collect::<Option<Vec<_>>>()?;
1016            Some(gcx.mk_ty_tuple(gcx.mk_tys(&tys)))
1017        }
1018        ExprKind::Ternary(_, true_expr, false_expr) => {
1019            let true_ty = expr_ty(gcx, hir, true_expr)?;
1020            let false_ty = expr_ty(gcx, hir, false_expr)?;
1021            common_ty(gcx, true_ty, false_ty)
1022        }
1023        ExprKind::Type(ty) | ExprKind::TypeCall(ty) => {
1024            let ty = gcx.type_of_hir_ty(ty);
1025            Some(gcx.mk_ty(TyKind::Type(ty)))
1026        }
1027        ExprKind::Unary(op, inner) => match op.kind {
1028            UnOpKind::Not => Some(gcx.types.bool),
1029            _ => expr_ty(gcx, hir, inner),
1030        },
1031        ExprKind::Binary(_, op, _) if binary_op_returns_bool(op.kind) => Some(gcx.types.bool),
1032        ExprKind::Assign(..) | ExprKind::Binary(..) | ExprKind::Delete(..) | ExprKind::Err(_) => {
1033            None
1034        }
1035    }
1036}
1037
1038const fn binary_op_returns_bool(op: BinOpKind) -> bool {
1039    matches!(
1040        op,
1041        BinOpKind::Lt
1042            | BinOpKind::Le
1043            | BinOpKind::Gt
1044            | BinOpKind::Ge
1045            | BinOpKind::Eq
1046            | BinOpKind::Ne
1047            | BinOpKind::And
1048            | BinOpKind::Or
1049    )
1050}
1051
1052fn member_ty<'hir>(
1053    gcx: Gcx<'hir>,
1054    hir: &'hir hir::Hir<'hir>,
1055    base: &'hir hir::Expr<'hir>,
1056    member_name: solar::interface::Symbol,
1057) -> Option<Ty<'hir>> {
1058    if is_this(base) || is_super(base) {
1059        return None;
1060    }
1061
1062    let base_ty = expr_ty(gcx, hir, base)?;
1063    unique(
1064        gcx.members_of(base_ty, base_item_source(hir, base), base_contract(hir, base))
1065            .filter(|member| member.name == member_name)
1066            .map(|member| member.ty),
1067    )
1068}
1069
1070fn common_ty<'hir>(gcx: Gcx<'hir>, lhs: Ty<'hir>, rhs: Ty<'hir>) -> Option<Ty<'hir>> {
1071    if lhs.convert_implicit_to(rhs, gcx) {
1072        Some(rhs)
1073    } else {
1074        rhs.convert_implicit_to(lhs, gcx).then_some(lhs)
1075    }
1076}
1077
1078fn fn_call_return_type<'hir>(gcx: Gcx<'hir>, returns: &'hir [Ty<'hir>]) -> Option<Ty<'hir>> {
1079    Some(match returns {
1080        [] => gcx.types.unit,
1081        [ret] => *ret,
1082        _ => gcx.mk_ty_tuple(returns),
1083    })
1084}
1085
1086fn explicit_cast_ty<'hir>(gcx: Gcx<'hir>, to: Ty<'hir>, args: &'hir CallArgs<'hir>) -> Ty<'hir> {
1087    match args.exprs().next().and_then(|arg| expr_ty(gcx, &gcx.hir, arg)) {
1088        Some(from) => from.try_convert_explicit_to(to, gcx).unwrap_or(to),
1089        None => to,
1090    }
1091}
1092
1093fn index_ty<'hir>(gcx: Gcx<'hir>, base_ty: Ty<'hir>) -> Option<Ty<'hir>> {
1094    let loc = indexed_base_data_location(base_ty);
1095    match base_ty.peel_refs().kind {
1096        TyKind::Mapping(_, value) => Some(value.with_loc_if_ref_opt(gcx, loc)),
1097        _ => base_ty.base_type(gcx),
1098    }
1099}
1100
1101fn indexed_base_data_location(ty: Ty<'_>) -> Option<DataLocation> {
1102    ty.loc().or_else(|| matches!(ty.kind, TyKind::Mapping(..)).then_some(DataLocation::Storage))
1103}
1104
1105fn base_item_source(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> hir::SourceId {
1106    referenced_item(expr)
1107        .map(|id| hir.item(id).source())
1108        .unwrap_or_else(|| hir.sources_enumerated().next().expect("HIR has a source").0)
1109}
1110
1111fn base_contract(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Option<hir::ContractId> {
1112    referenced_item(expr).and_then(|id| hir.item(id).contract())
1113}
1114
1115fn referenced_item(expr: &hir::Expr<'_>) -> Option<ItemId> {
1116    match &expr.peel_parens().kind {
1117        ExprKind::Ident([Res::Item(id), ..]) => Some(*id),
1118        _ => None,
1119    }
1120}
1121
1122fn variable_data_location(hir: &hir::Hir<'_>, var_id: VariableId) -> Option<DataLocation> {
1123    let var = hir.variable(var_id);
1124    var.data_location.or_else(|| var.kind.is_state().then_some(DataLocation::Storage))
1125}
1126
1127fn is_this(expr: &hir::Expr<'_>) -> bool {
1128    matches!(
1129        &expr.peel_parens().kind,
1130        ExprKind::Ident(reses)
1131            if reses.iter().any(|res| {
1132                matches!(res, Res::Builtin(builtin) if builtin.name() == sym::this)
1133            })
1134    )
1135}
1136
1137fn is_super(expr: &hir::Expr<'_>) -> bool {
1138    matches!(
1139        &expr.peel_parens().kind,
1140        ExprKind::Ident(reses)
1141            if reses.iter().any(|res| {
1142                matches!(res, Res::Builtin(builtin) if builtin.name() == sym::super_)
1143            })
1144    )
1145}
1146
/// Returns the iterator's sole element.
///
/// Returns `None` if the iterator is empty or yields more than one element. At most two
/// elements are pulled.
fn unique<T>(mut iter: impl Iterator<Item = T>) -> Option<T> {
    let only = iter.next()?;
    match iter.next() {
        None => Some(only),
        Some(_) => None,
    }
}
1151
1152fn is_zero_value(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
1153    let mut seen = BTreeSet::new();
1154    is_zero_value_inner(hir, expr, &mut seen)
1155}
1156
1157fn is_zero_value_inner(
1158    hir: &hir::Hir<'_>,
1159    expr: &hir::Expr<'_>,
1160    seen: &mut BTreeSet<VariableId>,
1161) -> bool {
1162    match &expr.peel_parens().kind {
1163        ExprKind::Lit(lit) => matches!(lit.kind, LitKind::Number(value) if value.is_zero()),
1164        ExprKind::Ident(reses) => {
1165            let mut saw_variable = false;
1166            reses.iter().all(|res| match res {
1167                Res::Item(ItemId::Variable(var_id)) => {
1168                    saw_variable = true;
1169                    constant_var_is_zero(hir, *var_id, seen)
1170                }
1171                _ => false,
1172            }) && saw_variable
1173        }
1174        ExprKind::Call(callee, args, opts)
1175            if opts.is_none()
1176                && matches!(callee.peel_parens().kind, ExprKind::Type(_))
1177                && args.exprs().count() == 1 =>
1178        {
1179            args.exprs().next().is_some_and(|arg| is_zero_value_inner(hir, arg, seen))
1180        }
1181        _ => false,
1182    }
1183}
1184
1185fn constant_var_is_zero(
1186    hir: &hir::Hir<'_>,
1187    var_id: VariableId,
1188    seen: &mut BTreeSet<VariableId>,
1189) -> bool {
1190    let var = hir.variable(var_id);
1191    if !var.is_constant() || !seen.insert(var_id) {
1192        return false;
1193    }
1194    var.initializer.is_some_and(|init| is_zero_value_inner(hir, init, seen))
1195}
1196
1197fn gas_option_forwards_all(expr: &hir::Expr<'_>) -> bool {
1198    let ExprKind::Call(callee, args, opts) = &expr.peel_parens().kind else {
1199        return false;
1200    };
1201    if opts.is_some() || args.exprs().next().is_some() {
1202        return false;
1203    }
1204    matches!(
1205        &callee.peel_parens().kind,
1206        ExprKind::Ident(reses)
1207            if reses.iter().any(|res| {
1208                matches!(res, Res::Builtin(builtin) if builtin.name() == sym::gasleft)
1209            })
1210    )
1211}
1212
1213fn resolved_function_ids<'hir>(
1214    callee: &'hir hir::Expr<'hir>,
1215) -> impl Iterator<Item = FunctionId> + 'hir {
1216    let reses = match &callee.peel_parens().kind {
1217        ExprKind::Ident(reses) => *reses,
1218        _ => &[],
1219    };
1220    reses.iter().filter_map(|res| match res {
1221        Res::Item(ItemId::Function(func_id)) => Some(*func_id),
1222        _ => None,
1223    })
1224}
1225
1226fn state_write_lhs_vars(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Vec<VariableId> {
1227    let mut vars = Vec::new();
1228    collect_state_write_lhs_vars(hir, expr, &mut vars);
1229    vars
1230}
1231
1232fn collect_state_write_lhs_vars(
1233    hir: &hir::Hir<'_>,
1234    expr: &hir::Expr<'_>,
1235    vars: &mut Vec<VariableId>,
1236) {
1237    match &expr.kind {
1238        ExprKind::Ident(reses) => {
1239            for &res in *reses {
1240                if let Res::Item(ItemId::Variable(var_id)) = res
1241                    && hir.variable(var_id).kind.is_state()
1242                {
1243                    push_unique(vars, var_id);
1244                }
1245            }
1246        }
1247        ExprKind::Index(base, _) | ExprKind::Slice(base, ..) => {
1248            collect_state_write_lhs_vars(hir, base, vars);
1249        }
1250        ExprKind::Member(base, _)
1251        | ExprKind::Payable(base)
1252        | ExprKind::Unary(_, base)
1253        | ExprKind::Delete(base) => collect_state_write_lhs_vars(hir, base, vars),
1254        ExprKind::Tuple(exprs) => {
1255            for expr in exprs.iter().copied().flatten() {
1256                collect_state_write_lhs_vars(hir, expr, vars);
1257            }
1258        }
1259        _ => {}
1260    }
1261}
1262
/// Pushes `item` onto `items` unless an equal element is already present.
/// Insertion order is preserved.
fn push_unique<T: Copy + Eq>(items: &mut Vec<T>, item: T) {
    if items.contains(&item) {
        return;
    }
    items.push(item);
}