1use solar::{
2 interface::data_structures::Never,
3 sema::{Gcx, hir},
4};
5use std::ops::ControlFlow;
6
7use super::LintContext;
8
9pub trait LateLintPass<'hir>: Send + Sync {
12 fn check_nested_source(
13 &mut self,
14 _ctx: &LintContext,
15 _hir: &'hir hir::Hir<'hir>,
16 _id: hir::SourceId,
17 ) {
18 }
19 fn check_nested_item(
20 &mut self,
21 _ctx: &LintContext,
22 _hir: &'hir hir::Hir<'hir>,
23 _id: hir::ItemId,
24 ) {
25 }
26 fn check_nested_contract(
27 &mut self,
28 _ctx: &LintContext,
29 _hir: &'hir hir::Hir<'hir>,
30 _id: hir::ContractId,
31 ) {
32 }
33 fn check_nested_function(
34 &mut self,
35 _ctx: &LintContext,
36 _hir: &'hir hir::Hir<'hir>,
37 _id: hir::FunctionId,
38 ) {
39 }
40 fn check_nested_var(
41 &mut self,
42 _ctx: &LintContext,
43 _hir: &'hir hir::Hir<'hir>,
44 _id: hir::VariableId,
45 ) {
46 }
47 fn check_item(
48 &mut self,
49 _ctx: &LintContext,
50 _hir: &'hir hir::Hir<'hir>,
51 _item: hir::Item<'hir, 'hir>,
52 ) {
53 }
54 fn check_contract(
55 &mut self,
56 _ctx: &LintContext,
57 _hir: &'hir hir::Hir<'hir>,
58 _contract: &'hir hir::Contract<'hir>,
59 ) {
60 }
61 fn check_function(
62 &mut self,
63 _ctx: &LintContext,
64 _hir: &'hir hir::Hir<'hir>,
65 _func: &'hir hir::Function<'hir>,
66 ) {
67 }
68 fn check_function_with_gcx(
69 &mut self,
70 ctx: &LintContext,
71 _gcx: Gcx<'hir>,
72 hir: &'hir hir::Hir<'hir>,
73 func: &'hir hir::Function<'hir>,
74 ) {
75 self.check_function(ctx, hir, func);
76 }
77 fn check_modifier(
78 &mut self,
79 _ctx: &LintContext,
80 _hir: &'hir hir::Hir<'hir>,
81 _mod: &'hir hir::Modifier<'hir>,
82 ) {
83 }
84 fn check_var(
85 &mut self,
86 _ctx: &LintContext,
87 _hir: &'hir hir::Hir<'hir>,
88 _var: &'hir hir::Variable<'hir>,
89 ) {
90 }
91 fn check_expr(
92 &mut self,
93 _ctx: &LintContext,
94 _hir: &'hir hir::Hir<'hir>,
95 _expr: &'hir hir::Expr<'hir>,
96 ) {
97 }
98 fn check_call_args(
99 &mut self,
100 _ctx: &LintContext,
101 _hir: &'hir hir::Hir<'hir>,
102 _args: &'hir hir::CallArgs<'hir>,
103 ) {
104 }
105 fn check_stmt(
106 &mut self,
107 _ctx: &LintContext,
108 _hir: &'hir hir::Hir<'hir>,
109 _stmt: &'hir hir::Stmt<'hir>,
110 ) {
111 }
112 fn check_ty(
113 &mut self,
114 _ctx: &LintContext,
115 _hir: &'hir hir::Hir<'hir>,
116 _ty: &'hir hir::Type<'hir>,
117 ) {
118 }
119}
120
121pub struct LateLintVisitor<'a, 's, 'hir> {
123 ctx: &'a LintContext<'s, 'a>,
124 passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
125 gcx: Gcx<'hir>,
126 hir: &'hir hir::Hir<'hir>,
127}
128
129impl<'a, 's, 'hir> LateLintVisitor<'a, 's, 'hir>
130where
131 's: 'hir,
132{
133 pub fn new(
134 ctx: &'a LintContext<'s, 'a>,
135 passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
136 gcx: Gcx<'hir>,
137 hir: &'hir hir::Hir<'hir>,
138 ) -> Self {
139 Self { ctx, passes, gcx, hir }
140 }
141}
142
143impl<'s, 'hir> hir::Visit<'hir> for LateLintVisitor<'_, 's, 'hir>
144where
145 's: 'hir,
146{
147 type BreakValue = Never;
148
149 fn hir(&self) -> &'hir hir::Hir<'hir> {
150 self.hir
151 }
152
153 fn visit_nested_source(&mut self, id: hir::SourceId) -> ControlFlow<Self::BreakValue> {
154 for pass in self.passes.iter_mut() {
155 pass.check_nested_source(self.ctx, self.hir, id);
156 }
157 self.walk_nested_source(id)
158 }
159
160 fn visit_nested_item(&mut self, id: hir::ItemId) -> ControlFlow<Self::BreakValue> {
161 for pass in self.passes.iter_mut() {
162 pass.check_nested_item(self.ctx, self.hir, id);
163 }
164 self.walk_nested_item(id)
165 }
166
167 fn visit_nested_contract(&mut self, id: hir::ContractId) -> ControlFlow<Self::BreakValue> {
168 for pass in self.passes.iter_mut() {
169 pass.check_nested_contract(self.ctx, self.hir, id);
170 }
171 self.walk_nested_contract(id)
172 }
173
174 fn visit_nested_function(&mut self, id: hir::FunctionId) -> ControlFlow<Self::BreakValue> {
175 for pass in self.passes.iter_mut() {
176 pass.check_nested_function(self.ctx, self.hir, id);
177 }
178 self.walk_nested_function(id)
179 }
180
181 fn visit_nested_var(&mut self, id: hir::VariableId) -> ControlFlow<Self::BreakValue> {
182 for pass in self.passes.iter_mut() {
183 pass.check_nested_var(self.ctx, self.hir, id);
184 }
185 self.walk_nested_var(id)
186 }
187
188 fn visit_contract(
189 &mut self,
190 contract: &'hir hir::Contract<'hir>,
191 ) -> ControlFlow<Self::BreakValue> {
192 for pass in self.passes.iter_mut() {
193 pass.check_contract(self.ctx, self.hir, contract);
194 }
195 self.walk_contract(contract)
196 }
197
198 fn visit_function(&mut self, func: &'hir hir::Function<'hir>) -> ControlFlow<Self::BreakValue> {
199 for pass in self.passes.iter_mut() {
200 pass.check_function_with_gcx(self.ctx, self.gcx, self.hir, func);
201 }
202 self.walk_function(func)
203 }
204
205 fn visit_modifier(
206 &mut self,
207 modifier: &'hir hir::Modifier<'hir>,
208 ) -> ControlFlow<Self::BreakValue> {
209 for pass in self.passes.iter_mut() {
210 pass.check_modifier(self.ctx, self.hir, modifier);
211 }
212 self.walk_modifier(modifier)
213 }
214
215 fn visit_item(&mut self, item: hir::Item<'hir, 'hir>) -> ControlFlow<Self::BreakValue> {
216 for pass in self.passes.iter_mut() {
217 pass.check_item(self.ctx, self.hir, item);
218 }
219 self.walk_item(item)
220 }
221
222 fn visit_var(&mut self, var: &'hir hir::Variable<'hir>) -> ControlFlow<Self::BreakValue> {
223 for pass in self.passes.iter_mut() {
224 pass.check_var(self.ctx, self.hir, var);
225 }
226 self.walk_var(var)
227 }
228
229 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
230 for pass in self.passes.iter_mut() {
231 pass.check_expr(self.ctx, self.hir, expr);
232 }
233 self.walk_expr(expr)
234 }
235
236 fn visit_call_args(
237 &mut self,
238 args: &'hir hir::CallArgs<'hir>,
239 ) -> ControlFlow<Self::BreakValue> {
240 for pass in self.passes.iter_mut() {
241 pass.check_call_args(self.ctx, self.hir, args);
242 }
243 self.walk_call_args(args)
244 }
245
246 fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
247 for pass in self.passes.iter_mut() {
248 pass.check_stmt(self.ctx, self.hir, stmt);
249 }
250 self.walk_stmt(stmt)
251 }
252
253 fn visit_ty(&mut self, ty: &'hir hir::Type<'hir>) -> ControlFlow<Self::BreakValue> {
254 for pass in self.passes.iter_mut() {
255 pass.check_ty(self.ctx, self.hir, ty);
256 }
257 self.walk_ty(ty)
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::LinterConfig;
    use foundry_common::comments::inline_config::InlineConfig;
    use foundry_config::lint::LintSpecificConfig;
    use solar::{
        interface::{Session, source_map::FileName},
        sema::Compiler,
    };
    use std::sync::{Arc, Mutex};

    /// Number of times each hook overridden by [`RecordingPass`] was invoked.
    #[derive(Debug, Default)]
    struct HookCounts {
        nested_item: usize,
        nested_contract: usize,
        nested_function: usize,
        nested_var: usize,
        modifier: usize,
        call_args: usize,
    }

    /// Lint pass that only counts how often selected hooks fire.
    ///
    /// The counters are shared through `Arc<Mutex<..>>` so the test can still read them after
    /// a clone of the `Arc` has been moved into the boxed pass list. This shape also satisfies
    /// the `Send + Sync` bound on `LateLintPass`.
    struct RecordingPass {
        counts: Arc<Mutex<HookCounts>>,
    }

    impl RecordingPass {
        /// Applies `update` to the shared counters while holding the mutex.
        fn record(&self, update: impl FnOnce(&mut HookCounts)) {
            update(&mut self.counts.lock().unwrap());
        }
    }

    // Each override bumps exactly one counter; every other hook keeps its empty default.
    impl<'hir> LateLintPass<'hir> for RecordingPass {
        fn check_nested_item(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ItemId,
        ) {
            self.record(|counts| counts.nested_item += 1);
        }

        fn check_nested_contract(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ContractId,
        ) {
            self.record(|counts| counts.nested_contract += 1);
        }

        fn check_nested_function(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::FunctionId,
        ) {
            self.record(|counts| counts.nested_function += 1);
        }

        fn check_nested_var(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::VariableId,
        ) {
            self.record(|counts| counts.nested_var += 1);
        }

        fn check_modifier(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _modifier: &'hir hir::Modifier<'hir>,
        ) {
            self.record(|counts| counts.modifier += 1);
        }

        fn check_call_args(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _args: &'hir hir::CallArgs<'hir>,
        ) {
            self.record(|counts| counts.call_args += 1);
        }
    }

    /// Lowers a small Solidity program and walks its single source with `LateLintVisitor`.
    ///
    /// Asserts that each hook counted by `RecordingPass` fires at least once. The program has:
    /// - contract `Test`, which inherits from `Base` and declares a state variable;
    /// - a modifier;
    /// - a function that applies the modifier and calls `hook(...)`.
    ///
    /// Together these give every counted hook something to fire on.
    #[test]
    fn calls_hooks_for_nested_items_modifiers_and_call_args() {
        let counts = Arc::new(Mutex::new(HookCounts::default()));
        let inline = InlineConfig::default();
        let lint_specific = LintSpecificConfig::default();
        let source = r#"
pragma solidity ^0.8.20;

contract Base {
    function hook(uint256 value) internal pure returns (uint256) {
        return value;
    }
}

contract Test is Base {
    uint256 stored;

    modifier gated(uint256 amount) {
        _;
    }

    function run(uint256 amount) public gated(amount) returns (uint256) {
        return hook(amount + stored);
    }
}
"#;

        // NOTE(review): the buffer emitter presumably keeps diagnostics out of stderr; confirm.
        let mut compiler =
            Compiler::new(Session::builder().with_buffer_emitter(Default::default()).build());
        compiler
            .enter_mut(|compiler| -> solar::interface::Result<()> {
                // Parse the in-memory snippet as stdin. It has no imports to resolve.
                let mut pcx = compiler.parse();
                pcx.set_resolve_imports(false);
                let file = compiler
                    .sess()
                    .source_map()
                    .new_source_file(FileName::Stdin, source)
                    .expect("failed to create source file");
                pcx.add_file(file);
                pcx.parse();

                // Lowering to HIR must run to completion; a `Break` here fails the test.
                let ControlFlow::Continue(()) = compiler.lower_asts()? else {
                    panic!("expected HIR lowering to continue");
                };

                let gcx = compiler.gcx();
                let source_id = gcx.hir.source_ids().next().expect("expected one lowered source");
                let ctx = LintContext::new(
                    gcx.sess,
                    false,
                    false,
                    LinterConfig { inline: &inline, lint_specific: &lint_specific },
                    Vec::new(),
                    None,
                );
                let mut passes: Vec<Box<dyn LateLintPass<'_>>> =
                    vec![Box::new(RecordingPass { counts: counts.clone() })];
                let mut visitor = LateLintVisitor::new(&ctx, &mut passes, gcx, &gcx.hir);
                // Start the walk from the source; the visitor's break type is `Never`,
                // so the returned `ControlFlow` carries no information and is discarded.
                let _ = hir::Visit::visit_nested_source(&mut visitor, source_id);
                Ok(())
            })
            .expect("failed to lower test source");

        let counts = counts.lock().unwrap();
        assert!(counts.nested_item > 0, "expected nested item hook to run");
        assert!(counts.nested_contract > 0, "expected nested contract hook to run");
        assert!(counts.nested_function > 0, "expected nested function hook to run");
        assert!(counts.nested_var > 0, "expected nested var hook to run");
        assert!(counts.modifier > 0, "expected modifier hook to run");
        assert!(counts.call_args > 0, "expected call args hook to run");
    }
}