1use super::ReentrancyUnlimitedGas;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{LitKind, StateMutability, UnOpKind, Visibility},
8 interface::{Span, kw, sym},
9 sema::hir::{self, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId},
10};
11use std::collections::{BTreeSet, HashSet};
12
// High-severity lint: a state variable that is read before an uncapped ETH transfer (a `.call`
// carrying a non-zero `value` without an effective `gas` cap) is written after that transfer,
// so the transfer can be reentered before the stored value is updated.
declare_forge_lint!(
    REENTRANCY_UNLIMITED_GAS,
    Severity::High,
    "reentrancy-unlimited-gas",
    "state read before uncapped ETH transfer is written after the transfer"
);
19
20impl<'hir> LateLintPass<'hir> for ReentrancyUnlimitedGas {
21 fn check_function(
22 &mut self,
23 ctx: &LintContext,
24 hir: &'hir hir::Hir<'hir>,
25 func: &'hir hir::Function<'hir>,
26 ) {
27 if !is_entry_point(func) {
28 return;
29 }
30
31 let Some(body) = func.body else { return };
32
33 let mut analyzer = Analyzer::new(ctx, hir);
34 let mut state = FlowState::default();
35 analyzer.analyze_callable(func, body, &mut state);
36 }
37}
38
39fn is_entry_point(func: &hir::Function<'_>) -> bool {
40 if matches!(func.state_mutability, StateMutability::Pure | StateMutability::View) {
41 return false;
42 }
43 if func.is_constructor() {
44 return false;
45 }
46 if func.is_special() {
47 return true;
48 }
49 func.kind.is_function() && matches!(func.visibility, Visibility::Public | Visibility::External)
50}
51
52#[derive(Clone, Debug, Default)]
53struct FlowState {
54 state_reads: BTreeSet<VariableId>,
55 pending_value_calls: Vec<PendingValueCall>,
56}
57
58#[derive(Clone, Debug)]
59struct PendingValueCall {
60 span: Span,
61 state_reads: BTreeSet<VariableId>,
62}
63
64impl FlowState {
65 fn push_read(&mut self, var_id: VariableId) {
66 self.state_reads.insert(var_id);
67 }
68
69 fn push_call(&mut self, span: Span) {
70 if self.state_reads.is_empty() {
71 return;
72 }
73
74 if let Some(existing) = self.pending_value_calls.iter_mut().find(|call| call.span == span) {
75 existing.state_reads.extend(self.state_reads.iter().copied());
76 } else {
77 self.pending_value_calls
78 .push(PendingValueCall { span, state_reads: self.state_reads.clone() });
79 }
80 }
81}
82
83struct Analyzer<'ctx, 's, 'c, 'hir> {
84 ctx: &'ctx LintContext<'s, 'c>,
85 hir: &'hir hir::Hir<'hir>,
86 emitted: HashSet<Span>,
87 call_stack: Vec<FunctionId>,
88}
89
90impl<'ctx, 's, 'c, 'hir> Analyzer<'ctx, 's, 'c, 'hir> {
91 fn new(ctx: &'ctx LintContext<'s, 'c>, hir: &'hir hir::Hir<'hir>) -> Self {
92 Self { ctx, hir, emitted: HashSet::new(), call_stack: Vec::new() }
93 }
94
95 fn analyze_callable(
96 &mut self,
97 func: &'hir hir::Function<'hir>,
98 body: hir::Block<'hir>,
99 state: &mut FlowState,
100 ) -> bool {
101 self.analyze_modifier_chain(func.modifiers, 0, body, state)
102 }
103
104 fn analyze_modifier_chain(
105 &mut self,
106 modifiers: &'hir [hir::Modifier<'hir>],
107 index: usize,
108 body: hir::Block<'hir>,
109 state: &mut FlowState,
110 ) -> bool {
111 let Some(modifier) = modifiers.get(index) else {
112 return self.analyze_block(body, None, state);
113 };
114
115 for arg in modifier.args.exprs() {
116 self.analyze_expr(arg, state);
117 }
118
119 let Some(modifier_id) = modifier.id.as_function() else {
120 return self.analyze_modifier_chain(modifiers, index + 1, body, state);
121 };
122
123 if self.call_stack.contains(&modifier_id) {
124 return self.analyze_modifier_chain(modifiers, index + 1, body, state);
125 }
126
127 let modifier_func = self.hir.function(modifier_id);
128 let Some(modifier_body) = modifier_func.body else {
129 return self.analyze_modifier_chain(modifiers, index + 1, body, state);
130 };
131
132 self.call_stack.push(modifier_id);
133 let falls_through =
134 self.analyze_block(modifier_body, Some((modifiers, index + 1, body)), state);
135 self.call_stack.pop();
136 falls_through
137 }
138
139 fn analyze_block(
140 &mut self,
141 block: hir::Block<'hir>,
142 placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
143 state: &mut FlowState,
144 ) -> bool {
145 for stmt in block.stmts {
146 if !self.analyze_stmt(stmt, placeholder, state) {
147 return false;
148 }
149 }
150 true
151 }
152
153 fn analyze_stmt(
154 &mut self,
155 stmt: &'hir hir::Stmt<'hir>,
156 placeholder: Option<(&'hir [hir::Modifier<'hir>], usize, hir::Block<'hir>)>,
157 state: &mut FlowState,
158 ) -> bool {
159 match stmt.kind {
160 StmtKind::DeclSingle(var_id) => {
161 if let Some(init) = self.hir.variable(var_id).initializer {
162 self.analyze_expr(init, state);
163 }
164 true
165 }
166 StmtKind::DeclMulti(_, expr) | StmtKind::Expr(expr) => {
167 self.analyze_expr(expr, state);
168 true
169 }
170 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
171 self.analyze_block(block, placeholder, state)
172 }
173 StmtKind::Emit(expr) => {
174 self.analyze_expr(expr, state);
175 true
176 }
177 StmtKind::Revert(expr) => {
178 self.analyze_expr(expr, state);
179 false
180 }
181 StmtKind::Return(expr) => {
182 if let Some(expr) = expr {
183 self.analyze_expr(expr, state);
184 }
185 false
186 }
187 StmtKind::Break | StmtKind::Continue => false,
188 StmtKind::Loop(block, _) => {
189 let before_loop = state.clone();
190 let mut body_state = state.clone();
191 self.analyze_block(block, placeholder, &mut body_state);
192 state.clear();
193 state.merge(&before_loop);
194 state.merge(&body_state);
195 true
196 }
197 StmtKind::If(cond, then_stmt, else_stmt) => {
198 self.analyze_expr(cond, state);
199
200 let mut then_state = state.clone();
201 let then_falls_through = self.analyze_stmt(then_stmt, placeholder, &mut then_state);
202
203 let mut else_state = state.clone();
204 let else_falls_through = if let Some(else_stmt) = else_stmt {
205 self.analyze_stmt(else_stmt, placeholder, &mut else_state)
206 } else {
207 true
208 };
209
210 state.clear();
211 if then_falls_through {
212 state.merge(&then_state);
213 }
214 if else_falls_through {
215 state.merge(&else_state);
216 }
217
218 then_falls_through || else_falls_through
219 }
220 StmtKind::Try(try_stmt) => {
221 self.analyze_expr(&try_stmt.expr, state);
222
223 let mut merged = FlowState::default();
224 let mut any_falls_through = false;
225 for clause in try_stmt.clauses {
226 let mut clause_state = state.clone();
227 let falls_through =
228 self.analyze_block(clause.block, placeholder, &mut clause_state);
229 if falls_through {
230 merged.merge(&clause_state);
231 any_falls_through = true;
232 }
233 }
234
235 *state = merged;
236 any_falls_through
237 }
238 StmtKind::Placeholder => {
239 if let Some((modifiers, index, body)) = placeholder {
240 self.analyze_modifier_chain(modifiers, index, body, state)
241 } else {
242 true
243 }
244 }
245 StmtKind::Err(_) => true,
246 }
247 }
248
249 fn analyze_expr(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
250 match &expr.kind {
251 ExprKind::Assign(lhs, op, rhs) => {
252 if op.is_some() {
253 self.analyze_expr(lhs, state);
254 }
255 self.analyze_expr(rhs, state);
256 let written_vars = state_write_lhs_vars(self.hir, lhs);
257 if !written_vars.is_empty() {
258 self.emit_pending_calls(state, &written_vars);
259 }
260 self.analyze_lhs_indices(lhs, state);
261 }
262 ExprKind::Delete(inner) => {
263 let written_vars = state_write_lhs_vars(self.hir, inner);
264 if !written_vars.is_empty() {
265 self.emit_pending_calls(state, &written_vars);
266 }
267 self.analyze_lhs_indices(inner, state);
268 }
269 ExprKind::Unary(op, inner)
270 if matches!(
271 op.kind,
272 UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
273 ) =>
274 {
275 self.analyze_expr(inner, state);
276 let written_vars = state_write_lhs_vars(self.hir, inner);
277 if !written_vars.is_empty() {
278 self.emit_pending_calls(state, &written_vars);
279 }
280 }
281 ExprKind::Unary(_, inner) => {
282 self.analyze_expr(inner, state);
283 }
284 ExprKind::Call(callee, args, opts) => {
285 self.analyze_expr(callee, state);
286 if let Some(opts) = opts {
287 for opt in *opts {
288 self.analyze_expr(&opt.value, state);
289 }
290 }
291 for arg in args.exprs() {
292 self.analyze_expr(arg, state);
293 }
294
295 for func_id in resolved_function_ids(callee) {
296 self.analyze_internal_call(func_id, state);
297 }
298 if is_uncapped_value_call(callee, *opts) {
299 state.push_call(expr.span);
300 }
301 }
302 ExprKind::Binary(lhs, _, rhs) => {
303 self.analyze_expr(lhs, state);
304 self.analyze_expr(rhs, state);
305 }
306 ExprKind::Index(base, index) => {
307 self.analyze_expr(base, state);
308 if let Some(index) = index {
309 self.analyze_expr(index, state);
310 }
311 }
312 ExprKind::Slice(base, start, end) => {
313 self.analyze_expr(base, state);
314 if let Some(start) = start {
315 self.analyze_expr(start, state);
316 }
317 if let Some(end) = end {
318 self.analyze_expr(end, state);
319 }
320 }
321 ExprKind::Ternary(cond, true_expr, false_expr) => {
322 self.analyze_expr(cond, state);
323
324 let mut true_state = state.clone();
325 self.analyze_expr(true_expr, &mut true_state);
326
327 let mut false_state = state.clone();
328 self.analyze_expr(false_expr, &mut false_state);
329
330 state.clear();
331 state.merge(&true_state);
332 state.merge(&false_state);
333 }
334 ExprKind::Array(exprs) => {
335 for expr in *exprs {
336 self.analyze_expr(expr, state);
337 }
338 }
339 ExprKind::Tuple(exprs) => {
340 for expr in exprs.iter().copied().flatten() {
341 self.analyze_expr(expr, state);
342 }
343 }
344 ExprKind::Member(base, _) | ExprKind::Payable(base) => {
345 self.analyze_expr(base, state);
346 }
347 ExprKind::New(_) | ExprKind::TypeCall(_) | ExprKind::Type(_) => {}
348 ExprKind::Ident(reses) => {
349 for &res in *reses {
350 if let Res::Item(ItemId::Variable(var_id)) = res
351 && self.hir.variable(var_id).kind.is_state()
352 {
353 state.push_read(var_id);
354 }
355 }
356 }
357 ExprKind::Lit(_) | ExprKind::Err(_) => {}
358 }
359 }
360
361 fn analyze_internal_call(&mut self, func_id: FunctionId, state: &mut FlowState) {
362 if self.call_stack.contains(&func_id) {
363 return;
364 }
365
366 let func = self.hir.function(func_id);
367 let Some(body) = func.body else { return };
368
369 self.call_stack.push(func_id);
370 self.analyze_callable(func, body, state);
371 self.call_stack.pop();
372 }
373
374 fn analyze_lhs_indices(&mut self, expr: &'hir hir::Expr<'hir>, state: &mut FlowState) {
375 match &expr.kind {
376 ExprKind::Index(base, index) => {
377 self.analyze_lhs_indices(base, state);
378 if let Some(index) = index {
379 self.analyze_expr(index, state);
380 }
381 }
382 ExprKind::Slice(base, start, end) => {
383 self.analyze_lhs_indices(base, state);
384 if let Some(start) = start {
385 self.analyze_expr(start, state);
386 }
387 if let Some(end) = end {
388 self.analyze_expr(end, state);
389 }
390 }
391 ExprKind::Member(base, _) | ExprKind::Payable(base) => {
392 self.analyze_lhs_indices(base, state);
393 }
394 ExprKind::Tuple(exprs) => {
395 for expr in exprs.iter().copied().flatten() {
396 self.analyze_lhs_indices(expr, state);
397 }
398 }
399 _ => {}
400 }
401 }
402
403 fn emit_pending_calls(&mut self, state: &FlowState, written_vars: &[VariableId]) {
404 for call in &state.pending_value_calls {
405 if self.emitted.contains(&call.span) {
406 continue;
407 }
408
409 if let Some(var_id) =
410 written_vars.iter().find(|&&var_id| call.state_reads.contains(&var_id))
411 {
412 let name = self
413 .hir
414 .variable(*var_id)
415 .name
416 .map(|name| name.as_str().to_string())
417 .unwrap_or_else(|| "state".to_string());
418 self.ctx.emit_with_msg(
419 &REENTRANCY_UNLIMITED_GAS,
420 call.span,
421 format!("uncapped ETH transfer can be reentered before `{name}` is updated"),
422 );
423 self.emitted.insert(call.span);
424 }
425 }
426 }
427}
428
429impl FlowState {
430 fn clear(&mut self) {
431 self.state_reads.clear();
432 self.pending_value_calls.clear();
433 }
434
435 fn merge(&mut self, other: &Self) {
436 self.state_reads.extend(other.state_reads.iter().copied());
437 for call in &other.pending_value_calls {
438 if let Some(existing) =
439 self.pending_value_calls.iter_mut().find(|existing| existing.span == call.span)
440 {
441 existing.state_reads.extend(call.state_reads.iter().copied());
442 } else {
443 self.pending_value_calls.push(call.clone());
444 }
445 }
446 }
447}
448
449fn is_uncapped_value_call(callee: &hir::Expr<'_>, opts: Option<&[hir::NamedArg<'_>]>) -> bool {
450 let Some(opts) = opts else { return false };
451 let ExprKind::Member(_, member) = &callee.kind else { return false };
452 if member.name != kw::Call {
453 return false;
454 }
455
456 let mut value = None;
457 let mut gas = None;
458 for opt in opts {
459 if opt.name.name == sym::value {
460 value = Some(&opt.value);
461 } else if opt.name.name == kw::Gas {
462 gas = Some(&opt.value);
463 }
464 }
465
466 value.is_some_and(|value| !is_zero_literal(value)) && gas.is_none_or(gas_option_forwards_all)
467}
468
469fn is_zero_literal(expr: &hir::Expr<'_>) -> bool {
470 matches!(
471 &expr.peel_parens().kind,
472 ExprKind::Lit(lit) if matches!(lit.kind, LitKind::Number(value) if value.is_zero())
473 )
474}
475
476fn gas_option_forwards_all(expr: &hir::Expr<'_>) -> bool {
477 let ExprKind::Call(callee, args, opts) = &expr.peel_parens().kind else {
478 return false;
479 };
480 if opts.is_some() || args.exprs().next().is_some() {
481 return false;
482 }
483 matches!(
484 &callee.peel_parens().kind,
485 ExprKind::Ident(reses)
486 if reses.iter().any(|res| {
487 matches!(res, Res::Builtin(builtin) if builtin.name() == sym::gasleft)
488 })
489 )
490}
491
492fn resolved_function_ids<'hir>(
493 callee: &'hir hir::Expr<'hir>,
494) -> impl Iterator<Item = FunctionId> + 'hir {
495 let reses = match &callee.peel_parens().kind {
496 ExprKind::Ident(reses) => *reses,
497 _ => &[],
498 };
499 reses.iter().filter_map(|res| match res {
500 Res::Item(ItemId::Function(func_id)) => Some(*func_id),
501 _ => None,
502 })
503}
504
505fn state_write_lhs_vars(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> Vec<VariableId> {
506 let mut vars = Vec::new();
507 collect_state_write_lhs_vars(hir, expr, &mut vars);
508 vars
509}
510
511fn collect_state_write_lhs_vars(
512 hir: &hir::Hir<'_>,
513 expr: &hir::Expr<'_>,
514 vars: &mut Vec<VariableId>,
515) {
516 match &expr.kind {
517 ExprKind::Ident(reses) => {
518 for &res in *reses {
519 if let Res::Item(ItemId::Variable(var_id)) = res
520 && hir.variable(var_id).kind.is_state()
521 {
522 push_unique(vars, var_id);
523 }
524 }
525 }
526 ExprKind::Index(base, _) | ExprKind::Slice(base, ..) => {
527 collect_state_write_lhs_vars(hir, base, vars);
528 }
529 ExprKind::Member(base, _)
530 | ExprKind::Payable(base)
531 | ExprKind::Unary(_, base)
532 | ExprKind::Delete(base) => collect_state_write_lhs_vars(hir, base, vars),
533 ExprKind::Tuple(exprs) => {
534 for expr in exprs.iter().copied().flatten() {
535 collect_state_write_lhs_vars(hir, expr, vars);
536 }
537 }
538 _ => {}
539 }
540}
541
/// Appends `item` to `items` unless an equal element is already present.
///
/// Insertion order is preserved.
fn push_unique<T: Copy + Eq>(items: &mut Vec<T>, item: T) {
    let already_present = items.iter().any(|existing| *existing == item);
    if !already_present {
        items.push(item);
    }
}