1use super::UnchangedStateVariables;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{self, UnOpKind},
8 interface::{kw, sym},
9 sema::hir::{self, ExprKind, Res, StmtKind, TypeKind},
10};
11use std::collections::HashSet;
12
// Gas lint emitted for a state variable of value/contract type that is only written at
// construction time (by a non-constant initializer and/or constructor-body assignments) and
// never at runtime, so it could be declared `immutable`.
declare_forge_lint!(
    COULD_BE_IMMUTABLE,
    Severity::Gas,
    "could-be-immutable",
    "state variable could be declared immutable"
);

// Gas lint emitted for a state variable whose initializer is a compile-time constant and which
// is never written anywhere else (constructor, other initializers, or runtime code), so it
// could be declared `constant`.
declare_forge_lint!(
    COULD_BE_CONSTANT,
    Severity::Gas,
    "could-be-constant",
    "state variable could be declared constant"
);
26
27impl<'hir> LateLintPass<'hir> for UnchangedStateVariables {
28 fn check_nested_contract(
29 &mut self,
30 ctx: &LintContext,
31 _gcx: solar::sema::Gcx<'hir>,
32 hir: &'hir hir::Hir<'hir>,
33 contract_id: hir::ContractId,
34 ) {
35 let contract = hir.contract(contract_id);
36 if contract.kind == ast::ContractKind::Interface {
37 return;
38 }
39 if !is_most_derived_contract(hir, contract_id) {
40 return;
41 }
42
43 let candidates: Vec<_> = contract
45 .linearized_bases
46 .iter()
47 .flat_map(|&contract_id| hir.contract(contract_id).variables())
48 .filter(|&id| is_constant_candidate_type(hir.variable(id)))
49 .collect();
50
51 if candidates.is_empty() {
52 return;
53 }
54 let candidate_set: HashSet<_> = candidates.iter().copied().collect();
55
56 if contract_contains_assembly_or_unknown_stmt(hir, contract) {
57 return;
58 }
59
60 let mut constructor_body_writes = HashSet::new();
61 let mut initializer_side_effect_writes = HashSet::new();
62 let mut runtime_writes = HashSet::new();
63 let mut non_constant_initializer = HashSet::new();
64
65 for &var_id in &candidates {
66 let var = hir.variable(var_id);
67 if let Some(expr) = var.initializer {
68 if !is_compile_time_constant(hir, expr) {
69 non_constant_initializer.insert(var_id);
70 }
71 collect_expr_writes(expr, &candidate_set, &mut initializer_side_effect_writes);
73 }
74 }
75
76 for &contract_id in contract.linearized_bases {
77 for function_id in hir.contract(contract_id).all_functions() {
78 let function = hir.function(function_id);
79 if function.is_constructor() {
80 collect_modifier_writes(
81 hir,
82 function,
83 &candidate_set,
84 &mut constructor_body_writes,
85 &mut runtime_writes,
86 &mut HashSet::new(),
87 );
88
89 if let Some(body) = function.body {
90 collect_state_writes(
91 hir,
92 body,
93 &candidate_set,
94 &mut constructor_body_writes,
95 );
96 }
97 } else {
98 let mut modifier_argument_writes = HashSet::new();
101 collect_modifier_writes(
102 hir,
103 function,
104 &candidate_set,
105 &mut modifier_argument_writes,
106 &mut runtime_writes,
107 &mut HashSet::new(),
108 );
109 runtime_writes.extend(modifier_argument_writes);
110
111 if let Some(body) = function.body {
112 collect_state_writes(hir, body, &candidate_set, &mut runtime_writes);
113 }
114 }
115 }
116 }
117
118 for &var_id in &candidates {
119 if runtime_writes.contains(&var_id) {
120 continue;
121 }
122 let var = hir.variable(var_id);
123 let span = var.name.map_or(var.span, |name| name.span);
124
125 let has_constant_initializer =
128 var.initializer.is_some_and(|expr| is_compile_time_constant(hir, expr));
129 if has_constant_initializer
130 && !constructor_body_writes.contains(&var_id)
131 && !initializer_side_effect_writes.contains(&var_id)
132 {
133 ctx.emit(&COULD_BE_CONSTANT, span);
134 continue;
135 }
136
137 if !is_immutable_candidate_type(var) {
140 continue;
141 }
142 if non_constant_initializer.contains(&var_id)
143 || constructor_body_writes.contains(&var_id)
144 {
145 ctx.emit(&COULD_BE_IMMUTABLE, span);
146 }
147 }
148 }
149}
150
151fn is_most_derived_contract(hir: &hir::Hir<'_>, contract_id: hir::ContractId) -> bool {
152 !hir.contracts()
153 .any(|contract| contract.linearized_bases.iter().skip(1).any(|&id| id == contract_id))
154}
155
156fn collect_modifier_writes<'hir>(
157 hir: &'hir hir::Hir<'hir>,
158 function: &'hir hir::Function<'hir>,
159 candidates: &HashSet<hir::VariableId>,
160 argument_writes: &mut HashSet<hir::VariableId>,
161 body_writes: &mut HashSet<hir::VariableId>,
162 visited_modifiers: &mut HashSet<hir::FunctionId>,
163) {
164 for modifier in function.modifiers {
165 for expr in modifier.args.exprs() {
166 collect_expr_writes(expr, candidates, argument_writes);
167 }
168
169 let Some(modifier_id) = modifier.id.as_function() else { continue };
170 if !visited_modifiers.insert(modifier_id) {
171 continue;
172 }
173
174 let modifier = hir.function(modifier_id);
175 let mut nested_argument_writes = HashSet::new();
176 collect_modifier_writes(
177 hir,
178 modifier,
179 candidates,
180 &mut nested_argument_writes,
181 body_writes,
182 visited_modifiers,
183 );
184 body_writes.extend(nested_argument_writes);
185 if let Some(body) = modifier.body {
186 collect_state_writes(hir, body, candidates, body_writes);
187 }
188 }
189}
190
191fn is_immutable_candidate_type(var: &hir::Variable<'_>) -> bool {
192 var.is_state_variable()
193 && var.mutability.is_none()
194 && match var.ty.kind {
195 TypeKind::Elementary(ty) => ty.is_value_type(),
196 TypeKind::Custom(hir::ItemId::Contract(_)) => true,
197 _ => false,
198 }
199}
200
201fn is_constant_candidate_type(var: &hir::Variable<'_>) -> bool {
203 var.is_state_variable()
204 && var.mutability.is_none()
205 && matches!(
206 var.ty.kind,
207 TypeKind::Elementary(_) | TypeKind::Custom(hir::ItemId::Contract(_))
208 )
209}
210
211fn contract_contains_assembly_or_unknown_stmt<'hir>(
212 hir: &'hir hir::Hir<'hir>,
213 contract: &'hir hir::Contract<'hir>,
214) -> bool {
215 contract.linearized_bases.iter().any(|&contract_id| {
216 hir.contract(contract_id).all_functions().any(|function_id| {
217 hir.function(function_id)
218 .body
219 .is_some_and(|body| block_contains_assembly_or_unknown_stmt(body))
220 })
221 })
222}
223
224fn block_contains_assembly_or_unknown_stmt(block: hir::Block<'_>) -> bool {
225 block.stmts.iter().any(stmt_contains_assembly_or_unknown_stmt)
226}
227
228fn stmt_contains_assembly_or_unknown_stmt(stmt: &hir::Stmt<'_>) -> bool {
229 match &stmt.kind {
230 StmtKind::AssemblyBlock(_) | StmtKind::Switch(_) | StmtKind::Err(_) => true,
231 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
232 block_contains_assembly_or_unknown_stmt(*block)
233 }
234 StmtKind::If(_, then_stmt, else_stmt) => {
235 stmt_contains_assembly_or_unknown_stmt(then_stmt)
236 || else_stmt.is_some_and(stmt_contains_assembly_or_unknown_stmt)
237 }
238 StmtKind::Try(stmt_try) => stmt_try
239 .clauses
240 .iter()
241 .any(|clause| block_contains_assembly_or_unknown_stmt(clause.block)),
242 StmtKind::DeclSingle(_)
243 | StmtKind::DeclMulti(_, _)
244 | StmtKind::Emit(_)
245 | StmtKind::Revert(_)
246 | StmtKind::Return(_)
247 | StmtKind::Break
248 | StmtKind::Continue
249 | StmtKind::Expr(_)
250 | StmtKind::Placeholder => false,
251 }
252}
253
254fn collect_state_writes<'hir>(
255 hir: &'hir hir::Hir<'hir>,
256 block: hir::Block<'hir>,
257 candidates: &HashSet<hir::VariableId>,
258 writes: &mut HashSet<hir::VariableId>,
259) {
260 for stmt in block.stmts {
261 collect_stmt_writes(hir, stmt, candidates, writes);
262 }
263}
264
265fn collect_stmt_writes<'hir>(
266 hir: &'hir hir::Hir<'hir>,
267 stmt: &'hir hir::Stmt<'hir>,
268 candidates: &HashSet<hir::VariableId>,
269 writes: &mut HashSet<hir::VariableId>,
270) {
271 match &stmt.kind {
272 StmtKind::Block(block) | StmtKind::UncheckedBlock(block) | StmtKind::Loop(block, _) => {
273 collect_state_writes(hir, *block, candidates, writes);
274 }
275 StmtKind::If(condition, then_stmt, else_stmt) => {
276 collect_expr_writes(condition, candidates, writes);
277 collect_stmt_writes(hir, then_stmt, candidates, writes);
278 if let Some(else_stmt) = else_stmt {
279 collect_stmt_writes(hir, else_stmt, candidates, writes);
280 }
281 }
282 StmtKind::Try(stmt_try) => {
283 collect_expr_writes(&stmt_try.expr, candidates, writes);
284 for clause in stmt_try.clauses {
285 collect_state_writes(hir, clause.block, candidates, writes);
286 }
287 }
288 StmtKind::DeclSingle(var_id) => {
289 if let Some(initializer) = hir.variable(*var_id).initializer {
290 collect_expr_writes(initializer, candidates, writes);
291 }
292 }
293 StmtKind::DeclMulti(_, expr)
294 | StmtKind::Emit(expr)
295 | StmtKind::Revert(expr)
296 | StmtKind::Return(Some(expr))
297 | StmtKind::Expr(expr) => collect_expr_writes(expr, candidates, writes),
298 StmtKind::Return(None)
299 | StmtKind::Break
300 | StmtKind::Continue
301 | StmtKind::Placeholder
302 | StmtKind::AssemblyBlock(_)
303 | StmtKind::Switch(_)
304 | StmtKind::Err(_) => {}
305 }
306}
307
308fn collect_expr_writes<'hir>(
309 expr: &'hir hir::Expr<'hir>,
310 candidates: &HashSet<hir::VariableId>,
311 writes: &mut HashSet<hir::VariableId>,
312) {
313 match &expr.kind {
314 ExprKind::Assign(lhs, _, rhs) => {
315 collect_lvalue_writes(lhs, candidates, writes);
316 collect_expr_writes(lhs, candidates, writes);
317 collect_expr_writes(rhs, candidates, writes);
318 }
319 ExprKind::Delete(inner) => {
320 collect_lvalue_writes(inner, candidates, writes);
321 collect_expr_writes(inner, candidates, writes);
322 }
323 ExprKind::Unary(op, inner) => {
324 if op.kind.has_side_effects() {
325 collect_lvalue_writes(inner, candidates, writes);
326 }
327 collect_expr_writes(inner, candidates, writes);
328 }
329 ExprKind::Array(exprs) => {
330 for expr in *exprs {
331 collect_expr_writes(expr, candidates, writes);
332 }
333 }
334 ExprKind::Binary(lhs, _, rhs) => {
335 collect_expr_writes(lhs, candidates, writes);
336 collect_expr_writes(rhs, candidates, writes);
337 }
338 ExprKind::Call(callee, args, named_args) => {
339 collect_expr_writes(callee, candidates, writes);
340 for expr in args.exprs() {
341 collect_expr_writes(expr, candidates, writes);
342 }
343 if let Some(named_args) = named_args {
344 for arg in named_args.args {
345 collect_expr_writes(&arg.value, candidates, writes);
346 }
347 }
348 }
349 ExprKind::Index(base, index) => {
350 collect_expr_writes(base, candidates, writes);
351 if let Some(index) = index {
352 collect_expr_writes(index, candidates, writes);
353 }
354 }
355 ExprKind::Slice(base, start, end) => {
356 collect_expr_writes(base, candidates, writes);
357 if let Some(start) = start {
358 collect_expr_writes(start, candidates, writes);
359 }
360 if let Some(end) = end {
361 collect_expr_writes(end, candidates, writes);
362 }
363 }
364 ExprKind::Member(base, _) | ExprKind::Payable(base) => {
365 collect_expr_writes(base, candidates, writes);
366 }
367 ExprKind::Ternary(condition, then_expr, else_expr) => {
368 collect_expr_writes(condition, candidates, writes);
369 collect_expr_writes(then_expr, candidates, writes);
370 collect_expr_writes(else_expr, candidates, writes);
371 }
372 ExprKind::Tuple(exprs) => {
373 for expr in exprs.iter().flatten() {
374 collect_expr_writes(expr, candidates, writes);
375 }
376 }
377 ExprKind::Ident(_)
378 | ExprKind::Lit(_)
379 | ExprKind::New(_)
380 | ExprKind::TypeCall(_)
381 | ExprKind::Type(_)
382 | ExprKind::YulMember(..)
383 | ExprKind::Err(_) => {}
384 }
385}
386
387fn collect_lvalue_writes(
388 expr: &hir::Expr<'_>,
389 candidates: &HashSet<hir::VariableId>,
390 writes: &mut HashSet<hir::VariableId>,
391) {
392 match &expr.peel_parens().kind {
393 ExprKind::Ident([Res::Item(hir::ItemId::Variable(id)), ..]) if candidates.contains(id) => {
394 writes.insert(*id);
395 }
396 ExprKind::Tuple(exprs) => {
397 for expr in exprs.iter().flatten() {
398 collect_lvalue_writes(expr, candidates, writes);
399 }
400 }
401 ExprKind::Index(base, _)
402 | ExprKind::Slice(base, _, _)
403 | ExprKind::Member(base, _)
404 | ExprKind::Payable(base) => collect_lvalue_writes(base, candidates, writes),
405 _ => {}
406 }
407}
408
409fn is_compile_time_constant(hir: &hir::Hir<'_>, expr: &hir::Expr<'_>) -> bool {
410 match &expr.kind {
411 ExprKind::Lit(_) | ExprKind::Type(_) | ExprKind::TypeCall(_) => true,
412 ExprKind::Ident(resolutions) => {
413 let mut has_const_var = false;
414 let all_safe = resolutions.iter().all(|res| match res {
415 Res::Item(hir::ItemId::Variable(var_id)) => {
416 let is_const = hir.variable(*var_id).is_constant();
417 has_const_var |= is_const;
418 is_const
419 }
420 Res::Item(hir::ItemId::Function(_)) => true,
421 _ => false,
422 });
423 all_safe && has_const_var
424 }
425 ExprKind::Unary(op, inner) => {
426 !matches!(
427 op.kind,
428 UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
429 ) && is_compile_time_constant(hir, inner)
430 }
431 ExprKind::Binary(lhs, _, rhs) => {
432 is_compile_time_constant(hir, lhs) && is_compile_time_constant(hir, rhs)
433 }
434 ExprKind::Call(callee, args, named_args) => {
435 is_allowed_constant_call(callee)
436 && args.exprs().all(|expr| is_compile_time_constant(hir, expr))
437 && named_args.is_none_or(|args| {
438 args.args.iter().all(|arg| is_compile_time_constant(hir, &arg.value))
439 })
440 }
441 ExprKind::Ternary(condition, then_expr, else_expr) => {
442 is_compile_time_constant(hir, condition)
443 && is_compile_time_constant(hir, then_expr)
444 && is_compile_time_constant(hir, else_expr)
445 }
446 ExprKind::Tuple(exprs) => {
447 exprs.iter().flatten().all(|expr| is_compile_time_constant(hir, expr))
448 }
449 ExprKind::Member(base, member) => match (&base.kind, member.as_str()) {
452 (ExprKind::TypeCall(ty), "min" | "max") => matches!(
453 ty.kind,
454 TypeKind::Elementary(ast::ElementaryType::Int(_) | ast::ElementaryType::UInt(_))
455 | TypeKind::Custom(hir::ItemId::Enum(_))
456 ),
457 (ExprKind::TypeCall(ty), "interfaceId") => matches!(
458 ty.kind,
459 TypeKind::Custom(hir::ItemId::Contract(cid))
460 if hir.contract(cid).kind == ast::ContractKind::Interface
461 ),
462 _ => false,
463 },
464 ExprKind::Array(_)
465 | ExprKind::Assign(_, _, _)
466 | ExprKind::Delete(_)
467 | ExprKind::Index(_, _)
468 | ExprKind::Slice(_, _, _)
469 | ExprKind::New(_)
470 | ExprKind::Payable(_)
471 | ExprKind::YulMember(..)
472 | ExprKind::Err(_) => false,
473 }
474}
475
476fn is_allowed_constant_call(callee: &hir::Expr<'_>) -> bool {
477 match &callee.kind {
478 ExprKind::Type(_) => true,
480 ExprKind::Ident([Res::Item(hir::ItemId::Contract(_)), ..]) => true,
482 ExprKind::Ident([Res::Builtin(builtin), ..]) => {
483 let name = builtin.name();
484 name == kw::Keccak256
485 || name == kw::Addmod
486 || name == kw::Mulmod
487 || name == sym::sha256
488 || name == sym::ripemd160
489 || name == sym::ecrecover
490 }
491 _ => false,
492 }
493}