1use solar::{
2 interface::data_structures::Never,
3 sema::{Gcx, hir},
4};
5use std::ops::ControlFlow;
6
7use super::LintContext;
8
9pub trait LateLintPass<'hir>: Send + Sync {
12 fn check_nested_source(
13 &mut self,
14 _ctx: &LintContext,
15 _hir: &'hir hir::Hir<'hir>,
16 _id: hir::SourceId,
17 ) {
18 }
19 fn check_nested_item(
20 &mut self,
21 _ctx: &LintContext,
22 _hir: &'hir hir::Hir<'hir>,
23 _id: hir::ItemId,
24 ) {
25 }
26 fn check_nested_contract(
27 &mut self,
28 _ctx: &LintContext,
29 _gcx: Gcx<'hir>,
30 _hir: &'hir hir::Hir<'hir>,
31 _id: hir::ContractId,
32 ) {
33 }
34 fn check_nested_function(
35 &mut self,
36 _ctx: &LintContext,
37 _hir: &'hir hir::Hir<'hir>,
38 _id: hir::FunctionId,
39 ) {
40 }
41 fn check_nested_var(
42 &mut self,
43 _ctx: &LintContext,
44 _hir: &'hir hir::Hir<'hir>,
45 _id: hir::VariableId,
46 ) {
47 }
48 fn check_item(
49 &mut self,
50 _ctx: &LintContext,
51 _hir: &'hir hir::Hir<'hir>,
52 _item: hir::Item<'hir, 'hir>,
53 ) {
54 }
55 fn check_contract(
56 &mut self,
57 _ctx: &LintContext,
58 _gcx: Gcx<'hir>,
59 _hir: &'hir hir::Hir<'hir>,
60 _contract: &'hir hir::Contract<'hir>,
61 ) {
62 }
63 fn check_function(
64 &mut self,
65 _ctx: &LintContext,
66 _gcx: Gcx<'hir>,
67 _hir: &'hir hir::Hir<'hir>,
68 _func: &'hir hir::Function<'hir>,
69 ) {
70 }
71 fn check_modifier(
72 &mut self,
73 _ctx: &LintContext,
74 _hir: &'hir hir::Hir<'hir>,
75 _mod: &'hir hir::Modifier<'hir>,
76 ) {
77 }
78 fn check_var(
79 &mut self,
80 _ctx: &LintContext,
81 _hir: &'hir hir::Hir<'hir>,
82 _var: &'hir hir::Variable<'hir>,
83 ) {
84 }
85 fn check_expr(
86 &mut self,
87 _ctx: &LintContext,
88 _gcx: Gcx<'hir>,
89 _hir: &'hir hir::Hir<'hir>,
90 _expr: &'hir hir::Expr<'hir>,
91 ) {
92 }
93 fn check_call_args(
94 &mut self,
95 _ctx: &LintContext,
96 _hir: &'hir hir::Hir<'hir>,
97 _args: &'hir hir::CallArgs<'hir>,
98 ) {
99 }
100 fn check_stmt(
101 &mut self,
102 _ctx: &LintContext,
103 _gcx: Gcx<'hir>,
104 _hir: &'hir hir::Hir<'hir>,
105 _stmt: &'hir hir::Stmt<'hir>,
106 ) {
107 }
108 fn check_ty(
109 &mut self,
110 _ctx: &LintContext,
111 _hir: &'hir hir::Hir<'hir>,
112 _ty: &'hir hir::Type<'hir>,
113 ) {
114 }
115}
116
117pub struct LateLintVisitor<'a, 's, 'hir> {
119 ctx: &'a LintContext<'s, 'a>,
120 passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
121 gcx: Gcx<'hir>,
122 hir: &'hir hir::Hir<'hir>,
123}
124
125impl<'a, 's, 'hir> LateLintVisitor<'a, 's, 'hir>
126where
127 's: 'hir,
128{
129 pub fn new(
130 ctx: &'a LintContext<'s, 'a>,
131 passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
132 gcx: Gcx<'hir>,
133 hir: &'hir hir::Hir<'hir>,
134 ) -> Self {
135 Self { ctx, passes, gcx, hir }
136 }
137}
138
139impl<'s, 'hir> hir::Visit<'hir> for LateLintVisitor<'_, 's, 'hir>
140where
141 's: 'hir,
142{
143 type BreakValue = Never;
144
145 fn hir(&self) -> &'hir hir::Hir<'hir> {
146 self.hir
147 }
148
149 fn visit_nested_source(&mut self, id: hir::SourceId) -> ControlFlow<Self::BreakValue> {
150 for pass in self.passes.iter_mut() {
151 pass.check_nested_source(self.ctx, self.hir, id);
152 }
153 self.walk_nested_source(id)
154 }
155
156 fn visit_nested_item(&mut self, id: hir::ItemId) -> ControlFlow<Self::BreakValue> {
157 for pass in self.passes.iter_mut() {
158 pass.check_nested_item(self.ctx, self.hir, id);
159 }
160 self.walk_nested_item(id)
161 }
162
163 fn visit_nested_contract(&mut self, id: hir::ContractId) -> ControlFlow<Self::BreakValue> {
164 for pass in self.passes.iter_mut() {
165 pass.check_nested_contract(self.ctx, self.gcx, self.hir, id);
166 }
167 self.walk_nested_contract(id)
168 }
169
170 fn visit_nested_function(&mut self, id: hir::FunctionId) -> ControlFlow<Self::BreakValue> {
171 for pass in self.passes.iter_mut() {
172 pass.check_nested_function(self.ctx, self.hir, id);
173 }
174 self.walk_nested_function(id)
175 }
176
177 fn visit_nested_var(&mut self, id: hir::VariableId) -> ControlFlow<Self::BreakValue> {
178 for pass in self.passes.iter_mut() {
179 pass.check_nested_var(self.ctx, self.hir, id);
180 }
181 self.walk_nested_var(id)
182 }
183
184 fn visit_contract(
185 &mut self,
186 contract: &'hir hir::Contract<'hir>,
187 ) -> ControlFlow<Self::BreakValue> {
188 for pass in self.passes.iter_mut() {
189 pass.check_contract(self.ctx, self.gcx, self.hir, contract);
190 }
191 self.walk_contract(contract)
192 }
193
194 fn visit_function(&mut self, func: &'hir hir::Function<'hir>) -> ControlFlow<Self::BreakValue> {
195 for pass in self.passes.iter_mut() {
196 pass.check_function(self.ctx, self.gcx, self.hir, func);
197 }
198 self.walk_function(func)
199 }
200
201 fn visit_modifier(
202 &mut self,
203 modifier: &'hir hir::Modifier<'hir>,
204 ) -> ControlFlow<Self::BreakValue> {
205 for pass in self.passes.iter_mut() {
206 pass.check_modifier(self.ctx, self.hir, modifier);
207 }
208 self.walk_modifier(modifier)
209 }
210
211 fn visit_item(&mut self, item: hir::Item<'hir, 'hir>) -> ControlFlow<Self::BreakValue> {
212 for pass in self.passes.iter_mut() {
213 pass.check_item(self.ctx, self.hir, item);
214 }
215 self.walk_item(item)
216 }
217
218 fn visit_var(&mut self, var: &'hir hir::Variable<'hir>) -> ControlFlow<Self::BreakValue> {
219 for pass in self.passes.iter_mut() {
220 pass.check_var(self.ctx, self.hir, var);
221 }
222 self.walk_var(var)
223 }
224
225 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
226 for pass in self.passes.iter_mut() {
227 pass.check_expr(self.ctx, self.gcx, self.hir, expr);
228 }
229 self.walk_expr(expr)
230 }
231
232 fn visit_call_args(
233 &mut self,
234 args: &'hir hir::CallArgs<'hir>,
235 ) -> ControlFlow<Self::BreakValue> {
236 for pass in self.passes.iter_mut() {
237 pass.check_call_args(self.ctx, self.hir, args);
238 }
239 self.walk_call_args(args)
240 }
241
242 fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
243 for pass in self.passes.iter_mut() {
244 pass.check_stmt(self.ctx, self.gcx, self.hir, stmt);
245 }
246 self.walk_stmt(stmt)
247 }
248
249 fn visit_ty(&mut self, ty: &'hir hir::Type<'hir>) -> ControlFlow<Self::BreakValue> {
250 for pass in self.passes.iter_mut() {
251 pass.check_ty(self.ctx, self.hir, ty);
252 }
253 self.walk_ty(ty)
254 }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::LinterConfig;
    use foundry_common::comments::inline_config::InlineConfig;
    use foundry_config::lint::LintSpecificConfig;
    use solar::{
        interface::{Session, source_map::FileName},
        sema::Compiler,
    };
    use std::sync::{Arc, Mutex};

    /// How many times each hook exercised by this test was invoked.
    #[derive(Debug, Default)]
    struct HookCounts {
        nested_source: usize,
        nested_item: usize,
        nested_contract: usize,
        nested_function: usize,
        nested_var: usize,
        modifier: usize,
        call_args: usize,
    }

    /// A pass that performs no linting and only records which hooks fired.
    ///
    /// The counters sit behind `Arc<Mutex<_>>` so the test keeps a handle to them after the
    /// boxed pass has been moved into the pass list, and so the pass satisfies the
    /// `Send + Sync` bound on `LateLintPass`.
    struct RecordingPass {
        counts: Arc<Mutex<HookCounts>>,
    }

    impl RecordingPass {
        /// Applies `update` to the shared counters while holding the lock.
        fn record(&self, update: impl FnOnce(&mut HookCounts)) {
            update(&mut self.counts.lock().unwrap());
        }
    }

    impl<'hir> LateLintPass<'hir> for RecordingPass {
        fn check_nested_source(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::SourceId,
        ) {
            self.record(|counts| counts.nested_source += 1);
        }

        fn check_nested_item(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ItemId,
        ) {
            self.record(|counts| counts.nested_item += 1);
        }

        fn check_nested_contract(
            &mut self,
            _ctx: &LintContext,
            _gcx: Gcx<'hir>,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ContractId,
        ) {
            self.record(|counts| counts.nested_contract += 1);
        }

        fn check_nested_function(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::FunctionId,
        ) {
            self.record(|counts| counts.nested_function += 1);
        }

        fn check_nested_var(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::VariableId,
        ) {
            self.record(|counts| counts.nested_var += 1);
        }

        fn check_modifier(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _modifier: &'hir hir::Modifier<'hir>,
        ) {
            self.record(|counts| counts.modifier += 1);
        }

        fn check_call_args(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _args: &'hir hir::CallArgs<'hir>,
        ) {
            self.record(|counts| counts.call_args += 1);
        }
    }

    #[test]
    fn calls_hooks_for_nested_items_modifiers_and_call_args() {
        let counts = Arc::new(Mutex::new(HookCounts::default()));
        let inline = InlineConfig::default();
        let lint_specific = LintSpecificConfig::default();
        let source = r#"
            pragma solidity ^0.8.20;

            contract Base {
                function hook(uint256 value) internal pure returns (uint256) {
                    return value;
                }
            }

            contract Test is Base {
                uint256 stored;

                modifier gated(uint256 amount) {
                    _;
                }

                function run(uint256 amount) public gated(amount) returns (uint256) {
                    return hook(amount + stored);
                }
            }
        "#;

        let mut compiler =
            Compiler::new(Session::builder().with_buffer_emitter(Default::default()).build());
        compiler
            .enter_mut(|compiler| -> solar::interface::Result<()> {
                let mut pcx = compiler.parse();
                pcx.set_resolve_imports(false);
                let file = compiler
                    .sess()
                    .source_map()
                    .new_source_file(FileName::Stdin, source)
                    .expect("failed to create source file");
                pcx.add_file(file);
                pcx.parse();

                let ControlFlow::Continue(()) = compiler.lower_asts()? else {
                    panic!("expected HIR lowering to continue");
                };

                let gcx = compiler.gcx();
                let source_id = gcx.hir.source_ids().next().expect("expected one lowered source");
                let ctx = LintContext::new(
                    gcx.sess,
                    false,
                    false,
                    LinterConfig { inline: &inline, lint_specific: &lint_specific },
                    Vec::new(),
                    None,
                );
                let mut passes: Vec<Box<dyn LateLintPass<'_>>> =
                    vec![Box::new(RecordingPass { counts: Arc::clone(&counts) })];
                let mut visitor = LateLintVisitor::new(&ctx, &mut passes, gcx, &gcx.hir);
                let _ = hir::Visit::visit_nested_source(&mut visitor, source_id);
                Ok(())
            })
            .expect("failed to lower test source");

        let counts = counts.lock().unwrap();
        // `visit_nested_source` is invoked exactly once above, and the test source has no
        // imports, so the source hook must fire exactly once.
        assert_eq!(counts.nested_source, 1, "expected nested source hook to run exactly once");
        assert!(counts.nested_item > 0, "expected nested item hook to run");
        assert!(counts.nested_contract > 0, "expected nested contract hook to run");
        assert!(counts.nested_function > 0, "expected nested function hook to run");
        assert!(counts.nested_var > 0, "expected nested var hook to run");
        assert!(counts.modifier > 0, "expected modifier hook to run");
        assert!(counts.call_args > 0, "expected call args hook to run");
    }
}