Skip to main content

forge_lint/sol/low/
missing_zero_check.rs

1use super::MissingZeroCheck;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{
5        Severity, SolLint,
6        analysis::primitives::{
7            address_call_receiver, branch_always_exits, is_address_type, is_require_or_assert,
8        },
9    },
10};
11use solar::{
12    ast,
13    interface::data_structures::Never,
14    sema::hir::{self, ExprKind, ItemId, Res, StmtKind, Visit},
15};
16use std::{
17    collections::{HashMap, HashSet},
18    ops::ControlFlow,
19};
20
// Declares the low-severity `missing-zero-check` lint. It is emitted on an address parameter
// that is used in a state write or value transfer without a zero-address check.
declare_forge_lint!(
    MISSING_ZERO_CHECK,
    Severity::Low,
    "missing-zero-check",
    "address parameter is used in a state write or value transfer without a zero-address check"
);
27
28impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
29    fn check_function(
30        &mut self,
31        ctx: &LintContext,
32        _gcx: solar::sema::Gcx<'hir>,
33        hir: &'hir hir::Hir<'hir>,
34        func: &'hir hir::Function<'hir>,
35    ) {
36        if !is_entry_point(func) {
37            return;
38        }
39
40        let params: HashSet<hir::VariableId> =
41            func.parameters.iter().copied().filter(|id| is_address_type(hir, *id)).collect();
42
43        if params.is_empty() {
44            return;
45        }
46
47        let Some(body) = func.body else { return };
48
49        let mut a = Analyzer::new(hir, &params);
50
51        for m in func.modifiers {
52            collect_modifier_guards(hir, m, &params, &mut a.guarded);
53        }
54
55        for stmt in body.stmts {
56            let _ = a.visit_stmt(stmt);
57        }
58
59        for &p in &params {
60            if a.sinks.contains(&p) {
61                ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
62            }
63        }
64    }
65}
66
67/// Externally callable, state-mutating functions and constructors.
68fn is_entry_point(func: &hir::Function<'_>) -> bool {
69    if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
70        return false;
71    }
72    if func.is_constructor() {
73        return true;
74    }
75    func.kind.is_function()
76        && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
77}
78
79/// Tracks address-parameter taint, sinks reached, and guards observed in a function body.
80struct Analyzer<'hir> {
81    hir: &'hir hir::Hir<'hir>,
82    /// Variables transitively derived from candidate parameters, mapped to their sources.
83    /// Each parameter is initially mapped to itself.
84    taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
85    /// Source parameters that reached a sink.
86    sinks: HashSet<hir::VariableId>,
87    /// Source parameters read inside an `if`/`require`/`assert` predicate.
88    guarded: HashSet<hir::VariableId>,
89    guard_depth: u32,
90    sink_depth: u32,
91}
92
93impl<'hir> Analyzer<'hir> {
94    fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
95        let mut taint = HashMap::with_capacity(params.len());
96        for &p in params {
97            taint.insert(p, HashSet::from([p]));
98        }
99        Self {
100            hir,
101            taint,
102            sinks: HashSet::new(),
103            guarded: HashSet::new(),
104            guard_depth: 0,
105            sink_depth: 0,
106        }
107    }
108
109    fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
110        let mut out = HashSet::new();
111        collect_taint_sources(&self.taint, expr, &mut out);
112        out
113    }
114}
115
116fn collect_taint_sources(
117    taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
118    expr: &hir::Expr<'_>,
119    out: &mut HashSet<hir::VariableId>,
120) {
121    match &expr.kind {
122        ExprKind::Ident(reses) => {
123            for res in *reses {
124                if let Res::Item(ItemId::Variable(vid)) = res
125                    && let Some(srcs) = taint.get(vid)
126                {
127                    out.extend(srcs.iter().copied());
128                }
129            }
130        }
131        ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
132        ExprKind::Binary(lhs, _, rhs) => {
133            collect_taint_sources(taint, lhs, out);
134            collect_taint_sources(taint, rhs, out);
135        }
136        ExprKind::Unary(_, e)
137        | ExprKind::Delete(e)
138        | ExprKind::Member(e, _)
139        | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
140        ExprKind::Ternary(_, t, f) => {
141            collect_taint_sources(taint, t, out);
142            collect_taint_sources(taint, f, out);
143        }
144        ExprKind::Tuple(elems) => {
145            for e in elems.iter().copied().flatten() {
146                collect_taint_sources(taint, e, out);
147            }
148        }
149        ExprKind::Array(elems) => {
150            for e in *elems {
151                collect_taint_sources(taint, e, out);
152            }
153        }
154        ExprKind::Index(base, idx) => {
155            collect_taint_sources(taint, base, out);
156            if let Some(i) = idx {
157                collect_taint_sources(taint, i, out);
158            }
159        }
160        // Covers type casts (`address(x)`), method calls, and ordinary calls; conservative.
161        ExprKind::Call(callee, args, _) => {
162            collect_taint_sources(taint, callee, out);
163            for a in args.exprs() {
164                collect_taint_sources(taint, a, out);
165            }
166        }
167        _ => {}
168    }
169}
170
171/// Returns the underlying local `VariableId` if `lhs` is a direct identifier reference to a
172/// non-state variable.
173fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
174    if let ExprKind::Ident(reses) = &lhs.kind {
175        for res in *reses {
176            if let Res::Item(ItemId::Variable(vid)) = res
177                && !hir.variable(*vid).kind.is_state()
178            {
179                return Some(*vid);
180            }
181        }
182    }
183    None
184}
185
186impl<'hir> Visit<'hir> for Analyzer<'hir> {
187    type BreakValue = Never;
188
189    fn hir(&self) -> &'hir hir::Hir<'hir> {
190        self.hir
191    }
192
193    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
194        match stmt.kind {
195            StmtKind::If(cond, then, else_) => {
196                self.guard_depth += 1;
197                let _ = self.visit_expr(cond);
198                self.guard_depth -= 1;
199
200                let baseline = self.guarded.clone();
201                let _ = self.visit_stmt(then);
202                let then_added: HashSet<hir::VariableId> =
203                    self.guarded.difference(&baseline).copied().collect();
204                let then_exits = branch_always_exits(then);
205
206                let (else_added, else_exits) = if let Some(e) = else_ {
207                    self.guarded = baseline.clone();
208                    let _ = self.visit_stmt(e);
209                    let added: HashSet<hir::VariableId> =
210                        self.guarded.difference(&baseline).copied().collect();
211                    (added, branch_always_exits(e))
212                } else {
213                    (HashSet::new(), false)
214                };
215
216                self.guarded = baseline;
217                let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
218                    (true, true) => then_added.union(&else_added).copied().collect(),
219                    (true, false) => else_added,
220                    (false, true) => then_added,
221                    (false, false) => then_added.intersection(&else_added).copied().collect(),
222                };
223                self.guarded.extend(to_add);
224
225                return ControlFlow::Continue(());
226            }
227            // Loop bodies may execute zero times, so guards inside must not persist.
228            StmtKind::Loop(block, _) => {
229                let baseline = self.guarded.clone();
230                for s in block.stmts {
231                    let _ = self.visit_stmt(s);
232                }
233                self.guarded = baseline;
234                return ControlFlow::Continue(());
235            }
236            // Each try/catch clause is taken on a single path; discard clause-local guards.
237            StmtKind::Try(t) => {
238                let _ = self.visit_expr(&t.expr);
239                for clause in t.clauses {
240                    let baseline = self.guarded.clone();
241                    for s in clause.block.stmts {
242                        let _ = self.visit_stmt(s);
243                    }
244                    self.guarded = baseline;
245                }
246                return ControlFlow::Continue(());
247            }
248            // Propagate taint through address-typed local declarations only; this avoids
249            // marking unrelated values (e.g. `bool ok = a.send(1)`) as derived from `a`.
250            StmtKind::DeclSingle(var_id) => {
251                let v = self.hir.variable(var_id);
252                if let Some(init) = v.initializer
253                    && is_address_type(self.hir, var_id)
254                {
255                    let srcs = self.taint_sources(init);
256                    if !srcs.is_empty() {
257                        self.taint.entry(var_id).or_default().extend(srcs);
258                    }
259                }
260            }
261            _ => {}
262        }
263        self.walk_stmt(stmt)
264    }
265
266    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
267        match &expr.kind {
268            // `require(cond, ..)` / `assert(cond)`: only the first arg is a guard predicate.
269            ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
270                let mut iter = args.exprs();
271                if let Some(cond) = iter.next() {
272                    self.guard_depth += 1;
273                    let _ = self.visit_expr(cond);
274                    self.guard_depth -= 1;
275                }
276                for rest in iter {
277                    let _ = self.visit_expr(rest);
278                }
279                return ControlFlow::Continue(());
280            }
281
282            // `<addr>.call/.delegatecall/.transfer/.send(..)`: receiver is the sink.
283            ExprKind::Call(callee, args, _) => {
284                if let Some(receiver) = address_call_receiver(callee) {
285                    self.sink_depth += 1;
286                    let _ = self.visit_expr(receiver);
287                    self.sink_depth -= 1;
288                    let _ = self.visit_call_args(args);
289                    return ControlFlow::Continue(());
290                }
291            }
292
293            ExprKind::Assign(lhs, _, rhs) => {
294                // Sink: assignment to an address state variable.
295                if is_address_state_var_lhs(self.hir, lhs) {
296                    let _ = self.visit_expr(lhs);
297                    self.sink_depth += 1;
298                    let _ = self.visit_expr(rhs);
299                    self.sink_depth -= 1;
300                    return ControlFlow::Continue(());
301                }
302                // Taint propagation: assignment to an address local.
303                if let Some(local) = lhs_local_var(self.hir, lhs)
304                    && is_address_type(self.hir, local)
305                {
306                    let srcs = self.taint_sources(rhs);
307                    if !srcs.is_empty() {
308                        self.taint.entry(local).or_default().extend(srcs);
309                    }
310                }
311            }
312
313            // Identifier reads contribute to whichever contexts are currently active.
314            ExprKind::Ident(reses) => {
315                for res in *reses {
316                    if let Res::Item(ItemId::Variable(vid)) = res
317                        && let Some(srcs) = self.taint.get(vid)
318                    {
319                        if self.guard_depth > 0 {
320                            self.guarded.extend(srcs.iter().copied());
321                        }
322                        if self.sink_depth > 0 {
323                            for &src in srcs {
324                                if !self.guarded.contains(&src) {
325                                    self.sinks.insert(src);
326                                }
327                            }
328                        }
329                    }
330                }
331            }
332
333            _ => {}
334        }
335        self.walk_expr(expr)
336    }
337}
338
339fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
340    if let ExprKind::Ident(reses) = &lhs.kind {
341        for res in *reses {
342            if let Res::Item(ItemId::Variable(vid)) = res
343                && hir.variable(*vid).kind.is_state()
344                && is_address_type(hir, *vid)
345            {
346                return true;
347            }
348        }
349    }
350    false
351}
352
353/// Maps each direct-ident modifier argument back to its caller-side parameter, runs the same guard
354/// analysis on the modifier body, and records any caller params whose mapped modifier parameter is
355/// guarded.
356fn collect_modifier_guards<'hir>(
357    hir: &'hir hir::Hir<'hir>,
358    invocation: &hir::Modifier<'hir>,
359    caller_params: &HashSet<hir::VariableId>,
360    guarded: &mut HashSet<hir::VariableId>,
361) {
362    let ItemId::Function(fid) = invocation.id else { return };
363    let modifier = hir.function(fid);
364    if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
365        return;
366    }
367
368    let mod_params = modifier.parameters;
369    let mut mapping: HashSet<hir::VariableId> = HashSet::new();
370    let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
371    for (i, arg_expr) in invocation.args.exprs().enumerate() {
372        if let ExprKind::Ident(reses) = &arg_expr.kind {
373            for res in *reses {
374                if let Res::Item(ItemId::Variable(vid)) = res
375                    && caller_params.contains(vid)
376                    && let Some(&mp) = mod_params.get(i)
377                {
378                    caller_for_modparam.insert(mp, *vid);
379                    mapping.insert(mp);
380                }
381            }
382        }
383    }
384    if mapping.is_empty() {
385        return;
386    }
387
388    let Some(body) = modifier.body else { return };
389    let mut a = Analyzer::new(hir, &mapping);
390    for stmt in body.stmts {
391        let _ = a.visit_stmt(stmt);
392    }
393
394    for mp in a.guarded {
395        if let Some(caller_vid) = caller_for_modparam.get(&mp) {
396            guarded.insert(*caller_vid);
397        }
398    }
399}