1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4 abi::console::{format_units_int, format_units_uint},
5 backend::GLOBAL_FAIL_SLOT,
6 constants::CHEATCODE_ADDRESS,
7 decode::ASSERTION_FAILED_PREFIX,
8 evm::FoundryEvmNetwork,
9};
10use itertools::Itertools;
11use revm::context::{ContextTr, JournalTr};
12use std::{borrow::Cow, fmt};
13
/// Number of fixed-point decimals used for relative deltas in the `assertApproxEqRel*`
/// cheatcodes: `calc_delta_full` scales by `10^EQ_REL_DELTA_RESOLUTION`, so a relative
/// delta of `10^18` represents 100%.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
15
/// Failure of a binary comparison assertion (`assertEq`, `assertGt`, ...), holding the
/// comparison that was asserted and borrowed references to both operands.
struct ComparisonAssertionError<'a, T> {
    /// The comparison that was expected to hold.
    kind: AssertionKind,
    /// Left-hand operand of the assertion.
    left: &'a T,
    /// Right-hand operand of the assertion.
    right: &'a T,
}
21
/// Comparison operator checked by a comparison assertion.
#[derive(Clone, Copy)]
enum AssertionKind {
    /// `left == right`
    Eq,
    /// `left != right`
    Ne,
    /// `left > right`
    Gt,
    /// `left >= right`
    Ge,
    /// `left < right`
    Lt,
    /// `left <= right`
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this comparison (e.g. `>` becomes `<=`).
    ///
    /// Failure messages show the negated operator, i.e. the relation that actually held.
    const fn inverse(self) -> Self {
        match self {
            AssertionKind::Lt => AssertionKind::Ge,
            AssertionKind::Ge => AssertionKind::Lt,
            AssertionKind::Gt => AssertionKind::Le,
            AssertionKind::Le => AssertionKind::Gt,
            AssertionKind::Eq => AssertionKind::Ne,
            AssertionKind::Ne => AssertionKind::Eq,
        }
    }

    /// Returns the Solidity-style operator symbol for this comparison.
    const fn to_str(self) -> &'static str {
        match self {
            AssertionKind::Lt => "<",
            AssertionKind::Le => "<=",
            AssertionKind::Gt => ">",
            AssertionKind::Ge => ">=",
            AssertionKind::Eq => "==",
            AssertionKind::Ne => "!=",
        }
    }
}
55
56impl<T> ComparisonAssertionError<'_, T> {
57 fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
58 format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
59 }
60}
61
62impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
63 fn format_for_values(&self) -> String {
64 self.format_values(T::to_string)
65 }
66}
67
68impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
69 fn format_for_arrays(&self) -> String {
70 self.format_values(|v| format!("[{}]", v.iter().format(", ")))
71 }
72}
73
74impl ComparisonAssertionError<'_, U256> {
75 fn format_with_decimals(&self, decimals: &U256) -> String {
76 self.format_values(|v| format_units_uint(v, decimals))
77 }
78}
79
80impl ComparisonAssertionError<'_, I256> {
81 fn format_with_decimals(&self, decimals: &U256) -> String {
82 self.format_values(|v| format_units_int(v, decimals))
83 }
84}
85
/// Failure of an absolute approximate-equality assertion (`assertApproxEqAbs*`).
///
/// `T` is the operand type and `D` the type of the delta values.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand.
    right: T,
    /// Largest difference the assertion allowed.
    max_delta: D,
    /// Observed absolute difference between `left` and `right`.
    real_delta: D,
}
94
95impl EqAbsAssertionError<U256, U256> {
96 fn format_with_decimals(&self, decimals: &U256) -> String {
97 format!(
98 "{} !~= {} (max delta: {}, real delta: {})",
99 format_units_uint(&self.left, decimals),
100 format_units_uint(&self.right, decimals),
101 format_units_uint(&self.max_delta, decimals),
102 format_units_uint(&self.real_delta, decimals),
103 )
104 }
105}
106
107impl EqAbsAssertionError<I256, U256> {
108 fn format_with_decimals(&self, decimals: &U256) -> String {
109 format!(
110 "{} !~= {} (max delta: {}, real delta: {})",
111 format_units_int(&self.left, decimals),
112 format_units_int(&self.right, decimals),
113 format_units_uint(&self.max_delta, decimals),
114 format_units_uint(&self.real_delta, decimals),
115 )
116 }
117}
118
119fn format_delta_percent(delta: &U256) -> String {
120 format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
121}
122
/// Observed relative delta reported by a failed `assertApproxEqRel*` check.
#[derive(Debug)]
enum EqRelDelta {
    /// Fixed-point relative delta (`10^18` means 100%).
    Defined(U256),
    /// No delta could be computed because the right-hand operand is zero.
    Undefined,
}
128
129impl fmt::Display for EqRelDelta {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 match self {
132 Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
133 Self::Undefined => write!(f, "undefined"),
134 }
135 }
136}
137
/// Details of a failed relative approximate-equality assertion (`assertApproxEqRel*`).
///
/// Both deltas are displayed as percentages; `real_delta` prints `undefined` when it could
/// not be computed.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand; the delta is computed relative to its magnitude.
    right: T,
    /// Maximum allowed relative delta, fixed-point with `10^18` meaning 100%.
    max_delta: U256,
    /// Observed relative delta.
    real_delta: EqRelDelta,
}
150
/// Error produced by the relative approximate-equality checks.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The operands are not within the allowed relative delta (boxed, presumably to keep
    /// the error small).
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// The scaled relative delta does not fit in a `U256`.
    #[error("overflow in delta calculation")]
    Overflow,
}
158
159impl EqRelAssertionError<U256> {
160 fn format_with_decimals(&self, decimals: &U256) -> String {
161 match self {
162 Self::Failure(f) => format!(
163 "{} !~= {} (max delta: {}, real delta: {})",
164 format_units_uint(&f.left, decimals),
165 format_units_uint(&f.right, decimals),
166 format_delta_percent(&f.max_delta),
167 f.real_delta,
168 ),
169 Self::Overflow => self.to_string(),
170 }
171 }
172}
173
174impl EqRelAssertionError<I256> {
175 fn format_with_decimals(&self, decimals: &U256) -> String {
176 match self {
177 Self::Failure(f) => format!(
178 "{} !~= {} (max delta: {}, real delta: {})",
179 format_units_int(&f.left, decimals),
180 format_units_int(&f.right, decimals),
181 format_delta_percent(&f.max_delta),
182 f.real_delta,
183 ),
184 Self::Overflow => self.to_string(),
185 }
186 }
187}
188
/// Outcome of a comparison assertion: `Ok(())` on success, otherwise the failed comparison.
type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
190
191#[cold]
192fn handle_assertion_result<FEN: FoundryEvmNetwork, E>(
193 ccx: &mut CheatsCtxt<'_, '_, FEN>,
194 executor: &mut dyn CheatcodesExecutor<FEN>,
195 err: E,
196 error_formatter: Option<&dyn Fn(&E) -> String>,
197 error_msg: Option<&str>,
198) -> Result {
199 let error_msg = error_msg.unwrap_or(ASSERTION_FAILED_PREFIX);
200 let msg = if let Some(error_formatter) = error_formatter {
201 Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
202 } else {
203 Cow::Borrowed(error_msg)
204 };
205 handle_assertion_result_mono(ccx, executor, msg)
206}
207
208fn handle_assertion_result_mono<FEN: FoundryEvmNetwork>(
209 ccx: &mut CheatsCtxt<'_, '_, FEN>,
210 executor: &mut dyn CheatcodesExecutor<FEN>,
211 msg: Cow<'_, str>,
212) -> Result {
213 if ccx.state.config.assertions_revert {
214 Err(msg.into_owned().into())
215 } else {
216 executor.console_log(&msg);
217 ccx.ecx.journal_mut().sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
218 Ok(Default::default())
219 }
220}
221
/// Implements `crate::Cheatcode` for pairs of assertion cheatcode call types.
///
/// Arguments, in order:
/// - `|field, ...| body`: the call struct's fields, destructured from `self` (so they are
///   references), and an expression that evaluates to `Ok(())` on success or `Err(e)` on
///   failure.
/// - An optional formatter controlling how a failure `e` is appended to the message:
///   `false` appends nothing, omitting it uses `ToString::to_string`, and any other
///   expression is used as a `Fn(&E) -> String`.
/// - One or more `(no_error, with_error)` pairs: the call type without and with an extra
///   `err` field. For the latter, `err` replaces `ASSERTION_FAILED_PREFIX` as the message.
macro_rules! impl_assertions {
    // `false` formatter: report only the message. This arm must come before the
    // `$error_formatter:expr` arm, because `false` would also match there as an expression.
    (|$($arg:ident),*| $body:expr, false, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, None, $(($no_error, $with_error)),* }
    };
    // No formatter given: render the assertion error with `ToString::to_string`.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&ToString::to_string), $(($no_error, $with_error)),* }
    };
    // Explicit formatter expression (a function path or closure taking `&E`).
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&$error_formatter), $(($no_error, $with_error)),* }
    };

    // Forward the field list as a single token tree, so the same list can be reused
    // unchanged for every pair inside the repetition below.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $no_error, $with_error, $args, $body, $error_formatter }
        )*
    };

    // Emit both `Cheatcode` impls for a single `(no_error, with_error)` pair.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                let Self { $($arg),* } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, None)
                }
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                // `err` is the caller-supplied failure message.
                let Self { $($arg,)* err } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(assertion_err) => {
                        handle_assertion_result(ccx, executor, assertion_err, $error_formatter, Some(err))
                    }
                }
            }
        }
    };
}
281
// `assertTrue`: the failure carries no operand details, so only the message is reported.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

// `assertFalse`: same reporting as `assertTrue`, with the condition negated.
impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
293
// `assertEq` overloads on single values; failures render operands via `Display`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` overloads on arrays; failures render each side as a bracketed list.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` only affects how operands are rendered on failure
// (operands must be `U256` or `I256`, the types with a `format_with_decimals` impl).
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}

// `assertNotEq` overloads on single values.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` overloads on arrays.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
355
// `assertGt`: `left > right`.
impl_assertions! {
    |left, right| assert_gt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

// `assertGtDecimal`: `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`: `left >= right`.
impl_assertions! {
    |left, right| assert_ge(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

// `assertGeDecimal`: `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`: `left < right`.
impl_assertions! {
    |left, right| assert_lt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

// `assertLtDecimal`: `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`: `left <= right`.
impl_assertions! {
    |left, right| assert_le(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

// `assertLeDecimal`: `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
411
// `assertApproxEqAbs` on unsigned operands: `|left - right| <= maxDelta`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs` on signed operands.
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

// `assertApproxEqAbsDecimal` (unsigned): `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

// `assertApproxEqAbsDecimal` (signed).
impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}

// `assertApproxEqRel` on unsigned operands: `maxPercentDelta` is fixed-point with
// `10^18` meaning 100%, relative to `right`.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

// `assertApproxEqRel` on signed operands (relative to `|right|`).
impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

// `assertApproxEqRelDecimal` (unsigned): `decimals` only affects failure rendering.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

// `assertApproxEqRelDecimal` (signed).
impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
455
/// Succeeds exactly when `condition` is `true`.
const fn assert_true(condition: bool) -> Result<(), ()> {
    match condition {
        true => Ok(()),
        false => Err(()),
    }
}

/// Succeeds exactly when `condition` is `false`.
const fn assert_false(condition: bool) -> Result<(), ()> {
    match condition {
        true => Err(()),
        false => Ok(()),
    }
}
463
464fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
465 if left == right {
466 Ok(())
467 } else {
468 Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
469 }
470}
471
472fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
473 if left == right {
474 Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
475 } else {
476 Ok(())
477 }
478}
479
480fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
481 if left > right {
482 Ok(())
483 } else {
484 Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
485 }
486}
487
488fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
489 if left >= right {
490 Ok(())
491 } else {
492 Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
493 }
494}
495
496fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
497 if left < right {
498 Ok(())
499 } else {
500 Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
501 }
502}
503
504fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
505 if left <= right {
506 Ok(())
507 } else {
508 Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
509 }
510}
511
512fn get_delta_int(left: I256, right: I256) -> U256 {
513 let (left_sign, left_abs) = left.into_sign_and_abs();
514 let (right_sign, right_abs) = right.into_sign_and_abs();
515
516 if left_sign == right_sign {
517 if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
518 } else {
519 left_abs.wrapping_add(right_abs)
520 }
521}
522
523fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
527 let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
528 / U512::from(right);
529 U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
530}
531
532fn uint_assert_approx_eq_abs(
533 left: U256,
534 right: U256,
535 max_delta: U256,
536) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
537 let delta = left.abs_diff(right);
538
539 if delta <= max_delta {
540 Ok(())
541 } else {
542 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
543 }
544}
545
546fn int_assert_approx_eq_abs(
547 left: I256,
548 right: I256,
549 max_delta: U256,
550) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
551 let delta = get_delta_int(left, right);
552
553 if delta <= max_delta {
554 Ok(())
555 } else {
556 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
557 }
558}
559
560fn uint_assert_approx_eq_rel(
561 left: U256,
562 right: U256,
563 max_delta: U256,
564) -> Result<(), EqRelAssertionError<U256>> {
565 if right.is_zero() {
566 if left.is_zero() {
567 return Ok(());
568 }
569 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
570 left,
571 right,
572 max_delta,
573 real_delta: EqRelDelta::Undefined,
574 })));
575 }
576
577 let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
578
579 if delta <= max_delta {
580 Ok(())
581 } else {
582 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
583 left,
584 right,
585 max_delta,
586 real_delta: EqRelDelta::Defined(delta),
587 })))
588 }
589}
590
591fn int_assert_approx_eq_rel(
592 left: I256,
593 right: I256,
594 max_delta: U256,
595) -> Result<(), EqRelAssertionError<I256>> {
596 if right.is_zero() {
597 if left.is_zero() {
598 return Ok(());
599 }
600 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
601 left,
602 right,
603 max_delta,
604 real_delta: EqRelDelta::Undefined,
605 })));
606 }
607
608 let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
609
610 if delta <= max_delta {
611 Ok(())
612 } else {
613 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
614 left,
615 right,
616 max_delta,
617 real_delta: EqRelDelta::Defined(delta),
618 })))
619 }
620}