use alloy_json_abi::Function;
use alloy_primitives::Bytes;
use alloy_sol_types::SolError;
use std::{fmt, path::Path};
/// A filter over tests, contracts and paths.
///
/// Implementors must be `Send + Sync`.
pub trait TestFilter: Send + Sync {
    /// Returns whether `test_name` matches this filter.
    fn matches_test(&self, test_name: &str) -> bool;
    /// Returns whether `contract_name` matches this filter.
    fn matches_contract(&self, contract_name: &str) -> bool;
    /// Returns whether `path` matches this filter.
    fn matches_path(&self, path: &Path) -> bool;
}
/// Classification helpers for anything that can expose a function name and
/// whether it takes inputs.
///
/// Every predicate except [`is_before_test_setup`](Self::is_before_test_setup)
/// is answered through [`test_function_kind`](Self::test_function_kind).
pub trait TestFunctionExt {
    /// Classifies this function with [`TestFunctionKind::classify`].
    fn test_function_kind(&self) -> TestFunctionKind {
        let name = self.tfe_as_str();
        let has_inputs = self.tfe_has_inputs();
        TestFunctionKind::classify(name, has_inputs)
    }

    /// Returns `true` if this function classifies as [`TestFunctionKind::Setup`].
    fn is_setup(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_setup()
    }

    /// Returns `true` if this function is a unit, fuzz, or invariant test.
    fn is_any_test(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_any_test()
    }

    /// Returns `true` if this function is a unit or fuzz test flagged `should_fail`.
    fn is_any_test_fail(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_any_test_fail()
    }

    /// Returns `true` if this function is a unit test.
    fn is_unit_test(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_unit_test()
    }

    /// Returns `true` if the name equals `beforeTestSetup`, ignoring ASCII case.
    ///
    /// This check looks at the name only and does not go through
    /// [`test_function_kind`](Self::test_function_kind).
    fn is_before_test_setup(&self) -> bool {
        "beforetestsetup".eq_ignore_ascii_case(self.tfe_as_str())
    }

    /// Returns `true` if this function is a fuzz test.
    fn is_fuzz_test(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_fuzz_test()
    }

    /// Returns `true` if this function is an invariant test.
    fn is_invariant_test(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_invariant_test()
    }

    /// Returns `true` if this function classifies as [`TestFunctionKind::AfterInvariant`].
    fn is_after_invariant(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_after_invariant()
    }

    /// Returns `true` if this function classifies as [`TestFunctionKind::Fixture`].
    fn is_fixture(&self) -> bool {
        let kind = self.test_function_kind();
        kind.is_fixture()
    }

    /// Implementation hook: the function name used for classification.
    #[doc(hidden)]
    fn tfe_as_str(&self) -> &str;

    /// Implementation hook: whether the function takes any inputs.
    #[doc(hidden)]
    fn tfe_has_inputs(&self) -> bool;
}
impl TestFunctionExt for Function {
    /// Uses the function's `name` field for classification.
    fn tfe_as_str(&self) -> &str {
        // `&String` coerces to `&str` at the return site.
        &self.name
    }

    /// Returns `true` when the function's `inputs` list is non-empty.
    fn tfe_has_inputs(&self) -> bool {
        let no_inputs = self.inputs.is_empty();
        !no_inputs
    }
}
impl TestFunctionExt for String {
    /// The string's contents are used as the function name.
    fn tfe_as_str(&self) -> &str {
        self.as_str()
    }

    /// Always `false`: this impl reports no inputs for a bare name.
    fn tfe_has_inputs(&self) -> bool {
        false
    }
}
impl TestFunctionExt for str {
    /// The string slice itself is used as the function name.
    fn tfe_as_str(&self) -> &str {
        self
    }
    /// Always `false`: this impl reports no inputs for a bare name.
    fn tfe_has_inputs(&self) -> bool {
        false
    }
}
/// Classification of a function, derived from its name and from whether it
/// takes inputs. See [`TestFunctionKind::classify`] for the exact rules.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TestFunctionKind {
    /// Name equals `setUp`, compared ASCII case-insensitively.
    Setup,
    /// Name starts with `test` and the function takes no inputs.
    /// `should_fail` is set when the name starts with `testFail`.
    UnitTest { should_fail: bool },
    /// Name starts with `test` and the function takes inputs.
    /// `should_fail` is set when the name starts with `testFail`.
    FuzzTest { should_fail: bool },
    /// Name starts with `invariant` or `statefulFuzz`.
    InvariantTest,
    /// Name equals `afterInvariant`, compared ASCII case-insensitively.
    AfterInvariant,
    /// Name starts with `fixture`.
    Fixture,
    /// None of the other rules matched.
    Unknown,
}

impl TestFunctionKind {
    /// Classifies a function from its `name` and whether it `has_inputs`.
    ///
    /// Rules are evaluated top to bottom and the first match wins. Prefix
    /// checks are case-sensitive; the `setUp` and `afterInvariant` checks
    /// compare the whole name ignoring ASCII case. `has_inputs` only
    /// distinguishes unit tests from fuzz tests.
    #[inline]
    pub fn classify(name: &str, has_inputs: bool) -> Self {
        match () {
            _ if name.starts_with("test") => {
                let should_fail = name.starts_with("testFail");
                if has_inputs {
                    Self::FuzzTest { should_fail }
                } else {
                    Self::UnitTest { should_fail }
                }
            }
            _ if name.starts_with("invariant") || name.starts_with("statefulFuzz") => {
                Self::InvariantTest
            }
            _ if name.eq_ignore_ascii_case("setup") => Self::Setup,
            _ if name.eq_ignore_ascii_case("afterinvariant") => Self::AfterInvariant,
            _ if name.starts_with("fixture") => Self::Fixture,
            _ => Self::Unknown,
        }
    }

    /// Returns a short static label for this kind.
    pub const fn name(&self) -> &'static str {
        match self {
            Self::Setup => "setUp",
            Self::UnitTest { should_fail: false } => "test",
            Self::UnitTest { should_fail: true } => "testFail",
            Self::FuzzTest { should_fail: false } => "fuzz",
            Self::FuzzTest { should_fail: true } => "fuzz fail",
            Self::InvariantTest => "invariant",
            Self::AfterInvariant => "afterInvariant",
            Self::Fixture => "fixture",
            Self::Unknown => "unknown",
        }
    }

    /// Returns `true` for [`Self::Setup`].
    #[inline]
    pub const fn is_setup(&self) -> bool {
        matches!(self, Self::Setup)
    }

    /// Returns `true` for unit, fuzz, and invariant tests.
    #[inline]
    pub const fn is_any_test(&self) -> bool {
        matches!(self, Self::UnitTest { .. } | Self::FuzzTest { .. } | Self::InvariantTest)
    }

    /// Returns `true` for unit or fuzz tests whose `should_fail` flag is set.
    #[inline]
    pub const fn is_any_test_fail(&self) -> bool {
        matches!(self, Self::UnitTest { should_fail: true } | Self::FuzzTest { should_fail: true })
    }

    /// Returns `true` for [`Self::UnitTest`], regardless of `should_fail`.
    ///
    /// Declared `const` like every other predicate on this type so it can be
    /// evaluated in constant contexts as well.
    #[inline]
    pub const fn is_unit_test(&self) -> bool {
        matches!(self, Self::UnitTest { .. })
    }

    /// Returns `true` for [`Self::FuzzTest`], regardless of `should_fail`.
    #[inline]
    pub const fn is_fuzz_test(&self) -> bool {
        matches!(self, Self::FuzzTest { .. })
    }

    /// Returns `true` for [`Self::InvariantTest`].
    #[inline]
    pub const fn is_invariant_test(&self) -> bool {
        matches!(self, Self::InvariantTest)
    }

    /// Returns `true` for [`Self::AfterInvariant`].
    #[inline]
    pub const fn is_after_invariant(&self) -> bool {
        matches!(self, Self::AfterInvariant)
    }

    /// Returns `true` for [`Self::Fixture`].
    #[inline]
    pub const fn is_fixture(&self) -> bool {
        matches!(self, Self::Fixture)
    }

    /// Returns `true` for every kind except [`Self::Unknown`].
    #[inline]
    pub const fn is_known(&self) -> bool {
        !matches!(self, Self::Unknown)
    }

    /// Returns `true` for [`Self::Unknown`].
    #[inline]
    pub const fn is_unknown(&self) -> bool {
        matches!(self, Self::Unknown)
    }
}
impl fmt::Display for TestFunctionKind {
    /// Formats the kind as the label returned by `name`, delegating to
    /// `str`'s `Display` implementation with the same formatter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label: &str = self.name();
        fmt::Display::fmt(label, f)
    }
}
/// Extension trait for [`std::error::Error`] types that can be turned into
/// ABI-encoded revert bytes.
pub trait ErrorExt: std::error::Error {
    /// Returns this error encoded as revert [`Bytes`].
    fn abi_encode_revert(&self) -> Bytes;
}
impl<T: std::error::Error> ErrorExt for T {
fn abi_encode_revert(&self) -> Bytes {
alloy_sol_types::Revert::from(self.to_string()).abi_encode().into()
}
}