1use super::MissingZeroCheck;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{
5 Severity, SolLint,
6 analysis::primitives::{
7 address_call_receiver, branch_always_exits, is_address_type, is_require_or_assert,
8 },
9 },
10};
11use solar::{
12 ast,
13 interface::data_structures::Never,
14 sema::hir::{self, ExprKind, ItemId, Res, StmtKind, Visit},
15};
16use std::{
17 collections::{HashMap, HashSet},
18 ops::ControlFlow,
19};
20
// Declares the `MISSING_ZERO_CHECK` lint (id "missing-zero-check", `Severity::Low`), emitted by
// the `LateLintPass` impl below on an address parameter that reaches a state write or value
// transfer without having been checked against the zero address.
declare_forge_lint!(
    MISSING_ZERO_CHECK,
    Severity::Low,
    "missing-zero-check",
    "address parameter is used in a state write or value transfer without a zero-address check"
);
27
28impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
29 fn check_function(
30 &mut self,
31 ctx: &LintContext,
32 hir: &'hir hir::Hir<'hir>,
33 func: &'hir hir::Function<'hir>,
34 ) {
35 if !is_entry_point(func) {
36 return;
37 }
38
39 let params: HashSet<hir::VariableId> =
40 func.parameters.iter().copied().filter(|id| is_address_type(hir, *id)).collect();
41
42 if params.is_empty() {
43 return;
44 }
45
46 let Some(body) = func.body else { return };
47
48 let mut a = Analyzer::new(hir, ¶ms);
49
50 for m in func.modifiers {
51 collect_modifier_guards(hir, m, ¶ms, &mut a.guarded);
52 }
53
54 for stmt in body.stmts {
55 let _ = a.visit_stmt(stmt);
56 }
57
58 for &p in ¶ms {
59 if a.sinks.contains(&p) {
60 ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
61 }
62 }
63 }
64}
65
66fn is_entry_point(func: &hir::Function<'_>) -> bool {
68 if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
69 return false;
70 }
71 if func.is_constructor() {
72 return true;
73 }
74 func.kind.is_function()
75 && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
76}
77
78struct Analyzer<'hir> {
80 hir: &'hir hir::Hir<'hir>,
81 taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
84 sinks: HashSet<hir::VariableId>,
86 guarded: HashSet<hir::VariableId>,
88 guard_depth: u32,
89 sink_depth: u32,
90}
91
92impl<'hir> Analyzer<'hir> {
93 fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
94 let mut taint = HashMap::with_capacity(params.len());
95 for &p in params {
96 taint.insert(p, HashSet::from([p]));
97 }
98 Self {
99 hir,
100 taint,
101 sinks: HashSet::new(),
102 guarded: HashSet::new(),
103 guard_depth: 0,
104 sink_depth: 0,
105 }
106 }
107
108 fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
109 let mut out = HashSet::new();
110 collect_taint_sources(&self.taint, expr, &mut out);
111 out
112 }
113}
114
115fn collect_taint_sources(
116 taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
117 expr: &hir::Expr<'_>,
118 out: &mut HashSet<hir::VariableId>,
119) {
120 match &expr.kind {
121 ExprKind::Ident(reses) => {
122 for res in *reses {
123 if let Res::Item(ItemId::Variable(vid)) = res
124 && let Some(srcs) = taint.get(vid)
125 {
126 out.extend(srcs.iter().copied());
127 }
128 }
129 }
130 ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
131 ExprKind::Binary(lhs, _, rhs) => {
132 collect_taint_sources(taint, lhs, out);
133 collect_taint_sources(taint, rhs, out);
134 }
135 ExprKind::Unary(_, e)
136 | ExprKind::Delete(e)
137 | ExprKind::Member(e, _)
138 | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
139 ExprKind::Ternary(_, t, f) => {
140 collect_taint_sources(taint, t, out);
141 collect_taint_sources(taint, f, out);
142 }
143 ExprKind::Tuple(elems) => {
144 for e in elems.iter().copied().flatten() {
145 collect_taint_sources(taint, e, out);
146 }
147 }
148 ExprKind::Array(elems) => {
149 for e in *elems {
150 collect_taint_sources(taint, e, out);
151 }
152 }
153 ExprKind::Index(base, idx) => {
154 collect_taint_sources(taint, base, out);
155 if let Some(i) = idx {
156 collect_taint_sources(taint, i, out);
157 }
158 }
159 ExprKind::Call(callee, args, _) => {
161 collect_taint_sources(taint, callee, out);
162 for a in args.exprs() {
163 collect_taint_sources(taint, a, out);
164 }
165 }
166 _ => {}
167 }
168}
169
170fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
173 if let ExprKind::Ident(reses) = &lhs.kind {
174 for res in *reses {
175 if let Res::Item(ItemId::Variable(vid)) = res
176 && !hir.variable(*vid).kind.is_state()
177 {
178 return Some(*vid);
179 }
180 }
181 }
182 None
183}
184
185impl<'hir> Visit<'hir> for Analyzer<'hir> {
186 type BreakValue = Never;
187
188 fn hir(&self) -> &'hir hir::Hir<'hir> {
189 self.hir
190 }
191
192 fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
193 match stmt.kind {
194 StmtKind::If(cond, then, else_) => {
195 self.guard_depth += 1;
196 let _ = self.visit_expr(cond);
197 self.guard_depth -= 1;
198
199 let baseline = self.guarded.clone();
200 let _ = self.visit_stmt(then);
201 let then_added: HashSet<hir::VariableId> =
202 self.guarded.difference(&baseline).copied().collect();
203 let then_exits = branch_always_exits(then);
204
205 let (else_added, else_exits) = if let Some(e) = else_ {
206 self.guarded = baseline.clone();
207 let _ = self.visit_stmt(e);
208 let added: HashSet<hir::VariableId> =
209 self.guarded.difference(&baseline).copied().collect();
210 (added, branch_always_exits(e))
211 } else {
212 (HashSet::new(), false)
213 };
214
215 self.guarded = baseline;
216 let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
217 (true, true) => then_added.union(&else_added).copied().collect(),
218 (true, false) => else_added,
219 (false, true) => then_added,
220 (false, false) => then_added.intersection(&else_added).copied().collect(),
221 };
222 self.guarded.extend(to_add);
223
224 return ControlFlow::Continue(());
225 }
226 StmtKind::Loop(block, _) => {
228 let baseline = self.guarded.clone();
229 for s in block.stmts {
230 let _ = self.visit_stmt(s);
231 }
232 self.guarded = baseline;
233 return ControlFlow::Continue(());
234 }
235 StmtKind::Try(t) => {
237 let _ = self.visit_expr(&t.expr);
238 for clause in t.clauses {
239 let baseline = self.guarded.clone();
240 for s in clause.block.stmts {
241 let _ = self.visit_stmt(s);
242 }
243 self.guarded = baseline;
244 }
245 return ControlFlow::Continue(());
246 }
247 StmtKind::DeclSingle(var_id) => {
250 let v = self.hir.variable(var_id);
251 if let Some(init) = v.initializer
252 && is_address_type(self.hir, var_id)
253 {
254 let srcs = self.taint_sources(init);
255 if !srcs.is_empty() {
256 self.taint.entry(var_id).or_default().extend(srcs);
257 }
258 }
259 }
260 _ => {}
261 }
262 self.walk_stmt(stmt)
263 }
264
265 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
266 match &expr.kind {
267 ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
269 let mut iter = args.exprs();
270 if let Some(cond) = iter.next() {
271 self.guard_depth += 1;
272 let _ = self.visit_expr(cond);
273 self.guard_depth -= 1;
274 }
275 for rest in iter {
276 let _ = self.visit_expr(rest);
277 }
278 return ControlFlow::Continue(());
279 }
280
281 ExprKind::Call(callee, args, _) => {
283 if let Some(receiver) = address_call_receiver(callee) {
284 self.sink_depth += 1;
285 let _ = self.visit_expr(receiver);
286 self.sink_depth -= 1;
287 let _ = self.visit_call_args(args);
288 return ControlFlow::Continue(());
289 }
290 }
291
292 ExprKind::Assign(lhs, _, rhs) => {
293 if is_address_state_var_lhs(self.hir, lhs) {
295 let _ = self.visit_expr(lhs);
296 self.sink_depth += 1;
297 let _ = self.visit_expr(rhs);
298 self.sink_depth -= 1;
299 return ControlFlow::Continue(());
300 }
301 if let Some(local) = lhs_local_var(self.hir, lhs)
303 && is_address_type(self.hir, local)
304 {
305 let srcs = self.taint_sources(rhs);
306 if !srcs.is_empty() {
307 self.taint.entry(local).or_default().extend(srcs);
308 }
309 }
310 }
311
312 ExprKind::Ident(reses) => {
314 for res in *reses {
315 if let Res::Item(ItemId::Variable(vid)) = res
316 && let Some(srcs) = self.taint.get(vid)
317 {
318 if self.guard_depth > 0 {
319 self.guarded.extend(srcs.iter().copied());
320 }
321 if self.sink_depth > 0 {
322 for &src in srcs {
323 if !self.guarded.contains(&src) {
324 self.sinks.insert(src);
325 }
326 }
327 }
328 }
329 }
330 }
331
332 _ => {}
333 }
334 self.walk_expr(expr)
335 }
336}
337
338fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
339 if let ExprKind::Ident(reses) = &lhs.kind {
340 for res in *reses {
341 if let Res::Item(ItemId::Variable(vid)) = res
342 && hir.variable(*vid).kind.is_state()
343 && is_address_type(hir, *vid)
344 {
345 return true;
346 }
347 }
348 }
349 false
350}
351
352fn collect_modifier_guards(
356 hir: &hir::Hir<'_>,
357 invocation: &hir::Modifier<'_>,
358 caller_params: &HashSet<hir::VariableId>,
359 guarded: &mut HashSet<hir::VariableId>,
360) {
361 let ItemId::Function(fid) = invocation.id else { return };
362 let modifier = hir.function(fid);
363 if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
364 return;
365 }
366
367 let mod_params = modifier.parameters;
368 let mut mapping: HashSet<hir::VariableId> = HashSet::new();
369 let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
370 for (i, arg_expr) in invocation.args.exprs().enumerate() {
371 if let ExprKind::Ident(reses) = &arg_expr.kind {
372 for res in *reses {
373 if let Res::Item(ItemId::Variable(vid)) = res
374 && caller_params.contains(vid)
375 && let Some(&mp) = mod_params.get(i)
376 {
377 caller_for_modparam.insert(mp, *vid);
378 mapping.insert(mp);
379 }
380 }
381 }
382 }
383 if mapping.is_empty() {
384 return;
385 }
386
387 let Some(body) = modifier.body else { return };
388 let mut a = Analyzer::new(hir, &mapping);
389 for stmt in body.stmts {
390 let _ = a.visit_stmt(stmt);
391 }
392
393 for mp in a.guarded {
394 if let Some(caller_vid) = caller_for_modparam.get(&mp) {
395 guarded.insert(*caller_vid);
396 }
397 }
398}