1use crate::{Error, Result};
2use alloy_dyn_abi::{DynSolValue, ErrorExt};
3use alloy_primitives::{Address, Bytes, address, hex};
4use alloy_sol_types::{SolError, SolValue};
5use foundry_common::{ContractsByArtifact, abi::get_error};
6use foundry_evm_core::decode::RevertDecoder;
7use revm::interpreter::{InstructionResult, return_ok};
8use spec::Vm;
9
10use super::{
11 assume::{AcceptableRevertParameters, AssumeNoRevert},
12 expect::ExpectedRevert,
13};
14
/// Placeholder return data (8192 zero bytes) returned by `handle_expect_revert` for a call
/// whose revert expectation was satisfied, in place of the call's real output.
// NOTE(review): the 8192-byte size is presumably chosen so callers that decode the output do
// not run out of data — TODO confirm.
static DUMMY_CALL_OUTPUT: Bytes = Bytes::from_static(&[0u8; 8192]);

/// Placeholder address returned by `handle_expect_revert` for a create whose revert
/// expectation was satisfied; the create counterpart of `DUMMY_CALL_OUTPUT`.
const DUMMY_CREATE_ADDRESS: Address = address!("0x0000000000000000000000000000000000000001");
25
26fn stringify(data: &[u8]) -> String {
27 if let Ok(s) = String::abi_decode(data) {
28 return s;
29 }
30 if data.is_ascii() {
31 return std::str::from_utf8(data).unwrap().to_owned();
32 }
33 hex::encode_prefixed(data)
34}
35
36pub(crate) trait RevertParameters {
38 fn reverter(&self) -> Option<Address>;
39 fn reason(&self) -> Option<&[u8]>;
40 fn partial_match(&self) -> bool;
41}
42
43impl RevertParameters for AcceptableRevertParameters {
44 fn reverter(&self) -> Option<Address> {
45 self.reverter
46 }
47
48 fn reason(&self) -> Option<&[u8]> {
49 Some(&self.reason)
50 }
51
52 fn partial_match(&self) -> bool {
53 self.partial_match
54 }
55}
56
57fn handle_revert(
59 is_cheatcode: bool,
60 revert_params: &impl RevertParameters,
61 status: InstructionResult,
62 retdata: &Bytes,
63 known_contracts: &Option<ContractsByArtifact>,
64 reverter: Option<&Address>,
65) -> Result<(), Error> {
66 if let (Some(expected_reverter), Some(&actual_reverter)) = (revert_params.reverter(), reverter)
68 && expected_reverter != actual_reverter
69 {
70 return Err(fmt_err!(
71 "Reverter != expected reverter: {} != {}",
72 actual_reverter,
73 expected_reverter
74 ));
75 }
76
77 let expected_reason = revert_params.reason();
78 let Some(expected_reason) = expected_reason else {
80 return Ok(());
81 };
82
83 if !expected_reason.is_empty() && retdata.is_empty() {
84 bail!("call reverted as expected, but without data");
85 }
86
87 let mut actual_revert: Vec<u8> = retdata.to_vec();
88
89 if revert_params.partial_match() && actual_revert.get(..4) == expected_reason.get(..4) {
91 return Ok(());
92 }
93
94 actual_revert = decode_revert(actual_revert);
96
97 if actual_revert == expected_reason
98 || (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
99 {
100 return Ok(());
101 }
102
103 if expected_reason.len() >= 4
106 && let Ok(e) = get_error("Error(string)")
107 && let Ok(dec) = e.decode_error(expected_reason)
108 && let Some(DynSolValue::String(revert_str)) = dec.body.first()
109 && revert_str.as_str() == String::from_utf8_lossy(&actual_revert)
110 {
111 return Ok(());
112 }
113
114 let (actual, expected) = if let Some(contracts) = known_contracts {
115 let decoder = RevertDecoder::new().with_abis(contracts.values().map(|c| &c.abi));
116 (
117 &decoder.decode(actual_revert.as_slice(), Some(status)),
118 &decoder.decode(expected_reason, Some(status)),
119 )
120 } else {
121 (&stringify(&actual_revert), &stringify(expected_reason))
122 };
123
124 if expected == actual {
125 return Ok(());
126 }
127
128 Err(fmt_err!("Error != expected error: {} != {}", actual, expected))
129}
130
131pub(crate) fn handle_assume_no_revert(
132 assume_no_revert: &AssumeNoRevert,
133 status: InstructionResult,
134 retdata: &Bytes,
135 known_contracts: &Option<ContractsByArtifact>,
136) -> Result<()> {
137 if assume_no_revert.reasons.is_empty() {
140 Ok(())
141 } else {
142 assume_no_revert
143 .reasons
144 .iter()
145 .find_map(|reason| {
146 handle_revert(
147 false,
148 reason,
149 status,
150 retdata,
151 known_contracts,
152 assume_no_revert.reverted_by.as_ref(),
153 )
154 .ok()
155 })
156 .ok_or_else(|| retdata.clone().into())
157 }
158}
159
160pub(crate) fn handle_expect_revert(
161 is_cheatcode: bool,
162 is_create: bool,
163 internal_expect_revert: bool,
164 expected_revert: &ExpectedRevert,
165 status: InstructionResult,
166 retdata: Bytes,
167 known_contracts: &Option<ContractsByArtifact>,
168) -> Result<(Option<Address>, Bytes)> {
169 let success_return = || {
170 if is_create {
171 (Some(DUMMY_CREATE_ADDRESS), Bytes::new())
172 } else {
173 (None, DUMMY_CALL_OUTPUT.clone())
174 }
175 };
176
177 if !is_cheatcode && !internal_expect_revert {
179 ensure!(
180 expected_revert.max_depth > expected_revert.depth,
181 "call didn't revert at a lower depth than cheatcode call depth"
182 );
183 }
184
185 if expected_revert.count == 0 {
186 if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
188 ensure!(
189 matches!(status, return_ok!()),
190 "call reverted when it was expected not to revert"
191 );
192 return Ok(success_return());
193 }
194
195 let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
197 let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
198
199 if !matches!(status, return_ok!()) {
202 if let (Some(expected_reverter), Some(actual_reverter)) =
207 (expected_revert.reverter, expected_revert.reverted_by)
208 && expected_reverter == actual_reverter
209 {
210 reverter_match = Some(true);
211 }
212
213 let expected_reason = expected_revert.reason();
215 if let Some(expected_reason) = expected_reason {
216 let mut actual_revert: Vec<u8> = retdata.to_vec();
217 actual_revert = decode_revert(actual_revert);
218
219 if actual_revert == expected_reason {
220 reason_match = Some(true);
221 }
222 }
223
224 match (reason_match, reverter_match) {
225 (Some(true), Some(true)) => Err(fmt_err!(
226 "expected 0 reverts with reason: {}, from address: {}, but got one",
227 stringify(expected_reason.unwrap_or_default()),
228 expected_revert.reverter.unwrap()
229 )),
230 (Some(true), None) => Err(fmt_err!(
231 "expected 0 reverts with reason: {}, but got one",
232 stringify(expected_reason.unwrap_or_default())
233 )),
234 (None, Some(true)) => Err(fmt_err!(
235 "expected 0 reverts from address: {}, but got one",
236 expected_revert.reverter.unwrap()
237 )),
238 _ => {
239 let decoded_revert = decode_revert(retdata.to_vec());
242
243 if let Some(reverter) = expected_revert.reverter {
245 if expected_revert.reason.is_some() {
246 Err(fmt_err!(
247 "call reverted with '{}' from {}, but expected 0 reverts with reason '{}' from {}",
248 stringify(&decoded_revert),
249 expected_revert.reverted_by.unwrap_or_default(),
250 stringify(expected_reason.unwrap_or_default()),
251 reverter
252 ))
253 } else {
254 Err(fmt_err!(
255 "call reverted with '{}' from {}, but expected 0 reverts from {}",
256 stringify(&decoded_revert),
257 expected_revert.reverted_by.unwrap_or_default(),
258 reverter
259 ))
260 }
261 } else {
262 Err(fmt_err!(
263 "call reverted with '{}' when it was expected not to revert",
264 stringify(&decoded_revert)
265 ))
266 }
267 }
268 }
269 } else {
270 Ok(success_return())
272 }
273 } else {
274 ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
275
276 handle_revert(
277 is_cheatcode,
278 expected_revert,
279 status,
280 &retdata,
281 known_contracts,
282 expected_revert.reverted_by.as_ref(),
283 )?;
284 Ok(success_return())
285 }
286}
287
288fn decode_revert(revert: Vec<u8>) -> Vec<u8> {
289 if matches!(
290 revert.get(..4).map(|s| s.try_into().unwrap()),
291 Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
292 ) && let Ok(decoded) = Vec::<u8>::abi_decode(&revert[4..])
293 {
294 return decoded;
295 }
296 revert
297}