Skip to main content

forge_lint/sol/low/missing_zero_check.rs

1use super::MissingZeroCheck;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{
5        Severity, SolLint,
6        analysis::primitives::{
7            address_call_receiver, branch_always_exits, is_address_type, is_require_or_assert,
8        },
9    },
10};
11use solar::{
12    ast,
13    interface::data_structures::Never,
14    sema::hir::{self, ExprKind, ItemId, Res, StmtKind, Visit},
15};
16use std::{
17    collections::{HashMap, HashSet},
18    ops::ControlFlow,
19};
20
21declare_forge_lint!(
22    MISSING_ZERO_CHECK,
23    Severity::Low,
24    "missing-zero-check",
25    "address parameter is used in a state write or value transfer without a zero-address check"
26);
27
28impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
29    fn check_function(
30        &mut self,
31        ctx: &LintContext,
32        hir: &'hir hir::Hir<'hir>,
33        func: &'hir hir::Function<'hir>,
34    ) {
35        if !is_entry_point(func) {
36            return;
37        }
38
39        let params: HashSet<hir::VariableId> =
40            func.parameters.iter().copied().filter(|id| is_address_type(hir, *id)).collect();
41
42        if params.is_empty() {
43            return;
44        }
45
46        let Some(body) = func.body else { return };
47
48        let mut a = Analyzer::new(hir, &params);
49
50        for m in func.modifiers {
51            collect_modifier_guards(hir, m, &params, &mut a.guarded);
52        }
53
54        for stmt in body.stmts {
55            let _ = a.visit_stmt(stmt);
56        }
57
58        for &p in &params {
59            if a.sinks.contains(&p) {
60                ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
61            }
62        }
63    }
64}
65
66/// Externally callable, state-mutating functions and constructors.
67fn is_entry_point(func: &hir::Function<'_>) -> bool {
68    if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
69        return false;
70    }
71    if func.is_constructor() {
72        return true;
73    }
74    func.kind.is_function()
75        && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
76}
77
78/// Tracks address-parameter taint, sinks reached, and guards observed in a function body.
79struct Analyzer<'hir> {
80    hir: &'hir hir::Hir<'hir>,
81    /// Variables transitively derived from candidate parameters, mapped to their sources.
82    /// Each parameter is initially mapped to itself.
83    taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
84    /// Source parameters that reached a sink.
85    sinks: HashSet<hir::VariableId>,
86    /// Source parameters read inside an `if`/`require`/`assert` predicate.
87    guarded: HashSet<hir::VariableId>,
88    guard_depth: u32,
89    sink_depth: u32,
90}
91
92impl<'hir> Analyzer<'hir> {
93    fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
94        let mut taint = HashMap::with_capacity(params.len());
95        for &p in params {
96            taint.insert(p, HashSet::from([p]));
97        }
98        Self {
99            hir,
100            taint,
101            sinks: HashSet::new(),
102            guarded: HashSet::new(),
103            guard_depth: 0,
104            sink_depth: 0,
105        }
106    }
107
108    fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
109        let mut out = HashSet::new();
110        collect_taint_sources(&self.taint, expr, &mut out);
111        out
112    }
113}
114
115fn collect_taint_sources(
116    taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
117    expr: &hir::Expr<'_>,
118    out: &mut HashSet<hir::VariableId>,
119) {
120    match &expr.kind {
121        ExprKind::Ident(reses) => {
122            for res in *reses {
123                if let Res::Item(ItemId::Variable(vid)) = res
124                    && let Some(srcs) = taint.get(vid)
125                {
126                    out.extend(srcs.iter().copied());
127                }
128            }
129        }
130        ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
131        ExprKind::Binary(lhs, _, rhs) => {
132            collect_taint_sources(taint, lhs, out);
133            collect_taint_sources(taint, rhs, out);
134        }
135        ExprKind::Unary(_, e)
136        | ExprKind::Delete(e)
137        | ExprKind::Member(e, _)
138        | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
139        ExprKind::Ternary(_, t, f) => {
140            collect_taint_sources(taint, t, out);
141            collect_taint_sources(taint, f, out);
142        }
143        ExprKind::Tuple(elems) => {
144            for e in elems.iter().copied().flatten() {
145                collect_taint_sources(taint, e, out);
146            }
147        }
148        ExprKind::Array(elems) => {
149            for e in *elems {
150                collect_taint_sources(taint, e, out);
151            }
152        }
153        ExprKind::Index(base, idx) => {
154            collect_taint_sources(taint, base, out);
155            if let Some(i) = idx {
156                collect_taint_sources(taint, i, out);
157            }
158        }
159        // Covers type casts (`address(x)`), method calls, and ordinary calls; conservative.
160        ExprKind::Call(callee, args, _) => {
161            collect_taint_sources(taint, callee, out);
162            for a in args.exprs() {
163                collect_taint_sources(taint, a, out);
164            }
165        }
166        _ => {}
167    }
168}
169
170/// Returns the underlying local `VariableId` if `lhs` is a direct identifier reference to a
171/// non-state variable.
172fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
173    if let ExprKind::Ident(reses) = &lhs.kind {
174        for res in *reses {
175            if let Res::Item(ItemId::Variable(vid)) = res
176                && !hir.variable(*vid).kind.is_state()
177            {
178                return Some(*vid);
179            }
180        }
181    }
182    None
183}
184
185impl<'hir> Visit<'hir> for Analyzer<'hir> {
186    type BreakValue = Never;
187
188    fn hir(&self) -> &'hir hir::Hir<'hir> {
189        self.hir
190    }
191
192    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
193        match stmt.kind {
194            StmtKind::If(cond, then, else_) => {
195                self.guard_depth += 1;
196                let _ = self.visit_expr(cond);
197                self.guard_depth -= 1;
198
199                let baseline = self.guarded.clone();
200                let _ = self.visit_stmt(then);
201                let then_added: HashSet<hir::VariableId> =
202                    self.guarded.difference(&baseline).copied().collect();
203                let then_exits = branch_always_exits(then);
204
205                let (else_added, else_exits) = if let Some(e) = else_ {
206                    self.guarded = baseline.clone();
207                    let _ = self.visit_stmt(e);
208                    let added: HashSet<hir::VariableId> =
209                        self.guarded.difference(&baseline).copied().collect();
210                    (added, branch_always_exits(e))
211                } else {
212                    (HashSet::new(), false)
213                };
214
215                self.guarded = baseline;
216                let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
217                    (true, true) => then_added.union(&else_added).copied().collect(),
218                    (true, false) => else_added,
219                    (false, true) => then_added,
220                    (false, false) => then_added.intersection(&else_added).copied().collect(),
221                };
222                self.guarded.extend(to_add);
223
224                return ControlFlow::Continue(());
225            }
226            // Loop bodies may execute zero times, so guards inside must not persist.
227            StmtKind::Loop(block, _) => {
228                let baseline = self.guarded.clone();
229                for s in block.stmts {
230                    let _ = self.visit_stmt(s);
231                }
232                self.guarded = baseline;
233                return ControlFlow::Continue(());
234            }
235            // Each try/catch clause is taken on a single path; discard clause-local guards.
236            StmtKind::Try(t) => {
237                let _ = self.visit_expr(&t.expr);
238                for clause in t.clauses {
239                    let baseline = self.guarded.clone();
240                    for s in clause.block.stmts {
241                        let _ = self.visit_stmt(s);
242                    }
243                    self.guarded = baseline;
244                }
245                return ControlFlow::Continue(());
246            }
247            // Propagate taint through address-typed local declarations only; this avoids
248            // marking unrelated values (e.g. `bool ok = a.send(1)`) as derived from `a`.
249            StmtKind::DeclSingle(var_id) => {
250                let v = self.hir.variable(var_id);
251                if let Some(init) = v.initializer
252                    && is_address_type(self.hir, var_id)
253                {
254                    let srcs = self.taint_sources(init);
255                    if !srcs.is_empty() {
256                        self.taint.entry(var_id).or_default().extend(srcs);
257                    }
258                }
259            }
260            _ => {}
261        }
262        self.walk_stmt(stmt)
263    }
264
265    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
266        match &expr.kind {
267            // `require(cond, ..)` / `assert(cond)`: only the first arg is a guard predicate.
268            ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
269                let mut iter = args.exprs();
270                if let Some(cond) = iter.next() {
271                    self.guard_depth += 1;
272                    let _ = self.visit_expr(cond);
273                    self.guard_depth -= 1;
274                }
275                for rest in iter {
276                    let _ = self.visit_expr(rest);
277                }
278                return ControlFlow::Continue(());
279            }
280
281            // `<addr>.call/.delegatecall/.transfer/.send(..)`: receiver is the sink.
282            ExprKind::Call(callee, args, _) => {
283                if let Some(receiver) = address_call_receiver(callee) {
284                    self.sink_depth += 1;
285                    let _ = self.visit_expr(receiver);
286                    self.sink_depth -= 1;
287                    let _ = self.visit_call_args(args);
288                    return ControlFlow::Continue(());
289                }
290            }
291
292            ExprKind::Assign(lhs, _, rhs) => {
293                // Sink: assignment to an address state variable.
294                if is_address_state_var_lhs(self.hir, lhs) {
295                    let _ = self.visit_expr(lhs);
296                    self.sink_depth += 1;
297                    let _ = self.visit_expr(rhs);
298                    self.sink_depth -= 1;
299                    return ControlFlow::Continue(());
300                }
301                // Taint propagation: assignment to an address local.
302                if let Some(local) = lhs_local_var(self.hir, lhs)
303                    && is_address_type(self.hir, local)
304                {
305                    let srcs = self.taint_sources(rhs);
306                    if !srcs.is_empty() {
307                        self.taint.entry(local).or_default().extend(srcs);
308                    }
309                }
310            }
311
312            // Identifier reads contribute to whichever contexts are currently active.
313            ExprKind::Ident(reses) => {
314                for res in *reses {
315                    if let Res::Item(ItemId::Variable(vid)) = res
316                        && let Some(srcs) = self.taint.get(vid)
317                    {
318                        if self.guard_depth > 0 {
319                            self.guarded.extend(srcs.iter().copied());
320                        }
321                        if self.sink_depth > 0 {
322                            for &src in srcs {
323                                if !self.guarded.contains(&src) {
324                                    self.sinks.insert(src);
325                                }
326                            }
327                        }
328                    }
329                }
330            }
331
332            _ => {}
333        }
334        self.walk_expr(expr)
335    }
336}
337
338fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
339    if let ExprKind::Ident(reses) = &lhs.kind {
340        for res in *reses {
341            if let Res::Item(ItemId::Variable(vid)) = res
342                && hir.variable(*vid).kind.is_state()
343                && is_address_type(hir, *vid)
344            {
345                return true;
346            }
347        }
348    }
349    false
350}
351
352/// Maps each direct-ident modifier argument back to its caller-side parameter, runs the same guard
353/// analysis on the modifier body, and records any caller params whose mapped modifier parameter is
354/// guarded.
355fn collect_modifier_guards(
356    hir: &hir::Hir<'_>,
357    invocation: &hir::Modifier<'_>,
358    caller_params: &HashSet<hir::VariableId>,
359    guarded: &mut HashSet<hir::VariableId>,
360) {
361    let ItemId::Function(fid) = invocation.id else { return };
362    let modifier = hir.function(fid);
363    if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
364        return;
365    }
366
367    let mod_params = modifier.parameters;
368    let mut mapping: HashSet<hir::VariableId> = HashSet::new();
369    let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
370    for (i, arg_expr) in invocation.args.exprs().enumerate() {
371        if let ExprKind::Ident(reses) = &arg_expr.kind {
372            for res in *reses {
373                if let Res::Item(ItemId::Variable(vid)) = res
374                    && caller_params.contains(vid)
375                    && let Some(&mp) = mod_params.get(i)
376                {
377                    caller_for_modparam.insert(mp, *vid);
378                    mapping.insert(mp);
379                }
380            }
381        }
382    }
383    if mapping.is_empty() {
384        return;
385    }
386
387    let Some(body) = modifier.body else { return };
388    let mut a = Analyzer::new(hir, &mapping);
389    for stmt in body.stmts {
390        let _ = a.visit_stmt(stmt);
391    }
392
393    for mp in a.guarded {
394        if let Some(caller_vid) = caller_for_modparam.get(&mp) {
395            guarded.insert(*caller_vid);
396        }
397    }
398}