Skip to main content

forge_lint/sol/gas/
write_after_write.rs

1use super::WriteAfterWrite;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    interface::Span,
8    sema::{
9        Hir,
10        hir::{
11            BinOpKind, Block, Expr, ExprKind, Function, ItemId, Res, Stmt, StmtKind, VariableId,
12        },
13    },
14};
15use std::collections::HashMap;
16
// Declares the `WRITE_AFTER_WRITE` lint: gas severity, id "write-after-write".
// It is emitted at the span of a state-variable write whose value is overwritten
// before anything reads it (see the `ctx.emit(&WRITE_AFTER_WRITE, ..)` call sites below).
declare_forge_lint!(
    WRITE_AFTER_WRITE,
    Severity::Gas,
    "write-after-write",
    "redundant storage write; value overwritten before being read"
);
23
24impl<'hir> LateLintPass<'hir> for WriteAfterWrite {
25    fn check_function(
26        &mut self,
27        ctx: &LintContext,
28        _gcx: solar::sema::Gcx<'hir>,
29        hir: &'hir Hir<'hir>,
30        func: &'hir Function<'hir>,
31    ) {
32        if let Some(body) = func.body {
33            let mut pending = HashMap::default();
34            check_block(ctx, hir, body, &mut pending);
35        }
36    }
37}
38
/// Whether control flow continues past a statement or block.
///
/// Returned by the statement/block walkers so that callers can stop analyzing
/// statements that follow an unconditional exit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Flow {
    /// Execution may fall through to the next statement.
    Continue,
    /// Execution never reaches the next statement in the same block.
    Stop,
}
45
46fn check_block<'hir>(
47    ctx: &LintContext,
48    hir: &'hir Hir<'hir>,
49    block: Block<'hir>,
50    pending: &mut HashMap<VariableId, Span>,
51) -> Flow {
52    for stmt in block.stmts {
53        if check_stmt(ctx, hir, stmt, pending) == Flow::Stop {
54            return Flow::Stop;
55        }
56    }
57    Flow::Continue
58}
59
60fn check_stmt<'hir>(
61    ctx: &LintContext,
62    hir: &'hir Hir<'hir>,
63    stmt: &'hir Stmt<'hir>,
64    pending: &mut HashMap<VariableId, Span>,
65) -> Flow {
66    match &stmt.kind {
67        StmtKind::Expr(expr) => {
68            process_expr(ctx, hir, expr.peel_parens(), pending);
69        }
70        StmtKind::DeclSingle(var_id) => {
71            if let Some(init) = hir.variable(*var_id).initializer {
72                collect_reads(ctx, hir, init, pending);
73            }
74        }
75        StmtKind::DeclMulti(_, expr) => {
76            collect_reads(ctx, hir, expr, pending);
77        }
78        // return/revert/break/continue are terminal; code after them is unreachable,
79        // so the "pending" writes will never be overwritten. Reads still matter for the
80        // value carried by return/revert, but after processing we must clear because no
81        // subsequent statement in the same block can execute.
82        StmtKind::Return(Some(expr)) => {
83            collect_reads(ctx, hir, expr, pending);
84            pending.clear();
85            return Flow::Stop;
86        }
87        StmtKind::Return(None) | StmtKind::Break | StmtKind::Continue => {
88            pending.clear();
89            return Flow::Stop;
90        }
91        StmtKind::Emit(expr) => {
92            // Emit only logs; it doesn't invoke external code that could observe state.
93            // Walk the call args directly so pending is not cleared.
94            if let ExprKind::Call(callee, args, named_args) = &expr.peel_parens().kind {
95                collect_reads(ctx, hir, callee, pending);
96                for arg in args.exprs() {
97                    collect_reads(ctx, hir, arg, pending);
98                }
99                walk_named_args(ctx, hir, named_args, pending);
100            } else {
101                collect_reads(ctx, hir, expr, pending);
102            }
103        }
104        StmtKind::Revert(expr) => {
105            collect_reads(ctx, hir, expr, pending);
106            pending.clear();
107            return Flow::Stop;
108        }
109        // Branches and loops: recurse with a fresh map so intra-body pairs are still
110        // caught, then clear the outer pending conservatively since any branch may
111        // observe or skip the outer write.
112        // Propagate Stop only when both branches unconditionally stop (no else = Continue).
113        StmtKind::If(cond, then_stmt, else_stmt) => {
114            collect_reads(ctx, hir, cond, pending);
115            pending.clear();
116            let mut branch_pending = HashMap::default();
117            let then_flow = check_stmt(ctx, hir, then_stmt, &mut branch_pending);
118            if let Some(else_stmt) = else_stmt {
119                let mut else_pending = HashMap::default();
120                let else_flow = check_stmt(ctx, hir, else_stmt, &mut else_pending);
121                if then_flow == Flow::Stop && else_flow == Flow::Stop {
122                    return Flow::Stop;
123                }
124            }
125        }
126        StmtKind::Loop(block, _) => {
127            pending.clear();
128            let mut loop_pending = HashMap::default();
129            check_block(ctx, hir, *block, &mut loop_pending);
130            // A loop may execute zero times, so it never guarantees Stop for outer flow.
131        }
132        StmtKind::Try(try_stmt) => {
133            collect_reads(ctx, hir, &try_stmt.expr, pending);
134            pending.clear();
135            for clause in try_stmt.clauses {
136                let mut clause_pending = HashMap::default();
137                check_block(ctx, hir, clause.block, &mut clause_pending);
138            }
139        }
140        // Nested blocks are sequential; share the same pending map so reads inside
141        // them properly invalidate outer writes. Propagate terminal flow outward.
142        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) => {
143            return check_block(ctx, hir, *block, pending);
144        }
145        // The placeholder `_` in a modifier body invokes the modified function, which
146        // can freely read any storage variable. Conservatively clear everything.
147        StmtKind::Placeholder => {
148            pending.clear();
149        }
150        // Inline assembly or parse errors: we can't reason about what is read or
151        // written, so clear conservatively (same approach as unprotected_initializer).
152        StmtKind::AssemblyBlock(_) | StmtKind::Switch(_) | StmtKind::Err(_) => {
153            pending.clear();
154        }
155    }
156    Flow::Continue
157}
158
159/// Process an expression that appears as a statement, tracking writes and reads.
160fn process_expr<'hir>(
161    ctx: &LintContext,
162    hir: &'hir Hir<'hir>,
163    expr: &'hir Expr<'hir>,
164    pending: &mut HashMap<VariableId, Span>,
165) {
166    match &expr.kind {
167        ExprKind::Assign(lhs, op, rhs) => {
168            // RHS is always evaluated (read) before the assignment takes effect.
169            collect_reads(ctx, hir, rhs, pending);
170
171            if op.is_none() {
172                // Plain `=`: recursively handle the LHS as a write target.
173                process_assignment_lhs(ctx, hir, lhs, expr.span, pending);
174            } else {
175                // Compound assignment (+=, etc.) reads the current value of LHS first.
176                collect_reads(ctx, hir, lhs, pending);
177            }
178        }
179        ExprKind::Unary(op, inner) if op.kind.has_side_effects() => {
180            // Pre/post inc/dec: reads the variable, then writes it.
181            collect_reads(ctx, hir, inner, pending);
182            if let Some(var_id) = simple_state_var_id(hir, inner) {
183                pending.insert(var_id, expr.span);
184            }
185        }
186        ExprKind::Delete(inner) => {
187            // `delete x` is a pure write with no read of the previous value.
188            if let Some(var_id) = simple_state_var_id(hir, inner) {
189                if let Some(&prev_span) = pending.get(&var_id) {
190                    ctx.emit(&WRITE_AFTER_WRITE, prev_span);
191                }
192                pending.insert(var_id, expr.span);
193            } else {
194                collect_reads(ctx, hir, inner, pending);
195            }
196        }
197        // Any function/method call can observe state through re-entrancy or view
198        // calls, so conservatively treat it as reading everything pending.
199        // Walk callee, arguments, and call options first — all are evaluated before the call.
200        ExprKind::Call(callee, args, named_args) => {
201            collect_reads(ctx, hir, callee, pending);
202            for arg in args.exprs() {
203                collect_reads(ctx, hir, arg, pending);
204            }
205            walk_named_args(ctx, hir, named_args, pending);
206            pending.clear();
207        }
208        // For any other expression used as a statement, scan for reads.
209        _ => collect_reads(ctx, hir, expr, pending),
210    }
211}
212
213/// Recursively handle a plain-`=` assignment LHS, tracking each component as a write.
214/// For tuple destructuring `(x, y) = ...`, each element is processed independently.
215fn process_assignment_lhs<'hir>(
216    ctx: &LintContext,
217    hir: &'hir Hir<'hir>,
218    lhs: &'hir Expr<'hir>,
219    assign_span: Span,
220    pending: &mut HashMap<VariableId, Span>,
221) {
222    match &lhs.peel_parens().kind {
223        ExprKind::Tuple(exprs) => {
224            for e in exprs.iter().flatten() {
225                process_assignment_lhs(ctx, hir, e, e.span, pending);
226            }
227        }
228        _ => {
229            if let Some(var_id) = simple_state_var_id(hir, lhs) {
230                if let Some(&prev_span) = pending.get(&var_id) {
231                    ctx.emit(&WRITE_AFTER_WRITE, prev_span);
232                }
233                pending.insert(var_id, assign_span);
234            } else {
235                // Non-simple LHS (index/member access): slot computation reads the base.
236                collect_reads(ctx, hir, lhs, pending);
237            }
238        }
239    }
240}
241
242/// Remove any state variable mentioned in `expr` from `pending` (it was read).
243/// For nested assignments, delegates to `process_expr` so writes are handled correctly.
244fn collect_reads<'hir>(
245    ctx: &LintContext,
246    hir: &'hir Hir<'hir>,
247    expr: &'hir Expr<'hir>,
248    pending: &mut HashMap<VariableId, Span>,
249) {
250    match &expr.peel_parens().kind {
251        ExprKind::Ident(resolutions) => {
252            for res in *resolutions {
253                if let Res::Item(ItemId::Variable(id)) = res
254                    && hir.variable(*id).is_state_variable()
255                {
256                    pending.remove(id);
257                }
258            }
259        }
260        ExprKind::Assign(_, _, _) => {
261            // A nested assignment (e.g. `uint256 z = (x = v)`) writes to its LHS, not
262            // reads it. Delegate to process_expr so the write is tracked correctly.
263            process_expr(ctx, hir, expr.peel_parens(), pending);
264        }
265        // Short-circuit operators: LHS always evaluates, RHS may not.
266        // Clear outer pending before RHS to avoid false-positive WAW in the conditional path.
267        ExprKind::Binary(lhs, op, rhs) if matches!(op.kind, BinOpKind::And | BinOpKind::Or) => {
268            collect_reads(ctx, hir, lhs, pending);
269            pending.clear();
270            let mut rhs_pending = HashMap::default();
271            collect_reads(ctx, hir, rhs, &mut rhs_pending);
272        }
273        ExprKind::Binary(lhs, _, rhs) => {
274            collect_reads(ctx, hir, lhs, pending);
275            collect_reads(ctx, hir, rhs, pending);
276        }
277        ExprKind::Unary(_, inner) | ExprKind::Payable(inner) => {
278            collect_reads(ctx, hir, inner, pending);
279        }
280        // Ternary arms are mutually exclusive; analyze each independently with a fresh
281        // pending to avoid false-positive WAW between branches.
282        ExprKind::Ternary(cond, t, f) => {
283            collect_reads(ctx, hir, cond, pending);
284            pending.clear();
285            let mut then_pending = HashMap::default();
286            let mut else_pending = HashMap::default();
287            collect_reads(ctx, hir, t, &mut then_pending);
288            collect_reads(ctx, hir, f, &mut else_pending);
289        }
290        // Any call may observe storage through re-entrancy or view semantics.
291        // Walk callee, arguments, and call options first — all evaluated before the call.
292        ExprKind::Call(callee, args, named_args) => {
293            collect_reads(ctx, hir, callee, pending);
294            for arg in args.exprs() {
295                collect_reads(ctx, hir, arg, pending);
296            }
297            walk_named_args(ctx, hir, named_args, pending);
298            pending.clear();
299        }
300        ExprKind::Index(base, index) => {
301            collect_reads(ctx, hir, base, pending);
302            if let Some(idx) = index {
303                collect_reads(ctx, hir, idx, pending);
304            }
305        }
306        ExprKind::Slice(base, start, end) => {
307            collect_reads(ctx, hir, base, pending);
308            if let Some(s) = start {
309                collect_reads(ctx, hir, s, pending);
310            }
311            if let Some(e) = end {
312                collect_reads(ctx, hir, e, pending);
313            }
314        }
315        ExprKind::Member(base, _) => collect_reads(ctx, hir, base, pending),
316        ExprKind::Tuple(exprs) => {
317            for e in exprs.iter().flatten() {
318                collect_reads(ctx, hir, e, pending);
319            }
320        }
321        ExprKind::Array(exprs) => {
322            for e in *exprs {
323                collect_reads(ctx, hir, e, pending);
324            }
325        }
326        // A nested `delete` is a write; delegate to process_expr for correct tracking.
327        ExprKind::Delete(_) => {
328            process_expr(ctx, hir, expr.peel_parens(), pending);
329        }
330        ExprKind::Lit(_) | ExprKind::New(_) | ExprKind::TypeCall(_) | ExprKind::Type(_) => {}
331        ExprKind::YulMember(..) | ExprKind::Err(_) => {
332            pending.clear();
333        }
334    }
335}
336
337/// Walk named call arguments (e.g. `{value: expr, gas: expr}`) for reads and writes.
338fn walk_named_args<'hir>(
339    ctx: &LintContext,
340    hir: &'hir Hir<'hir>,
341    named_args: &Option<&'hir solar::sema::hir::CallOptions<'hir>>,
342    pending: &mut HashMap<VariableId, Span>,
343) {
344    if let Some(named) = named_args {
345        for na in named.args {
346            collect_reads(ctx, hir, &na.value, pending);
347        }
348    }
349}
350
351/// Returns `Some(id)` if the expression is a bare state variable identifier (no indexing/member).
352fn simple_state_var_id(hir: &Hir<'_>, expr: &Expr<'_>) -> Option<VariableId> {
353    match &expr.peel_parens().kind {
354        ExprKind::Ident(resolutions) => resolutions.iter().find_map(|res| match res {
355            Res::Item(ItemId::Variable(id)) if hir.variable(*id).is_state_variable() => Some(*id),
356            _ => None,
357        }),
358        _ => None,
359    }
360}