Skip to main content

foundry_cheatcodes/test/
revert_handlers.rs

1use crate::{Error, Result};
2use alloy_dyn_abi::{DynSolValue, ErrorExt};
3use alloy_primitives::{Address, Bytes, address, hex};
4use alloy_sol_types::{SolError, SolValue};
5use foundry_common::{ContractsByArtifact, abi::get_error};
6use foundry_evm_core::decode::RevertDecoder;
7use revm::interpreter::{InstructionResult, return_ok};
8use spec::Vm;
9
10use super::{
11    assume::{AcceptableRevertParameters, AssumeNoRevert},
12    expect::ExpectedRevert,
13};
14
/// Placeholder return data for calls whose status a cheatcode rewrites internally, e.g. a call
/// whose revert was expected via `expectRevert`.
///
/// Solidity will see a successful call and attempt to decode the return data. Therefore, we need
/// to populate the return with dummy (zeroed) bytes so the decode doesn't fail.
///
/// 8192 bytes was arbitrarily chosen because it is long enough for return values up to 256 words
/// (32 bytes each) in size.
static DUMMY_CALL_OUTPUT: Bytes = Bytes::from_static(&[0u8; 8192]);

/// Placeholder address reported for contract creations whose status a cheatcode rewrites
/// internally. Same reasoning as [`DUMMY_CALL_OUTPUT`], but for creates.
const DUMMY_CREATE_ADDRESS: Address = address!("0x0000000000000000000000000000000000000001");
25
26fn stringify(data: &[u8]) -> String {
27    if let Ok(s) = String::abi_decode(data) {
28        return s;
29    }
30    if data.is_ascii() {
31        return std::str::from_utf8(data).unwrap().to_owned();
32    }
33    hex::encode_prefixed(data)
34}
35
36/// Common parameters for expected or assumed reverts. Allows for code reuse.
37pub(crate) trait RevertParameters {
38    fn reverter(&self) -> Option<Address>;
39    fn reason(&self) -> Option<&[u8]>;
40    fn partial_match(&self) -> bool;
41}
42
43impl RevertParameters for AcceptableRevertParameters {
44    fn reverter(&self) -> Option<Address> {
45        self.reverter
46    }
47
48    fn reason(&self) -> Option<&[u8]> {
49        Some(&self.reason)
50    }
51
52    fn partial_match(&self) -> bool {
53        self.partial_match
54    }
55}
56
57/// Core logic for handling reverts that may or may not be expected (or assumed).
58fn handle_revert(
59    is_cheatcode: bool,
60    revert_params: &impl RevertParameters,
61    status: InstructionResult,
62    retdata: &Bytes,
63    known_contracts: &Option<ContractsByArtifact>,
64    reverter: Option<&Address>,
65) -> Result<(), Error> {
66    // If expected reverter address is set then check it matches the actual reverter.
67    if let (Some(expected_reverter), Some(&actual_reverter)) = (revert_params.reverter(), reverter)
68        && expected_reverter != actual_reverter
69    {
70        return Err(fmt_err!(
71            "Reverter != expected reverter: {} != {}",
72            actual_reverter,
73            expected_reverter
74        ));
75    }
76
77    let expected_reason = revert_params.reason();
78    // If None, accept any revert.
79    let Some(expected_reason) = expected_reason else {
80        return Ok(());
81    };
82
83    if !expected_reason.is_empty() && retdata.is_empty() {
84        bail!("call reverted as expected, but without data");
85    }
86
87    let mut actual_revert: Vec<u8> = retdata.to_vec();
88
89    // Compare only the first 4 bytes if partial match.
90    if revert_params.partial_match() && actual_revert.get(..4) == expected_reason.get(..4) {
91        return Ok(());
92    }
93
94    // Try decoding as known errors.
95    actual_revert = decode_revert(actual_revert);
96
97    if actual_revert == expected_reason
98        || (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
99    {
100        return Ok(());
101    }
102
103    // If expected reason is `Error(string)` then decode and compare with actual revert.
104    // See <https://github.com/foundry-rs/foundry/issues/12511>
105    if expected_reason.len() >= 4
106        && let Ok(e) = get_error("Error(string)")
107        && let Ok(dec) = e.decode_error(expected_reason)
108        && let Some(DynSolValue::String(revert_str)) = dec.body.first()
109        && revert_str.as_str() == String::from_utf8_lossy(&actual_revert)
110    {
111        return Ok(());
112    }
113
114    let (actual, expected) = if let Some(contracts) = known_contracts {
115        let decoder = RevertDecoder::new().with_abis(contracts.values().map(|c| &c.abi));
116        (
117            &decoder.decode(actual_revert.as_slice(), Some(status)),
118            &decoder.decode(expected_reason, Some(status)),
119        )
120    } else {
121        (&stringify(&actual_revert), &stringify(expected_reason))
122    };
123
124    if expected == actual {
125        return Ok(());
126    }
127
128    Err(fmt_err!("Error != expected error: {} != {}", actual, expected))
129}
130
131pub(crate) fn handle_assume_no_revert(
132    assume_no_revert: &AssumeNoRevert,
133    status: InstructionResult,
134    retdata: &Bytes,
135    known_contracts: &Option<ContractsByArtifact>,
136) -> Result<()> {
137    // if a generic AssumeNoRevert, return Ok(). Otherwise, iterate over acceptable reasons and try
138    // to match against any, otherwise, return an Error with the revert data
139    if assume_no_revert.reasons.is_empty() {
140        Ok(())
141    } else {
142        assume_no_revert
143            .reasons
144            .iter()
145            .find_map(|reason| {
146                handle_revert(
147                    false,
148                    reason,
149                    status,
150                    retdata,
151                    known_contracts,
152                    assume_no_revert.reverted_by.as_ref(),
153                )
154                .ok()
155            })
156            .ok_or_else(|| retdata.clone().into())
157    }
158}
159
160pub(crate) fn handle_expect_revert(
161    is_cheatcode: bool,
162    is_create: bool,
163    internal_expect_revert: bool,
164    expected_revert: &ExpectedRevert,
165    status: InstructionResult,
166    retdata: Bytes,
167    known_contracts: &Option<ContractsByArtifact>,
168) -> Result<(Option<Address>, Bytes)> {
169    let success_return = || {
170        if is_create {
171            (Some(DUMMY_CREATE_ADDRESS), Bytes::new())
172        } else {
173            (None, DUMMY_CALL_OUTPUT.clone())
174        }
175    };
176
177    // Check depths if it's not an expect cheatcode call and if internal expect reverts not enabled.
178    if !is_cheatcode && !internal_expect_revert {
179        ensure!(
180            expected_revert.max_depth > expected_revert.depth,
181            "call didn't revert at a lower depth than cheatcode call depth"
182        );
183    }
184
185    if expected_revert.count == 0 {
186        // If no specific reason or reverter is expected, we just check if it reverted
187        if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
188            ensure!(
189                matches!(status, return_ok!()),
190                "call reverted when it was expected not to revert"
191            );
192            return Ok(success_return());
193        }
194
195        // Flags to track if the reason and reverter match.
196        let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
197        let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
198
199        // If we expect no reverts with a specific reason/reverter, but got a revert,
200        // we need to check if it matches our criteria
201        if !matches!(status, return_ok!()) {
202            // We got a revert, but we expected 0 reverts
203            // We need to check if this revert matches our expected criteria
204
205            // Reverter check
206            if let (Some(expected_reverter), Some(actual_reverter)) =
207                (expected_revert.reverter, expected_revert.reverted_by)
208                && expected_reverter == actual_reverter
209            {
210                reverter_match = Some(true);
211            }
212
213            // Reason check
214            let expected_reason = expected_revert.reason();
215            if let Some(expected_reason) = expected_reason {
216                let mut actual_revert: Vec<u8> = retdata.to_vec();
217                actual_revert = decode_revert(actual_revert);
218
219                if actual_revert == expected_reason {
220                    reason_match = Some(true);
221                }
222            }
223
224            match (reason_match, reverter_match) {
225                (Some(true), Some(true)) => Err(fmt_err!(
226                    "expected 0 reverts with reason: {}, from address: {}, but got one",
227                    stringify(expected_reason.unwrap_or_default()),
228                    expected_revert.reverter.unwrap()
229                )),
230                (Some(true), None) => Err(fmt_err!(
231                    "expected 0 reverts with reason: {}, but got one",
232                    stringify(expected_reason.unwrap_or_default())
233                )),
234                (None, Some(true)) => Err(fmt_err!(
235                    "expected 0 reverts from address: {}, but got one",
236                    expected_revert.reverter.unwrap()
237                )),
238                _ => {
239                    // The revert doesn't match our criteria, which means it's a different revert
240                    // For expectRevert with count=0, any revert should fail the test
241                    let decoded_revert = decode_revert(retdata.to_vec());
242
243                    // Provide more specific error messages based on what was expected
244                    if let Some(reverter) = expected_revert.reverter {
245                        if expected_revert.reason.is_some() {
246                            Err(fmt_err!(
247                                "call reverted with '{}' from {}, but expected 0 reverts with reason '{}' from {}",
248                                stringify(&decoded_revert),
249                                expected_revert.reverted_by.unwrap_or_default(),
250                                stringify(expected_reason.unwrap_or_default()),
251                                reverter
252                            ))
253                        } else {
254                            Err(fmt_err!(
255                                "call reverted with '{}' from {}, but expected 0 reverts from {}",
256                                stringify(&decoded_revert),
257                                expected_revert.reverted_by.unwrap_or_default(),
258                                reverter
259                            ))
260                        }
261                    } else {
262                        Err(fmt_err!(
263                            "call reverted with '{}' when it was expected not to revert",
264                            stringify(&decoded_revert)
265                        ))
266                    }
267                }
268            }
269        } else {
270            // No revert occurred, which is what we expected
271            Ok(success_return())
272        }
273    } else {
274        ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
275
276        handle_revert(
277            is_cheatcode,
278            expected_revert,
279            status,
280            &retdata,
281            known_contracts,
282            expected_revert.reverted_by.as_ref(),
283        )?;
284        Ok(success_return())
285    }
286}
287
288fn decode_revert(revert: Vec<u8>) -> Vec<u8> {
289    if matches!(
290        revert.get(..4).map(|s| s.try_into().unwrap()),
291        Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
292    ) && let Ok(decoded) = Vec::<u8>::abi_decode(&revert[4..])
293    {
294        return decoded;
295    }
296    revert
297}