Skip to main content

forge_lint/sol/gas/
immutable.rs

1use super::UnchangedStateVariables;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::{self, UnOpKind},
8    interface::{kw, sym},
9    sema::hir::{self, ExprKind, Res, StmtKind, TypeKind},
10};
11use std::collections::HashSet;
12
// Gas lint for a state variable that is never written outside of construction and whose value
// comes from a non-constant inline initializer or a constructor-body assignment, so it could be
// declared `immutable` (value types and contract types only).
declare_forge_lint!(
    COULD_BE_IMMUTABLE,
    Severity::Gas,
    "could-be-immutable",
    "state variable could be declared immutable"
);

// Gas lint for a state variable whose inline initializer is a compile-time constant and which is
// never written afterwards (not in a constructor, another state-variable initializer, or at
// runtime), so it could be declared `constant`.
declare_forge_lint!(
    COULD_BE_CONSTANT,
    Severity::Gas,
    "could-be-constant",
    "state variable could be declared constant"
);
26
27impl<'hir> LateLintPass<'hir> for UnchangedStateVariables {
28    fn check_nested_contract(
29        &mut self,
30        ctx: &LintContext,
31        hir: &'hir hir::Hir<'hir>,
32        contract_id: hir::ContractId,
33    ) {
34        let contract = hir.contract(contract_id);
35        if contract.kind == ast::ContractKind::Interface {
36            return;
37        }
38        if !is_most_derived_contract(hir, contract_id) {
39            return;
40        }
41
42        // Use the broader `constant`-eligible filter so the candidate set covers both lints.
43        let candidates: Vec<_> = contract
44            .linearized_bases
45            .iter()
46            .flat_map(|&contract_id| hir.contract(contract_id).variables())
47            .filter(|&id| is_constant_candidate_type(hir.variable(id)))
48            .collect();
49
50        if candidates.is_empty() {
51            return;
52        }
53        let candidate_set: HashSet<_> = candidates.iter().copied().collect();
54
55        if contract_contains_unlowered_stmt(hir, contract) {
56            return;
57        }
58
59        let mut constructor_body_writes = HashSet::new();
60        let mut initializer_side_effect_writes = HashSet::new();
61        let mut runtime_writes = HashSet::new();
62        let mut non_constant_initializer = HashSet::new();
63
64        for &var_id in &candidates {
65            let var = hir.variable(var_id);
66            if let Some(expr) = var.initializer {
67                if !is_compile_time_constant(hir, expr) {
68                    non_constant_initializer.insert(var_id);
69                }
70                // Tracked separately: blocks `could-be-constant` but not `could-be-immutable`.
71                collect_expr_writes(expr, &candidate_set, &mut initializer_side_effect_writes);
72            }
73        }
74
75        for &contract_id in contract.linearized_bases {
76            for function_id in hir.contract(contract_id).all_functions() {
77                let function = hir.function(function_id);
78                if function.is_constructor() {
79                    collect_modifier_writes(
80                        hir,
81                        function,
82                        &candidate_set,
83                        &mut constructor_body_writes,
84                        &mut runtime_writes,
85                        &mut HashSet::new(),
86                    );
87
88                    if let Some(body) = function.body {
89                        collect_state_writes(
90                            hir,
91                            body,
92                            &candidate_set,
93                            &mut constructor_body_writes,
94                        );
95                    }
96                } else {
97                    // Immutable variables can only be assigned inline or directly in constructor
98                    // bodies, so writes hidden behind internal helpers are not valid candidates.
99                    let mut modifier_argument_writes = HashSet::new();
100                    collect_modifier_writes(
101                        hir,
102                        function,
103                        &candidate_set,
104                        &mut modifier_argument_writes,
105                        &mut runtime_writes,
106                        &mut HashSet::new(),
107                    );
108                    runtime_writes.extend(modifier_argument_writes);
109
110                    if let Some(body) = function.body {
111                        collect_state_writes(hir, body, &candidate_set, &mut runtime_writes);
112                    }
113                }
114            }
115        }
116
117        for &var_id in &candidates {
118            if runtime_writes.contains(&var_id) {
119                continue;
120            }
121            let var = hir.variable(var_id);
122            let span = var.name.map_or(var.span, |name| name.span);
123
124            // `could-be-constant`: requires a compile-time-constant inline initializer and no
125            // writes anywhere (constructor body, other state-var initializers, or runtime).
126            let has_constant_initializer =
127                var.initializer.is_some_and(|expr| is_compile_time_constant(hir, expr));
128            if has_constant_initializer
129                && !constructor_body_writes.contains(&var_id)
130                && !initializer_side_effect_writes.contains(&var_id)
131            {
132                ctx.emit(&COULD_BE_CONSTANT, span);
133                continue;
134            }
135
136            // `could-be-immutable`: written in the constructor (body or non-constant initializer)
137            // and never at runtime; type must be immutable-compatible.
138            if !is_immutable_candidate_type(var) {
139                continue;
140            }
141            if non_constant_initializer.contains(&var_id)
142                || constructor_body_writes.contains(&var_id)
143            {
144                ctx.emit(&COULD_BE_IMMUTABLE, span);
145            }
146        }
147    }
148}
149
150fn is_most_derived_contract(hir: &hir::Hir<'_>, contract_id: hir::ContractId) -> bool {
151    !hir.contracts()
152        .any(|contract| contract.linearized_bases.iter().skip(1).any(|&id| id == contract_id))
153}
154
155fn collect_modifier_writes<'hir>(
156    hir: &'hir hir::Hir<'hir>,
157    function: &'hir hir::Function<'hir>,
158    candidates: &HashSet<hir::VariableId>,
159    argument_writes: &mut HashSet<hir::VariableId>,
160    body_writes: &mut HashSet<hir::VariableId>,
161    visited_modifiers: &mut HashSet<hir::FunctionId>,
162) {
163    for modifier in function.modifiers {
164        for expr in modifier.args.exprs() {
165            collect_expr_writes(expr, candidates, argument_writes);
166        }
167
168        let Some(modifier_id) = modifier.id.as_function() else { continue };
169        if !visited_modifiers.insert(modifier_id) {
170            continue;
171        }
172
173        let modifier = hir.function(modifier_id);
174        let mut nested_argument_writes = HashSet::new();
175        collect_modifier_writes(
176            hir,
177            modifier,
178            candidates,
179            &mut nested_argument_writes,
180            body_writes,
181            visited_modifiers,
182        );
183        body_writes.extend(nested_argument_writes);
184        if let Some(body) = modifier.body {
185            collect_state_writes(hir, body, candidates, body_writes);
186        }
187    }
188}
189
190fn is_immutable_candidate_type(var: &hir::Variable<'_>) -> bool {
191    var.is_state_variable()
192        && var.mutability.is_none()
193        && match var.ty.kind {
194            TypeKind::Elementary(ty) => ty.is_value_type(),
195            TypeKind::Custom(hir::ItemId::Contract(_)) => true,
196            _ => false,
197        }
198}
199
200/// Constants accept any elementary type (value types plus `string`/`bytes`) and contract types.
201fn is_constant_candidate_type(var: &hir::Variable<'_>) -> bool {
202    var.is_state_variable()
203        && var.mutability.is_none()
204        && matches!(
205            var.ty.kind,
206            TypeKind::Elementary(_) | TypeKind::Custom(hir::ItemId::Contract(_))
207        )
208}
209
210fn contract_contains_unlowered_stmt<'hir>(
211    hir: &'hir hir::Hir<'hir>,
212    contract: &'hir hir::Contract<'hir>,
213) -> bool {
214    contract.linearized_bases.iter().any(|&contract_id| {
215        hir.contract(contract_id).all_functions().any(|function_id| {
216            hir.function(function_id).body.is_some_and(|body| block_contains_unlowered_stmt(body))
217        })
218    })
219}
220
221fn block_contains_unlowered_stmt(block: hir::Block<'_>) -> bool {
222    block.stmts.iter().any(stmt_contains_unlowered_stmt)
223}
224
225fn stmt_contains_unlowered_stmt(stmt: &hir::Stmt<'_>) -> bool {
226    match &stmt.kind {
227        StmtKind::Err(_) => true,
228        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
229            block_contains_unlowered_stmt(*block)
230        }
231        StmtKind::If(_, then_stmt, else_stmt) => {
232            stmt_contains_unlowered_stmt(then_stmt)
233                || else_stmt.is_some_and(stmt_contains_unlowered_stmt)
234        }
235        StmtKind::Try(stmt_try) => {
236            stmt_try.clauses.iter().any(|clause| block_contains_unlowered_stmt(clause.block))
237        }
238        StmtKind::DeclSingle(_)
239        | StmtKind::DeclMulti(_, _)
240        | StmtKind::Emit(_)
241        | StmtKind::Revert(_)
242        | StmtKind::Return(_)
243        | StmtKind::Break
244        | StmtKind::Continue
245        | StmtKind::Expr(_)
246        | StmtKind::Placeholder => false,
247    }
248}
249
250fn collect_state_writes<'hir>(
251    hir: &'hir hir::Hir<'hir>,
252    block: hir::Block<'hir>,
253    candidates: &HashSet<hir::VariableId>,
254    writes: &mut HashSet<hir::VariableId>,
255) {
256    for stmt in block.stmts {
257        collect_stmt_writes(hir, stmt, candidates, writes);
258    }
259}
260
261fn collect_stmt_writes<'hir>(
262    hir: &'hir hir::Hir<'hir>,
263    stmt: &'hir hir::Stmt<'hir>,
264    candidates: &HashSet<hir::VariableId>,
265    writes: &mut HashSet<hir::VariableId>,
266) {
267    match &stmt.kind {
268        StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
269            collect_state_writes(hir, *block, candidates, writes);
270        }
271        StmtKind::If(condition, then_stmt, else_stmt) => {
272            collect_expr_writes(condition, candidates, writes);
273            collect_stmt_writes(hir, then_stmt, candidates, writes);
274            if let Some(else_stmt) = else_stmt {
275                collect_stmt_writes(hir, else_stmt, candidates, writes);
276            }
277        }
278        StmtKind::Try(stmt_try) => {
279            collect_expr_writes(&stmt_try.expr, candidates, writes);
280            for clause in stmt_try.clauses {
281                collect_state_writes(hir, clause.block, candidates, writes);
282            }
283        }
284        StmtKind::DeclSingle(var_id) => {
285            if let Some(initializer) = hir.variable(*var_id).initializer {
286                collect_expr_writes(initializer, candidates, writes);
287            }
288        }
289        StmtKind::DeclMulti(_, expr)
290        | StmtKind::Emit(expr)
291        | StmtKind::Revert(expr)
292        | StmtKind::Return(Some(expr))
293        | StmtKind::Expr(expr) => collect_expr_writes(expr, candidates, writes),
294        StmtKind::Return(None)
295        | StmtKind::Break
296        | StmtKind::Continue
297        | StmtKind::Placeholder
298        | StmtKind::Err(_) => {}
299    }
300}
301
302fn collect_expr_writes<'hir>(
303    expr: &'hir hir::Expr<'hir>,
304    candidates: &HashSet<hir::VariableId>,
305    writes: &mut HashSet<hir::VariableId>,
306) {
307    match &expr.kind {
308        ExprKind::Assign(lhs, _, rhs) => {
309            collect_lvalue_writes(lhs, candidates, writes);
310            collect_expr_writes(lhs, candidates, writes);
311            collect_expr_writes(rhs, candidates, writes);
312        }
313        ExprKind::Delete(inner) => {
314            collect_lvalue_writes(inner, candidates, writes);
315            collect_expr_writes(inner, candidates, writes);
316        }
317        ExprKind::Unary(op, inner) => {
318            if op.kind.has_side_effects() {
319                collect_lvalue_writes(inner, candidates, writes);
320            }
321            collect_expr_writes(inner, candidates, writes);
322        }
323        ExprKind::Array(exprs) => {
324            for expr in *exprs {
325                collect_expr_writes(expr, candidates, writes);
326            }
327        }
328        ExprKind::Binary(lhs, _, rhs) => {
329            collect_expr_writes(lhs, candidates, writes);
330            collect_expr_writes(rhs, candidates, writes);
331        }
332        ExprKind::Call(callee, args, named_args) => {
333            collect_expr_writes(callee, candidates, writes);
334            for expr in args.exprs() {
335                collect_expr_writes(expr, candidates, writes);
336            }
337            if let Some(named_args) = named_args {
338                for arg in *named_args {
339                    collect_expr_writes(&arg.value, candidates, writes);
340                }
341            }
342        }
343        ExprKind::Index(base, index) => {
344            collect_expr_writes(base, candidates, writes);
345            if let Some(index) = index {
346                collect_expr_writes(index, candidates, writes);
347            }
348        }
349        ExprKind::Slice(base, start, end) => {
350            collect_expr_writes(base, candidates, writes);
351            if let Some(start) = start {
352                collect_expr_writes(start, candidates, writes);
353            }
354            if let Some(end) = end {
355                collect_expr_writes(end, candidates, writes);
356            }
357        }
358        ExprKind::Member(base, _) | ExprKind::Payable(base) => {
359            collect_expr_writes(base, candidates, writes);
360        }
361        ExprKind::Ternary(condition, then_expr, else_expr) => {
362            collect_expr_writes(condition, candidates, writes);
363            collect_expr_writes(then_expr, candidates, writes);
364            collect_expr_writes(else_expr, candidates, writes);
365        }
366        ExprKind::Tuple(exprs) => {
367            for expr in exprs.iter().flatten() {
368                collect_expr_writes(expr, candidates, writes);
369            }
370        }
371        ExprKind::Ident(_)
372        | ExprKind::Lit(_)
373        | ExprKind::New(_)
374        | ExprKind::TypeCall(_)
375        | ExprKind::Type(_)
376        | ExprKind::Err(_) => {}
377    }
378}
379
380fn collect_lvalue_writes(
381    expr: &hir::Expr<'_>,
382    candidates: &HashSet<hir::VariableId>,
383    writes: &mut HashSet<hir::VariableId>,
384) {
385    match &expr.peel_parens().kind {
386        ExprKind::Ident([Res::Item(hir::ItemId::Variable(id)), ..]) if candidates.contains(id) => {
387            writes.insert(*id);
388        }
389        ExprKind::Tuple(exprs) => {
390            for expr in exprs.iter().flatten() {
391                collect_lvalue_writes(expr, candidates, writes);
392            }
393        }
394        ExprKind::Index(base, _)
395        | ExprKind::Slice(base, _, _)
396        | ExprKind::Member(base, _)
397        | ExprKind::Payable(base) => collect_lvalue_writes(base, candidates, writes),
398        _ => {}
399    }
400}
401
402fn is_compile_time_constant(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
403    match &expr.kind {
404        ExprKind::Lit(_) | ExprKind::Type(_) | ExprKind::TypeCall(_) => true,
405        ExprKind::Ident(resolutions) => {
406            let mut has_const_var = false;
407            let all_safe = resolutions.iter().all(|res| match res {
408                Res::Item(hir::ItemId::Variable(var_id)) => {
409                    let is_const = hir.variable(*var_id).is_constant();
410                    has_const_var |= is_const;
411                    is_const
412                }
413                Res::Item(hir::ItemId::Function(_)) => true,
414                _ => false,
415            });
416            all_safe && has_const_var
417        }
418        ExprKind::Unary(op, inner) => {
419            !matches!(
420                op.kind,
421                UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
422            ) && is_compile_time_constant(hir, inner)
423        }
424        ExprKind::Binary(lhs, _, rhs) => {
425            is_compile_time_constant(hir, lhs) && is_compile_time_constant(hir, rhs)
426        }
427        ExprKind::Call(callee, args, named_args) => {
428            is_allowed_constant_call(callee)
429                && args.exprs().all(|expr| is_compile_time_constant(hir, expr))
430                && named_args.is_none_or(|args| {
431                    args.iter().all(|arg| is_compile_time_constant(hir, &arg.value))
432                })
433        }
434        ExprKind::Ternary(condition, then_expr, else_expr) => {
435            is_compile_time_constant(hir, condition)
436                && is_compile_time_constant(hir, then_expr)
437                && is_compile_time_constant(hir, else_expr)
438        }
439        ExprKind::Tuple(exprs) => {
440            exprs.iter().flatten().all(|expr| is_compile_time_constant(hir, expr))
441        }
442        // `type(T).min`/`type(T).max` for integer/enum types; `type(T).interfaceId` for
443        // interface types.
444        ExprKind::Member(base, member) => match (&base.kind, member.as_str()) {
445            (ExprKind::TypeCall(ty), "min" | "max") => matches!(
446                ty.kind,
447                TypeKind::Elementary(ast::ElementaryType::Int(_) | ast::ElementaryType::UInt(_))
448                    | TypeKind::Custom(hir::ItemId::Enum(_))
449            ),
450            (ExprKind::TypeCall(ty), "interfaceId") => matches!(
451                ty.kind,
452                TypeKind::Custom(hir::ItemId::Contract(cid))
453                    if hir.contract(cid).kind == ast::ContractKind::Interface
454            ),
455            _ => false,
456        },
457        ExprKind::Array(_)
458        | ExprKind::Assign(_, _, _)
459        | ExprKind::Delete(_)
460        | ExprKind::Index(_, _)
461        | ExprKind::Slice(_, _, _)
462        | ExprKind::New(_)
463        | ExprKind::Payable(_)
464        | ExprKind::Err(_) => false,
465    }
466}
467
468fn is_allowed_constant_call(callee: &hir::Expr<'_>) -> bool {
469    match &callee.kind {
470        // Type casts: `address(0xCAFE)`, `uint256(x)`, etc.
471        ExprKind::Type(_) => true,
472        // Contract/interface casts: `IToken(address(0xCAFE))`.
473        ExprKind::Ident([Res::Item(hir::ItemId::Contract(_)), ..]) => true,
474        ExprKind::Ident([Res::Builtin(builtin), ..]) => {
475            let name = builtin.name();
476            name == kw::Keccak256
477                || name == kw::Addmod
478                || name == kw::Mulmod
479                || name == sym::sha256
480                || name == sym::ripemd160
481                || name == sym::ecrecover
482        }
483        _ => false,
484    }
485}