1use super::UnchangedStateVariables;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{self, UnOpKind},
8 interface::{kw, sym},
9 sema::hir::{self, ExprKind, Res, StmtKind, TypeKind},
10};
11use std::collections::HashSet;
12
// Gas lint reported for a state variable that is never written after deployment but gets its
// value during construction (non-constant initializer or constructor assignment), so it could be
// declared `immutable`.
declare_forge_lint!(
    COULD_BE_IMMUTABLE,
    Severity::Gas,
    "could-be-immutable",
    "state variable could be declared immutable"
);
19
// Gas lint reported for a state variable whose initializer is a compile-time constant and which
// is never written afterwards, so it could be declared `constant`.
declare_forge_lint!(
    COULD_BE_CONSTANT,
    Severity::Gas,
    "could-be-constant",
    "state variable could be declared constant"
);
26
27impl<'hir> LateLintPass<'hir> for UnchangedStateVariables {
28 fn check_nested_contract(
29 &mut self,
30 ctx: &LintContext,
31 hir: &'hir hir::Hir<'hir>,
32 contract_id: hir::ContractId,
33 ) {
34 let contract = hir.contract(contract_id);
35 if contract.kind == ast::ContractKind::Interface {
36 return;
37 }
38 if !is_most_derived_contract(hir, contract_id) {
39 return;
40 }
41
42 let candidates: Vec<_> = contract
44 .linearized_bases
45 .iter()
46 .flat_map(|&contract_id| hir.contract(contract_id).variables())
47 .filter(|&id| is_constant_candidate_type(hir.variable(id)))
48 .collect();
49
50 if candidates.is_empty() {
51 return;
52 }
53 let candidate_set: HashSet<_> = candidates.iter().copied().collect();
54
55 if contract_contains_unlowered_stmt(hir, contract) {
56 return;
57 }
58
59 let mut constructor_body_writes = HashSet::new();
60 let mut initializer_side_effect_writes = HashSet::new();
61 let mut runtime_writes = HashSet::new();
62 let mut non_constant_initializer = HashSet::new();
63
64 for &var_id in &candidates {
65 let var = hir.variable(var_id);
66 if let Some(expr) = var.initializer {
67 if !is_compile_time_constant(hir, expr) {
68 non_constant_initializer.insert(var_id);
69 }
70 collect_expr_writes(expr, &candidate_set, &mut initializer_side_effect_writes);
72 }
73 }
74
75 for &contract_id in contract.linearized_bases {
76 for function_id in hir.contract(contract_id).all_functions() {
77 let function = hir.function(function_id);
78 if function.is_constructor() {
79 collect_modifier_writes(
80 hir,
81 function,
82 &candidate_set,
83 &mut constructor_body_writes,
84 &mut runtime_writes,
85 &mut HashSet::new(),
86 );
87
88 if let Some(body) = function.body {
89 collect_state_writes(
90 hir,
91 body,
92 &candidate_set,
93 &mut constructor_body_writes,
94 );
95 }
96 } else {
97 let mut modifier_argument_writes = HashSet::new();
100 collect_modifier_writes(
101 hir,
102 function,
103 &candidate_set,
104 &mut modifier_argument_writes,
105 &mut runtime_writes,
106 &mut HashSet::new(),
107 );
108 runtime_writes.extend(modifier_argument_writes);
109
110 if let Some(body) = function.body {
111 collect_state_writes(hir, body, &candidate_set, &mut runtime_writes);
112 }
113 }
114 }
115 }
116
117 for &var_id in &candidates {
118 if runtime_writes.contains(&var_id) {
119 continue;
120 }
121 let var = hir.variable(var_id);
122 let span = var.name.map_or(var.span, |name| name.span);
123
124 let has_constant_initializer =
127 var.initializer.is_some_and(|expr| is_compile_time_constant(hir, expr));
128 if has_constant_initializer
129 && !constructor_body_writes.contains(&var_id)
130 && !initializer_side_effect_writes.contains(&var_id)
131 {
132 ctx.emit(&COULD_BE_CONSTANT, span);
133 continue;
134 }
135
136 if !is_immutable_candidate_type(var) {
139 continue;
140 }
141 if non_constant_initializer.contains(&var_id)
142 || constructor_body_writes.contains(&var_id)
143 {
144 ctx.emit(&COULD_BE_IMMUTABLE, span);
145 }
146 }
147 }
148}
149
150fn is_most_derived_contract(hir: &hir::Hir<'_>, contract_id: hir::ContractId) -> bool {
151 !hir.contracts()
152 .any(|contract| contract.linearized_bases.iter().skip(1).any(|&id| id == contract_id))
153}
154
155fn collect_modifier_writes<'hir>(
156 hir: &'hir hir::Hir<'hir>,
157 function: &'hir hir::Function<'hir>,
158 candidates: &HashSet<hir::VariableId>,
159 argument_writes: &mut HashSet<hir::VariableId>,
160 body_writes: &mut HashSet<hir::VariableId>,
161 visited_modifiers: &mut HashSet<hir::FunctionId>,
162) {
163 for modifier in function.modifiers {
164 for expr in modifier.args.exprs() {
165 collect_expr_writes(expr, candidates, argument_writes);
166 }
167
168 let Some(modifier_id) = modifier.id.as_function() else { continue };
169 if !visited_modifiers.insert(modifier_id) {
170 continue;
171 }
172
173 let modifier = hir.function(modifier_id);
174 let mut nested_argument_writes = HashSet::new();
175 collect_modifier_writes(
176 hir,
177 modifier,
178 candidates,
179 &mut nested_argument_writes,
180 body_writes,
181 visited_modifiers,
182 );
183 body_writes.extend(nested_argument_writes);
184 if let Some(body) = modifier.body {
185 collect_state_writes(hir, body, candidates, body_writes);
186 }
187 }
188}
189
190fn is_immutable_candidate_type(var: &hir::Variable<'_>) -> bool {
191 var.is_state_variable()
192 && var.mutability.is_none()
193 && match var.ty.kind {
194 TypeKind::Elementary(ty) => ty.is_value_type(),
195 TypeKind::Custom(hir::ItemId::Contract(_)) => true,
196 _ => false,
197 }
198}
199
200fn is_constant_candidate_type(var: &hir::Variable<'_>) -> bool {
202 var.is_state_variable()
203 && var.mutability.is_none()
204 && matches!(
205 var.ty.kind,
206 TypeKind::Elementary(_) | TypeKind::Custom(hir::ItemId::Contract(_))
207 )
208}
209
210fn contract_contains_unlowered_stmt<'hir>(
211 hir: &'hir hir::Hir<'hir>,
212 contract: &'hir hir::Contract<'hir>,
213) -> bool {
214 contract.linearized_bases.iter().any(|&contract_id| {
215 hir.contract(contract_id).all_functions().any(|function_id| {
216 hir.function(function_id).body.is_some_and(|body| block_contains_unlowered_stmt(body))
217 })
218 })
219}
220
221fn block_contains_unlowered_stmt(block: hir::Block<'_>) -> bool {
222 block.stmts.iter().any(stmt_contains_unlowered_stmt)
223}
224
225fn stmt_contains_unlowered_stmt(stmt: &hir::Stmt<'_>) -> bool {
226 match &stmt.kind {
227 StmtKind::Err(_) => true,
228 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
229 block_contains_unlowered_stmt(*block)
230 }
231 StmtKind::If(_, then_stmt, else_stmt) => {
232 stmt_contains_unlowered_stmt(then_stmt)
233 || else_stmt.is_some_and(stmt_contains_unlowered_stmt)
234 }
235 StmtKind::Try(stmt_try) => {
236 stmt_try.clauses.iter().any(|clause| block_contains_unlowered_stmt(clause.block))
237 }
238 StmtKind::DeclSingle(_)
239 | StmtKind::DeclMulti(_, _)
240 | StmtKind::Emit(_)
241 | StmtKind::Revert(_)
242 | StmtKind::Return(_)
243 | StmtKind::Break
244 | StmtKind::Continue
245 | StmtKind::Expr(_)
246 | StmtKind::Placeholder => false,
247 }
248}
249
250fn collect_state_writes<'hir>(
251 hir: &'hir hir::Hir<'hir>,
252 block: hir::Block<'hir>,
253 candidates: &HashSet<hir::VariableId>,
254 writes: &mut HashSet<hir::VariableId>,
255) {
256 for stmt in block.stmts {
257 collect_stmt_writes(hir, stmt, candidates, writes);
258 }
259}
260
261fn collect_stmt_writes<'hir>(
262 hir: &'hir hir::Hir<'hir>,
263 stmt: &'hir hir::Stmt<'hir>,
264 candidates: &HashSet<hir::VariableId>,
265 writes: &mut HashSet<hir::VariableId>,
266) {
267 match &stmt.kind {
268 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
269 collect_state_writes(hir, *block, candidates, writes);
270 }
271 StmtKind::If(condition, then_stmt, else_stmt) => {
272 collect_expr_writes(condition, candidates, writes);
273 collect_stmt_writes(hir, then_stmt, candidates, writes);
274 if let Some(else_stmt) = else_stmt {
275 collect_stmt_writes(hir, else_stmt, candidates, writes);
276 }
277 }
278 StmtKind::Try(stmt_try) => {
279 collect_expr_writes(&stmt_try.expr, candidates, writes);
280 for clause in stmt_try.clauses {
281 collect_state_writes(hir, clause.block, candidates, writes);
282 }
283 }
284 StmtKind::DeclSingle(var_id) => {
285 if let Some(initializer) = hir.variable(*var_id).initializer {
286 collect_expr_writes(initializer, candidates, writes);
287 }
288 }
289 StmtKind::DeclMulti(_, expr)
290 | StmtKind::Emit(expr)
291 | StmtKind::Revert(expr)
292 | StmtKind::Return(Some(expr))
293 | StmtKind::Expr(expr) => collect_expr_writes(expr, candidates, writes),
294 StmtKind::Return(None)
295 | StmtKind::Break
296 | StmtKind::Continue
297 | StmtKind::Placeholder
298 | StmtKind::Err(_) => {}
299 }
300}
301
302fn collect_expr_writes<'hir>(
303 expr: &'hir hir::Expr<'hir>,
304 candidates: &HashSet<hir::VariableId>,
305 writes: &mut HashSet<hir::VariableId>,
306) {
307 match &expr.kind {
308 ExprKind::Assign(lhs, _, rhs) => {
309 collect_lvalue_writes(lhs, candidates, writes);
310 collect_expr_writes(lhs, candidates, writes);
311 collect_expr_writes(rhs, candidates, writes);
312 }
313 ExprKind::Delete(inner) => {
314 collect_lvalue_writes(inner, candidates, writes);
315 collect_expr_writes(inner, candidates, writes);
316 }
317 ExprKind::Unary(op, inner) => {
318 if op.kind.has_side_effects() {
319 collect_lvalue_writes(inner, candidates, writes);
320 }
321 collect_expr_writes(inner, candidates, writes);
322 }
323 ExprKind::Array(exprs) => {
324 for expr in *exprs {
325 collect_expr_writes(expr, candidates, writes);
326 }
327 }
328 ExprKind::Binary(lhs, _, rhs) => {
329 collect_expr_writes(lhs, candidates, writes);
330 collect_expr_writes(rhs, candidates, writes);
331 }
332 ExprKind::Call(callee, args, named_args) => {
333 collect_expr_writes(callee, candidates, writes);
334 for expr in args.exprs() {
335 collect_expr_writes(expr, candidates, writes);
336 }
337 if let Some(named_args) = named_args {
338 for arg in *named_args {
339 collect_expr_writes(&arg.value, candidates, writes);
340 }
341 }
342 }
343 ExprKind::Index(base, index) => {
344 collect_expr_writes(base, candidates, writes);
345 if let Some(index) = index {
346 collect_expr_writes(index, candidates, writes);
347 }
348 }
349 ExprKind::Slice(base, start, end) => {
350 collect_expr_writes(base, candidates, writes);
351 if let Some(start) = start {
352 collect_expr_writes(start, candidates, writes);
353 }
354 if let Some(end) = end {
355 collect_expr_writes(end, candidates, writes);
356 }
357 }
358 ExprKind::Member(base, _) | ExprKind::Payable(base) => {
359 collect_expr_writes(base, candidates, writes);
360 }
361 ExprKind::Ternary(condition, then_expr, else_expr) => {
362 collect_expr_writes(condition, candidates, writes);
363 collect_expr_writes(then_expr, candidates, writes);
364 collect_expr_writes(else_expr, candidates, writes);
365 }
366 ExprKind::Tuple(exprs) => {
367 for expr in exprs.iter().flatten() {
368 collect_expr_writes(expr, candidates, writes);
369 }
370 }
371 ExprKind::Ident(_)
372 | ExprKind::Lit(_)
373 | ExprKind::New(_)
374 | ExprKind::TypeCall(_)
375 | ExprKind::Type(_)
376 | ExprKind::Err(_) => {}
377 }
378}
379
380fn collect_lvalue_writes(
381 expr: &hir::Expr<'_>,
382 candidates: &HashSet<hir::VariableId>,
383 writes: &mut HashSet<hir::VariableId>,
384) {
385 match &expr.peel_parens().kind {
386 ExprKind::Ident([Res::Item(hir::ItemId::Variable(id)), ..]) if candidates.contains(id) => {
387 writes.insert(*id);
388 }
389 ExprKind::Tuple(exprs) => {
390 for expr in exprs.iter().flatten() {
391 collect_lvalue_writes(expr, candidates, writes);
392 }
393 }
394 ExprKind::Index(base, _)
395 | ExprKind::Slice(base, _, _)
396 | ExprKind::Member(base, _)
397 | ExprKind::Payable(base) => collect_lvalue_writes(base, candidates, writes),
398 _ => {}
399 }
400}
401
402fn is_compile_time_constant(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
403 match &expr.kind {
404 ExprKind::Lit(_) | ExprKind::Type(_) | ExprKind::TypeCall(_) => true,
405 ExprKind::Ident(resolutions) => {
406 let mut has_const_var = false;
407 let all_safe = resolutions.iter().all(|res| match res {
408 Res::Item(hir::ItemId::Variable(var_id)) => {
409 let is_const = hir.variable(*var_id).is_constant();
410 has_const_var |= is_const;
411 is_const
412 }
413 Res::Item(hir::ItemId::Function(_)) => true,
414 _ => false,
415 });
416 all_safe && has_const_var
417 }
418 ExprKind::Unary(op, inner) => {
419 !matches!(
420 op.kind,
421 UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
422 ) && is_compile_time_constant(hir, inner)
423 }
424 ExprKind::Binary(lhs, _, rhs) => {
425 is_compile_time_constant(hir, lhs) && is_compile_time_constant(hir, rhs)
426 }
427 ExprKind::Call(callee, args, named_args) => {
428 is_allowed_constant_call(callee)
429 && args.exprs().all(|expr| is_compile_time_constant(hir, expr))
430 && named_args.is_none_or(|args| {
431 args.iter().all(|arg| is_compile_time_constant(hir, &arg.value))
432 })
433 }
434 ExprKind::Ternary(condition, then_expr, else_expr) => {
435 is_compile_time_constant(hir, condition)
436 && is_compile_time_constant(hir, then_expr)
437 && is_compile_time_constant(hir, else_expr)
438 }
439 ExprKind::Tuple(exprs) => {
440 exprs.iter().flatten().all(|expr| is_compile_time_constant(hir, expr))
441 }
442 ExprKind::Member(base, member) => match (&base.kind, member.as_str()) {
445 (ExprKind::TypeCall(ty), "min" | "max") => matches!(
446 ty.kind,
447 TypeKind::Elementary(ast::ElementaryType::Int(_) | ast::ElementaryType::UInt(_))
448 | TypeKind::Custom(hir::ItemId::Enum(_))
449 ),
450 (ExprKind::TypeCall(ty), "interfaceId") => matches!(
451 ty.kind,
452 TypeKind::Custom(hir::ItemId::Contract(cid))
453 if hir.contract(cid).kind == ast::ContractKind::Interface
454 ),
455 _ => false,
456 },
457 ExprKind::Array(_)
458 | ExprKind::Assign(_, _, _)
459 | ExprKind::Delete(_)
460 | ExprKind::Index(_, _)
461 | ExprKind::Slice(_, _, _)
462 | ExprKind::New(_)
463 | ExprKind::Payable(_)
464 | ExprKind::Err(_) => false,
465 }
466}
467
468fn is_allowed_constant_call(callee: &hir::Expr<'_>) -> bool {
469 match &callee.kind {
470 ExprKind::Type(_) => true,
472 ExprKind::Ident([Res::Item(hir::ItemId::Contract(_)), ..]) => true,
474 ExprKind::Ident([Res::Builtin(builtin), ..]) => {
475 let name = builtin.name();
476 name == kw::Keccak256
477 || name == kw::Addmod
478 || name == kw::Mulmod
479 || name == sym::sha256
480 || name == sym::ripemd160
481 || name == sym::ecrecover
482 }
483 _ => false,
484 }
485}