1use super::MissingZeroCheck;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{
5 Severity, SolLint,
6 analysis::primitives::{
7 address_call_receiver, branch_always_exits, is_address_type, is_require_or_assert,
8 },
9 },
10};
11use solar::{
12 ast,
13 interface::data_structures::Never,
14 sema::hir::{self, ExprKind, ItemId, Res, StmtKind, Visit},
15};
16use std::{
17 collections::{HashMap, HashSet},
18 ops::ControlFlow,
19};
20
// Declares the `MISSING_ZERO_CHECK` lint: id "missing-zero-check", severity `Low`.
// The string on the last line is the lint's description text.
declare_forge_lint!(
    MISSING_ZERO_CHECK,
    Severity::Low,
    "missing-zero-check",
    "address parameter is used in a state write or value transfer without a zero-address check"
);
27
28impl<'hir> LateLintPass<'hir> for MissingZeroCheck {
29 fn check_function(
30 &mut self,
31 ctx: &LintContext,
32 _gcx: solar::sema::Gcx<'hir>,
33 hir: &'hir hir::Hir<'hir>,
34 func: &'hir hir::Function<'hir>,
35 ) {
36 if !is_entry_point(func) {
37 return;
38 }
39
40 let params: HashSet<hir::VariableId> =
41 func.parameters.iter().copied().filter(|id| is_address_type(hir, *id)).collect();
42
43 if params.is_empty() {
44 return;
45 }
46
47 let Some(body) = func.body else { return };
48
49 let mut a = Analyzer::new(hir, ¶ms);
50
51 for m in func.modifiers {
52 collect_modifier_guards(hir, m, ¶ms, &mut a.guarded);
53 }
54
55 for stmt in body.stmts {
56 let _ = a.visit_stmt(stmt);
57 }
58
59 for &p in ¶ms {
60 if a.sinks.contains(&p) {
61 ctx.emit(&MISSING_ZERO_CHECK, hir.variable(p).span);
62 }
63 }
64 }
65}
66
67fn is_entry_point(func: &hir::Function<'_>) -> bool {
69 if matches!(func.state_mutability, ast::StateMutability::Pure | ast::StateMutability::View) {
70 return false;
71 }
72 if func.is_constructor() {
73 return true;
74 }
75 func.kind.is_function()
76 && matches!(func.visibility, ast::Visibility::Public | ast::Visibility::External)
77}
78
/// Per-function walker state used to decide which address parameters reach a sink unguarded.
struct Analyzer<'hir> {
    /// HIR used to look up variables while walking.
    hir: &'hir hir::Hir<'hir>,
    /// For each tracked variable, the set of seed parameters its value may derive from.
    /// Seeded with `p -> {p}` for every parameter and extended on address-typed local
    /// declarations and assignments.
    taint: HashMap<hir::VariableId, HashSet<hir::VariableId>>,
    /// Seed parameters observed in a sink context while not present in `guarded`.
    sinks: HashSet<hir::VariableId>,
    /// Seed parameters observed in a guard context (or reported as guarded by a modifier).
    guarded: HashSet<hir::VariableId>,
    /// Non-zero while visiting an `if` condition or the first argument of a require/assert call.
    guard_depth: u32,
    /// Non-zero while visiting the receiver of an address call or the right-hand side of an
    /// assignment to an address-typed state variable.
    sink_depth: u32,
}
92
93impl<'hir> Analyzer<'hir> {
94 fn new(hir: &'hir hir::Hir<'hir>, params: &HashSet<hir::VariableId>) -> Self {
95 let mut taint = HashMap::with_capacity(params.len());
96 for &p in params {
97 taint.insert(p, HashSet::from([p]));
98 }
99 Self {
100 hir,
101 taint,
102 sinks: HashSet::new(),
103 guarded: HashSet::new(),
104 guard_depth: 0,
105 sink_depth: 0,
106 }
107 }
108
109 fn taint_sources(&self, expr: &hir::Expr<'_>) -> HashSet<hir::VariableId> {
110 let mut out = HashSet::new();
111 collect_taint_sources(&self.taint, expr, &mut out);
112 out
113 }
114}
115
116fn collect_taint_sources(
117 taint: &HashMap<hir::VariableId, HashSet<hir::VariableId>>,
118 expr: &hir::Expr<'_>,
119 out: &mut HashSet<hir::VariableId>,
120) {
121 match &expr.kind {
122 ExprKind::Ident(reses) => {
123 for res in *reses {
124 if let Res::Item(ItemId::Variable(vid)) = res
125 && let Some(srcs) = taint.get(vid)
126 {
127 out.extend(srcs.iter().copied());
128 }
129 }
130 }
131 ExprKind::Assign(_, _, rhs) => collect_taint_sources(taint, rhs, out),
132 ExprKind::Binary(lhs, _, rhs) => {
133 collect_taint_sources(taint, lhs, out);
134 collect_taint_sources(taint, rhs, out);
135 }
136 ExprKind::Unary(_, e)
137 | ExprKind::Delete(e)
138 | ExprKind::Member(e, _)
139 | ExprKind::Payable(e) => collect_taint_sources(taint, e, out),
140 ExprKind::Ternary(_, t, f) => {
141 collect_taint_sources(taint, t, out);
142 collect_taint_sources(taint, f, out);
143 }
144 ExprKind::Tuple(elems) => {
145 for e in elems.iter().copied().flatten() {
146 collect_taint_sources(taint, e, out);
147 }
148 }
149 ExprKind::Array(elems) => {
150 for e in *elems {
151 collect_taint_sources(taint, e, out);
152 }
153 }
154 ExprKind::Index(base, idx) => {
155 collect_taint_sources(taint, base, out);
156 if let Some(i) = idx {
157 collect_taint_sources(taint, i, out);
158 }
159 }
160 ExprKind::Call(callee, args, _) => {
162 collect_taint_sources(taint, callee, out);
163 for a in args.exprs() {
164 collect_taint_sources(taint, a, out);
165 }
166 }
167 _ => {}
168 }
169}
170
171fn lhs_local_var(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> Option<hir::VariableId> {
174 if let ExprKind::Ident(reses) = &lhs.kind {
175 for res in *reses {
176 if let Res::Item(ItemId::Variable(vid)) = res
177 && !hir.variable(*vid).kind.is_state()
178 {
179 return Some(*vid);
180 }
181 }
182 }
183 None
184}
185
/// HIR walk that records guards and sinks for tainted identifiers.
///
/// The break value is `Never`, so the walk never short-circuits; results of nested visits are
/// therefore discarded with `let _ =`.
impl<'hir> Visit<'hir> for Analyzer<'hir> {
    type BreakValue = Never;

    fn hir(&self) -> &'hir hir::Hir<'hir> {
        self.hir
    }

    /// Visits a statement, handling branching constructs specially so that guards found in
    /// code that may not execute do not leak into the statements that follow.
    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
        match stmt.kind {
            StmtKind::If(cond, then, else_) => {
                // Any tainted identifier mentioned in the condition counts as guarded. This
                // happens before the baseline snapshot, so it persists after the `if`.
                self.guard_depth += 1;
                let _ = self.visit_expr(cond);
                self.guard_depth -= 1;

                // Visit each branch starting from the same baseline and record only what that
                // branch newly guarded.
                let baseline = self.guarded.clone();
                let _ = self.visit_stmt(then);
                let then_added: HashSet<hir::VariableId> =
                    self.guarded.difference(&baseline).copied().collect();
                // NOTE(review): `branch_always_exits` is defined elsewhere; presumably it
                // reports whether the branch unconditionally leaves the function — confirm.
                let then_exits = branch_always_exits(then);

                let (else_added, else_exits) = if let Some(e) = else_ {
                    self.guarded = baseline.clone();
                    let _ = self.visit_stmt(e);
                    let added: HashSet<hir::VariableId> =
                        self.guarded.difference(&baseline).copied().collect();
                    (added, branch_always_exits(e))
                } else {
                    // A missing `else` adds no guards and does not exit.
                    (HashSet::new(), false)
                };

                // Merge: a branch that always exits does not constrain what follows, so only
                // the other branch's guards survive; if neither exits, keep guards that both
                // branches established.
                self.guarded = baseline;
                let to_add: HashSet<hir::VariableId> = match (then_exits, else_exits) {
                    (true, true) => then_added.union(&else_added).copied().collect(),
                    (true, false) => else_added,
                    (false, true) => then_added,
                    (false, false) => then_added.intersection(&else_added).copied().collect(),
                };
                self.guarded.extend(to_add);

                return ControlFlow::Continue(());
            }
            StmtKind::Loop(block, _) => {
                // Guards found inside the loop body are discarded afterwards; sinks recorded
                // while walking the body are kept.
                let baseline = self.guarded.clone();
                for s in block.stmts {
                    let _ = self.visit_stmt(s);
                }
                self.guarded = baseline;
                return ControlFlow::Continue(());
            }
            StmtKind::Try(t) => {
                // The tried expression is visited normally; each clause is walked from the
                // same baseline and its guards are dropped afterwards.
                let _ = self.visit_expr(&t.expr);
                for clause in t.clauses {
                    let baseline = self.guarded.clone();
                    for s in clause.block.stmts {
                        let _ = self.visit_stmt(s);
                    }
                    self.guarded = baseline;
                }
                return ControlFlow::Continue(());
            }
            StmtKind::DeclSingle(var_id) => {
                // An address-typed local initialized from a tainted expression inherits the
                // seeds of that expression. Falls through to the default walk below.
                let v = self.hir.variable(var_id);
                if let Some(init) = v.initializer
                    && is_address_type(self.hir, var_id)
                {
                    let srcs = self.taint_sources(init);
                    if !srcs.is_empty() {
                        self.taint.entry(var_id).or_default().extend(srcs);
                    }
                }
            }
            _ => {}
        }
        self.walk_stmt(stmt)
    }

    /// Visits an expression, switching guard/sink context for the sub-expressions that
    /// establish guards or constitute sinks, and classifying tainted identifiers accordingly.
    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
        match &expr.kind {
            ExprKind::Call(callee, args, _) if is_require_or_assert(callee) => {
                // Only the first argument (the condition) is visited in guard context; any
                // remaining arguments are visited normally.
                let mut iter = args.exprs();
                if let Some(cond) = iter.next() {
                    self.guard_depth += 1;
                    let _ = self.visit_expr(cond);
                    self.guard_depth -= 1;
                }
                for rest in iter {
                    let _ = self.visit_expr(rest);
                }
                return ControlFlow::Continue(());
            }

            ExprKind::Call(callee, args, _) => {
                // NOTE(review): `address_call_receiver` is defined elsewhere; presumably it
                // yields the address expression a call/transfer is made on — confirm.
                // The receiver is visited in sink context, the arguments normally. Calls
                // without such a receiver fall through to the default walk.
                if let Some(receiver) = address_call_receiver(callee) {
                    self.sink_depth += 1;
                    let _ = self.visit_expr(receiver);
                    self.sink_depth -= 1;
                    let _ = self.visit_call_args(args);
                    return ControlFlow::Continue(());
                }
            }

            ExprKind::Assign(lhs, _, rhs) => {
                if is_address_state_var_lhs(self.hir, lhs) {
                    // Writing to an address state variable: the right-hand side is a sink.
                    let _ = self.visit_expr(lhs);
                    self.sink_depth += 1;
                    let _ = self.visit_expr(rhs);
                    self.sink_depth -= 1;
                    return ControlFlow::Continue(());
                }
                if let Some(local) = lhs_local_var(self.hir, lhs)
                    // Assigning to an address-typed non-state variable propagates taint from
                    // the right-hand side; existing taint on the target is kept.
                    && is_address_type(self.hir, local)
                {
                    let srcs = self.taint_sources(rhs);
                    if !srcs.is_empty() {
                        self.taint.entry(local).or_default().extend(srcs);
                    }
                }
            }

            ExprKind::Ident(reses) => {
                // Guard marking runs before the sink check, so an identifier seen while both
                // depths are non-zero is treated as guarded rather than as a sink.
                for res in *reses {
                    if let Res::Item(ItemId::Variable(vid)) = res
                        && let Some(srcs) = self.taint.get(vid)
                    {
                        if self.guard_depth > 0 {
                            self.guarded.extend(srcs.iter().copied());
                        }
                        if self.sink_depth > 0 {
                            for &src in srcs {
                                if !self.guarded.contains(&src) {
                                    self.sinks.insert(src);
                                }
                            }
                        }
                    }
                }
            }

            _ => {}
        }
        self.walk_expr(expr)
    }
}
338
339fn is_address_state_var_lhs(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
340 if let ExprKind::Ident(reses) = &lhs.kind {
341 for res in *reses {
342 if let Res::Item(ItemId::Variable(vid)) = res
343 && hir.variable(*vid).kind.is_state()
344 && is_address_type(hir, *vid)
345 {
346 return true;
347 }
348 }
349 }
350 false
351}
352
353fn collect_modifier_guards<'hir>(
357 hir: &'hir hir::Hir<'hir>,
358 invocation: &hir::Modifier<'hir>,
359 caller_params: &HashSet<hir::VariableId>,
360 guarded: &mut HashSet<hir::VariableId>,
361) {
362 let ItemId::Function(fid) = invocation.id else { return };
363 let modifier = hir.function(fid);
364 if !matches!(modifier.kind, hir::FunctionKind::Modifier) {
365 return;
366 }
367
368 let mod_params = modifier.parameters;
369 let mut mapping: HashSet<hir::VariableId> = HashSet::new();
370 let mut caller_for_modparam: HashMap<hir::VariableId, hir::VariableId> = HashMap::new();
371 for (i, arg_expr) in invocation.args.exprs().enumerate() {
372 if let ExprKind::Ident(reses) = &arg_expr.kind {
373 for res in *reses {
374 if let Res::Item(ItemId::Variable(vid)) = res
375 && caller_params.contains(vid)
376 && let Some(&mp) = mod_params.get(i)
377 {
378 caller_for_modparam.insert(mp, *vid);
379 mapping.insert(mp);
380 }
381 }
382 }
383 }
384 if mapping.is_empty() {
385 return;
386 }
387
388 let Some(body) = modifier.body else { return };
389 let mut a = Analyzer::new(hir, &mapping);
390 for stmt in body.stmts {
391 let _ = a.visit_stmt(stmt);
392 }
393
394 for mp in a.guarded {
395 if let Some(caller_vid) = caller_for_modparam.get(&mp) {
396 guarded.insert(*caller_vid);
397 }
398 }
399}