Skip to main content

forge_lint/sol/med/
assert_state_change.rs

1use super::AssertStateChange;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::{DataLocation, ElementaryType, UnOpKind},
8    interface::{Span, Symbol, kw, sym},
9    sema::{
10        Gcx, Hir, Ty,
11        hir::{ContractId, Expr, ExprKind, FunctionId, ItemId, Res, Type, TypeKind},
12        ty::TyKind,
13    },
14};
15use std::{cell::RefCell, collections::HashMap, rc::Rc};
16
17declare_forge_lint!(
18    ASSERT_STATE_CHANGE,
19    Severity::Med,
20    "assert-state-change",
21    "assert() should not contain state-modifying expressions"
22);
23
24thread_local! {
25    static CURRENT_CONTRACT: RefCell<Option<ContractId>> = const { RefCell::new(None) };
26}
27
28impl<'hir> LateLintPass<'hir> for AssertStateChange {
29    fn check_nested_contract(
30        &mut self,
31        _ctx: &LintContext,
32        _gcx: solar::sema::Gcx<'hir>,
33        _hir: &'hir Hir<'hir>,
34        id: ContractId,
35    ) {
36        set_current_contract(Some(id));
37    }
38
39    fn check_function(
40        &mut self,
41        _ctx: &LintContext,
42        _gcx: solar::sema::Gcx<'hir>,
43        _hir: &'hir Hir<'hir>,
44        func: &'hir solar::sema::hir::Function<'hir>,
45    ) {
46        set_current_contract(func.contract);
47    }
48
49    fn check_expr(
50        &mut self,
51        ctx: &LintContext,
52        gcx: Gcx<'hir>,
53        hir: &'hir Hir<'hir>,
54        expr: &'hir Expr<'hir>,
55    ) {
56        let ExprKind::Call(callee, args, _) = &expr.kind else { return };
57        if !is_assert(callee) {
58            return;
59        }
60
61        let current_contract = current_contract();
62        for arg in args.exprs() {
63            if let Some(span) = find_state_change(gcx, hir, current_contract, arg) {
64                ctx.emit_with_msg(
65                    &ASSERT_STATE_CHANGE,
66                    span,
67                    "assert() argument contains a state-modifying expression; \
68                     assert() is for invariants, hoist the mutation before the assert, \
69                     or use require() for validation",
70                );
71            }
72        }
73    }
74}
75
76fn set_current_contract(id: Option<ContractId>) {
77    CURRENT_CONTRACT.with(|cell| *cell.borrow_mut() = id);
78}
79
80fn current_contract() -> Option<ContractId> {
81    CURRENT_CONTRACT.with(|cell| *cell.borrow())
82}
83
84fn is_assert(callee: &Expr<'_>) -> bool {
85    let ExprKind::Ident(reses) = &callee.kind else { return false };
86    reses.iter().any(|r| matches!(r, Res::Builtin(b) if b.name() == sym::assert))
87}
88
89/// Recursively searches `expr` for the first sub-expression that modifies state.
90/// Returns its span so the diagnostic points at exactly where the mutation occurs.
91fn find_state_change<'hir>(
92    gcx: Gcx<'hir>,
93    hir: &'hir Hir<'hir>,
94    current_contract: Option<ContractId>,
95    expr: &'hir Expr<'hir>,
96) -> Option<Span> {
97    match &expr.kind {
98        // x = y, x += y, etc., only when the lvalue targets a state variable
99        ExprKind::Assign(lhs, _, rhs) => {
100            if lvalue_is_state_var(hir, lhs) {
101                return Some(expr.span);
102            }
103            find_state_change(gcx, hir, current_contract, lhs)
104                .or_else(|| find_state_change(gcx, hir, current_contract, rhs))
105        }
106
107        // delete x, only when x is a state variable
108        ExprKind::Delete(inner) => {
109            if lvalue_is_state_var(hir, inner) {
110                return Some(expr.span);
111            }
112            find_state_change(gcx, hir, current_contract, inner)
113        }
114
115        // ++x, x++, --x, x--, only when x is a state variable
116        ExprKind::Unary(op, inner)
117            if matches!(
118                op.kind,
119                UnOpKind::PreInc | UnOpKind::PostInc | UnOpKind::PreDec | UnOpKind::PostDec
120            ) =>
121        {
122            if lvalue_is_state_var(hir, inner) {
123                return Some(expr.span);
124            }
125            find_state_change(gcx, hir, current_contract, inner)
126        }
127
128        ExprKind::Call(callee, args, named_args) => {
129            // arr.push(...) / arr.pop() on a storage array/bytes are mutations.
130            // Positive type check (`is_dynamic_array_or_bytes`) avoids FPs on interface/contract
131            // methods named push/pop, even when the receiver is not a simple Ident (e.g.
132            // `stateStruct.queue.push(x)` where `queue: IQueue`).
133            if let ExprKind::Member(base, method) = &callee.kind
134                && (method.name == sym::push || method.name.as_str() == "pop")
135                && is_dynamic_array_or_bytes(gcx, base)
136                && lvalue_is_state_var(hir, base)
137            {
138                return Some(expr.span);
139            }
140
141            // Low-level address calls (.call/.delegatecall/.send/.transfer) are always mutating.
142            // Only apply this name-based heuristic when the receiver is syntactically address-like.
143            // Using a positive address check rather than "not a known contract" avoids FPs on
144            // non-Ident receivers (function-call results, member chains, `this`) whose contract
145            // type `contract_id_of` cannot resolve syntactically.
146            if let ExprKind::Member(base, method) = &callee.kind {
147                let n = method.name;
148                if (n == kw::Call || n == kw::Delegatecall || n == sym::send || n == sym::transfer)
149                    && is_address_like(gcx, base)
150                {
151                    return Some(expr.span);
152                }
153            }
154
155            // Resolvable contract member calls: check mutates_state() via HIR.
156            // We collect all overloads with the same name and arity, then flag when
157            // any candidate mutates state. Using `any` avoids FNs where a mutating
158            // overload coexists with a view overload of the same arity.
159            let candidates =
160                resolve_member_overloads(gcx, hir, current_contract, callee, args.len());
161            if !candidates.is_empty()
162                && candidates.iter().any(|&fid| hir.function(fid).mutates_state())
163            {
164                return Some(expr.span);
165            }
166
167            if candidates.is_empty()
168                && let ExprKind::Member(base, method) = &callee.kind
169                && lvalue_is_state_var(hir, base)
170                && let Some(recv_ty) = expr_ty(gcx, base)
171            {
172                let lib_candidates =
173                    resolve_library_extension(gcx, hir, method.name, args.len(), recv_ty);
174                if !lib_candidates.is_empty()
175                    && lib_candidates.iter().all(|&fid| hir.function(fid).mutates_state())
176                {
177                    return Some(expr.span);
178                }
179            }
180
181            // Bare-identifier internal function calls: same any-mutates policy as member calls,
182            // since Solar does not resolve which specific overload was selected.
183            let reses = match &callee.peel_parens().kind {
184                ExprKind::Ident(r) => *r,
185                _ => &[],
186            };
187            let fn_reses: Vec<FunctionId> = reses
188                .iter()
189                .filter_map(|res| {
190                    if let Res::Item(ItemId::Function(fid)) = res { Some(*fid) } else { None }
191                })
192                .filter(|&fid| hir.function(fid).parameters.len() == args.len())
193                .collect();
194            if !fn_reses.is_empty() && fn_reses.iter().any(|&fid| hir.function(fid).mutates_state())
195            {
196                return Some(expr.span);
197            }
198
199            // Recurse into callee, positional args, and named args
200            find_state_change(gcx, hir, current_contract, callee)
201                .or_else(|| {
202                    args.exprs().find_map(|a| find_state_change(gcx, hir, current_contract, a))
203                })
204                .or_else(|| {
205                    named_args
206                        .iter()
207                        .flat_map(|opts| opts.args.iter())
208                        .find_map(|na| find_state_change(gcx, hir, current_contract, &na.value))
209                })
210        }
211
212        ExprKind::Unary(_, inner) | ExprKind::Member(inner, _) | ExprKind::Payable(inner) => {
213            find_state_change(gcx, hir, current_contract, inner)
214        }
215        ExprKind::Binary(lhs, _, rhs) => find_state_change(gcx, hir, current_contract, lhs)
216            .or_else(|| find_state_change(gcx, hir, current_contract, rhs)),
217        ExprKind::Ternary(cond, t, f) => find_state_change(gcx, hir, current_contract, cond)
218            .or_else(|| find_state_change(gcx, hir, current_contract, t))
219            .or_else(|| find_state_change(gcx, hir, current_contract, f)),
220        ExprKind::Index(base, idx) => find_state_change(gcx, hir, current_contract, base)
221            .or_else(|| idx.and_then(|i| find_state_change(gcx, hir, current_contract, i))),
222        ExprKind::Slice(base, start, end) => find_state_change(gcx, hir, current_contract, base)
223            .or_else(|| start.and_then(|s| find_state_change(gcx, hir, current_contract, s)))
224            .or_else(|| end.and_then(|e| find_state_change(gcx, hir, current_contract, e))),
225        ExprKind::Array(exprs) => {
226            exprs.iter().find_map(|e| find_state_change(gcx, hir, current_contract, e))
227        }
228        ExprKind::Tuple(exprs) => exprs
229            .iter()
230            .copied()
231            .flatten()
232            .find_map(|e| find_state_change(gcx, hir, current_contract, e)),
233        ExprKind::Ident(_)
234        | ExprKind::Lit(_)
235        | ExprKind::New(_)
236        | ExprKind::TypeCall(_)
237        | ExprKind::Type(_)
238        | ExprKind::YulMember(..)
239        | ExprKind::Err(_) => None,
240    }
241}
242
243/// Returns all overloads of the called member function that match the call's argument count.
244/// Matching by arity narrows overload candidates; the caller flags the call if any candidate
245/// mutates state, since Solar does not resolve which specific overload was selected.
246fn resolve_member_overloads<'hir>(
247    gcx: Gcx<'hir>,
248    hir: &'hir Hir<'hir>,
249    current_contract: Option<ContractId>,
250    callee: &'hir Expr<'hir>,
251    arg_count: usize,
252) -> Vec<FunctionId> {
253    let ExprKind::Member(base, method) = &callee.peel_parens().kind else { return vec![] };
254    let Some(cid) = contract_id_of(gcx, hir, current_contract, base) else { return vec![] };
255    hir.contract_item_ids(cid)
256        .filter_map(|item| {
257            let fid = item.as_function()?;
258            let f = hir.function(fid);
259            (f.name.is_some_and(|n| n.name == method.name) && f.parameters.len() == arg_count)
260                .then_some(fid)
261        })
262        .collect()
263}
264
265/// Extracts the contract ID from an expression whose static type is a contract or interface.
266fn contract_id_of<'hir>(
267    gcx: Gcx<'hir>,
268    _hir: &'hir Hir<'hir>,
269    current_contract: Option<ContractId>,
270    expr: &'hir Expr<'hir>,
271) -> Option<ContractId> {
272    if is_this_or_super(expr) {
273        return current_contract;
274    }
275    // `IToken(addr).foo()`, explicit interface cast; the callee Ident resolves to the contract
276    // itself rather than a function, so receiver_type's Call arm would not match it.
277    if let ExprKind::Call(
278        Expr { kind: ExprKind::Ident([Res::Item(ItemId::Contract(cid))]), .. },
279        ..,
280    ) = &expr.peel_parens().kind
281    {
282        return Some(*cid);
283    }
284    type_contract_id(expr_ty(gcx, expr)?)
285}
286
287fn is_this_or_super(expr: &Expr<'_>) -> bool {
288    let ExprKind::Ident(reses) = &expr.peel_parens().kind else { return false };
289    reses
290        .iter()
291        .any(|r| matches!(r, Res::Builtin(b) if b.name() == sym::this || b.name() == sym::super_))
292}
293
294/// Finds library functions in the HIR that could be a `using for` extension matching the given
295/// method name, call arity, **and** receiver type. A library extension function has
296/// `arg_count + 1` parameters (the extra one being the receiver passed implicitly) with the
297/// first parameter in storage, and that first parameter's type must structurally match the
298/// receiver's static type, otherwise an unrelated library with a same-named function would
299/// false-positive on a contract/interface call.
300///
301/// Solar does not yet embed resolution info on `ExprKind::Member` for extension methods, so this
302/// is a best-effort fallback. The per-name lookup table is memoized per HIR (see
303/// `library_extensions_by_name`) to avoid a full `function_ids()` scan on every eligible call.
304fn resolve_library_extension<'hir>(
305    gcx: Gcx<'hir>,
306    hir: &Hir<'hir>,
307    method_name: Symbol,
308    arg_count: usize,
309    receiver_ty: Ty<'hir>,
310) -> Vec<FunctionId> {
311    let expected_params = arg_count + 1; // +1 for the implicit storage receiver
312    let by_name = library_extensions_by_name(hir);
313    let Some(fids) = by_name.get(&method_name) else { return Vec::new() };
314    fids.iter()
315        .copied()
316        .filter(|&fid| {
317            let f = hir.function(fid);
318            if f.parameters.len() != expected_params {
319                return false;
320            }
321            // First param must be a storage reference of a type matching the receiver.
322            let Some(first_id) = f.parameters.first().copied() else {
323                return false;
324            };
325            let first = hir.variable(first_id);
326            if first.data_location != Some(DataLocation::Storage) {
327                return false;
328            }
329            receiver_ty.convert_implicit_to(gcx.type_of_item(first_id.into()), gcx)
330        })
331        .collect()
332}
333
334/// Memoized per-HIR map of library function names to candidate `FunctionId`s. Building the map
335/// requires a full `hir.function_ids()` scan; without memoization that scan would run on every
336/// eligible call site in the program and scale poorly.
337///
338/// Identity is keyed on the `Hir<'_>` raw pointer. A given lint run sees a single HIR with a
339/// stable address, so pointer comparison is safe; we never deref the pointer beyond identity
340/// checking. The cache is `thread_local`, so concurrent project lint workers each maintain
341/// their own.
342fn library_extensions_by_name(hir: &Hir<'_>) -> Rc<HashMap<Symbol, Vec<FunctionId>>> {
343    type Cache = (usize, Rc<HashMap<Symbol, Vec<FunctionId>>>);
344    thread_local! {
345        static CACHE: RefCell<Option<Cache>> = const { RefCell::new(None) };
346    }
347    let key = hir as *const Hir<'_> as usize;
348    CACHE.with(|cell| {
349        if let Some((cached_key, map)) = &*cell.borrow()
350            && *cached_key == key
351        {
352            return map.clone();
353        }
354        let mut map: HashMap<Symbol, Vec<FunctionId>> = HashMap::new();
355        for fid in hir.function_ids() {
356            let f = hir.function(fid);
357            let Some(cid) = f.contract else { continue };
358            if !hir.contract(cid).kind.is_library() {
359                continue;
360            }
361            let Some(name) = f.name else { continue };
362            map.entry(name.name).or_default().push(fid);
363        }
364        let rc = Rc::new(map);
365        *cell.borrow_mut() = Some((key, rc.clone()));
366        rc
367    })
368}
369
370fn expr_ty<'hir>(gcx: Gcx<'hir>, expr: &'hir Expr<'hir>) -> Option<Ty<'hir>> {
371    gcx.type_of_expr(expr.peel_parens().id)
372}
373
374fn type_contract_id(ty: Ty<'_>) -> Option<ContractId> {
375    match ty.peel_refs().kind {
376        TyKind::Contract(id) => Some(id),
377        _ => None,
378    }
379}
380
381/// Returns `true` when `expr` is a dynamic array or `bytes`
382fn is_dynamic_array_or_bytes<'hir>(gcx: Gcx<'hir>, expr: &'hir Expr<'hir>) -> bool {
383    expr_ty(gcx, expr).is_some_and(|ty| {
384        matches!(
385            ty.peel_refs().kind,
386            TyKind::DynArray(_) | TyKind::Array(..) | TyKind::Elementary(ElementaryType::Bytes)
387        )
388    })
389}
390
391fn is_address_like<'hir>(gcx: Gcx<'hir>, expr: &'hir Expr<'hir>) -> bool {
392    if expr_ty(gcx, expr).is_some_and(ty_is_address) {
393        return true;
394    }
395
396    match &expr.peel_parens().kind {
397        ExprKind::Payable(_) => true,
398        // `address(x)` / `address payable(x)` casts parse as `Call(Type(Address), [x])`.
399        ExprKind::Call(callee, _, _) => matches!(
400            &callee.peel_parens().kind,
401            ExprKind::Type(Type { kind: TypeKind::Elementary(ElementaryType::Address(_)), .. })
402        ),
403        // `msg.sender`, `tx.origin`, `block.coinbase`.
404        ExprKind::Member(base, member) => is_address_builtin_member(base, member.name),
405        ExprKind::Tuple(exprs) => {
406            let mut iter = exprs.iter().flatten();
407            match (iter.next(), iter.next()) {
408                (Some(inner), None) => is_address_like(gcx, inner),
409                _ => false,
410            }
411        }
412        _ => false,
413    }
414}
415
416fn ty_is_address(ty: Ty<'_>) -> bool {
417    matches!(ty.peel_refs().kind, TyKind::Elementary(ElementaryType::Address(_)))
418}
419
420fn is_address_builtin_member(base: &Expr<'_>, member: Symbol) -> bool {
421    let ExprKind::Ident(reses) = &base.peel_parens().kind else { return false };
422    reses.iter().any(|res| {
423        let Res::Builtin(builtin) = res else { return false };
424        matches!(
425            (builtin.name(), member),
426            (sym::msg, sym::sender) | (sym::tx, kw::Origin) | (sym::block, kw::Coinbase)
427        )
428    })
429}
430
431/// Returns `true` if the lvalue expression ultimately targets a storage variable.
432/// Peels through index, slice, member, and payable wrappers to find the root identifier.
433/// Locals declared `storage` are aliases into contract storage and count as state mutations.
434fn lvalue_is_state_var(hir: &Hir<'_>, expr: &Expr<'_>) -> bool {
435    match &expr.peel_parens().kind {
436        ExprKind::Ident([Res::Item(ItemId::Variable(id)), ..]) => {
437            let v = hir.variable(*id);
438            v.is_state_variable() || v.data_location == Some(DataLocation::Storage)
439        }
440        ExprKind::Index(base, _)
441        | ExprKind::Slice(base, _, _)
442        | ExprKind::Member(base, _)
443        | ExprKind::Payable(base) => lvalue_is_state_var(hir, base),
444        ExprKind::Tuple(exprs) => exprs.iter().flatten().any(|e| lvalue_is_state_var(hir, e)),
445        _ => false,
446    }
447}