1use super::ReentrancyEth;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{
5 Severity, SolLint,
6 analysis::helper_cache::{DEFAULT_HELPER_ANALYSIS_CACHE_LIMIT, HelperAnalysisCache},
7 },
8};
9use solar::{
10 ast::{
11 BinOpKind, DataLocation, ElementaryType, LitKind, StateMutability, StrKind, TypeSize,
12 UnOpKind, Visibility,
13 },
14 interface::{Span, kw, sym},
15 sema::{
16 Gcx, Ty,
17 hir::{
18 self, CallArgs, CallArgsKind, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId,
19 },
20 ty::{TyFnKind, TyKind},
21 },
22};
23use std::collections::{BTreeSet, HashMap, HashSet};
24
// High-severity lint: a state variable read before an uncapped, ETH-sending low-level
// `.call{value: ...}` is written after that call (see `is_uncapped_value_call`).
declare_forge_lint!(
    REENTRANCY_ETH,
    Severity::High,
    "reentrancy-eth",
    "state read before ETH transfer is written after the transfer"
);
31
// Medium-severity lint: same read-call-write pattern, but for external calls that send no
// ETH (low-level calls on addresses and state-mutating external calls; see
// `is_no_eth_reentrant_call`).
declare_forge_lint!(
    REENTRANCY_NO_ETH,
    Severity::Med,
    "reentrancy-no-eth",
    "state read before external call is written after the call"
);
38
39impl<'hir> LateLintPass<'hir> for ReentrancyEth {
40 fn check_function(
41 &mut self,
42 ctx: &LintContext,
43 gcx: Gcx<'hir>,
44 hir: &'hir hir::Hir<'hir>,
45 func: &'hir hir::Function<'hir>,
46 ) {
47 if !is_entry_point(func) {
48 return;
49 }
50
51 let Some(body) = func.body else { return };
52
53 let mut analyzer = Analyzer::new(ctx, gcx, hir);
54 if !analyzer.has_enabled_lints() {
55 return;
56 }
57 let mut state = FlowState::default();
58 analyzer.analyze_callable(func, body, &mut state);
59 }
60}
61
62fn is_entry_point(func: &hir::Function<'_>) -> bool {
63 if matches!(func.state_mutability, StateMutability::Pure | StateMutability::View) {
64 return false;
65 }
66 if func.is_constructor() {
67 return false;
68 }
69 if func.is_special() {
70 return true;
71 }
72 func.kind.is_function() && matches!(func.visibility, Visibility::Public | Visibility::External)
73}
74
/// Abstract state carried through a function body by the analyzer.
#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
struct FlowState {
    /// State variables read so far (unioned across merged branches).
    state_reads: BTreeSet<VariableId>,
    /// Potentially reentrant calls seen so far, each with the state reads that preceded it.
    pending_calls: Vec<PendingCall>,
}
80
/// A potentially reentrant call awaiting a later write to one of the state it depends on.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct PendingCall {
    /// Span of the call expression; also the identity used for deduplication and reporting.
    span: Span,
    /// Which lint this call is reported under.
    kind: ReentrantCallKind,
    /// State variables read before the call was reached.
    state_reads: BTreeSet<VariableId>,
}
87
/// Classification of a potentially reentrant call.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum ReentrantCallKind {
    /// Uncapped ETH-sending low-level `.call`; reported as `reentrancy-eth`.
    Eth,
    /// External call that sends no ETH; reported as `reentrancy-no-eth`.
    NoEth,
}
93
94impl FlowState {
95 fn push_read(&mut self, var_id: VariableId) {
96 self.state_reads.insert(var_id);
97 }
98
99 fn push_call(&mut self, span: Span, kind: ReentrantCallKind) {
100 if self.state_reads.is_empty() {
101 return;
102 }
103
104 if let Some(existing) =
105 self.pending_calls.iter_mut().find(|call| call.span == span && call.kind == kind)
106 {
107 existing.state_reads.extend(self.state_reads.iter().copied());
108 } else {
109 self.pending_calls.push(PendingCall {
110 span,
111 kind,
112 state_reads: self.state_reads.clone(),
113 });
114 }
115 }
116}
117
/// Walks a function body in execution order, inlining its modifiers and internal calls.
/// It tracks state reads and potentially reentrant calls, and reports state writes that
/// follow such calls.
struct Analyzer<'ctx, 's, 'c, 'hir> {
    /// Lint context used to check lint enablement and emit diagnostics.
    ctx: &'ctx LintContext<'s, 'c>,
    /// Global compiler context, used for type queries.
    gcx: Gcx<'hir>,
    /// The HIR being linted.
    hir: &'hir hir::Hir<'hir>,
    /// Spans of calls already reported, so each call is reported at most once.
    emitted: HashSet<Span>,
    /// Functions and modifiers currently being inlined; used to cut recursion.
    call_stack: Vec<FunctionId>,
    /// Memoized flow state after inlining an internal call
    /// (capacity bounded by `DEFAULT_HELPER_ANALYSIS_CACHE_LIMIT`, presumably — confirm).
    inline_cache: HelperAnalysisCache<InlineCallKey, FlowState>,
    /// Memoized results of `first_recursive_cut`, keyed by callee and active call stack.
    recursive_cut_frontiers: HashMap<RecursiveFrontierKey, Vec<FunctionId>>,
    /// Memoized direct internal callees (including modifiers) of each function.
    direct_internal_calls: HashMap<FunctionId, Vec<FunctionId>>,
    /// Whether `reentrancy-eth` is enabled.
    reentrancy_eth_enabled: bool,
    /// Whether `reentrancy-no-eth` is enabled.
    reentrancy_no_eth_enabled: bool,
}
130
/// Memoization key for the flow state produced by inlining an internal call.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct InlineCallKey {
    /// The called function.
    func_id: FunctionId,
    /// First function on the active call stack that `func_id` can reach through direct
    /// internal calls (as computed by `Analyzer::first_recursive_cut`), or `None`.
    ///
    /// NOTE(review): inlining skips *every* call-stack member it reaches, but only the first
    /// one is recorded here. Two call stacks that share this first cut but differ in other
    /// reachable members therefore map to the same key — confirm this imprecision is intended.
    recursive_cut: Option<FunctionId>,
    /// Flow state at the call site.
    state: FlowState,
}
138
/// Memoization key for `Analyzer::first_recursive_cut`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RecursiveFrontierKey {
    /// The function whose call graph is searched.
    func_id: FunctionId,
    /// The active call stack as a sorted, deduplicated list (built from a `BTreeSet`).
    active_call_stack: Vec<FunctionId>,
}
144
impl<'ctx, 's, 'c, 'hir> Analyzer<'ctx, 's, 'c, 'hir> {
    /// Creates an analyzer with empty caches, reading from `ctx` which of the two
    /// reentrancy lints are enabled.
    fn new(ctx: &'ctx LintContext<'s, 'c>, gcx: Gcx<'hir>, hir: &'hir hir::Hir<'hir>) -> Self {
        Self {
            ctx,
            gcx,
            hir,
            emitted: HashSet::new(),
            call_stack: Vec::new(),
            inline_cache: HelperAnalysisCache::new(DEFAULT_HELPER_ANALYSIS_CACHE_LIMIT),
            recursive_cut_frontiers: HashMap::new(),
            direct_internal_calls: HashMap::new(),
            reentrancy_eth_enabled: ctx.is_lint_enabled(REENTRANCY_ETH.id),
            reentrancy_no_eth_enabled: ctx.is_lint_enabled(REENTRANCY_NO_ETH.id),
        }
    }

    /// Returns `true` if at least one of the two reentrancy lints is enabled.
    const fn has_enabled_lints(&self) -> bool {
        self.reentrancy_eth_enabled || self.reentrancy_no_eth_enabled
    }

    /// Analyzes `func` with the given `body`, wrapping it in the function's modifiers in
    /// declaration order. Returns whether control can fall through to the end.
    fn analyze_callable(
        &mut self,
        func: &'hir hir::Function<'hir>,
        body: hir::Block<'hir>,
        state: &mut FlowState,
    ) -> bool {
        self.analyze_modifier_chain(func.modifiers, 0, body, state)
    }

    /// Analyzes `modifiers[index..]` followed by `body`.
    ///
    /// The modifier's arguments are evaluated first. Its body is then analyzed, with each
    /// `_` placeholder resuming the chain at `index + 1`. Modifiers that do not resolve to a
    /// function, are already on the call stack (recursion), or have no body are skipped.
    /// Returns whether control can fall through.
    fn analyze_modifier_chain(
        &mut self,
        modifiers: &'hir [hir::Modifier<'hir>],
        index: usize,
        body: hir::Block<'hir>,
        state: &mut FlowState,
    ) -> bool {
        // All modifiers consumed: analyze the function body itself.
        let Some(modifier) = modifiers.get(index) else {
            return self.analyze_block(body, None, state);
        };

        // Modifier arguments are evaluated before the modifier body runs.
        for arg in modifier.args.exprs() {
            self.analyze_expr(arg, state);
        }

        let Some(modifier_id) = modifier.id.as_function() else {
            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
        };

        // Recursive modifier application: skip it to avoid infinite recursion.
        if self.call_stack.contains(&modifier_id) {
            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
        }

        let modifier_func = self.hir.function(modifier_id);
        let Some(modifier_body) = modifier_func.body else {
            return self.analyze_modifier_chain(modifiers, index + 1, body, state);
        };

        self.call_stack.push(modifier_id);
        let falls_through =
            self.analyze_block(modifier_body, Some((modifiers, index + 1, body)), state);
        self.call_stack.pop();
        falls_through
    }

    /// Analyzes the statements of `block` in order, stopping at the first one after which
    /// control cannot fall through. `placeholder` is the pending modifier chain to resume at
    /// `_`, if any. Returns whether control falls through the whole block.
    fn analyze_block(
        &mut self,
        block: hir::Block<'hir>,
        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
        state: &mut FlowState,
    ) -> bool {
        for stmt in block.stmts {
            if !self.analyze_stmt(stmt, placeholder, state) {
                return false;
            }
        }
        true
    }

    /// Analyzes one statement, updating `state`.
    ///
    /// Returns `false` when control cannot reach the next statement. That is the case for
    /// `return`, `revert`, `break` and `continue`, and for compound statements none of whose
    /// paths fall through.
    fn analyze_stmt(
        &mut self,
        stmt: &'hir hir::Stmt<'hir>,
        placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
        state: &mut FlowState,
    ) -> bool {
        match stmt.kind {
            StmtKind::DeclSingle(var_id) => {
                if let Some(init) = self.hir.variable(var_id).initializer {
                    self.analyze_expr(init, state);
                }
                true
            }
            StmtKind::DeclMulti(_, expr) | StmtKind::Expr(expr) => {
                self.analyze_expr(expr, state);
                true
            }
            StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
                self.analyze_block(block, placeholder, state)
            }
            StmtKind::Emit(expr) => {
                self.analyze_expr(expr, state);
                true
            }
            StmtKind::Revert(expr) => {
                self.analyze_expr(expr, state);
                false
            }
            StmtKind::Return(expr) => {
                if let Some(expr) = expr {
                    self.analyze_expr(expr, state);
                }
                false
            }
            StmtKind::Break | StmtKind::Continue => false,
            StmtKind::Loop(block, _) => {
                // The body is analyzed once. The result joins "loop not entered" with "after
                // one iteration". NOTE(review): a write that precedes a call within the body
                // is never matched against that call from a previous iteration — confirm
                // this under-approximation is acceptable.
                let before_loop = state.clone();
                let mut body_state = state.clone();
                self.analyze_block(block, placeholder, &mut body_state);
                state.clear();
                state.merge(&before_loop);
                state.merge(&body_state);
                true
            }
            StmtKind::If(cond, then_stmt, else_stmt) => {
                self.analyze_expr(cond, state);

                let mut then_state = state.clone();
                let then_falls_through = self.analyze_stmt(then_stmt, placeholder, &mut then_state);

                // A missing `else` behaves like an empty branch that falls through.
                let mut else_state = state.clone();
                let else_falls_through = if let Some(else_stmt) = else_stmt {
                    self.analyze_stmt(else_stmt, placeholder, &mut else_state)
                } else {
                    true
                };

                // Only branches that fall through contribute to the state after the `if`.
                state.clear();
                if then_falls_through {
                    state.merge(&then_state);
                }
                if else_falls_through {
                    state.merge(&else_state);
                }

                then_falls_through || else_falls_through
            }
            StmtKind::Try(try_stmt) => {
                self.analyze_expr(&try_stmt.expr, state);

                // Each clause starts from the post-call state; fall-through clauses are joined.
                let mut merged = FlowState::default();
                let mut any_falls_through = false;
                for clause in try_stmt.clauses {
                    let mut clause_state = state.clone();
                    let falls_through =
                        self.analyze_block(clause.block, placeholder, &mut clause_state);
                    if falls_through {
                        merged.merge(&clause_state);
                        any_falls_through = true;
                    }
                }

                *state = merged;
                any_falls_through
            }
            StmtKind::Placeholder => {
                // `_` inside a modifier: continue with the remaining chain and the body.
                if let Some((modifiers, index, body)) = placeholder {
                    self.analyze_modifier_chain(modifiers, index, body, state)
                } else {
                    true
                }
            }
            // Inline assembly is not analyzed.
            StmtKind::AssemblyBlock(_) | StmtKind::Switch(_) | StmtKind::Err(_) => true,
        }
    }

    /// Analyzes an expression in evaluation order, recording state reads and potentially
    /// reentrant calls, and reporting pending calls when a state variable they depend on
    /// is written.
    fn analyze_expr(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
        match &expr.kind {
            ExprKind::Assign(lhs, op, rhs) => {
                // A compound assignment (`op` is `Some`) also reads its target.
                if op.is_some() {
                    self.analyze_expr(lhs, state);
                }
                self.analyze_expr(rhs, state);
                // NOTE(review): for compound assignments, index/slice sub-expressions of
                // `lhs` were already analyzed above and are analyzed again here — confirm
                // this is intended.
                self.analyze_lhs_indices(lhs, state);
                let written_vars = state_write_lhs_vars(self.hir, lhs);
                if !written_vars.is_empty() {
                    self.emit_pending_calls(state, &written_vars);
                }
            }
            ExprKind::Delete(inner) => {
                // `delete x` writes `x` without reading it; only its indices are evaluated.
                self.analyze_lhs_indices(inner, state);
                let written_vars = state_write_lhs_vars(self.hir, inner);
                if !written_vars.is_empty() {
                    self.emit_pending_calls(state, &written_vars);
                }
            }
            ExprKind::Unary(op, inner)
                if matches!(
                    op.kind,
                    UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
                ) =>
            {
                // Increment/decrement both reads and writes its operand.
                self.analyze_expr(inner, state);
                let written_vars = state_write_lhs_vars(self.hir, inner);
                if !written_vars.is_empty() {
                    self.emit_pending_calls(state, &written_vars);
                }
            }
            ExprKind::Unary(_, inner) => {
                self.analyze_expr(inner, state);
            }
            ExprKind::Call(callee, args, opts) => {
                self.analyze_expr(callee, state);
                if let Some(opts) = opts {
                    for opt in opts.args {
                        self.analyze_expr(&opt.value, state);
                    }
                }
                for arg in args.exprs() {
                    self.analyze_expr(arg, state);
                }

                // Internal callees are inlined before the call is classified, so reads they
                // perform are attached to the call if it is reentrant.
                for func_id in resolved_function_ids(callee) {
                    self.analyze_internal_call(func_id, state);
                }
                if !state.state_reads.is_empty()
                    && let Some(kind) = self.reentrant_call_kind(callee, args, *opts)
                {
                    state.push_call(expr.span, kind);
                }
            }
            ExprKind::Binary(lhs, _, rhs) => {
                self.analyze_expr(lhs, state);
                self.analyze_expr(rhs, state);
            }
            ExprKind::Index(base, index) => {
                self.analyze_expr(base, state);
                if let Some(index) = index {
                    self.analyze_expr(index, state);
                }
            }
            ExprKind::Slice(base, start, end) => {
                self.analyze_expr(base, state);
                if let Some(start) = start {
                    self.analyze_expr(start, state);
                }
                if let Some(end) = end {
                    self.analyze_expr(end, state);
                }
            }
            ExprKind::Ternary(cond, true_expr, false_expr) => {
                self.analyze_expr(cond, state);

                // Each arm starts from the post-condition state; the result joins both arms.
                let mut true_state = state.clone();
                self.analyze_expr(true_expr, &mut true_state);

                let mut false_state = state.clone();
                self.analyze_expr(false_expr, &mut false_state);

                state.clear();
                state.merge(&true_state);
                state.merge(&false_state);
            }
            ExprKind::Array(exprs) => {
                for expr in *exprs {
                    self.analyze_expr(expr, state);
                }
            }
            ExprKind::Tuple(exprs) => {
                for expr in exprs.iter().copied().flatten() {
                    self.analyze_expr(expr, state);
                }
            }
            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
                self.analyze_expr(base, state);
            }
            ExprKind::New(_) | ExprKind::TypeCall(_) | ExprKind::Type(_) => {}
            ExprKind::Ident(reses) => {
                // An identifier resolving to a state variable is a state read.
                for &res in *reses {
                    if let Res::Item(ItemId::Variable(var_id)) = res
                        && self.hir.variable(var_id).kind.is_state()
                    {
                        state.push_read(var_id);
                    }
                }
            }
            ExprKind::Lit(_) | ExprKind::YulMember(..) | ExprKind::Err(_) => {}
        }
    }

    /// Inlines the internal function `func_id` at a call site, replacing `state` with the
    /// state after the callee.
    ///
    /// Callees already on the call stack (recursion) or without a body leave `state`
    /// unchanged. Results are memoized per `InlineCallKey`, and a key whose analysis is
    /// still in progress is treated as a no-op.
    fn analyze_internal_call(&mut self, func_id: FunctionId, state: &mut FlowState) {
        if self.call_stack.contains(&func_id) {
            return;
        }

        let func = self.hir.function(func_id);
        let Some(body) = func.body else { return };

        let key = InlineCallKey {
            func_id,
            recursive_cut: self.first_recursive_cut(func_id),
            state: state.clone(),
        };
        if self.inline_cache.is_in_progress(&key) {
            return;
        }
        if let Some(cached) = self.inline_cache.get(&key) {
            *state = cached.clone();
            return;
        }

        let mut after = state.clone();
        self.inline_cache.start(key.clone());
        self.call_stack.push(func_id);
        self.analyze_callable(func, body, &mut after);
        self.call_stack.pop();

        self.inline_cache.finish(key, after.clone());
        *state = after;
    }

    /// Returns the first function on the current call stack reachable from `func_id`
    /// through direct internal calls, or `None` if the stack is empty or nothing on it is
    /// reachable.
    ///
    /// The search visits callees in `FunctionId` order and does not descend into functions
    /// already on the stack. Results are memoized per callee and per set of stack entries.
    fn first_recursive_cut(&mut self, func_id: FunctionId) -> Option<FunctionId> {
        let active_call_stack = self.call_stack.iter().copied().collect::<BTreeSet<_>>();
        if active_call_stack.is_empty() {
            return None;
        }

        let active_call_stack = active_call_stack.into_iter().collect::<Vec<_>>();
        let key = RecursiveFrontierKey { func_id, active_call_stack };
        if let Some(frontier) = self.recursive_cut_frontiers.get(&key) {
            return frontier.first().copied();
        }

        let active_call_stack = key.active_call_stack.iter().copied().collect::<BTreeSet<_>>();
        let mut seen = HashSet::new();
        let cut = self.first_recursive_cut_function(func_id, &active_call_stack, &mut seen);
        self.recursive_cut_frontiers.insert(key, cut.into_iter().collect::<Vec<_>>());
        cut
    }

    /// Depth-first helper for `first_recursive_cut`. `seen` prevents revisiting functions.
    fn first_recursive_cut_function(
        &mut self,
        func_id: FunctionId,
        active_call_stack: &BTreeSet<FunctionId>,
        seen: &mut HashSet<FunctionId>,
    ) -> Option<FunctionId> {
        if !seen.insert(func_id) {
            return None;
        }

        for callee_id in self.direct_internal_calls(func_id) {
            if active_call_stack.contains(&callee_id) {
                return Some(callee_id);
            }
            if let Some(cut) = self.first_recursive_cut_function(callee_id, active_call_stack, seen)
            {
                return Some(cut);
            }
        }
        None
    }

    /// Returns (and memoizes) the sorted, deduplicated functions directly called by
    /// `func_id`. These are its modifiers, calls inside modifier arguments, and calls in its
    /// body that `resolved_function_ids` resolves.
    fn direct_internal_calls(&mut self, func_id: FunctionId) -> Vec<FunctionId> {
        if let Some(calls) = self.direct_internal_calls.get(&func_id) {
            return calls.clone();
        }

        let mut calls = BTreeSet::new();
        let func = self.hir.function(func_id);
        for modifier in func.modifiers {
            for arg in modifier.args.exprs() {
                self.collect_direct_internal_calls_expr(arg, &mut calls);
            }
            if let Some(modifier_id) = modifier.id.as_function() {
                calls.insert(modifier_id);
            }
        }
        if let Some(body) = func.body {
            self.collect_direct_internal_calls_block(body, &mut calls);
        }

        let calls = calls.into_iter().collect::<Vec<_>>();
        self.direct_internal_calls.insert(func_id, calls.clone());
        calls
    }

    /// Collects resolved internal callees from every statement of `block` into `calls`.
    fn collect_direct_internal_calls_block(
        &mut self,
        block: hir::Block<'hir>,
        calls: &mut BTreeSet<FunctionId>,
    ) {
        for stmt in block.stmts {
            self.collect_direct_internal_calls_stmt(stmt, calls);
        }
    }

    /// Collects resolved internal callees from one statement (all branches) into `calls`.
    fn collect_direct_internal_calls_stmt(
        &mut self,
        stmt: &'hir hir::Stmt<'hir>,
        calls: &mut BTreeSet<FunctionId>,
    ) {
        match stmt.kind {
            StmtKind::DeclSingle(var_id) => {
                if let Some(init) = self.hir.variable(var_id).initializer {
                    self.collect_direct_internal_calls_expr(init, calls);
                }
            }
            StmtKind::DeclMulti(_, expr)
            | StmtKind::Expr(expr)
            | StmtKind::Emit(expr)
            | StmtKind::Revert(expr) => {
                self.collect_direct_internal_calls_expr(expr, calls);
            }
            StmtKind::Return(expr) => {
                if let Some(expr) = expr {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
                self.collect_direct_internal_calls_block(block, calls);
            }
            StmtKind::If(cond, then_stmt, else_stmt) => {
                self.collect_direct_internal_calls_expr(cond, calls);
                self.collect_direct_internal_calls_stmt(then_stmt, calls);
                if let Some(else_stmt) = else_stmt {
                    self.collect_direct_internal_calls_stmt(else_stmt, calls);
                }
            }
            StmtKind::Try(try_stmt) => {
                self.collect_direct_internal_calls_expr(&try_stmt.expr, calls);
                for clause in try_stmt.clauses {
                    self.collect_direct_internal_calls_block(clause.block, calls);
                }
            }
            StmtKind::Break
            | StmtKind::Continue
            | StmtKind::Placeholder
            | StmtKind::AssemblyBlock(_)
            | StmtKind::Switch(_)
            | StmtKind::Err(_) => {}
        }
    }

    /// Collects resolved internal callees from an expression and its sub-expressions.
    fn collect_direct_internal_calls_expr(
        &mut self,
        expr: &'hir hir::Expr<'hir>,
        calls: &mut BTreeSet<FunctionId>,
    ) {
        match &expr.kind {
            ExprKind::Assign(lhs, _, rhs) | ExprKind::Binary(lhs, _, rhs) => {
                self.collect_direct_internal_calls_expr(lhs, calls);
                self.collect_direct_internal_calls_expr(rhs, calls);
            }
            ExprKind::Unary(_, inner)
            | ExprKind::Delete(inner)
            | ExprKind::Member(inner, _)
            | ExprKind::Payable(inner) => {
                self.collect_direct_internal_calls_expr(inner, calls);
            }
            ExprKind::Call(callee, args, opts) => {
                self.collect_direct_internal_calls_expr(callee, calls);
                if let Some(opts) = opts {
                    for opt in opts.args {
                        self.collect_direct_internal_calls_expr(&opt.value, calls);
                    }
                }
                for arg in args.exprs() {
                    self.collect_direct_internal_calls_expr(arg, calls);
                }
                for func_id in resolved_function_ids(callee) {
                    calls.insert(func_id);
                }
            }
            ExprKind::Index(base, index) => {
                self.collect_direct_internal_calls_expr(base, calls);
                if let Some(index) = index {
                    self.collect_direct_internal_calls_expr(index, calls);
                }
            }
            ExprKind::Slice(base, start, end) => {
                self.collect_direct_internal_calls_expr(base, calls);
                if let Some(start) = start {
                    self.collect_direct_internal_calls_expr(start, calls);
                }
                if let Some(end) = end {
                    self.collect_direct_internal_calls_expr(end, calls);
                }
            }
            ExprKind::Ternary(cond, true_expr, false_expr) => {
                self.collect_direct_internal_calls_expr(cond, calls);
                self.collect_direct_internal_calls_expr(true_expr, calls);
                self.collect_direct_internal_calls_expr(false_expr, calls);
            }
            ExprKind::Array(exprs) => {
                for expr in *exprs {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            ExprKind::Tuple(exprs) => {
                for expr in exprs.iter().copied().flatten() {
                    self.collect_direct_internal_calls_expr(expr, calls);
                }
            }
            ExprKind::Ident(_)
            | ExprKind::Lit(_)
            | ExprKind::New(_)
            | ExprKind::TypeCall(_)
            | ExprKind::Type(_)
            | ExprKind::YulMember(..)
            | ExprKind::Err(_) => {}
        }
    }

    /// Analyzes only the parts of an assignment target that are evaluated as reads (index
    /// and slice bounds, recursively through bases, members and tuples), without treating
    /// the target itself as a read.
    fn analyze_lhs_indices(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
        match &expr.kind {
            ExprKind::Index(base, index) => {
                self.analyze_lhs_indices(base, state);
                if let Some(index) = index {
                    self.analyze_expr(index, state);
                }
            }
            ExprKind::Slice(base, start, end) => {
                self.analyze_lhs_indices(base, state);
                if let Some(start) = start {
                    self.analyze_expr(start, state);
                }
                if let Some(end) = end {
                    self.analyze_expr(end, state);
                }
            }
            ExprKind::Member(base, _) | ExprKind::Payable(base) => {
                self.analyze_lhs_indices(base, state);
            }
            ExprKind::Tuple(exprs) => {
                for expr in exprs.iter().copied().flatten() {
                    self.analyze_lhs_indices(expr, state);
                }
            }
            _ => {}
        }
    }

    /// Reports every pending call that read one of `written_vars` before the call.
    ///
    /// Each call span is reported at most once, under the lint for its kind and only if
    /// that lint is enabled. The message names the first written variable (in
    /// `written_vars` order) that the call depends on, or `state` if it is unnamed.
    fn emit_pending_calls(&mut self, state: &FlowState, written_vars: &[VariableId]) {
        for call in &state.pending_calls {
            let (lint, msg_prefix) = match call.kind {
                ReentrantCallKind::Eth => {
                    (&REENTRANCY_ETH, "uncapped ETH transfer can be reentered before")
                }
                ReentrantCallKind::NoEth => {
                    (&REENTRANCY_NO_ETH, "external call can be reentered before")
                }
            };
            if !self.ctx.is_lint_enabled(lint.id) || self.emitted.contains(&call.span) {
                continue;
            }

            if let Some(var_id) =
                written_vars.iter().find(|&&var_id| call.state_reads.contains(&var_id))
            {
                let name = self
                    .hir
                    .variable(*var_id)
                    .name
                    .map(|name| name.as_str().to_string())
                    .unwrap_or_else(|| "state".to_string());
                self.ctx.emit_with_msg(
                    lint,
                    call.span,
                    format!("{msg_prefix} `{name}` is updated"),
                );
                self.emitted.insert(call.span);
            }
        }
    }

    /// Classifies `callee{opts}(args)` as a reentrant call, checking the ETH kind first.
    /// Only kinds whose lint is enabled are returned.
    fn reentrant_call_kind(
        &self,
        callee: &'hir hir::Expr<'hir>,
        args: &CallArgs<'hir>,
        opts: Option<&hir::CallOptions<'hir>>,
    ) -> Option<ReentrantCallKind> {
        if self.reentrancy_eth_enabled && is_uncapped_value_call(self.hir, callee, opts) {
            return Some(ReentrantCallKind::Eth);
        }
        if self.reentrancy_no_eth_enabled
            && is_no_eth_reentrant_call(self.gcx, self.hir, callee, args, opts)
        {
            return Some(ReentrantCallKind::NoEth);
        }
        None
    }
}
735
736impl FlowState {
737 fn clear(&mut self) {
738 self.state_reads.clear();
739 self.pending_calls.clear();
740 }
741
742 fn merge(&mut self, other: &Self) {
743 self.state_reads.extend(other.state_reads.iter().copied());
744 for call in &other.pending_calls {
745 if let Some(existing) = self
746 .pending_calls
747 .iter_mut()
748 .find(|existing| existing.span == call.span && existing.kind == call.kind)
749 {
750 existing.state_reads.extend(call.state_reads.iter().copied());
751 } else {
752 self.pending_calls.push(call.clone());
753 }
754 }
755 }
756}
757
758fn is_uncapped_value_call(
759 hir: &hir::Hir<'_>,
760 callee: &hir::Expr<'_>,
761 opts: Option<&hir::CallOptions<'_>>,
762) -> bool {
763 let Some(opts) = opts else { return false };
764 let ExprKind::Member(_, member) = &callee.peel_parens().kind else { return false };
765 if member.name != kw::Call {
766 return false;
767 }
768
769 let mut value = None;
770 let mut gas = None;
771 for opt in opts.args {
772 if opt.name.name == sym::value {
773 value = Some(&opt.value);
774 } else if opt.name.name == kw::Gas {
775 gas = Some(&opt.value);
776 }
777 }
778
779 value.is_some_and(|value| !is_zero_value(hir, value)) && gas.is_none_or(gas_option_forwards_all)
780}
781
782fn is_no_eth_reentrant_call<'hir>(
783 gcx: Gcx<'hir>,
784 hir: &'hir hir::Hir<'hir>,
785 callee: &'hir hir::Expr<'hir>,
786 args: &CallArgs<'hir>,
787 opts: Option<&hir::CallOptions<'hir>>,
788) -> bool {
789 if call_sends_eth(hir, opts) {
790 return false;
791 }
792
793 match &callee.peel_parens().kind {
794 ExprKind::Member(receiver, member) => match member.name {
795 kw::Call | kw::Callcode | kw::Delegatecall => is_address_like(gcx, hir, receiver),
796 kw::Staticcall => false,
797 _ => external_member_call_can_reenter(gcx, hir, receiver, member.name, args),
798 },
799 _ => external_function_pointer_can_reenter(gcx, hir, callee, args),
800 }
801}
802
803fn call_sends_eth(hir: &hir::Hir<'_>, opts: Option<&hir::CallOptions<'_>>) -> bool {
804 opts.is_some_and(|opts| {
805 opts.args.iter().any(|opt| opt.name.name == sym::value && !is_zero_value(hir, &opt.value))
806 })
807}
808
809fn external_member_call_can_reenter<'hir>(
810 gcx: Gcx<'hir>,
811 hir: &'hir hir::Hir<'hir>,
812 receiver: &'hir hir::Expr<'hir>,
813 member: solar::interface::Symbol,
814 args: &CallArgs<'hir>,
815) -> bool {
816 if is_super(receiver) {
817 return false;
818 }
819
820 let Some(receiver_ty) = expr_ty(gcx, hir, receiver) else { return false };
821 gcx.members_of(receiver_ty, base_item_source(hir, receiver), base_contract(hir, receiver))
822 .filter(|candidate| candidate.name == member)
823 .any(|candidate| match (candidate.res, candidate.ty.kind) {
824 (Some(Res::Item(ItemId::Function(function_id))), _) => {
825 let function = hir.function(function_id);
826 is_externally_callable(function)
827 && args_match_function(gcx, hir, args, function.parameters)
828 && function.mutates_state()
829 }
830 (_, TyKind::Fn(function)) => {
831 is_externally_callable_fn_kind(function.kind)
832 && args_match_types(gcx, hir, args, function.parameters)
833 && !matches!(
834 function.state_mutability,
835 StateMutability::Pure | StateMutability::View
836 )
837 }
838 _ => false,
839 })
840}
841
842fn external_function_pointer_can_reenter<'hir>(
843 gcx: Gcx<'hir>,
844 hir: &'hir hir::Hir<'hir>,
845 callee: &'hir hir::Expr<'hir>,
846 args: &CallArgs<'hir>,
847) -> bool {
848 let Some(ty) = expr_ty(gcx, hir, callee) else { return false };
849 let TyKind::Fn(function) = ty.kind else { return false };
850 function.kind == TyFnKind::External
851 && args_match_types(gcx, hir, args, function.parameters)
852 && !matches!(function.state_mutability, StateMutability::Pure | StateMutability::View)
853}
854
855const fn is_externally_callable(func: &hir::Function<'_>) -> bool {
856 matches!(func.visibility, Visibility::Public | Visibility::External)
857}
858
859const fn is_externally_callable_fn_kind(kind: TyFnKind) -> bool {
860 matches!(kind, TyFnKind::External | TyFnKind::Declaration | TyFnKind::DelegateCall)
861}
862
863fn args_match_function<'hir>(
864 gcx: Gcx<'hir>,
865 hir: &'hir hir::Hir<'hir>,
866 args: &CallArgs<'hir>,
867 params: &'hir [VariableId],
868) -> bool {
869 if args.len() != params.len() {
870 return false;
871 }
872
873 match args.kind {
874 CallArgsKind::Unnamed(exprs) => exprs.iter().zip(params).all(|(arg, ¶m)| {
875 let param = hir.variable(param);
876 let param_ty =
877 gcx.type_of_hir_ty(¶m.ty).with_loc_if_ref_opt(gcx, param.data_location);
878 arg_matches_type(gcx, hir, arg, param_ty)
879 }),
880 CallArgsKind::Named(named_args) => named_args.iter().all(|arg| {
881 params
882 .iter()
883 .copied()
884 .find(|¶m| {
885 hir.variable(param).name.is_some_and(|name| name.name == arg.name.name)
886 })
887 .is_some_and(|param| {
888 let param = hir.variable(param);
889 let param_ty =
890 gcx.type_of_hir_ty(¶m.ty).with_loc_if_ref_opt(gcx, param.data_location);
891 arg_matches_type(gcx, hir, &arg.value, param_ty)
892 })
893 }),
894 }
895}
896
897fn args_match_types<'hir>(
898 gcx: Gcx<'hir>,
899 hir: &'hir hir::Hir<'hir>,
900 args: &CallArgs<'hir>,
901 params: &'hir [Ty<'hir>],
902) -> bool {
903 if args.len() != params.len() {
904 return false;
905 }
906
907 match args.kind {
908 CallArgsKind::Unnamed(exprs) => {
909 exprs.iter().zip(params).all(|(arg, ¶m)| arg_matches_type(gcx, hir, arg, param))
910 }
911 CallArgsKind::Named(_) => false,
912 }
913}
914
915fn arg_matches_type<'hir>(
916 gcx: Gcx<'hir>,
917 hir: &'hir hir::Hir<'hir>,
918 arg: &'hir hir::Expr<'hir>,
919 param_ty: Ty<'hir>,
920) -> bool {
921 expr_ty(gcx, hir, arg).is_some_and(|arg_ty| arg_ty.convert_implicit_to(param_ty, gcx))
922}
923
924fn is_address_like<'hir>(
925 gcx: Gcx<'hir>,
926 hir: &'hir hir::Hir<'hir>,
927 expr: &'hir hir::Expr<'hir>,
928) -> bool {
929 match &expr.peel_parens().kind {
930 ExprKind::Payable(_) => true,
931 ExprKind::Call(callee, _, _) if is_address_type_expr(callee) => true,
932 _ => expr_ty(gcx, hir, expr).is_some_and(type_is_address_like),
933 }
934}
935
936fn is_address_type_expr(expr: &hir::Expr<'_>) -> bool {
937 matches!(
938 &expr.peel_parens().kind,
939 ExprKind::Type(hir::Type {
940 kind: hir::TypeKind::Elementary(ElementaryType::Address(_)),
941 ..
942 })
943 )
944}
945
946fn type_is_address_like(ty: Ty<'_>) -> bool {
947 matches!(ty.peel_refs().kind, TyKind::Elementary(ElementaryType::Address(_)))
948}
949
950fn expr_ty<'hir>(
951 gcx: Gcx<'hir>,
952 hir: &'hir hir::Hir<'hir>,
953 expr: &'hir hir::Expr<'hir>,
954) -> Option<Ty<'hir>> {
955 match &expr.peel_parens().kind {
956 ExprKind::Array(_) | ExprKind::YulMember(..) => None,
957 ExprKind::Call(callee, args, _) => {
958 let callee_ty = expr_ty(gcx, hir, callee)?;
959 match callee_ty.kind {
960 TyKind::Fn(func) => fn_call_return_type(gcx, func.returns),
961 TyKind::Type(to) => Some(explicit_cast_ty(gcx, to, args)),
962 _ => None,
963 }
964 }
965 ExprKind::Ident(reses) => {
966 let res = unique(reses.iter().filter(|res| !matches!(res, Res::Err(_))).copied())?;
967 match res {
968 Res::Builtin(builtin) if matches!(builtin.name(), sym::this | sym::super_) => None,
969 Res::Item(ItemId::Variable(var_id)) => Some(
970 gcx.type_of_res(res)
971 .with_loc_if_ref_opt(gcx, variable_data_location(hir, var_id)),
972 ),
973 _ => Some(gcx.type_of_res(res)),
974 }
975 }
976 ExprKind::Index(lhs, index) => {
977 let lhs_ty = expr_ty(gcx, hir, lhs)?;
978 if let Some(index) = index
979 && !expr_ty(gcx, hir, index)?.convert_implicit_to(gcx.types.uint(256), gcx)
980 {
981 return None;
982 }
983 index_ty(gcx, lhs_ty)
984 }
985 ExprKind::Lit(lit) => Some(match &lit.kind {
986 LitKind::Str(StrKind::Hex, s, _) => {
987 let size = TypeSize::try_new_fb_bytes(s.as_byte_str().len().min(32) as u8)?;
988 gcx.types.fixed_bytes(size.bytes())
989 }
990 LitKind::Str(_, s, _) => gcx.mk_ty_string_literal(s.as_byte_str()),
991 LitKind::Number(int) => gcx.mk_ty_int_literal(false, int.bit_len() as _)?,
992 LitKind::Rational(_) | LitKind::Err(_) => return None,
993 LitKind::Address(_) => gcx.types.address,
994 LitKind::Bool(_) => gcx.types.bool,
995 }),
996 ExprKind::Member(base, member) => member_ty(gcx, hir, base, member.name),
997 ExprKind::New(ty) => {
998 let ty = gcx.type_of_hir_ty(ty);
999 Some(gcx.mk_ty(TyKind::Type(ty)))
1000 }
1001 ExprKind::Payable(inner) => {
1002 let inner_ty = expr_ty(gcx, hir, inner)?;
1003 inner_ty
1004 .convert_explicit_to(gcx.types.address_payable, gcx)
1005 .then_some(gcx.types.address_payable)
1006 }
1007 ExprKind::Slice(lhs, ..) => {
1008 let lhs_ty = expr_ty(gcx, hir, lhs)?;
1009 lhs_ty.is_sliceable().then_some(gcx.mk_ty(TyKind::Slice(lhs_ty)))
1010 }
1011 ExprKind::Tuple(exprs) => {
1012 let tys = exprs
1013 .iter()
1014 .map(|expr| expr.and_then(|expr| expr_ty(gcx, hir, expr)))
1015 .collect::<Option<Vec<_>>>()?;
1016 Some(gcx.mk_ty_tuple(gcx.mk_tys(&tys)))
1017 }
1018 ExprKind::Ternary(_, true_expr, false_expr) => {
1019 let true_ty = expr_ty(gcx, hir, true_expr)?;
1020 let false_ty = expr_ty(gcx, hir, false_expr)?;
1021 common_ty(gcx, true_ty, false_ty)
1022 }
1023 ExprKind::Type(ty) | ExprKind::TypeCall(ty) => {
1024 let ty = gcx.type_of_hir_ty(ty);
1025 Some(gcx.mk_ty(TyKind::Type(ty)))
1026 }
1027 ExprKind::Unary(op, inner) => match op.kind {
1028 UnOpKind::Not => Some(gcx.types.bool),
1029 _ => expr_ty(gcx, hir, inner),
1030 },
1031 ExprKind::Binary(_, op, _) if binary_op_returns_bool(op.kind) => Some(gcx.types.bool),
1032 ExprKind::Assign(..) | ExprKind::Binary(..) | ExprKind::Delete(..) | ExprKind::Err(_) => {
1033 None
1034 }
1035 }
1036}
1037
1038const fn binary_op_returns_bool(op: BinOpKind) -> bool {
1039 matches!(
1040 op,
1041 BinOpKind::Lt
1042 | BinOpKind::Le
1043 | BinOpKind::Gt
1044 | BinOpKind::Ge
1045 | BinOpKind::Eq
1046 | BinOpKind::Ne
1047 | BinOpKind::And
1048 | BinOpKind::Or
1049 )
1050}
1051
1052fn member_ty<'hir>(
1053 gcx: Gcx<'hir>,
1054 hir: &'hir hir::Hir<'hir>,
1055 base: &'hir hir::Expr<'hir>,
1056 member_name: solar::interface::Symbol,
1057) -> Option<Ty<'hir>> {
1058 if is_this(base) || is_super(base) {
1059 return None;
1060 }
1061
1062 let base_ty = expr_ty(gcx, hir, base)?;
1063 unique(
1064 gcx.members_of(base_ty, base_item_source(hir, base), base_contract(hir, base))
1065 .filter(|member| member.name == member_name)
1066 .map(|member| member.ty),
1067 )
1068}
1069
1070fn common_ty<'hir>(gcx: Gcx<'hir>, lhs: Ty<'hir>, rhs: Ty<'hir>) -> Option<Ty<'hir>> {
1071 if lhs.convert_implicit_to(rhs, gcx) {
1072 Some(rhs)
1073 } else {
1074 rhs.convert_implicit_to(lhs, gcx).then_some(lhs)
1075 }
1076}
1077
1078fn fn_call_return_type<'hir>(gcx: Gcx<'hir>, returns: &'hir [Ty<'hir>]) -> Option<Ty<'hir>> {
1079 Some(match returns {
1080 [] => gcx.types.unit,
1081 [ret] => *ret,
1082 _ => gcx.mk_ty_tuple(returns),
1083 })
1084}
1085
1086fn explicit_cast_ty<'hir>(gcx: Gcx<'hir>, to: Ty<'hir>, args: &'hir CallArgs<'hir>) -> Ty<'hir> {
1087 match args.exprs().next().and_then(|arg| expr_ty(gcx, &gcx.hir, arg)) {
1088 Some(from) => from.try_convert_explicit_to(to, gcx).unwrap_or(to),
1089 None => to,
1090 }
1091}
1092
1093fn index_ty<'hir>(gcx: Gcx<'hir>, base_ty: Ty<'hir>) -> Option<Ty<'hir>> {
1094 let loc = indexed_base_data_location(base_ty);
1095 match base_ty.peel_refs().kind {
1096 TyKind::Mapping(_, value) => Some(value.with_loc_if_ref_opt(gcx, loc)),
1097 _ => base_ty.base_type(gcx),
1098 }
1099}
1100
1101fn indexed_base_data_location(ty: Ty<'_>) -> Option<DataLocation> {
1102 ty.loc().or_else(|| matches!(ty.kind, TyKind::Mapping(..)).then_some(DataLocation::Storage))
1103}
1104
1105fn base_item_source(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> hir::SourceId {
1106 referenced_item(expr)
1107 .map(|id| hir.item(id).source())
1108 .unwrap_or_else(|| hir.sources_enumerated().next().expect("HIR has a source").0)
1109}
1110
1111fn base_contract(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Option<hir::ContractId> {
1112 referenced_item(expr).and_then(|id| hir.item(id).contract())
1113}
1114
1115fn referenced_item(expr: &hir::Expr<'_>) -> Option<ItemId> {
1116 match &expr.peel_parens().kind {
1117 ExprKind::Ident([Res::Item(id), ..]) => Some(*id),
1118 _ => None,
1119 }
1120}
1121
1122fn variable_data_location(hir: &hir::Hir<'_>, var_id: VariableId) -> Option<DataLocation> {
1123 let var = hir.variable(var_id);
1124 var.data_location.or_else(|| var.kind.is_state().then_some(DataLocation::Storage))
1125}
1126
1127fn is_this(expr: &hir::Expr<'_>) -> bool {
1128 matches!(
1129 &expr.peel_parens().kind,
1130 ExprKind::Ident(reses)
1131 if reses.iter().any(|res| {
1132 matches!(res, Res::Builtin(builtin) if builtin.name() == sym::this)
1133 })
1134 )
1135}
1136
1137fn is_super(expr: &hir::Expr<'_>) -> bool {
1138 matches!(
1139 &expr.peel_parens().kind,
1140 ExprKind::Ident(reses)
1141 if reses.iter().any(|res| {
1142 matches!(res, Res::Builtin(builtin) if builtin.name() == sym::super_)
1143 })
1144 )
1145}
1146
/// Returns the only element of `iter`, or `None` if it yields zero or several elements.
///
/// At most two elements are pulled. An empty iterator is polled only once.
fn unique<T>(iter: impl Iterator<Item = T>) -> Option<T> {
    let mut remaining = iter;
    let only = remaining.next()?;
    match remaining.next() {
        None => Some(only),
        Some(_) => None,
    }
}
1151
1152fn is_zero_value(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
1153 let mut seen = BTreeSet::new();
1154 is_zero_value_inner(hir, expr, &mut seen)
1155}
1156
1157fn is_zero_value_inner(
1158 hir: &hir::Hir<'_>,
1159 expr: &hir::Expr<'_>,
1160 seen: &mut BTreeSet<VariableId>,
1161) -> bool {
1162 match &expr.peel_parens().kind {
1163 ExprKind::Lit(lit) => matches!(lit.kind, LitKind::Number(value) if value.is_zero()),
1164 ExprKind::Ident(reses) => {
1165 let mut saw_variable = false;
1166 reses.iter().all(|res| match res {
1167 Res::Item(ItemId::Variable(var_id)) => {
1168 saw_variable = true;
1169 constant_var_is_zero(hir, *var_id, seen)
1170 }
1171 _ => false,
1172 }) && saw_variable
1173 }
1174 ExprKind::Call(callee, args, opts)
1175 if opts.is_none()
1176 && matches!(callee.peel_parens().kind, ExprKind::Type(_))
1177 && args.exprs().count() == 1 =>
1178 {
1179 args.exprs().next().is_some_and(|arg| is_zero_value_inner(hir, arg, seen))
1180 }
1181 _ => false,
1182 }
1183}
1184
1185fn constant_var_is_zero(
1186 hir: &hir::Hir<'_>,
1187 var_id: VariableId,
1188 seen: &mut BTreeSet<VariableId>,
1189) -> bool {
1190 let var = hir.variable(var_id);
1191 if !var.is_constant() || !seen.insert(var_id) {
1192 return false;
1193 }
1194 var.initializer.is_some_and(|init| is_zero_value_inner(hir, init, seen))
1195}
1196
1197fn gas_option_forwards_all(expr: &hir::Expr<'_>) -> bool {
1198 let ExprKind::Call(callee, args, opts) = &expr.peel_parens().kind else {
1199 return false;
1200 };
1201 if opts.is_some() || args.exprs().next().is_some() {
1202 return false;
1203 }
1204 matches!(
1205 &callee.peel_parens().kind,
1206 ExprKind::Ident(reses)
1207 if reses.iter().any(|res| {
1208 matches!(res, Res::Builtin(builtin) if builtin.name() == sym::gasleft)
1209 })
1210 )
1211}
1212
1213fn resolved_function_ids<'hir>(
1214 callee: &'hir hir::Expr<'hir>,
1215) -> impl Iterator<Item = FunctionId> + 'hir {
1216 let reses = match &callee.peel_parens().kind {
1217 ExprKind::Ident(reses) => *reses,
1218 _ => &[],
1219 };
1220 reses.iter().filter_map(|res| match res {
1221 Res::Item(ItemId::Function(func_id)) => Some(*func_id),
1222 _ => None,
1223 })
1224}
1225
1226fn state_write_lhs_vars(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Vec<VariableId> {
1227 let mut vars = Vec::new();
1228 collect_state_write_lhs_vars(hir, expr, &mut vars);
1229 vars
1230}
1231
1232fn collect_state_write_lhs_vars(
1233 hir: &hir::Hir<'_>,
1234 expr: &hir::Expr<'_>,
1235 vars: &mut Vec<VariableId>,
1236) {
1237 match &expr.kind {
1238 ExprKind::Ident(reses) => {
1239 for &res in *reses {
1240 if let Res::Item(ItemId::Variable(var_id)) = res
1241 && hir.variable(var_id).kind.is_state()
1242 {
1243 push_unique(vars, var_id);
1244 }
1245 }
1246 }
1247 ExprKind::Index(base, _) | ExprKind::Slice(base, ..) => {
1248 collect_state_write_lhs_vars(hir, base, vars);
1249 }
1250 ExprKind::Member(base, _)
1251 | ExprKind::Payable(base)
1252 | ExprKind::Unary(_, base)
1253 | ExprKind::Delete(base) => collect_state_write_lhs_vars(hir, base, vars),
1254 ExprKind::Tuple(exprs) => {
1255 for expr in exprs.iter().copied().flatten() {
1256 collect_state_write_lhs_vars(hir, expr, vars);
1257 }
1258 }
1259 _ => {}
1260 }
1261}
1262
/// Appends `item` to `items` unless an equal element is already present.
///
/// Uses a linear scan, so insertion order is preserved.
fn push_unique<T: Copy + Eq>(items: &mut Vec<T>, item: T) {
    let already_present = items.iter().any(|existing| *existing == item);
    if already_present {
        return;
    }
    items.push(item);
}