Skip to main content

forge_lint/linter/
late.rs

1use solar::{
2    interface::data_structures::Never,
3    sema::{Gcx, hir},
4};
5use std::ops::ControlFlow;
6
7use super::LintContext;
8
9/// Trait for lints that operate on the HIR (High-level Intermediate Representation).
10/// Its methods mirror `hir::visit::Visit`, with the addition of `LintContext`.
11pub trait LateLintPass<'hir>: Send + Sync {
12    fn check_nested_source(
13        &mut self,
14        _ctx: &LintContext,
15        _hir: &'hir hir::Hir<'hir>,
16        _id: hir::SourceId,
17    ) {
18    }
19    fn check_nested_item(
20        &mut self,
21        _ctx: &LintContext,
22        _hir: &'hir hir::Hir<'hir>,
23        _id: hir::ItemId,
24    ) {
25    }
26    fn check_nested_contract(
27        &mut self,
28        _ctx: &LintContext,
29        _gcx: Gcx<'hir>,
30        _hir: &'hir hir::Hir<'hir>,
31        _id: hir::ContractId,
32    ) {
33    }
34    fn check_nested_function(
35        &mut self,
36        _ctx: &LintContext,
37        _hir: &'hir hir::Hir<'hir>,
38        _id: hir::FunctionId,
39    ) {
40    }
41    fn check_nested_var(
42        &mut self,
43        _ctx: &LintContext,
44        _hir: &'hir hir::Hir<'hir>,
45        _id: hir::VariableId,
46    ) {
47    }
48    fn check_item(
49        &mut self,
50        _ctx: &LintContext,
51        _hir: &'hir hir::Hir<'hir>,
52        _item: hir::Item<'hir, 'hir>,
53    ) {
54    }
55    fn check_contract(
56        &mut self,
57        _ctx: &LintContext,
58        _gcx: Gcx<'hir>,
59        _hir: &'hir hir::Hir<'hir>,
60        _contract: &'hir hir::Contract<'hir>,
61    ) {
62    }
63    fn check_function(
64        &mut self,
65        _ctx: &LintContext,
66        _gcx: Gcx<'hir>,
67        _hir: &'hir hir::Hir<'hir>,
68        _func: &'hir hir::Function<'hir>,
69    ) {
70    }
71    fn check_modifier(
72        &mut self,
73        _ctx: &LintContext,
74        _hir: &'hir hir::Hir<'hir>,
75        _mod: &'hir hir::Modifier<'hir>,
76    ) {
77    }
78    fn check_var(
79        &mut self,
80        _ctx: &LintContext,
81        _hir: &'hir hir::Hir<'hir>,
82        _var: &'hir hir::Variable<'hir>,
83    ) {
84    }
85    fn check_expr(
86        &mut self,
87        _ctx: &LintContext,
88        _gcx: Gcx<'hir>,
89        _hir: &'hir hir::Hir<'hir>,
90        _expr: &'hir hir::Expr<'hir>,
91    ) {
92    }
93    fn check_call_args(
94        &mut self,
95        _ctx: &LintContext,
96        _hir: &'hir hir::Hir<'hir>,
97        _args: &'hir hir::CallArgs<'hir>,
98    ) {
99    }
100    fn check_stmt(
101        &mut self,
102        _ctx: &LintContext,
103        _gcx: Gcx<'hir>,
104        _hir: &'hir hir::Hir<'hir>,
105        _stmt: &'hir hir::Stmt<'hir>,
106    ) {
107    }
108    fn check_ty(
109        &mut self,
110        _ctx: &LintContext,
111        _hir: &'hir hir::Hir<'hir>,
112        _ty: &'hir hir::Type<'hir>,
113    ) {
114    }
115}
116
117/// Visitor struct for `LateLintPass`es
118pub struct LateLintVisitor<'a, 's, 'hir> {
119    ctx: &'a LintContext<'s, 'a>,
120    passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
121    gcx: Gcx<'hir>,
122    hir: &'hir hir::Hir<'hir>,
123}
124
125impl<'a, 's, 'hir> LateLintVisitor<'a, 's, 'hir>
126where
127    's: 'hir,
128{
129    pub fn new(
130        ctx: &'a LintContext<'s, 'a>,
131        passes: &'a mut [Box<dyn LateLintPass<'hir> + 's>],
132        gcx: Gcx<'hir>,
133        hir: &'hir hir::Hir<'hir>,
134    ) -> Self {
135        Self { ctx, passes, gcx, hir }
136    }
137}
138
139impl<'s, 'hir> hir::Visit<'hir> for LateLintVisitor<'_, 's, 'hir>
140where
141    's: 'hir,
142{
143    type BreakValue = Never;
144
145    fn hir(&self) -> &'hir hir::Hir<'hir> {
146        self.hir
147    }
148
149    fn visit_nested_source(&mut self, id: hir::SourceId) -> ControlFlow<Self::BreakValue> {
150        for pass in self.passes.iter_mut() {
151            pass.check_nested_source(self.ctx, self.hir, id);
152        }
153        self.walk_nested_source(id)
154    }
155
156    fn visit_nested_item(&mut self, id: hir::ItemId) -> ControlFlow<Self::BreakValue> {
157        for pass in self.passes.iter_mut() {
158            pass.check_nested_item(self.ctx, self.hir, id);
159        }
160        self.walk_nested_item(id)
161    }
162
163    fn visit_nested_contract(&mut self, id: hir::ContractId) -> ControlFlow<Self::BreakValue> {
164        for pass in self.passes.iter_mut() {
165            pass.check_nested_contract(self.ctx, self.gcx, self.hir, id);
166        }
167        self.walk_nested_contract(id)
168    }
169
170    fn visit_nested_function(&mut self, id: hir::FunctionId) -> ControlFlow<Self::BreakValue> {
171        for pass in self.passes.iter_mut() {
172            pass.check_nested_function(self.ctx, self.hir, id);
173        }
174        self.walk_nested_function(id)
175    }
176
177    fn visit_nested_var(&mut self, id: hir::VariableId) -> ControlFlow<Self::BreakValue> {
178        for pass in self.passes.iter_mut() {
179            pass.check_nested_var(self.ctx, self.hir, id);
180        }
181        self.walk_nested_var(id)
182    }
183
184    fn visit_contract(
185        &mut self,
186        contract: &'hir hir::Contract<'hir>,
187    ) -> ControlFlow<Self::BreakValue> {
188        for pass in self.passes.iter_mut() {
189            pass.check_contract(self.ctx, self.gcx, self.hir, contract);
190        }
191        self.walk_contract(contract)
192    }
193
194    fn visit_function(&mut self, func: &'hir hir::Function<'hir>) -> ControlFlow<Self::BreakValue> {
195        for pass in self.passes.iter_mut() {
196            pass.check_function(self.ctx, self.gcx, self.hir, func);
197        }
198        self.walk_function(func)
199    }
200
201    fn visit_modifier(
202        &mut self,
203        modifier: &'hir hir::Modifier<'hir>,
204    ) -> ControlFlow<Self::BreakValue> {
205        for pass in self.passes.iter_mut() {
206            pass.check_modifier(self.ctx, self.hir, modifier);
207        }
208        self.walk_modifier(modifier)
209    }
210
211    fn visit_item(&mut self, item: hir::Item<'hir, 'hir>) -> ControlFlow<Self::BreakValue> {
212        for pass in self.passes.iter_mut() {
213            pass.check_item(self.ctx, self.hir, item);
214        }
215        self.walk_item(item)
216    }
217
218    fn visit_var(&mut self, var: &'hir hir::Variable<'hir>) -> ControlFlow<Self::BreakValue> {
219        for pass in self.passes.iter_mut() {
220            pass.check_var(self.ctx, self.hir, var);
221        }
222        self.walk_var(var)
223    }
224
225    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
226        for pass in self.passes.iter_mut() {
227            pass.check_expr(self.ctx, self.gcx, self.hir, expr);
228        }
229        self.walk_expr(expr)
230    }
231
232    fn visit_call_args(
233        &mut self,
234        args: &'hir hir::CallArgs<'hir>,
235    ) -> ControlFlow<Self::BreakValue> {
236        for pass in self.passes.iter_mut() {
237            pass.check_call_args(self.ctx, self.hir, args);
238        }
239        self.walk_call_args(args)
240    }
241
242    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
243        for pass in self.passes.iter_mut() {
244            pass.check_stmt(self.ctx, self.gcx, self.hir, stmt);
245        }
246        self.walk_stmt(stmt)
247    }
248
249    fn visit_ty(&mut self, ty: &'hir hir::Type<'hir>) -> ControlFlow<Self::BreakValue> {
250        for pass in self.passes.iter_mut() {
251            pass.check_ty(self.ctx, self.hir, ty);
252        }
253        self.walk_ty(ty)
254    }
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::linter::LinterConfig;
    use foundry_common::comments::inline_config::InlineConfig;
    use foundry_config::lint::LintSpecificConfig;
    use solar::{
        interface::{Session, source_map::FileName},
        sema::Compiler,
    };
    use std::sync::{Arc, Mutex};

    /// How many times each hook under test was invoked.
    #[derive(Debug, Default)]
    struct HookCounts {
        nested_item: usize,
        nested_contract: usize,
        nested_function: usize,
        nested_var: usize,
        modifier: usize,
        call_args: usize,
    }

    /// A pass that does no linting and only counts hook invocations into shared counters.
    struct RecordingPass {
        counts: Arc<Mutex<HookCounts>>,
    }

    impl RecordingPass {
        /// Applies `update` to the shared counters while holding the lock.
        fn bump(&self, update: impl FnOnce(&mut HookCounts)) {
            let mut guard = self.counts.lock().unwrap();
            update(&mut guard);
        }
    }

    impl<'hir> LateLintPass<'hir> for RecordingPass {
        fn check_nested_item(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _item_id: hir::ItemId,
        ) {
            self.bump(|c| c.nested_item += 1);
        }

        fn check_nested_contract(
            &mut self,
            _ctx: &LintContext,
            _gcx: solar::sema::Gcx<'hir>,
            _hir: &'hir hir::Hir<'hir>,
            _contract_id: hir::ContractId,
        ) {
            self.bump(|c| c.nested_contract += 1);
        }

        fn check_nested_function(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _function_id: hir::FunctionId,
        ) {
            self.bump(|c| c.nested_function += 1);
        }

        fn check_nested_var(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _var_id: hir::VariableId,
        ) {
            self.bump(|c| c.nested_var += 1);
        }

        fn check_modifier(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _modifier: &'hir hir::Modifier<'hir>,
        ) {
            self.bump(|c| c.modifier += 1);
        }

        fn check_call_args(
            &mut self,
            _ctx: &LintContext,
            _hir: &'hir hir::Hir<'hir>,
            _call_args: &'hir hir::CallArgs<'hir>,
        ) {
            self.bump(|c| c.call_args += 1);
        }
    }

    /// Lowers a small program containing an inherited contract, a state variable, a modifier
    /// invocation and a call. It then checks that the visitor fires each of the counted hooks at
    /// least once.
    #[test]
    fn calls_hooks_for_nested_items_modifiers_and_call_args() {
        let counts = Arc::new(Mutex::new(HookCounts::default()));
        let inline = InlineConfig::default();
        let lint_specific = LintSpecificConfig::default();
        let solidity = r#"
            pragma solidity ^0.8.20;

            contract Base {
                function hook(uint256 value) internal pure returns (uint256) {
                    return value;
                }
            }

            contract Test is Base {
                uint256 stored;

                modifier gated(uint256 amount) {
                    _;
                }

                function run(uint256 amount) public gated(amount) returns (uint256) {
                    return hook(amount + stored);
                }
            }
        "#;

        let mut compiler =
            Compiler::new(Session::builder().with_buffer_emitter(Default::default()).build());
        compiler
            .enter_mut(|compiler| -> solar::interface::Result<()> {
                // Parse the single in-memory file without trying to resolve imports.
                let mut parser = compiler.parse();
                parser.set_resolve_imports(false);
                let file = compiler
                    .sess()
                    .source_map()
                    .new_source_file(FileName::Stdin, solidity)
                    .expect("failed to create source file");
                parser.add_file(file);
                parser.parse();

                let ControlFlow::Continue(()) = compiler.lower_asts()? else {
                    panic!("expected HIR lowering to continue");
                };

                let gcx = compiler.gcx();
                let source_id = gcx.hir.source_ids().next().expect("expected one lowered source");
                let ctx = LintContext::new(
                    gcx.sess,
                    false,
                    false,
                    LinterConfig { inline: &inline, lint_specific: &lint_specific },
                    Vec::new(),
                    None,
                );
                let mut passes: Vec<Box<dyn LateLintPass<'_>>> =
                    vec![Box::new(RecordingPass { counts: Arc::clone(&counts) })];
                let mut visitor = LateLintVisitor::new(&ctx, &mut passes, gcx, &gcx.hir);
                // The break type is `Never`, so the returned `ControlFlow` carries no information.
                let _ = hir::Visit::visit_nested_source(&mut visitor, source_id);
                Ok(())
            })
            .expect("failed to lower test source");

        let counts = counts.lock().unwrap();
        let expectations = [
            (counts.nested_item, "expected nested item hook to run"),
            (counts.nested_contract, "expected nested contract hook to run"),
            (counts.nested_function, "expected nested function hook to run"),
            (counts.nested_var, "expected nested var hook to run"),
            (counts.modifier, "expected modifier hook to run"),
            (counts.call_args, "expected call args hook to run"),
        ];
        for (hits, message) in expectations {
            assert!(hits > 0, "{message}");
        }
    }
}