1use alloy_json_abi::Function;
4use alloy_primitives::Bytes;
5use alloy_sol_types::SolError;
6use std::{fmt, path::Path};
7
8pub trait TestFilter: Send + Sync {
10 fn matches_test(&self, test_signature: &str) -> bool;
12
13 fn matches_contract(&self, contract_name: &str) -> bool;
15
16 fn matches_path(&self, path: &Path) -> bool;
18
19 fn matches_test_function_in_contract(&self, _contract_id: &str, func: &Function) -> bool {
23 func.is_any_test() && self.matches_test(&func.signature())
24 }
25}
26
27impl<'a> dyn TestFilter + 'a {
28 pub fn matches_test_function(&self, func: &Function) -> bool {
30 func.is_any_test() && self.matches_test(&func.signature())
31 }
32}
33
34#[derive(Clone, Debug, Default)]
36pub struct EmptyTestFilter(());
37impl TestFilter for EmptyTestFilter {
38 fn matches_test(&self, _test_signature: &str) -> bool {
39 true
40 }
41
42 fn matches_contract(&self, _contract_name: &str) -> bool {
43 true
44 }
45
46 fn matches_path(&self, _path: &Path) -> bool {
47 true
48 }
49}
50
51pub trait TestFunctionExt {
53 fn test_function_kind(&self) -> TestFunctionKind {
55 TestFunctionKind::classify(self.tfe_as_str(), self.tfe_has_inputs())
56 }
57
58 fn is_setup(&self) -> bool {
60 self.test_function_kind().is_setup()
61 }
62
63 fn is_any_test(&self) -> bool {
65 self.test_function_kind().is_any_test()
66 }
67
68 fn is_any_test_fail(&self) -> bool {
70 self.test_function_kind().is_any_test_fail()
71 }
72
73 fn is_unit_test(&self) -> bool {
75 matches!(self.test_function_kind(), TestFunctionKind::UnitTest { .. })
76 }
77
78 fn is_before_test_setup(&self) -> bool {
80 self.tfe_as_str().eq_ignore_ascii_case("beforetestsetup")
81 }
82
83 fn is_fuzz_test(&self) -> bool {
85 self.test_function_kind().is_fuzz_test()
86 }
87
88 fn is_invariant_test(&self) -> bool {
90 self.test_function_kind().is_invariant_test()
91 }
92
93 fn is_symbolic_test(&self) -> bool {
95 self.test_function_kind().is_symbolic_test()
96 }
97
98 fn is_after_invariant(&self) -> bool {
100 self.test_function_kind().is_after_invariant()
101 }
102
103 fn is_fixture(&self) -> bool {
105 self.test_function_kind().is_fixture()
106 }
107
108 fn is_reserved(&self) -> bool {
110 self.is_any_test()
111 || self.is_setup()
112 || self.is_before_test_setup()
113 || self.is_after_invariant()
114 || self.is_fixture()
115 }
116
117 #[doc(hidden)]
118 fn tfe_as_str(&self) -> &str;
119 #[doc(hidden)]
120 fn tfe_has_inputs(&self) -> bool;
121}
122
123impl TestFunctionExt for Function {
124 fn tfe_as_str(&self) -> &str {
125 self.name.as_str()
126 }
127
128 fn tfe_has_inputs(&self) -> bool {
129 !self.inputs.is_empty()
130 }
131}
132
133impl TestFunctionExt for String {
134 fn tfe_as_str(&self) -> &str {
135 self
136 }
137
138 fn tfe_has_inputs(&self) -> bool {
139 false
140 }
141}
142
143impl TestFunctionExt for str {
144 fn tfe_as_str(&self) -> &str {
145 self
146 }
147
148 fn tfe_has_inputs(&self) -> bool {
149 false
150 }
151}
152
/// The role of a function in a test contract, as determined by [`TestFunctionKind::classify`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TestFunctionKind {
    /// The `setUp` function (ASCII case-insensitive, no inputs).
    Setup,
    /// A `test*` function without inputs; `should_fail` is set for `testFail*` names.
    UnitTest { should_fail: bool },
    /// A `test*` function with inputs; `should_fail` is set for `testFail*` names.
    FuzzTest { should_fail: bool },
    /// An `invariant*` or `statefulFuzz*` function.
    InvariantTest,
    /// A `table*` function.
    TableTest,
    /// A symbolic test.
    ///
    /// NOTE(review): [`TestFunctionKind::classify`] never returns this variant.
    SymbolicTest,
    /// The `afterInvariant` function (ASCII case-insensitive).
    AfterInvariant,
    /// A `fixture*` function.
    Fixture,
    /// Any function not matching the rules above.
    Unknown,
}

impl TestFunctionKind {
    /// Classifies a function from its `name` and whether it `has_inputs`.
    ///
    /// Rules are tried in order and the first match wins:
    /// 1. `test*`: [`FuzzTest`](Self::FuzzTest) with inputs, [`UnitTest`](Self::UnitTest)
    ///    without. `should_fail` is set when the name starts with `testFail`.
    /// 2. `invariant*` or `statefulFuzz*`: [`InvariantTest`](Self::InvariantTest).
    /// 3. `table*`: [`TableTest`](Self::TableTest).
    /// 4. `setup` (ASCII case-insensitive) without inputs: [`Setup`](Self::Setup).
    /// 5. `afterinvariant` (ASCII case-insensitive): [`AfterInvariant`](Self::AfterInvariant).
    /// 6. `fixture*`: [`Fixture`](Self::Fixture).
    /// 7. Anything else: [`Unknown`](Self::Unknown).
    ///
    /// Prefix checks are case-sensitive.
    pub fn classify(name: &str, has_inputs: bool) -> Self {
        if name.starts_with("test") {
            let should_fail = name.starts_with("testFail");
            return if has_inputs {
                Self::FuzzTest { should_fail }
            } else {
                Self::UnitTest { should_fail }
            };
        }
        if name.starts_with("invariant") || name.starts_with("statefulFuzz") {
            return Self::InvariantTest;
        }
        if name.starts_with("table") {
            return Self::TableTest;
        }
        // A `setUp` that takes inputs is deliberately not treated as the setup function.
        if !has_inputs && name.eq_ignore_ascii_case("setup") {
            return Self::Setup;
        }
        if name.eq_ignore_ascii_case("afterinvariant") {
            return Self::AfterInvariant;
        }
        if name.starts_with("fixture") {
            return Self::Fixture;
        }
        Self::Unknown
    }

    /// Returns a short, human-readable label for this kind.
    pub const fn name(&self) -> &'static str {
        match *self {
            Self::Setup => "setUp",
            Self::UnitTest { should_fail } => {
                if should_fail {
                    "testFail"
                } else {
                    "test"
                }
            }
            Self::FuzzTest { should_fail } => {
                if should_fail {
                    "fuzz fail"
                } else {
                    "fuzz"
                }
            }
            Self::InvariantTest => "invariant",
            Self::TableTest => "table",
            Self::SymbolicTest => "symbolic",
            Self::AfterInvariant => "afterInvariant",
            Self::Fixture => "fixture",
            Self::Unknown => "unknown",
        }
    }

    /// Returns `true` for [`Setup`](Self::Setup).
    #[inline]
    pub const fn is_setup(&self) -> bool {
        matches!(*self, Self::Setup)
    }

    /// Returns `true` for unit, fuzz, table, invariant, and symbolic tests.
    #[inline]
    pub const fn is_any_test(&self) -> bool {
        matches!(
            *self,
            Self::UnitTest { .. }
                | Self::FuzzTest { .. }
                | Self::TableTest
                | Self::InvariantTest
                | Self::SymbolicTest
        )
    }

    /// Returns `true` for unit or fuzz tests expected to fail.
    #[inline]
    pub const fn is_any_test_fail(&self) -> bool {
        matches!(*self, Self::UnitTest { should_fail: true } | Self::FuzzTest { should_fail: true })
    }

    /// Returns `true` for [`UnitTest`](Self::UnitTest).
    #[inline]
    pub const fn is_unit_test(&self) -> bool {
        matches!(*self, Self::UnitTest { .. })
    }

    /// Returns `true` for [`FuzzTest`](Self::FuzzTest).
    #[inline]
    pub const fn is_fuzz_test(&self) -> bool {
        matches!(*self, Self::FuzzTest { .. })
    }

    /// Returns `true` for [`InvariantTest`](Self::InvariantTest).
    #[inline]
    pub const fn is_invariant_test(&self) -> bool {
        matches!(*self, Self::InvariantTest)
    }

    /// Returns `true` for [`TableTest`](Self::TableTest).
    #[inline]
    pub const fn is_table_test(&self) -> bool {
        matches!(*self, Self::TableTest)
    }

    /// Returns `true` for [`SymbolicTest`](Self::SymbolicTest).
    #[inline]
    pub const fn is_symbolic_test(&self) -> bool {
        matches!(*self, Self::SymbolicTest)
    }

    /// Returns `true` for [`AfterInvariant`](Self::AfterInvariant).
    #[inline]
    pub const fn is_after_invariant(&self) -> bool {
        matches!(*self, Self::AfterInvariant)
    }

    /// Returns `true` for [`Fixture`](Self::Fixture).
    #[inline]
    pub const fn is_fixture(&self) -> bool {
        matches!(*self, Self::Fixture)
    }

    /// Returns `true` for every kind except [`Unknown`](Self::Unknown).
    #[inline]
    pub const fn is_known(&self) -> bool {
        !self.is_unknown()
    }

    /// Returns `true` for [`Unknown`](Self::Unknown).
    #[inline]
    pub const fn is_unknown(&self) -> bool {
        matches!(*self, Self::Unknown)
    }
}

impl fmt::Display for TestFunctionKind {
    /// Writes [`TestFunctionKind::name`], honoring width/fill/alignment flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.name())
    }
}
301
302pub trait ErrorExt: std::error::Error {
304 fn abi_encode_revert(&self) -> Bytes;
306}
307
308impl<T: std::error::Error> ErrorExt for T {
309 fn abi_encode_revert(&self) -> Bytes {
310 alloy_sol_types::Revert::from(self.to_string()).abi_encode().into()
311 }
312}
313
/// Unit tests for [`TestFunctionKind`] classification and the [`TestFunctionExt`] helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_setup_classification() {
        assert_eq!(TestFunctionKind::classify("setUp", false), TestFunctionKind::Setup);
        // `setUp` matching is ASCII case-insensitive.
        assert_eq!(TestFunctionKind::classify("setup", false), TestFunctionKind::Setup);
        assert_eq!(TestFunctionKind::classify("SETUP", false), TestFunctionKind::Setup);

        // A `setUp` that takes inputs is not treated as the setup function.
        assert_eq!(TestFunctionKind::classify("setUp", true), TestFunctionKind::Unknown);
    }

    #[test]
    fn test_unit_and_fuzz_classification() {
        assert_eq!(
            TestFunctionKind::classify("testFoo", false),
            TestFunctionKind::UnitTest { should_fail: false }
        );
        assert_eq!(
            TestFunctionKind::classify("testFailFoo", false),
            TestFunctionKind::UnitTest { should_fail: true }
        );
        assert_eq!(
            TestFunctionKind::classify("testFoo", true),
            TestFunctionKind::FuzzTest { should_fail: false }
        );
        assert_eq!(
            TestFunctionKind::classify("testFailFoo", true),
            TestFunctionKind::FuzzTest { should_fail: true }
        );
        // Prefix checks are case-sensitive.
        assert_eq!(TestFunctionKind::classify("TestFoo", false), TestFunctionKind::Unknown);
    }

    #[test]
    fn test_other_classifications() {
        assert_eq!(TestFunctionKind::classify("invariantFoo", false), TestFunctionKind::InvariantTest);
        assert_eq!(TestFunctionKind::classify("statefulFuzzFoo", false), TestFunctionKind::InvariantTest);
        assert_eq!(TestFunctionKind::classify("tableFoo", true), TestFunctionKind::TableTest);
        assert_eq!(TestFunctionKind::classify("afterInvariant", false), TestFunctionKind::AfterInvariant);
        assert_eq!(TestFunctionKind::classify("fixtureAmount", false), TestFunctionKind::Fixture);
        assert_eq!(TestFunctionKind::classify("helper", false), TestFunctionKind::Unknown);
    }

    #[test]
    fn test_names_and_display() {
        assert_eq!(TestFunctionKind::Setup.name(), "setUp");
        assert_eq!(TestFunctionKind::FuzzTest { should_fail: true }.name(), "fuzz fail");
        assert_eq!(TestFunctionKind::UnitTest { should_fail: true }.to_string(), "testFail");
        assert_eq!(format!("{:>7}", TestFunctionKind::Setup), "  setUp");
    }

    #[test]
    fn test_str_extension() {
        // Bare names report no inputs, so `test*` names are always unit tests.
        assert!("testFoo".is_unit_test());
        assert!(!"testFoo".is_fuzz_test());
        assert!("setUp".is_setup());
        assert!("setUp".is_reserved());
        assert!("beforeTestSetup".is_before_test_setup());
        assert!("beforeTestSetup".is_reserved());
        assert!(!"helper".is_reserved());
        assert!(String::from("testFailFoo").is_any_test_fail());
    }
}