1use crate::{Error, Result};
2use alloy_dyn_abi::{DynSolValue, ErrorExt};
3use alloy_primitives::{Address, Bytes, address, hex};
4use alloy_sol_types::{SolError, SolValue};
5use foundry_common::{ContractsByArtifact, abi::get_error};
6use foundry_evm_core::decode::RevertDecoder;
7use revm::interpreter::{InstructionResult, return_ok};
8use spec::Vm;
9
10use super::{
11 assume::{AcceptableRevertParameters, AssumeNoRevert},
12 expect::ExpectedRevert,
13};
14
/// 8192 zero bytes, handed back by `handle_expect_revert` as the call output when the
/// expectation is satisfied for a non-create call.
///
/// NOTE(review): presumably sized so that the caller can decode the placeholder as ordinary
/// return data without failing — confirm the rationale for 8192.
static DUMMY_CALL_OUTPUT: Bytes = Bytes::from_static(&[0u8; 8192]);
22
/// Placeholder address (`0x…01`) handed back by `handle_expect_revert` when the expectation is
/// satisfied for a create.
const DUMMY_CREATE_ADDRESS: Address = address!("0x0000000000000000000000000000000000000001");
25
26fn stringify(data: &[u8]) -> String {
27 if let Ok(s) = String::abi_decode(data) {
28 return s;
29 }
30 if data.is_ascii() {
31 return std::str::from_utf8(data).unwrap().to_owned();
32 }
33 hex::encode_prefixed(data)
34}
35
/// Read-only view of the criteria an observed revert is checked against.
///
/// `handle_revert` accepts any `impl RevertParameters`.
pub(crate) trait RevertParameters {
    /// Address the revert is expected to come from; `None` skips the reverter check.
    fn reverter(&self) -> Option<Address>;
    /// Expected revert data; `None` accepts a revert carrying any data.
    fn reason(&self) -> Option<&[u8]>;
    /// Whether a match on the first four bytes of the revert data is sufficient.
    fn partial_match(&self) -> bool;
}
42
/// Exposes the fields of `AcceptableRevertParameters` through `RevertParameters`.
impl RevertParameters for AcceptableRevertParameters {
    /// Returns the stored `reverter` field unchanged.
    fn reverter(&self) -> Option<Address> {
        self.reverter
    }

    /// Always `Some`: the stored `reason` is returned as a byte slice, so an empty reason is
    /// reported as `Some(&[])` rather than `None`.
    fn reason(&self) -> Option<&[u8]> {
        Some(&self.reason)
    }

    /// Returns the stored `partial_match` flag unchanged.
    fn partial_match(&self) -> bool {
        self.partial_match
    }
}
56
57fn handle_revert(
59 is_cheatcode: bool,
60 revert_params: &impl RevertParameters,
61 status: InstructionResult,
62 retdata: &Bytes,
63 known_contracts: &Option<ContractsByArtifact>,
64 reverter: Option<&Address>,
65) -> Result<(), Error> {
66 if let (Some(expected_reverter), Some(&actual_reverter)) = (revert_params.reverter(), reverter)
68 && expected_reverter != actual_reverter
69 {
70 return Err(fmt_err!(
71 "Reverter != expected reverter: {} != {}",
72 actual_reverter,
73 expected_reverter
74 ));
75 }
76
77 let expected_reason = revert_params.reason();
78 let Some(expected_reason) = expected_reason else {
80 return Ok(());
81 };
82
83 if !expected_reason.is_empty() && retdata.is_empty() {
84 bail!("call reverted as expected, but without data");
85 }
86
87 let mut actual_revert: Vec<u8> = retdata.to_vec();
88
89 if revert_params.partial_match() && actual_revert.get(..4) == expected_reason.get(..4) {
91 return Ok(());
92 }
93
94 actual_revert = decode_revert(actual_revert);
96
97 if actual_revert == expected_reason
98 || (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
99 {
100 return Ok(());
101 }
102
103 if let Ok(e) = get_error("Error(string)")
106 && let Ok(dec) = e.decode_error(expected_reason)
107 && let Some(DynSolValue::String(revert_str)) = dec.body.first()
108 && revert_str.as_str() == String::from_utf8_lossy(&actual_revert)
109 {
110 return Ok(());
111 }
112
113 let (actual, expected) = if let Some(contracts) = known_contracts {
114 let decoder = RevertDecoder::new().with_abis(contracts.values().map(|c| &c.abi));
115 (
116 &decoder.decode(actual_revert.as_slice(), Some(status)),
117 &decoder.decode(expected_reason, Some(status)),
118 )
119 } else {
120 (&stringify(&actual_revert), &stringify(expected_reason))
121 };
122
123 if expected == actual {
124 return Ok(());
125 }
126
127 Err(fmt_err!("Error != expected error: {} != {}", actual, expected))
128}
129
130pub(crate) fn handle_assume_no_revert(
131 assume_no_revert: &AssumeNoRevert,
132 status: InstructionResult,
133 retdata: &Bytes,
134 known_contracts: &Option<ContractsByArtifact>,
135) -> Result<()> {
136 if assume_no_revert.reasons.is_empty() {
139 Ok(())
140 } else {
141 assume_no_revert
142 .reasons
143 .iter()
144 .find_map(|reason| {
145 handle_revert(
146 false,
147 reason,
148 status,
149 retdata,
150 known_contracts,
151 assume_no_revert.reverted_by.as_ref(),
152 )
153 .ok()
154 })
155 .ok_or_else(|| retdata.clone().into())
156 }
157}
158
159pub(crate) fn handle_expect_revert(
160 is_cheatcode: bool,
161 is_create: bool,
162 internal_expect_revert: bool,
163 expected_revert: &ExpectedRevert,
164 status: InstructionResult,
165 retdata: Bytes,
166 known_contracts: &Option<ContractsByArtifact>,
167) -> Result<(Option<Address>, Bytes)> {
168 let success_return = || {
169 if is_create {
170 (Some(DUMMY_CREATE_ADDRESS), Bytes::new())
171 } else {
172 (None, DUMMY_CALL_OUTPUT.clone())
173 }
174 };
175
176 if !is_cheatcode && !internal_expect_revert {
178 ensure!(
179 expected_revert.max_depth > expected_revert.depth,
180 "call didn't revert at a lower depth than cheatcode call depth"
181 );
182 }
183
184 if expected_revert.count == 0 {
185 if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
187 ensure!(
188 matches!(status, return_ok!()),
189 "call reverted when it was expected not to revert"
190 );
191 return Ok(success_return());
192 }
193
194 let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
196 let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
197
198 if !matches!(status, return_ok!()) {
201 if let (Some(expected_reverter), Some(actual_reverter)) =
206 (expected_revert.reverter, expected_revert.reverted_by)
207 && expected_reverter == actual_reverter
208 {
209 reverter_match = Some(true);
210 }
211
212 let expected_reason = expected_revert.reason();
214 if let Some(expected_reason) = expected_reason {
215 let mut actual_revert: Vec<u8> = retdata.to_vec();
216 actual_revert = decode_revert(actual_revert);
217
218 if actual_revert == expected_reason {
219 reason_match = Some(true);
220 }
221 }
222
223 match (reason_match, reverter_match) {
224 (Some(true), Some(true)) => Err(fmt_err!(
225 "expected 0 reverts with reason: {}, from address: {}, but got one",
226 stringify(expected_reason.unwrap_or_default()),
227 expected_revert.reverter.unwrap()
228 )),
229 (Some(true), None) => Err(fmt_err!(
230 "expected 0 reverts with reason: {}, but got one",
231 stringify(expected_reason.unwrap_or_default())
232 )),
233 (None, Some(true)) => Err(fmt_err!(
234 "expected 0 reverts from address: {}, but got one",
235 expected_revert.reverter.unwrap()
236 )),
237 _ => {
238 let decoded_revert = decode_revert(retdata.to_vec());
241
242 if let Some(reverter) = expected_revert.reverter {
244 if expected_revert.reason.is_some() {
245 Err(fmt_err!(
246 "call reverted with '{}' from {}, but expected 0 reverts with reason '{}' from {}",
247 stringify(&decoded_revert),
248 expected_revert.reverted_by.unwrap_or_default(),
249 stringify(expected_reason.unwrap_or_default()),
250 reverter
251 ))
252 } else {
253 Err(fmt_err!(
254 "call reverted with '{}' from {}, but expected 0 reverts from {}",
255 stringify(&decoded_revert),
256 expected_revert.reverted_by.unwrap_or_default(),
257 reverter
258 ))
259 }
260 } else {
261 Err(fmt_err!(
262 "call reverted with '{}' when it was expected not to revert",
263 stringify(&decoded_revert)
264 ))
265 }
266 }
267 }
268 } else {
269 Ok(success_return())
271 }
272 } else {
273 ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
274
275 handle_revert(
276 is_cheatcode,
277 expected_revert,
278 status,
279 &retdata,
280 known_contracts,
281 expected_revert.reverted_by.as_ref(),
282 )?;
283 Ok(success_return())
284 }
285}
286
287fn decode_revert(revert: Vec<u8>) -> Vec<u8> {
288 if matches!(
289 revert.get(..4).map(|s| s.try_into().unwrap()),
290 Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
291 ) && let Ok(decoded) = Vec::<u8>::abi_decode(&revert[4..])
292 {
293 return decoded;
294 }
295 revert
296}