Skip to main content

forge_lint/linter/
late.rs

1use solar::{
2    interface::data_structures::Never,
3    sema::{Gcx, hir},
4};
5use std::ops::ControlFlow;
6
7use super::LintContext;
8
9/// Trait for lints that operate on the HIR (High-level Intermediate Representation).
10/// Its methods mirror `hir::visit::Visit`, with the addition of `LintContext`.
11pub trait LateLintPass<'hir>: Send + Sync {
12    fn check_nested_source(
13        &mut self,
14        _ctx: &LintContext,
15        _hir: &'hir hir::Hir<'hir>,
16        _id: hir::SourceId,
17    ) {
18    }
19    fn check_nested_item(
20        &mut self,
21        _ctx: &LintContext,
22        _hir: &'hir hir::Hir<'hir>,
23        _id: hir::ItemId,
24    ) {
25    }
26    fn check_nested_contract(
27        &mut self,
28        _ctx: &LintContext,
29        _hir: &'hir hir::Hir<'hir>,
30        _id: hir::ContractId,
31    ) {
32    }
33    fn check_nested_function(
34        &mut self,
35        _ctx: &LintContext,
36        _hir: &'hir hir::Hir<'hir>,
37        _id: hir::FunctionId,
38    ) {
39    }
40    fn check_nested_var(
41        &mut self,
42        _ctx: &LintContext,
43        _hir: &'hir hir::Hir<'hir>,
44        _id: hir::VariableId,
45    ) {
46    }
47    fn check_item(
48        &mut self,
49        _ctx: &LintContext,
50        _hir: &'hir hir::Hir<'hir>,
51        _item: hir::Item<'hir, 'hir>,
52    ) {
53    }
54    fn check_contract(
55        &mut self,
56        _ctx: &LintContext,
57        _hir: &'hir hir::Hir<'hir>,
58        _contract: &'hir hir::Contract<'hir>,
59    ) {
60    }
61    fn check_function(
62        &mut self,
63        _ctx: &LintContext,
64        _hir: &'hir hir::Hir<'hir>,
65        _func: &'hir hir::Function<'hir>,
66    ) {
67    }
68    fn check_function_with_gcx(
69        &mut self,
70        ctx: &LintContext,
71        _gcx: Gcx<'hir>,
72        hir: &'hir hir::Hir<'hir>,
73        func: &'hir hir::Function<'hir>,
74    ) {
75        self.check_function(ctx, hir, func);
76    }
77    fn check_modifier(
78        &mut self,
79        _ctx: &LintContext,
80        _hir: &'hir hir::Hir<'hir>,
81        _mod: &'hir hir::Modifier<'hir>,
82    ) {
83    }
84    fn check_var(
85        &mut self,
86        _ctx: &LintContext,
87        _hir: &'hir hir::Hir<'hir>,
88        _var: &'hir hir::Variable<'hir>,
89    ) {
90    }
91    fn check_expr(
92        &mut self,
93        _ctx: &LintContext,
94        _hir: &'hir hir::Hir<'hir>,
95        _expr: &'hir hir::Expr<'hir>,
96    ) {
97    }
98    fn check_call_args(
99        &mut self,
100        _ctx: &LintContext,
101        _hir: &'hir hir::Hir<'hir>,
102        _args: &'hir hir::CallArgs<'hir>,
103    ) {
104    }
105    fn check_stmt(
106        &mut self,
107        _ctx: &LintContext,
108        _hir: &'hir hir::Hir<'hir>,
109        _stmt: &'hir hir::Stmt<'hir>,
110    ) {
111    }
112    fn check_ty(
113        &mut self,
114        _ctx: &LintContext,
115        _hir: &'hir hir::Hir<'hir>,
116        _ty: &'hir hir::Type<'hir>,
117    ) {
118    }
119}
120
121/// Visitor struct for `LateLintPass`es
122pub struct LateLintVisitor<'a, 's, 'hir> {
123    ctx: &'a LintContext<'s, 'a>,
124    passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
125    gcx: Gcx<'hir>,
126    hir: &'hir hir::Hir<'hir>,
127}
128
129impl<'a, 's, 'hir> LateLintVisitor<'a, 's, 'hir>
130where
131    's: 'hir,
132{
133    pub fn new(
134        ctx: &'a LintContext<'s, 'a>,
135        passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
136        gcx: Gcx<'hir>,
137        hir: &'hir hir::Hir<'hir>,
138    ) -> Self {
139        Self { ctx, passes, gcx, hir }
140    }
141}
142
143impl<'s, 'hir> hir::Visit<'hir> for LateLintVisitor<'_, 's, 'hir>
144where
145    's: 'hir,
146{
147    type BreakValue = Never;
148
149    fn hir(&self) -> &'hir hir::Hir<'hir> {
150        self.hir
151    }
152
153    fn visit_nested_source(&mut self, id: hir::SourceId) -> ControlFlow<Self::BreakValue> {
154        for pass in self.passes.iter_mut() {
155            pass.check_nested_source(self.ctx, self.hir, id);
156        }
157        self.walk_nested_source(id)
158    }
159
160    fn visit_nested_item(&mut self, id: hir::ItemId) -> ControlFlow<Self::BreakValue> {
161        for pass in self.passes.iter_mut() {
162            pass.check_nested_item(self.ctx, self.hir, id);
163        }
164        self.walk_nested_item(id)
165    }
166
167    fn visit_nested_contract(&mut self, id: hir::ContractId) -> ControlFlow<Self::BreakValue> {
168        for pass in self.passes.iter_mut() {
169            pass.check_nested_contract(self.ctx, self.hir, id);
170        }
171        self.walk_nested_contract(id)
172    }
173
174    fn visit_nested_function(&mut self, id: hir::FunctionId) -> ControlFlow<Self::BreakValue> {
175        for pass in self.passes.iter_mut() {
176            pass.check_nested_function(self.ctx, self.hir, id);
177        }
178        self.walk_nested_function(id)
179    }
180
181    fn visit_nested_var(&mut self, id: hir::VariableId) -> ControlFlow<Self::BreakValue> {
182        for pass in self.passes.iter_mut() {
183            pass.check_nested_var(self.ctx, self.hir, id);
184        }
185        self.walk_nested_var(id)
186    }
187
188    fn visit_contract(
189        &mut self,
190        contract: &'hir hir::Contract<'hir>,
191    ) -> ControlFlow<Self::BreakValue> {
192        for pass in self.passes.iter_mut() {
193            pass.check_contract(self.ctx, self.hir, contract);
194        }
195        self.walk_contract(contract)
196    }
197
198    fn visit_function(&mut self, func: &'hir hir::Function<'hir>) -> ControlFlow<Self::BreakValue> {
199        for pass in self.passes.iter_mut() {
200            pass.check_function_with_gcx(self.ctx, self.gcx, self.hir, func);
201        }
202        self.walk_function(func)
203    }
204
205    fn visit_modifier(
206        &mut self,
207        modifier: &'hir hir::Modifier<'hir>,
208    ) -> ControlFlow<Self::BreakValue> {
209        for pass in self.passes.iter_mut() {
210            pass.check_modifier(self.ctx, self.hir, modifier);
211        }
212        self.walk_modifier(modifier)
213    }
214
215    fn visit_item(&mut self, item: hir::Item<'hir, 'hir>) -> ControlFlow<Self::BreakValue> {
216        for pass in self.passes.iter_mut() {
217            pass.check_item(self.ctx, self.hir, item);
218        }
219        self.walk_item(item)
220    }
221
222    fn visit_var(&mut self, var: &'hir hir::Variable<'hir>) -> ControlFlow<Self::BreakValue> {
223        for pass in self.passes.iter_mut() {
224            pass.check_var(self.ctx, self.hir, var);
225        }
226        self.walk_var(var)
227    }
228
229    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
230        for pass in self.passes.iter_mut() {
231            pass.check_expr(self.ctx, self.hir, expr);
232        }
233        self.walk_expr(expr)
234    }
235
236    fn visit_call_args(
237        &mut self,
238        args: &'hir hir::CallArgs<'hir>,
239    ) -> ControlFlow<Self::BreakValue> {
240        for pass in self.passes.iter_mut() {
241            pass.check_call_args(self.ctx, self.hir, args);
242        }
243        self.walk_call_args(args)
244    }
245
246    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
247        for pass in self.passes.iter_mut() {
248            pass.check_stmt(self.ctx, self.hir, stmt);
249        }
250        self.walk_stmt(stmt)
251    }
252
253    fn visit_ty(&mut self, ty: &'hir hir::Type<'hir>) -> ControlFlow<Self::BreakValue> {
254        for pass in self.passes.iter_mut() {
255            pass.check_ty(self.ctx, self.hir, ty);
256        }
257        self.walk_ty(ty)
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::LinterConfig;
    use foundry_common::comments::inline_config::InlineConfig;
    use foundry_config::lint::LintSpecificConfig;
    use solar::{
        interface::{Session, source_map::FileName},
        sema::Compiler,
    };
    use std::sync::{Arc, Mutex};

    /// Number of times each instrumented hook ran during one traversal.
    #[derive(Debug, Default)]
    struct HookCounts {
        nested_source: usize,
        nested_item: usize,
        nested_contract: usize,
        nested_function: usize,
        nested_var: usize,
        modifier: usize,
        call_args: usize,
    }

    /// Pass that only counts hook invocations.
    ///
    /// The counts sit behind `Arc<Mutex<_>>` for two reasons. The pass is moved into a boxed
    /// trait object, and the test still needs to read the counts afterwards. `LateLintPass` also
    /// requires `Send + Sync`.
    struct RecordingPass {
        counts: Arc<Mutex<HookCounts>>,
    }

    impl RecordingPass {
        /// Applies `update` to the shared counters while holding the lock.
        fn record(&self, update: impl FnOnce(&mut HookCounts)) {
            update(&mut self.counts.lock().unwrap());
        }
    }

    impl<'hir> LateLintPass<'hir> for RecordingPass {
        fn check_nested_source(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::SourceId,
        ) {
            self.record(|counts| counts.nested_source += 1);
        }

        fn check_nested_item(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ItemId,
        ) {
            self.record(|counts| counts.nested_item += 1);
        }

        fn check_nested_contract(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::ContractId,
        ) {
            self.record(|counts| counts.nested_contract += 1);
        }

        fn check_nested_function(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::FunctionId,
        ) {
            self.record(|counts| counts.nested_function += 1);
        }

        fn check_nested_var(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _id: hir::VariableId,
        ) {
            self.record(|counts| counts.nested_var += 1);
        }

        fn check_modifier(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _modifier: &'hir hir::Modifier<'hir>,
        ) {
            self.record(|counts| counts.modifier += 1);
        }

        fn check_call_args(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _args: &'hir hir::CallArgs<'hir>,
        ) {
            self.record(|counts| counts.call_args += 1);
        }
    }

    /// Lowers a small two-contract source and checks that the visitor forwards the nested-ID,
    /// modifier and call-argument hooks to a registered pass.
    #[test]
    fn calls_hooks_for_nested_items_modifiers_and_call_args() {
        let counts = Arc::new(Mutex::new(HookCounts::default()));
        let inline = InlineConfig::default();
        let lint_specific = LintSpecificConfig::default();
        let source = r#"
            pragma solidity ^0.8.20;

            contract Base {
                function hook(uint256 value) internal pure returns (uint256) {
                    return value;
                }
            }

            contract Test is Base {
                uint256 stored;

                modifier gated(uint256 amount) {
                    _;
                }

                function run(uint256 amount) public gated(amount) returns (uint256) {
                    return hook(amount + stored);
                }
            }
        "#;

        let mut compiler =
            Compiler::new(Session::builder().with_buffer_emitter(Default::default()).build());
        compiler
            .enter_mut(|compiler| -> solar::interface::Result<()> {
                let mut pcx = compiler.parse();
                pcx.set_resolve_imports(false);
                let file = compiler
                    .sess()
                    .source_map()
                    .new_source_file(FileName::Stdin, source)
                    .expect("failed to create source file");
                pcx.add_file(file);
                pcx.parse();

                let ControlFlow::Continue(()) = compiler.lower_asts()? else {
                    panic!("expected HIR lowering to continue");
                };

                let gcx = compiler.gcx();
                let source_id = gcx.hir.source_ids().next().expect("expected one lowered source");
                let ctx = LintContext::new(
                    gcx.sess,
                    false,
                    false,
                    LinterConfig { inline: &inline, lint_specific: &lint_specific },
                    Vec::new(),
                    None,
                );
                let mut passes: Vec<Box<dyn LateLintPass<'_>>> =
                    vec![Box::new(RecordingPass { counts: counts.clone() })];
                let mut visitor = LateLintVisitor::new(&ctx, &mut passes, gcx, &gcx.hir);
                let _ = hir::Visit::visit_nested_source(&mut visitor, source_id);
                Ok(())
            })
            .expect("failed to lower test source");

        let counts = counts.lock().unwrap();
        // The traversal is entered exactly once, on the single import-free source.
        assert_eq!(counts.nested_source, 1, "expected nested source hook to run exactly once");
        assert!(counts.nested_item > 0, "expected nested item hook to run");
        assert!(counts.nested_contract > 0, "expected nested contract hook to run");
        assert!(counts.nested_function > 0, "expected nested function hook to run");
        assert!(counts.nested_var > 0, "expected nested var hook to run");
        assert!(counts.modifier > 0, "expected modifier hook to run");
        assert!(counts.call_args > 0, "expected call args hook to run");
    }
}