Skip to main content

forge_lint/sol/low/
missing_zero_check.rs

1use super::MissingZeroCheck;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast,
8    interface::{data_structures::Never, kw, sym},
9    sema::hir::{self, ElementaryType, ExprKind, ItemId, Res, StmtKind, TypeKind, Visit},
10};
11use std::{
12    collections::{HashMap, HashSet},
13    ops::ControlFlow,
14};
15
// Declares the `missing-zero-check` lint at `Severity::Low`. It is emitted on an address
// parameter that reaches a state write or a value-transferring call before being guarded.
declare_forge_lint!(
    MISSING_ZERO_CHECK,
    Severity::Low,
    "missing-zero-check",
    "address parameter is used in a state write or value transfer without a zero-address check"
);
22
23impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
24    fn check_function(
25        &mut self,
26        ctx: &LintContext,
27        hir: &'hir hir::Hir<'hir>,
28        func: &'hir hir::Function<'hir>,
29    ) {
30        if !is_entry_point(func) {
31            return;
32        }
33
34        let params: HashSet<hir::VariableId> =
35            func.parameters.iter().copied().filter(|id| is_address(hir, *id)).collect();
36
37        if params.is_empty() {
38            return;
39        }
40
41        let Some(body) = func.body else { return };
42
43        let mut a = Analyzer::new(hir, &params);
44
45        for m in func.modifiers {
46            collect_modifier_guards(hir, m, &params, &mut a.guarded);
47        }
48
49        for stmt in body.stmts {
50            let _ = a.visit_stmt(stmt);
51        }
52
53        for &p in &params {
54            if a.sinks.contains(&p) {
55                ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
56            }
57        }
58    }
59}
60
61/// Externally callable, state-mutating functions and constructors.
62fn is_entry_point(func: &hir::Function<'_>) -> bool {
63    if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
64        return false;
65    }
66    if func.is_constructor() {
67        return true;
68    }
69    func.kind.is_function()
70        && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
71}
72
73fn is_address(hir: &hir::Hir<'_>, id: hir::VariableId) -> bool {
74    matches!(hir.variable(id).ty.kind, TypeKind::Elementary(ElementaryType::Address(_)))
75}
76
/// Tracks address-parameter taint, sinks reached, and guards observed in a function body.
struct Analyzer<'hir> {
    /// HIR used to look up variables (e.g. their types) while walking.
    hir: &'hir hir::Hir<'hir>,
    /// Variables transitively derived from candidate parameters, mapped to their sources.
    /// Each parameter is initially mapped to itself.
    taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
    /// Source parameters that reached a sink while not yet in `guarded`.
    sinks: HashSet<hir::VariableId>,
    /// Source parameters read inside an `if`/`require`/`assert` predicate.
    guarded: HashSet<hir::VariableId>,
    /// Nesting depth of guard predicates currently being visited. While it is non-zero,
    /// reads of tainted variables may mark their sources as guarded.
    guard_depth: u32,
    /// Nesting depth of sink contexts currently being visited. While it is non-zero, reads of
    /// tainted variables record their not-yet-guarded sources in `sinks`.
    sink_depth: u32,
}
90
91impl<'hir> Analyzer<'hir> {
92    fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
93        let mut taint = HashMap::with_capacity(params.len());
94        for &p in params {
95            taint.insert(p, HashSet::from([p]));
96        }
97        Self {
98            hir,
99            taint,
100            sinks: HashSet::new(),
101            guarded: HashSet::new(),
102            guard_depth: 0,
103            sink_depth: 0,
104        }
105    }
106
107    fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
108        let mut out = HashSet::new();
109        collect_taint_sources(&self.taint, expr, &mut out);
110        out
111    }
112}
113
114fn collect_taint_sources(
115    taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
116    expr: &hir::Expr<'_>,
117    out: &mut HashSet<hir::VariableId>,
118) {
119    match &expr.kind {
120        ExprKind::Ident(reses) => {
121            for res in *reses {
122                if let Res::Item(ItemId::Variable(vid)) = res
123                    && let Some(srcs) = taint.get(vid)
124                {
125                    out.extend(srcs.iter().copied());
126                }
127            }
128        }
129        ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
130        ExprKind::Binary(lhs, _, rhs) => {
131            collect_taint_sources(taint, lhs, out);
132            collect_taint_sources(taint, rhs, out);
133        }
134        ExprKind::Unary(_, e)
135        | ExprKind::Delete(e)
136        | ExprKind::Member(e, _)
137        | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
138        ExprKind::Ternary(_, t, f) => {
139            collect_taint_sources(taint, t, out);
140            collect_taint_sources(taint, f, out);
141        }
142        ExprKind::Tuple(elems) => {
143            for e in elems.iter().copied().flatten() {
144                collect_taint_sources(taint, e, out);
145            }
146        }
147        ExprKind::Array(elems) => {
148            for e in *elems {
149                collect_taint_sources(taint, e, out);
150            }
151        }
152        ExprKind::Index(base, idx) => {
153            collect_taint_sources(taint, base, out);
154            if let Some(i) = idx {
155                collect_taint_sources(taint, i, out);
156            }
157        }
158        // Covers type casts (`address(x)`), method calls, and ordinary calls; conservative.
159        ExprKind::Call(callee, args, _) => {
160            collect_taint_sources(taint, callee, out);
161            for a in args.exprs() {
162                collect_taint_sources(taint, a, out);
163            }
164        }
165        _ => {}
166    }
167}
168
169/// Returns the underlying local `VariableId` if `lhs` is a direct identifier reference to a
170/// non-state variable.
171fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
172    if let ExprKind::Ident(reses) = &lhs.kind {
173        for res in *reses {
174            if let Res::Item(ItemId::Variable(vid)) = res
175                && !hir.variable(*vid).kind.is_state()
176            {
177                return Some(*vid);
178            }
179        }
180    }
181    None
182}
183
184impl<'hir> Visit<'hir> for Analyzer<'hir> {
185    type BreakValue = Never;
186
187    fn hir(&self) -> &'hir hir::Hir<'hir> {
188        self.hir
189    }
190
191    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
192        match stmt.kind {
193            StmtKind::If(cond, then, else_) => {
194                self.guard_depth += 1;
195                let _ = self.visit_expr(cond);
196                self.guard_depth -= 1;
197
198                let baseline = self.guarded.clone();
199                let _ = self.visit_stmt(then);
200                let then_added: HashSet<hir::VariableId> =
201                    self.guarded.difference(&baseline).copied().collect();
202                let then_exits = branch_always_exits(then);
203
204                let (else_added, else_exits) = if let Some(e) = else_ {
205                    self.guarded = baseline.clone();
206                    let _ = self.visit_stmt(e);
207                    let added: HashSet<hir::VariableId> =
208                        self.guarded.difference(&baseline).copied().collect();
209                    (added, branch_always_exits(e))
210                } else {
211                    (HashSet::new(), false)
212                };
213
214                self.guarded = baseline;
215                let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
216                    (true, true) => then_added.union(&else_added).copied().collect(),
217                    (true, false) => else_added,
218                    (false, true) => then_added,
219                    (false, false) => then_added.intersection(&else_added).copied().collect(),
220                };
221                self.guarded.extend(to_add);
222
223                return ControlFlow::Continue(());
224            }
225            // Loop bodies may execute zero times, so guards inside must not persist.
226            StmtKind::Loop(block, _) => {
227                let baseline = self.guarded.clone();
228                for s in block.stmts {
229                    let _ = self.visit_stmt(s);
230                }
231                self.guarded = baseline;
232                return ControlFlow::Continue(());
233            }
234            // Each try/catch clause is taken on a single path; discard clause-local guards.
235            StmtKind::Try(t) => {
236                let _ = self.visit_expr(&t.expr);
237                for clause in t.clauses {
238                    let baseline = self.guarded.clone();
239                    for s in clause.block.stmts {
240                        let _ = self.visit_stmt(s);
241                    }
242                    self.guarded = baseline;
243                }
244                return ControlFlow::Continue(());
245            }
246            // Propagate taint through address-typed local declarations only; this avoids
247            // marking unrelated values (e.g. `bool ok = a.send(1)`) as derived from `a`.
248            StmtKind::DeclSingle(var_id) => {
249                let v = self.hir.variable(var_id);
250                if let Some(init) = v.initializer
251                    && is_address(self.hir, var_id)
252                {
253                    let srcs = self.taint_sources(init);
254                    if !srcs.is_empty() {
255                        self.taint.entry(var_id).or_default().extend(srcs);
256                    }
257                }
258            }
259            _ => {}
260        }
261        self.walk_stmt(stmt)
262    }
263
264    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
265        match &expr.kind {
266            // `require(cond, ..)` / `assert(cond)`: only the first arg is a guard predicate.
267            ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
268                let mut iter = args.exprs();
269                if let Some(cond) = iter.next() {
270                    self.guard_depth += 1;
271                    let _ = self.visit_expr(cond);
272                    self.guard_depth -= 1;
273                }
274                for rest in iter {
275                    let _ = self.visit_expr(rest);
276                }
277                return ControlFlow::Continue(());
278            }
279
280            // `<addr>.call/.delegatecall/.transfer/.send(..)`: receiver is the sink.
281            ExprKind::Call(callee, args, _) => {
282                if let Some(receiver) = address_call_receiver(callee) {
283                    self.sink_depth += 1;
284                    let _ = self.visit_expr(receiver);
285                    self.sink_depth -= 1;
286                    let _ = self.visit_call_args(args);
287                    return ControlFlow::Continue(());
288                }
289            }
290
291            ExprKind::Assign(lhs, _, rhs) => {
292                // Sink: assignment to an address state variable.
293                if is_address_state_var_lhs(self.hir, lhs) {
294                    let _ = self.visit_expr(lhs);
295                    self.sink_depth += 1;
296                    let _ = self.visit_expr(rhs);
297                    self.sink_depth -= 1;
298                    return ControlFlow::Continue(());
299                }
300                // Taint propagation: assignment to an address local.
301                if let Some(local) = lhs_local_var(self.hir, lhs)
302                    && is_address(self.hir, local)
303                {
304                    let srcs = self.taint_sources(rhs);
305                    if !srcs.is_empty() {
306                        self.taint.entry(local).or_default().extend(srcs);
307                    }
308                }
309            }
310
311            // Identifier reads contribute to whichever contexts are currently active.
312            ExprKind::Ident(reses) => {
313                for res in *reses {
314                    if let Res::Item(ItemId::Variable(vid)) = res
315                        && let Some(srcs) = self.taint.get(vid)
316                    {
317                        if self.guard_depth > 0 {
318                            self.guarded.extend(srcs.iter().copied());
319                        }
320                        if self.sink_depth > 0 {
321                            for &src in srcs {
322                                if !self.guarded.contains(&src) {
323                                    self.sinks.insert(src);
324                                }
325                            }
326                        }
327                    }
328                }
329            }
330
331            _ => {}
332        }
333        self.walk_expr(expr)
334    }
335}
336
337fn is_require_or_assert(callee: &hir::Expr<'_>) -> bool {
338    if let ExprKind::Ident(reses) = &callee.kind {
339        return reses.iter().any(|r| {
340            if let Res::Builtin(b) = r {
341                let n = b.name();
342                n == sym::require || n == sym::assert
343            } else {
344                false
345            }
346        });
347    }
348    false
349}
350
351/// If `callee` is `<receiver>.{call,delegatecall,transfer,send}` (with or without
352/// call options), returns the `<receiver>` expression.
353fn address_call_receiver<'hir>(callee: &'hir hir::Expr<'hir>) -> Option<&'hir hir::Expr<'hir>> {
354    // `addr.call{value: x}(..)` lowers as `Call(Member(receiver, "call"), ..)` — peel an
355    // outer call layer so the inner Member is reachable.
356    let inner = match &callee.kind {
357        ExprKind::Call(inner, ..) => inner,
358        _ => callee,
359    };
360    let target = if matches!(inner.kind, ExprKind::Member(..)) { inner } else { callee };
361    if let ExprKind::Member(receiver, name) = &target.kind {
362        let n = name.name;
363        if n == kw::Call || n == kw::Delegatecall || n == sym::transfer || n == sym::send {
364            return Some(receiver);
365        }
366    }
367    None
368}
369
370fn branch_always_exits(stmt: &hir::Stmt<'_>) -> bool {
371    match &stmt.kind {
372        StmtKind::Return(_) | StmtKind::Revert(_) => true,
373        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
374            block.stmts.last().is_some_and(branch_always_exits)
375        }
376        StmtKind::If(_, t, Some(e)) => branch_always_exits(t) && branch_always_exits(e),
377        _ => false,
378    }
379}
380
381fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
382    if let ExprKind::Ident(reses) = &lhs.kind {
383        for res in *reses {
384            if let Res::Item(ItemId::Variable(vid)) = res {
385                let v = hir.variable(*vid);
386                if v.kind.is_state()
387                    && matches!(v.ty.kind, TypeKind::Elementary(ElementaryType::Address(_)))
388                {
389                    return true;
390                }
391            }
392        }
393    }
394    false
395}
396
397/// Maps each direct-ident modifier argument back to its caller-side parameter, runs the same guard
398/// analysis on the modifier body, and records any caller params whose mapped modifier parameter is
399/// guarded.
400fn collect_modifier_guards(
401    hir: &hir::Hir<'_>,
402    invocation: &hir::Modifier<'_>,
403    caller_params: &HashSet<hir::VariableId>,
404    guarded: &mut HashSet<hir::VariableId>,
405) {
406    let ItemId::Function(fid) = invocation.id else { return };
407    let modifier = hir.function(fid);
408    if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
409        return;
410    }
411
412    let mod_params = modifier.parameters;
413    let mut mapping: HashSet<hir::VariableId> = HashSet::new();
414    let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
415    for (i, arg_expr) in invocation.args.exprs().enumerate() {
416        if let ExprKind::Ident(reses) = &arg_expr.kind {
417            for res in *reses {
418                if let Res::Item(ItemId::Variable(vid)) = res
419                    && caller_params.contains(vid)
420                    && let Some(&mp) = mod_params.get(i)
421                {
422                    caller_for_modparam.insert(mp, *vid);
423                    mapping.insert(mp);
424                }
425            }
426        }
427    }
428    if mapping.is_empty() {
429        return;
430    }
431
432    let Some(body) = modifier.body else { return };
433    let mut a = Analyzer::new(hir, &mapping);
434    for stmt in body.stmts {
435        let _ = a.visit_stmt(stmt);
436    }
437
438    for mp in a.guarded {
439        if let Some(caller_vid) = caller_for_modparam.get(&mp) {
440            guarded.insert(*caller_vid);
441        }
442    }
443}