1use crate::{Error, Result};
2use alloy_primitives::{Address, Bytes, address, hex};
3use alloy_sol_types::{SolError, SolValue};
4use foundry_common::ContractsByArtifact;
5use foundry_evm_core::decode::RevertDecoder;
6use revm::interpreter::{InstructionResult, return_ok};
7use spec::Vm;
8
9use super::{
10 assume::{AcceptableRevertParameters, AssumeNoRevert},
11 expect::ExpectedRevert,
12};
13
/// Zero-filled return data (8192 bytes) that `handle_expect_revert` hands back in place of the
/// real output whenever it succeeds for a non-create call.
// NOTE(review): presumably sized generously so that callers which ABI-decode the return data do
// not run out of bytes — confirm the rationale for 8192.
static DUMMY_CALL_OUTPUT: Bytes = Bytes::from_static(&[0u8; 8192]);

/// Placeholder address (`0x…01`) that `handle_expect_revert` hands back in place of the created
/// address whenever it succeeds for a create.
const DUMMY_CREATE_ADDRESS: Address = address!("0x0000000000000000000000000000000000000001");
24
25fn stringify(data: &[u8]) -> String {
26 if let Ok(s) = String::abi_decode(data) {
27 return s;
28 }
29 if data.is_ascii() {
30 return std::str::from_utf8(data).unwrap().to_owned();
31 }
32 hex::encode_prefixed(data)
33}
34
35pub(crate) trait RevertParameters {
37 fn reverter(&self) -> Option<Address>;
38 fn reason(&self) -> Option<&[u8]>;
39 fn partial_match(&self) -> bool;
40}
41
42impl RevertParameters for AcceptableRevertParameters {
43 fn reverter(&self) -> Option<Address> {
44 self.reverter
45 }
46
47 fn reason(&self) -> Option<&[u8]> {
48 Some(&self.reason)
49 }
50
51 fn partial_match(&self) -> bool {
52 self.partial_match
53 }
54}
55
56fn handle_revert(
58 is_cheatcode: bool,
59 revert_params: &impl RevertParameters,
60 status: InstructionResult,
61 retdata: &Bytes,
62 known_contracts: &Option<ContractsByArtifact>,
63 reverter: Option<&Address>,
64) -> Result<(), Error> {
65 if let (Some(expected_reverter), Some(&actual_reverter)) = (revert_params.reverter(), reverter)
67 && expected_reverter != actual_reverter
68 {
69 return Err(fmt_err!(
70 "Reverter != expected reverter: {} != {}",
71 actual_reverter,
72 expected_reverter
73 ));
74 }
75
76 let expected_reason = revert_params.reason();
77 let Some(expected_reason) = expected_reason else {
79 return Ok(());
80 };
81
82 if !expected_reason.is_empty() && retdata.is_empty() {
83 bail!("call reverted as expected, but without data");
84 }
85
86 let mut actual_revert: Vec<u8> = retdata.to_vec();
87
88 if revert_params.partial_match() && actual_revert.get(..4) == expected_reason.get(..4) {
90 return Ok(());
91 }
92
93 actual_revert = decode_revert(actual_revert);
95
96 if actual_revert == expected_reason
97 || (is_cheatcode && memchr::memmem::find(&actual_revert, expected_reason).is_some())
98 {
99 Ok(())
100 } else {
101 let (actual, expected) = if let Some(contracts) = known_contracts {
102 let decoder = RevertDecoder::new().with_abis(contracts.values().map(|c| &c.abi));
103 (
104 &decoder.decode(actual_revert.as_slice(), Some(status)),
105 &decoder.decode(expected_reason, Some(status)),
106 )
107 } else {
108 (&stringify(&actual_revert), &stringify(expected_reason))
109 };
110
111 if expected == actual {
112 return Ok(());
113 }
114
115 Err(fmt_err!("Error != expected error: {} != {}", actual, expected))
116 }
117}
118
119pub(crate) fn handle_assume_no_revert(
120 assume_no_revert: &AssumeNoRevert,
121 status: InstructionResult,
122 retdata: &Bytes,
123 known_contracts: &Option<ContractsByArtifact>,
124) -> Result<()> {
125 if assume_no_revert.reasons.is_empty() {
128 Ok(())
129 } else {
130 assume_no_revert
131 .reasons
132 .iter()
133 .find_map(|reason| {
134 handle_revert(
135 false,
136 reason,
137 status,
138 retdata,
139 known_contracts,
140 assume_no_revert.reverted_by.as_ref(),
141 )
142 .ok()
143 })
144 .ok_or_else(|| retdata.clone().into())
145 }
146}
147
148pub(crate) fn handle_expect_revert(
149 is_cheatcode: bool,
150 is_create: bool,
151 internal_expect_revert: bool,
152 expected_revert: &ExpectedRevert,
153 status: InstructionResult,
154 retdata: Bytes,
155 known_contracts: &Option<ContractsByArtifact>,
156) -> Result<(Option<Address>, Bytes)> {
157 let success_return = || {
158 if is_create {
159 (Some(DUMMY_CREATE_ADDRESS), Bytes::new())
160 } else {
161 (None, DUMMY_CALL_OUTPUT.clone())
162 }
163 };
164
165 if !is_cheatcode && !internal_expect_revert {
167 ensure!(
168 expected_revert.max_depth > expected_revert.depth,
169 "call didn't revert at a lower depth than cheatcode call depth"
170 );
171 }
172
173 if expected_revert.count == 0 {
174 if expected_revert.reverter.is_none() && expected_revert.reason.is_none() {
176 ensure!(
177 matches!(status, return_ok!()),
178 "call reverted when it was expected not to revert"
179 );
180 return Ok(success_return());
181 }
182
183 let mut reason_match = expected_revert.reason.as_ref().map(|_| false);
185 let mut reverter_match = expected_revert.reverter.as_ref().map(|_| false);
186
187 if !matches!(status, return_ok!()) {
190 if let (Some(expected_reverter), Some(actual_reverter)) =
195 (expected_revert.reverter, expected_revert.reverted_by)
196 && expected_reverter == actual_reverter
197 {
198 reverter_match = Some(true);
199 }
200
201 let expected_reason = expected_revert.reason.as_deref();
203 if let Some(expected_reason) = expected_reason {
204 let mut actual_revert: Vec<u8> = retdata.to_vec();
205 actual_revert = decode_revert(actual_revert);
206
207 if actual_revert == expected_reason {
208 reason_match = Some(true);
209 }
210 }
211
212 match (reason_match, reverter_match) {
213 (Some(true), Some(true)) => Err(fmt_err!(
214 "expected 0 reverts with reason: {}, from address: {}, but got one",
215 &stringify(expected_reason.unwrap_or_default()),
216 expected_revert.reverter.unwrap()
217 )),
218 (Some(true), None) => Err(fmt_err!(
219 "expected 0 reverts with reason: {}, but got one",
220 &stringify(expected_reason.unwrap_or_default())
221 )),
222 (None, Some(true)) => Err(fmt_err!(
223 "expected 0 reverts from address: {}, but got one",
224 expected_revert.reverter.unwrap()
225 )),
226 _ => {
227 let decoded_revert = decode_revert(retdata.to_vec());
230
231 if expected_revert.reverter.is_some() && expected_revert.reason.is_some() {
233 Err(fmt_err!(
234 "call reverted with '{}' from {}, but expected 0 reverts with reason '{}' from {}",
235 &stringify(&decoded_revert),
236 expected_revert.reverted_by.unwrap_or_default(),
237 &stringify(expected_reason.unwrap_or_default()),
238 expected_revert.reverter.unwrap()
239 ))
240 } else if expected_revert.reverter.is_some() {
241 Err(fmt_err!(
242 "call reverted with '{}' from {}, but expected 0 reverts from {}",
243 &stringify(&decoded_revert),
244 expected_revert.reverted_by.unwrap_or_default(),
245 expected_revert.reverter.unwrap()
246 ))
247 } else {
248 Err(fmt_err!(
249 "call reverted with '{}' when it was expected not to revert",
250 &stringify(&decoded_revert)
251 ))
252 }
253 }
254 }
255 } else {
256 Ok(success_return())
258 }
259 } else {
260 ensure!(!matches!(status, return_ok!()), "next call did not revert as expected");
261
262 handle_revert(
263 is_cheatcode,
264 expected_revert,
265 status,
266 &retdata,
267 known_contracts,
268 expected_revert.reverted_by.as_ref(),
269 )?;
270 Ok(success_return())
271 }
272}
273
274fn decode_revert(revert: Vec<u8>) -> Vec<u8> {
275 if matches!(
276 revert.get(..4).map(|s| s.try_into().unwrap()),
277 Some(Vm::CheatcodeError::SELECTOR | alloy_sol_types::Revert::SELECTOR)
278 ) && let Ok(decoded) = Vec::<u8>::abi_decode(&revert[4..])
279 {
280 return decoded;
281 }
282 revert
283}