1use super::MissingZeroCheck;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast,
8 interface::{data_structures::Never, kw, sym},
9 sema::hir::{self, ElementaryType, ExprKind, ItemId, Res, StmtKind, TypeKind, Visit},
10};
11use std::{
12 collections::{HashMap, HashSet},
13 ops::ControlFlow,
14};
15
// Declares the `MISSING_ZERO_CHECK` lint: id `missing-zero-check`, low severity. The final
// string is the description attached to the lint.
declare_forge_lint!(
    MISSING_ZERO_CHECK,
    Severity::Low,
    "missing-zero-check",
    "address parameter is used in a state write or value transfer without a zero-address check"
);
22
23impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
24 fn check_function(
25 &mut self,
26 ctx: &LintContext,
27 hir: &'hir hir::Hir<'hir>,
28 func: &'hir hir::Function<'hir>,
29 ) {
30 if !is_entry_point(func) {
31 return;
32 }
33
34 let params: HashSet<hir::VariableId> =
35 func.parameters.iter().copied().filter(|id| is_address(hir, *id)).collect();
36
37 if params.is_empty() {
38 return;
39 }
40
41 let Some(body) = func.body else { return };
42
43 let mut a = Analyzer::new(hir, ¶ms);
44
45 for m in func.modifiers {
46 collect_modifier_guards(hir, m, ¶ms, &mut a.guarded);
47 }
48
49 for stmt in body.stmts {
50 let _ = a.visit_stmt(stmt);
51 }
52
53 for &p in ¶ms {
54 if a.sinks.contains(&p) {
55 ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
56 }
57 }
58 }
59}
60
61fn is_entry_point(func: &hir::Function<'_>) -> bool {
63 if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
64 return false;
65 }
66 if func.is_constructor() {
67 return true;
68 }
69 func.kind.is_function()
70 && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
71}
72
73fn is_address(hir: &hir::Hir<'_>, id: hir::VariableId) -> bool {
74 matches!(hir.variable(id).ty.kind, TypeKind::Elementary(ElementaryType::Address(_)))
75}
76
77struct Analyzer<'hir> {
79 hir: &'hir hir::Hir<'hir>,
80 taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
83 sinks: HashSet<hir::VariableId>,
85 guarded: HashSet<hir::VariableId>,
87 guard_depth: u32,
88 sink_depth: u32,
89}
90
91impl<'hir> Analyzer<'hir> {
92 fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
93 let mut taint = HashMap::with_capacity(params.len());
94 for &p in params {
95 taint.insert(p, HashSet::from([p]));
96 }
97 Self {
98 hir,
99 taint,
100 sinks: HashSet::new(),
101 guarded: HashSet::new(),
102 guard_depth: 0,
103 sink_depth: 0,
104 }
105 }
106
107 fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
108 let mut out = HashSet::new();
109 collect_taint_sources(&self.taint, expr, &mut out);
110 out
111 }
112}
113
114fn collect_taint_sources(
115 taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
116 expr: &hir::Expr<'_>,
117 out: &mut HashSet<hir::VariableId>,
118) {
119 match &expr.kind {
120 ExprKind::Ident(reses) => {
121 for res in *reses {
122 if let Res::Item(ItemId::Variable(vid)) = res
123 && let Some(srcs) = taint.get(vid)
124 {
125 out.extend(srcs.iter().copied());
126 }
127 }
128 }
129 ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
130 ExprKind::Binary(lhs, _, rhs) => {
131 collect_taint_sources(taint, lhs, out);
132 collect_taint_sources(taint, rhs, out);
133 }
134 ExprKind::Unary(_, e)
135 | ExprKind::Delete(e)
136 | ExprKind::Member(e, _)
137 | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
138 ExprKind::Ternary(_, t, f) => {
139 collect_taint_sources(taint, t, out);
140 collect_taint_sources(taint, f, out);
141 }
142 ExprKind::Tuple(elems) => {
143 for e in elems.iter().copied().flatten() {
144 collect_taint_sources(taint, e, out);
145 }
146 }
147 ExprKind::Array(elems) => {
148 for e in *elems {
149 collect_taint_sources(taint, e, out);
150 }
151 }
152 ExprKind::Index(base, idx) => {
153 collect_taint_sources(taint, base, out);
154 if let Some(i) = idx {
155 collect_taint_sources(taint, i, out);
156 }
157 }
158 ExprKind::Call(callee, args, _) => {
160 collect_taint_sources(taint, callee, out);
161 for a in args.exprs() {
162 collect_taint_sources(taint, a, out);
163 }
164 }
165 _ => {}
166 }
167}
168
169fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
172 if let ExprKind::Ident(reses) = &lhs.kind {
173 for res in *reses {
174 if let Res::Item(ItemId::Variable(vid)) = res
175 && !hir.variable(*vid).kind.is_state()
176 {
177 return Some(*vid);
178 }
179 }
180 }
181 None
182}
183
184impl<'hir> Visit<'hir> for Analyzer<'hir> {
185 type BreakValue = Never;
186
187 fn hir(&self) -> &'hir hir::Hir<'hir> {
188 self.hir
189 }
190
191 fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
192 match stmt.kind {
193 StmtKind::If(cond, then, else_) => {
194 self.guard_depth += 1;
195 let _ = self.visit_expr(cond);
196 self.guard_depth -= 1;
197
198 let baseline = self.guarded.clone();
199 let _ = self.visit_stmt(then);
200 let then_added: HashSet<hir::VariableId> =
201 self.guarded.difference(&baseline).copied().collect();
202 let then_exits = branch_always_exits(then);
203
204 let (else_added, else_exits) = if let Some(e) = else_ {
205 self.guarded = baseline.clone();
206 let _ = self.visit_stmt(e);
207 let added: HashSet<hir::VariableId> =
208 self.guarded.difference(&baseline).copied().collect();
209 (added, branch_always_exits(e))
210 } else {
211 (HashSet::new(), false)
212 };
213
214 self.guarded = baseline;
215 let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
216 (true, true) => then_added.union(&else_added).copied().collect(),
217 (true, false) => else_added,
218 (false, true) => then_added,
219 (false, false) => then_added.intersection(&else_added).copied().collect(),
220 };
221 self.guarded.extend(to_add);
222
223 return ControlFlow::Continue(());
224 }
225 StmtKind::Loop(block, _) => {
227 let baseline = self.guarded.clone();
228 for s in block.stmts {
229 let _ = self.visit_stmt(s);
230 }
231 self.guarded = baseline;
232 return ControlFlow::Continue(());
233 }
234 StmtKind::Try(t) => {
236 let _ = self.visit_expr(&t.expr);
237 for clause in t.clauses {
238 let baseline = self.guarded.clone();
239 for s in clause.block.stmts {
240 let _ = self.visit_stmt(s);
241 }
242 self.guarded = baseline;
243 }
244 return ControlFlow::Continue(());
245 }
246 StmtKind::DeclSingle(var_id) => {
249 let v = self.hir.variable(var_id);
250 if let Some(init) = v.initializer
251 && is_address(self.hir, var_id)
252 {
253 let srcs = self.taint_sources(init);
254 if !srcs.is_empty() {
255 self.taint.entry(var_id).or_default().extend(srcs);
256 }
257 }
258 }
259 _ => {}
260 }
261 self.walk_stmt(stmt)
262 }
263
264 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
265 match &expr.kind {
266 ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
268 let mut iter = args.exprs();
269 if let Some(cond) = iter.next() {
270 self.guard_depth += 1;
271 let _ = self.visit_expr(cond);
272 self.guard_depth -= 1;
273 }
274 for rest in iter {
275 let _ = self.visit_expr(rest);
276 }
277 return ControlFlow::Continue(());
278 }
279
280 ExprKind::Call(callee, args, _) => {
282 if let Some(receiver) = address_call_receiver(callee) {
283 self.sink_depth += 1;
284 let _ = self.visit_expr(receiver);
285 self.sink_depth -= 1;
286 let _ = self.visit_call_args(args);
287 return ControlFlow::Continue(());
288 }
289 }
290
291 ExprKind::Assign(lhs, _, rhs) => {
292 if is_address_state_var_lhs(self.hir, lhs) {
294 let _ = self.visit_expr(lhs);
295 self.sink_depth += 1;
296 let _ = self.visit_expr(rhs);
297 self.sink_depth -= 1;
298 return ControlFlow::Continue(());
299 }
300 if let Some(local) = lhs_local_var(self.hir, lhs)
302 && is_address(self.hir, local)
303 {
304 let srcs = self.taint_sources(rhs);
305 if !srcs.is_empty() {
306 self.taint.entry(local).or_default().extend(srcs);
307 }
308 }
309 }
310
311 ExprKind::Ident(reses) => {
313 for res in *reses {
314 if let Res::Item(ItemId::Variable(vid)) = res
315 && let Some(srcs) = self.taint.get(vid)
316 {
317 if self.guard_depth > 0 {
318 self.guarded.extend(srcs.iter().copied());
319 }
320 if self.sink_depth > 0 {
321 for &src in srcs {
322 if !self.guarded.contains(&src) {
323 self.sinks.insert(src);
324 }
325 }
326 }
327 }
328 }
329 }
330
331 _ => {}
332 }
333 self.walk_expr(expr)
334 }
335}
336
337fn is_require_or_assert(callee: &hir::Expr<'_>) -> bool {
338 if let ExprKind::Ident(reses) = &callee.kind {
339 return reses.iter().any(|r| {
340 if let Res::Builtin(b) = r {
341 let n = b.name();
342 n == sym::require || n == sym::assert
343 } else {
344 false
345 }
346 });
347 }
348 false
349}
350
351fn address_call_receiver<'hir>(callee: &'hir hir::Expr<'hir>) -> Option<&'hir hir::Expr<'hir>> {
354 let inner = match &callee.kind {
357 ExprKind::Call(inner, ..) => inner,
358 _ => callee,
359 };
360 let target = if matches!(inner.kind, ExprKind::Member(..)) { inner } else { callee };
361 if let ExprKind::Member(receiver, name) = &target.kind {
362 let n = name.name;
363 if n == kw::Call || n == kw::Delegatecall || n == sym::transfer || n == sym::send {
364 return Some(receiver);
365 }
366 }
367 None
368}
369
370fn branch_always_exits(stmt: &hir::Stmt<'_>) -> bool {
371 match &stmt.kind {
372 StmtKind::Return(_) | StmtKind::Revert(_) => true,
373 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
374 block.stmts.last().is_some_and(branch_always_exits)
375 }
376 StmtKind::If(_, t, Some(e)) => branch_always_exits(t) && branch_always_exits(e),
377 _ => false,
378 }
379}
380
381fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
382 if let ExprKind::Ident(reses) = &lhs.kind {
383 for res in *reses {
384 if let Res::Item(ItemId::Variable(vid)) = res {
385 let v = hir.variable(*vid);
386 if v.kind.is_state()
387 && matches!(v.ty.kind, TypeKind::Elementary(ElementaryType::Address(_)))
388 {
389 return true;
390 }
391 }
392 }
393 }
394 false
395}
396
397fn collect_modifier_guards(
401 hir: &hir::Hir<'_>,
402 invocation: &hir::Modifier<'_>,
403 caller_params: &HashSet<hir::VariableId>,
404 guarded: &mut HashSet<hir::VariableId>,
405) {
406 let ItemId::Function(fid) = invocation.id else { return };
407 let modifier = hir.function(fid);
408 if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
409 return;
410 }
411
412 let mod_params = modifier.parameters;
413 let mut mapping: HashSet<hir::VariableId> = HashSet::new();
414 let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
415 for (i, arg_expr) in invocation.args.exprs().enumerate() {
416 if let ExprKind::Ident(reses) = &arg_expr.kind {
417 for res in *reses {
418 if let Res::Item(ItemId::Variable(vid)) = res
419 && caller_params.contains(vid)
420 && let Some(&mp) = mod_params.get(i)
421 {
422 caller_for_modparam.insert(mp, *vid);
423 mapping.insert(mp);
424 }
425 }
426 }
427 }
428 if mapping.is_empty() {
429 return;
430 }
431
432 let Some(body) = modifier.body else { return };
433 let mut a = Analyzer::new(hir, &mapping);
434 for stmt in body.stmts {
435 let _ = a.visit_stmt(stmt);
436 }
437
438 for mp in a.guarded {
439 if let Some(caller_vid) = caller_for_modparam.get(&mp) {
440 guarded.insert(*caller_vid);
441 }
442 }
443}