foundry_cheatcodes/test/
revert_handlers.rs

1use crate::{Error, Result};
2use alloy_primitives::{address, hex, Address, Bytes};
3use alloy_sol_types::{SolError, SolValue};
4use foundry_common::ContractsByArtifact;
5use foundry_evm_core::decode::RevertDecoder;
6use revm::interpreter::{return_ok, InstructionResult};
7use spec::Vm;
8
9use super::{
10    assume::{AcceptableRevertParameters, AssumeNoRevert},
11    expect::ExpectedRevert,
12};
13
14/// For some cheatcodes we may internally change the status of the call, i.e. in `expectRevert`.
15/// Solidity will see a successful call and attempt to decode the return data. Therefore, we need
16/// to populate the return with dummy bytes so the decode doesn't fail.
17///
18/// 8192 bytes was arbitrarily chosen because it is long enough for return values up to 256 words in
19/// size.
20static DUMMY_CALL_OUTPUT: Bytes = Bytes::from_static(&[0u8; 8192]);
21
22/// Same reasoning as [DUMMY_CALL_OUTPUT], but for creates.
23const DUMMY_CREATE_ADDRESS: Address = address!("0x0000000000000000000000000000000000000001");
24
25fn stringify(data: &[u8]) -> String {
26    if let Ok(s) = String::abi_decode(data, true) {
27        return s;
28    }
29    if data.is_ascii() {
30        return std::str::from_utf8(data).unwrap().to_owned();
31    }
32    hex::encode_prefixed(data)
33}
34
35/// Common parameters for expected or assumed reverts. Allows for code reuse.
36pub(crate) trait RevertParameters {
37    fn reverter(&self) -> Option<Address>;
38    fn reason(&self) -> Option<&[u8]>;
39    fn partial_match(&self) -> bool;
40}
41
42impl RevertParameters for AcceptableRevertParameters {
43    fn reverter(&self) -> Option<Address> {
44        self.reverter
45    }
46
47    fn reason(&self) -> Option<&[u8]> {
48        Some(&self.reason)
49    }
50
51    fn partial_match(&self) -> bool {
52        self.partial_match
53    }
54}
55
56/// Core logic for handling reverts that may or may not be expected (or assumed).
57fn handle_revert(
58    is_cheatcode: bool,
59    revert_params: &impl RevertParameters,
60    status: InstructionResult,
61    retdata: &Bytes,
62    known_contracts: &Option<ContractsByArtifact>,
63    reverter: Option<&Address>,
64) -> Result<(), Error> {
65    // If expected reverter address is set then check it matches the actual reverter.
66    if let (Some(expected_reverter), Some(&actual_reverter)) = (revert_params.reverter(), reverter)
67    {
68        if expected_reverter != actual_reverter {
69            return Err(fmt_err!(
70                "Reverter != expected reverter: {} != {}",
71                actual_reverter,
72                expected_reverter
73            ));
74        }
75    }
76
77    let expected_reason = revert_params.reason();
78    // If None, accept any revert.
79    let Some(expected_reason) = expected_reason else {
80        return Ok(());
81    };
82
83    if !expected_reason.is_empty() && retdata.is_empty() {
84        bail!("call reverted as expected, but without data");
85    }
86
87    let mut actual_revert: Vec<u8> = retdata.to_vec();
88
89    // Compare only the first 4 bytes if partial match.
90    if revert_params.partial_match() && actual_revert.get(..4) == expected_reason.get(..4) {
91        return Ok(());
92    }
93
94    // Try decoding as known errors.
95    actual_revert = decode_revert(actual_revert);
96
97    if actual_revert == expected_reason ||
98        (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
99    {
100        Ok(())
101    } else {
102        let (actual, expected) = if let Some(contracts) = known_contracts {
103            let decoder = RevertDecoder::new().with_abis(contracts.iter().map(|(_, c)| &c.abi));
104            (
105                &decoder.decode(actual_revert.as_slice(), Some(status)),
106                &decoder.decode(expected_reason, Some(status)),
107            )
108        } else {
109            (&stringify(&actual_revert), &stringify(expected_reason))
110        };
111
112        if expected == actual {
113            return Ok(());
114        }
115
116        Err(fmt_err!("Error != expected error: {} != {}", actual, expected))
117    }
118}
119
120pub(crate) fn handle_assume_no_revert(
121    assume_no_revert: &AssumeNoRevert,
122    status: InstructionResult,
123    retdata: &Bytes,
124    known_contracts: &Option<ContractsByArtifact>,
125) -> Result<()> {
126    // if a generic AssumeNoRevert, return Ok(). Otherwise, iterate over acceptable reasons and try
127    // to match against any, otherwise, return an Error with the revert data
128    if assume_no_revert.reasons.is_empty() {
129        Ok(())
130    } else {
131        assume_no_revert
132            .reasons
133            .iter()
134            .find_map(|reason| {
135                handle_revert(
136                    false,
137                    reason,
138                    status,
139                    retdata,
140                    known_contracts,
141                    assume_no_revert.reverted_by.as_ref(),
142                )
143                .ok()
144            })
145            .ok_or_else(|| retdata.clone().into())
146    }
147}
148
149pub(crate) fn handle_expect_revert(
150    is_cheatcode: bool,
151    is_create: bool,
152    internal_expect_revert: bool,
153    expected_revert: &ExpectedRevert,
154    status: InstructionResult,
155    retdata: Bytes,
156    known_contracts: &Option<ContractsByArtifact>,
157) -> Result<(Option<Address>, Bytes)> {
158    let success_return = || {
159        if is_create {
160            (Some(DUMMY_CREATE_ADDRESS), Bytes::new())
161        } else {
162            (None, DUMMY_CALL_OUTPUT.clone())
163        }
164    };
165
166    // Check depths if it's not an expect cheatcode call and if internal expect reverts not enabled.
167    if !is_cheatcode && !internal_expect_revert {
168        ensure!(
169            expected_revert.max_depth > expected_revert.depth,
170            "call didn't revert at a lower depth than cheatcode call depth"
171        );
172    }
173
174    if expected_revert.count == 0 {
175        if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
176            ensure!(
177                matches!(status, return_ok!()),
178                "call reverted when it was expected not to revert"
179            );
180            return Ok(success_return());
181        }
182
183        // Flags to track if the reason and reverter match.
184        let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
185        let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
186
187        // Reverter check
188        if let (Some(expected_reverter), Some(actual_reverter)) =
189            (expected_revert.reverter, expected_revert.reverted_by)
190        {
191            if expected_reverter == actual_reverter {
192                reverter_match = Some(true);
193            }
194        }
195
196        // Reason check
197        let expected_reason = expected_revert.reason.as_deref();
198        if let Some(expected_reason) = expected_reason {
199            let mut actual_revert: Vec<u8> = retdata.into();
200            actual_revert = decode_revert(actual_revert);
201
202            if actual_revert == expected_reason {
203                reason_match = Some(true);
204            }
205        };
206
207        match (reason_match, reverter_match) {
208            (Some(true), Some(true)) => Err(fmt_err!(
209                "expected 0 reverts with reason: {}, from address: {}, but got one",
210                &stringify(expected_reason.unwrap_or_default()),
211                expected_revert.reverter.unwrap()
212            )),
213            (Some(true), None) => Err(fmt_err!(
214                "expected 0 reverts with reason: {}, but got one",
215                &stringify(expected_reason.unwrap_or_default())
216            )),
217            (None, Some(true)) => Err(fmt_err!(
218                "expected 0 reverts from address: {}, but got one",
219                expected_revert.reverter.unwrap()
220            )),
221            _ => Ok(success_return()),
222        }
223    } else {
224        ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
225
226        handle_revert(
227            is_cheatcode,
228            expected_revert,
229            status,
230            &retdata,
231            known_contracts,
232            expected_revert.reverted_by.as_ref(),
233        )?;
234        Ok(success_return())
235    }
236}
237
238fn decode_revert(revert: Vec<u8>) -> Vec<u8> {
239    if matches!(
240        revert.get(..4).map(|s| s.try_into().unwrap()),
241        Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
242    ) {
243        if let Ok(decoded) = Vec::<u8>::abi_decode(&revert[4..], false) {
244            return decoded;
245        }
246    }
247    revert
248}