Skip to main content

foundry_common/
traits.rs

1//! Commonly used traits.
2
3use alloy_json_abi::Function;
4use alloy_primitives::Bytes;
5use alloy_sol_types::SolError;
6use std::{fmt, path::Path};
7
8/// Test filter.
9pub trait TestFilter: Send + Sync {
10    /// Returns whether the test should be included.
11    fn matches_test(&self, test_signature: &str) -> bool;
12
13    /// Returns whether the contract should be included.
14    fn matches_contract(&self, contract_name: &str) -> bool;
15
16    /// Returns a contract with the given path should be included.
17    fn matches_path(&self, path: &Path) -> bool;
18
19    /// Returns whether the test should be included for the given contract.
20    ///
21    /// `contract_id` is the full artifact identifier (`path:Contract`).
22    fn matches_test_function_in_contract(&self, _contract_id: &str, func: &Function) -> bool {
23        func.is_any_test() && self.matches_test(&func.signature())
24    }
25}
26
27impl<'a> dyn TestFilter + 'a {
28    /// Returns `true` if the function is a test function that matches the given filter.
29    pub fn matches_test_function(&self, func: &Function) -> bool {
30        func.is_any_test() && self.matches_test(&func.signature())
31    }
32}
33
34/// A test filter that filters out nothing.
35#[derive(Clone, Debug, Default)]
36pub struct EmptyTestFilter(());
37impl TestFilter for EmptyTestFilter {
38    fn matches_test(&self, _test_signature: &str) -> bool {
39        true
40    }
41
42    fn matches_contract(&self, _contract_name: &str) -> bool {
43        true
44    }
45
46    fn matches_path(&self, _path: &Path) -> bool {
47        true
48    }
49}
50
51/// Extension trait for `Function`.
52pub trait TestFunctionExt {
53    /// Returns the kind of test function.
54    fn test_function_kind(&self) -> TestFunctionKind {
55        TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs())
56    }
57
58    /// Returns `true` if this function is a `setUp` function.
59    fn is_setup(&self) -> bool {
60        self.test_function_kind().is_setup()
61    }
62
63    /// Returns `true` if this function is a unit, fuzz, or invariant test.
64    fn is_any_test(&self) -> bool {
65        self.test_function_kind().is_any_test()
66    }
67
68    /// Returns `true` if this function is a test that should fail.
69    fn is_any_test_fail(&self) -> bool {
70        self.test_function_kind().is_any_test_fail()
71    }
72
73    /// Returns `true` if this function is a unit test.
74    fn is_unit_test(&self) -> bool {
75        matches!(self.test_function_kind(), TestFunctionKind::UnitTest { .. })
76    }
77
78    /// Returns `true` if this function is a `beforeTestSetup` function.
79    fn is_before_test_setup(&self) -> bool {
80        self.tfe_as_str().eq_ignore_ascii_case("beforetestsetup")
81    }
82
83    /// Returns `true` if this function is a fuzz test.
84    fn is_fuzz_test(&self) -> bool {
85        self.test_function_kind().is_fuzz_test()
86    }
87
88    /// Returns `true` if this function is an invariant test.
89    fn is_invariant_test(&self) -> bool {
90        self.test_function_kind().is_invariant_test()
91    }
92
93    /// Returns `true` if this function is a symbolic test.
94    fn is_symbolic_test(&self) -> bool {
95        self.test_function_kind().is_symbolic_test()
96    }
97
98    /// Returns `true` if this function is an `afterInvariant` function.
99    fn is_after_invariant(&self) -> bool {
100        self.test_function_kind().is_after_invariant()
101    }
102
103    /// Returns `true` if this function is a `fixture` function.
104    fn is_fixture(&self) -> bool {
105        self.test_function_kind().is_fixture()
106    }
107
108    /// Returns `true` if this function is test reserved function.
109    fn is_reserved(&self) -> bool {
110        self.is_any_test()
111            || self.is_setup()
112            || self.is_before_test_setup()
113            || self.is_after_invariant()
114            || self.is_fixture()
115    }
116
117    #[doc(hidden)]
118    fn tfe_as_str(&self) -> &str;
119    #[doc(hidden)]
120    fn tfe_has_inputs(&self) -> bool;
121}
122
123impl TestFunctionExt for Function {
124    fn tfe_as_str(&self) -> &str {
125        self.name.as_str()
126    }
127
128    fn tfe_has_inputs(&self) -> bool {
129        !self.inputs.is_empty()
130    }
131}
132
133impl TestFunctionExt for String {
134    fn tfe_as_str(&self) -> &str {
135        self
136    }
137
138    fn tfe_has_inputs(&self) -> bool {
139        false
140    }
141}
142
143impl TestFunctionExt for str {
144    fn tfe_as_str(&self) -> &str {
145        self
146    }
147
148    fn tfe_has_inputs(&self) -> bool {
149        false
150    }
151}
152
/// Test function kind.
///
/// Produced by [`TestFunctionKind::classify`] from a function's name and whether it takes
/// inputs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TestFunctionKind {
    /// `setUp`, without arguments (name compared case-insensitively).
    Setup,
    /// `test*`, without arguments. `should_fail` is `true` for `testFail*`.
    UnitTest { should_fail: bool },
    /// `test*`, with arguments. `should_fail` is `true` for `testFail*`.
    FuzzTest { should_fail: bool },
    /// `invariant*` or `statefulFuzz*`.
    InvariantTest,
    /// `table*`.
    ///
    /// Table tests are meant to take arguments, but [`classify`](Self::classify) does not
    /// check for them.
    TableTest,
    /// `check*` or `prove*`, when selected by symbolic test mode.
    ///
    /// Never produced by [`classify`](Self::classify).
    SymbolicTest,
    /// `afterInvariant` (name compared case-insensitively).
    AfterInvariant,
    /// `fixture*`.
    Fixture,
    /// Unknown kind.
    Unknown,
}

impl TestFunctionKind {
    /// Classifies a function by its `name` and whether it `has_inputs`.
    ///
    /// Rules are checked in order and the first match wins. For example, `testFixture` is a
    /// unit test, not a fixture.
    ///
    /// Prefix matches (`test`, `testFail`, `invariant`, `statefulFuzz`, `table`, `fixture`)
    /// are case-sensitive. `setUp` and `afterInvariant` are compared case-insensitively.
    ///
    /// A `setUp` that takes arguments is [`Unknown`](Self::Unknown), not
    /// [`Setup`](Self::Setup). Contracts such as Gnosis Safe/Zodiac modules define their
    /// own parameterized `setUp`.
    ///
    /// [`SymbolicTest`](Self::SymbolicTest) is never returned here.
    pub fn classify(name: &str, has_inputs: bool) -> Self {
        match () {
            _ if name.starts_with("test") => {
                let should_fail = name.starts_with("testFail");
                if has_inputs {
                    Self::FuzzTest { should_fail }
                } else {
                    Self::UnitTest { should_fail }
                }
            }
            _ if name.starts_with("invariant") || name.starts_with("statefulFuzz") => {
                Self::InvariantTest
            }
            _ if name.starts_with("table") => Self::TableTest,
            _ if name.eq_ignore_ascii_case("setup") && !has_inputs => Self::Setup,
            _ if name.eq_ignore_ascii_case("afterinvariant") => Self::AfterInvariant,
            _ if name.starts_with("fixture") => Self::Fixture,
            _ => Self::Unknown,
        }
    }

    /// Returns the short, human-readable name of the function kind.
    ///
    /// This is also what the `Display` implementation prints.
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Setup => "setUp",
            Self::UnitTest { should_fail: false } => "test",
            Self::UnitTest { should_fail: true } => "testFail",
            Self::FuzzTest { should_fail: false } => "fuzz",
            Self::FuzzTest { should_fail: true } => "fuzz fail",
            Self::InvariantTest => "invariant",
            Self::TableTest => "table",
            Self::SymbolicTest => "symbolic",
            Self::AfterInvariant => "afterInvariant",
            Self::Fixture => "fixture",
            Self::Unknown => "unknown",
        }
    }

    /// Returns `true` if this function is a `setUp` function.
    #[inline]
    pub const fn is_setup(&self) -> bool {
        matches!(self, Self::Setup)
    }

    /// Returns `true` if this function is a unit, fuzz, table, invariant, or symbolic test.
    #[inline]
    pub const fn is_any_test(&self) -> bool {
        matches!(
            self,
            Self::UnitTest { .. }
                | Self::FuzzTest { .. }
                | Self::TableTest
                | Self::InvariantTest
                | Self::SymbolicTest
        )
    }

    /// Returns `true` if this function is a unit or fuzz test that should fail (`testFail*`).
    #[inline]
    pub const fn is_any_test_fail(&self) -> bool {
        matches!(self, Self::UnitTest { should_fail: true } | Self::FuzzTest { should_fail: true })
    }

    /// Returns `true` if this function is a unit test.
    #[inline]
    pub const fn is_unit_test(&self) -> bool {
        matches!(self, Self::UnitTest { .. })
    }

    /// Returns `true` if this function is a fuzz test.
    #[inline]
    pub const fn is_fuzz_test(&self) -> bool {
        matches!(self, Self::FuzzTest { .. })
    }

    /// Returns `true` if this function is an invariant test.
    #[inline]
    pub const fn is_invariant_test(&self) -> bool {
        matches!(self, Self::InvariantTest)
    }

    /// Returns `true` if this function is a table test.
    #[inline]
    pub const fn is_table_test(&self) -> bool {
        matches!(self, Self::TableTest)
    }

    /// Returns `true` if this function is a symbolic test.
    #[inline]
    pub const fn is_symbolic_test(&self) -> bool {
        matches!(self, Self::SymbolicTest)
    }

    /// Returns `true` if this function is an `afterInvariant` function.
    #[inline]
    pub const fn is_after_invariant(&self) -> bool {
        matches!(self, Self::AfterInvariant)
    }

    /// Returns `true` if this function is a `fixture` function.
    #[inline]
    pub const fn is_fixture(&self) -> bool {
        matches!(self, Self::Fixture)
    }

    /// Returns `true` if this function kind is known.
    #[inline]
    pub const fn is_known(&self) -> bool {
        !matches!(self, Self::Unknown)
    }

    /// Returns `true` if this function kind is unknown.
    #[inline]
    pub const fn is_unknown(&self) -> bool {
        matches!(self, Self::Unknown)
    }
}

impl fmt::Display for TestFunctionKind {
    /// Writes [`name`](Self::name), respecting the formatter's padding options as `str` does.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.name(), f)
    }
}
301
302/// An extension trait for `std::error::Error` for ABI encoding.
303pub trait ErrorExt: std::error::Error {
304    /// ABI-encodes the error using `Revert(string)`.
305    fn abi_encode_revert(&self) -> Bytes;
306}
307
308impl<T: std::error::Error> ErrorExt for T {
309    fn abi_encode_revert(&self) -> Bytes {
310        alloy_sol_types::Revert::from(self.to_string()).abi_encode().into()
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_setup_classification() {
        // setUp() with no params should be classified as Setup
        assert_eq!(TestFunctionKind::classify("setUp", false), TestFunctionKind::Setup);

        // setUp(bytes memory) with params should NOT be classified as Setup
        // This is common in Gnosis Safe/Zodiac modules
        assert_eq!(TestFunctionKind::classify("setUp", true), TestFunctionKind::Unknown);

        // The setUp name is compared case-insensitively.
        assert_eq!(TestFunctionKind::classify("SETUP", false), TestFunctionKind::Setup);
    }

    #[test]
    fn test_kind_classification() {
        assert_eq!(
            TestFunctionKind::classify("testFoo", false),
            TestFunctionKind::UnitTest { should_fail: false }
        );
        assert_eq!(
            TestFunctionKind::classify("testFailFoo", false),
            TestFunctionKind::UnitTest { should_fail: true }
        );
        assert_eq!(
            TestFunctionKind::classify("testFoo", true),
            TestFunctionKind::FuzzTest { should_fail: false }
        );
        assert_eq!(
            TestFunctionKind::classify("testFailFoo", true),
            TestFunctionKind::FuzzTest { should_fail: true }
        );
        assert_eq!(TestFunctionKind::classify("invariantFoo", false), TestFunctionKind::InvariantTest);
        assert_eq!(
            TestFunctionKind::classify("statefulFuzzFoo", false),
            TestFunctionKind::InvariantTest
        );
        assert_eq!(TestFunctionKind::classify("tableFoo", true), TestFunctionKind::TableTest);
        assert_eq!(
            TestFunctionKind::classify("afterInvariant", false),
            TestFunctionKind::AfterInvariant
        );
        assert_eq!(TestFunctionKind::classify("fixtureFoo", false), TestFunctionKind::Fixture);
        assert_eq!(TestFunctionKind::classify("helper", false), TestFunctionKind::Unknown);

        // Rules are checked in order: a `test` prefix wins over `fixture`.
        assert_eq!(
            TestFunctionKind::classify("testFixture", false),
            TestFunctionKind::UnitTest { should_fail: false }
        );
    }

    #[test]
    fn test_function_ext_on_names() {
        assert!("setUp".is_setup());
        assert!("beforeTestSetup".is_before_test_setup());
        assert!("testFailFoo".is_any_test_fail());
        assert!(String::from("testFoo").is_unit_test());
        assert!("fixtureFoo".is_reserved());
        assert!("afterInvariant".is_reserved());
        assert!(!"helper".is_reserved());
        assert!(TestFunctionKind::SymbolicTest.is_any_test());
        assert!(TestFunctionKind::TableTest.is_any_test());
        assert_eq!(TestFunctionKind::FuzzTest { should_fail: true }.to_string(), "fuzz fail");
    }
}