foundry_cheatcodes/test/
assert.rs

1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4    abi::console::{format_units_int, format_units_uint},
5    backend::GLOBAL_FAIL_SLOT,
6    constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use revm::context::JournalTr;
10use std::{borrow::Cow, fmt};
11
/// Number of decimal places of the fixed-point scale used for relative deltas.
///
/// `U256::from_limbs` takes little-endian limbs, so this is the value `18`. Relative deltas are
/// computed as `abs_diff * 10^EQ_REL_DELTA_RESOLUTION / right`, so a delta of `10^18` means the
/// difference equals 100% of `right`. Percentages are therefore printed with
/// `EQ_REL_DELTA_RESOLUTION - 2` decimals.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
13
14struct ComparisonAssertionError<'a, T> {
15    kind: AssertionKind,
16    left: &'a T,
17    right: &'a T,
18}
19
/// The relation a comparison assertion requires between its two operands.
#[derive(Clone, Copy)]
enum AssertionKind {
    /// `left == right`
    Eq,
    /// `left != right`
    Ne,
    /// `left > right`
    Gt,
    /// `left >= right`
    Ge,
    /// `left < right`
    Lt,
    /// `left <= right`
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this relation.
    ///
    /// Failure messages use the negated relation, e.g. a failed `Gt` is shown as `left <= right`.
    fn inverse(self) -> Self {
        match self {
            AssertionKind::Eq => AssertionKind::Ne,
            AssertionKind::Ne => AssertionKind::Eq,
            AssertionKind::Gt => AssertionKind::Le,
            AssertionKind::Le => AssertionKind::Gt,
            AssertionKind::Ge => AssertionKind::Lt,
            AssertionKind::Lt => AssertionKind::Ge,
        }
    }

    /// Returns the operator symbol for this relation.
    fn to_str(self) -> &'static str {
        match self {
            AssertionKind::Eq => "==",
            AssertionKind::Ne => "!=",
            AssertionKind::Gt => ">",
            AssertionKind::Le => "<=",
            AssertionKind::Ge => ">=",
            AssertionKind::Lt => "<",
        }
    }
}
53
54impl<T> ComparisonAssertionError<'_, T> {
55    fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
56        format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
57    }
58}
59
60impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
61    fn format_for_values(&self) -> String {
62        self.format_values(T::to_string)
63    }
64}
65
66impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
67    fn format_for_arrays(&self) -> String {
68        self.format_values(|v| format!("[{}]", v.iter().format(", ")))
69    }
70}
71
72impl ComparisonAssertionError<'_, U256> {
73    fn format_with_decimals(&self, decimals: &U256) -> String {
74        self.format_values(|v| format_units_uint(v, decimals))
75    }
76}
77
78impl ComparisonAssertionError<'_, I256> {
79    fn format_with_decimals(&self, decimals: &U256) -> String {
80        self.format_values(|v| format_units_int(v, decimals))
81    }
82}
83
/// Failure of an absolute approximate-equality assertion (`assertApproxEqAbs*`).
///
/// `T` is the operand type and `D` the type of the deltas.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand.
    right: T,
    /// Largest absolute difference the assertion allowed.
    max_delta: D,
    /// Absolute difference actually observed between `left` and `right`.
    real_delta: D,
}
92
93impl EqAbsAssertionError<U256, U256> {
94    fn format_with_decimals(&self, decimals: &U256) -> String {
95        format!(
96            "{} !~= {} (max delta: {}, real delta: {})",
97            format_units_uint(&self.left, decimals),
98            format_units_uint(&self.right, decimals),
99            format_units_uint(&self.max_delta, decimals),
100            format_units_uint(&self.real_delta, decimals),
101        )
102    }
103}
104
105impl EqAbsAssertionError<I256, U256> {
106    fn format_with_decimals(&self, decimals: &U256) -> String {
107        format!(
108            "{} !~= {} (max delta: {}, real delta: {})",
109            format_units_int(&self.left, decimals),
110            format_units_int(&self.right, decimals),
111            format_units_uint(&self.max_delta, decimals),
112            format_units_uint(&self.real_delta, decimals),
113        )
114    }
115}
116
117fn format_delta_percent(delta: &U256) -> String {
118    format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
119}
120
121#[derive(Debug)]
122enum EqRelDelta {
123    Defined(U256),
124    Undefined,
125}
126
127impl fmt::Display for EqRelDelta {
128    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129        match self {
130            Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
131            Self::Undefined => write!(f, "undefined"),
132        }
133    }
134}
135
/// Details of a failed relative approximate-equality assertion (`assertApproxEqRel*`).
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand; the relative delta is measured against its magnitude.
    right: T,
    /// Largest allowed relative delta, on the `10^EQ_REL_DELTA_RESOLUTION` fixed-point scale.
    max_delta: U256,
    /// Relative delta actually observed, or `Undefined` when `right` is zero.
    real_delta: EqRelDelta,
}
148
/// Error produced by a relative approximate-equality assertion.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The observed relative delta exceeded the allowed maximum, or was undefined.
    /// Boxed, presumably to keep `Result<(), Self>` small.
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// The scaled relative delta did not fit in a `U256`.
    #[error("overflow in delta calculation")]
    Overflow,
}
156
157impl EqRelAssertionError<U256> {
158    fn format_with_decimals(&self, decimals: &U256) -> String {
159        match self {
160            Self::Failure(f) => format!(
161                "{} !~= {} (max delta: {}, real delta: {})",
162                format_units_uint(&f.left, decimals),
163                format_units_uint(&f.right, decimals),
164                format_delta_percent(&f.max_delta),
165                &f.real_delta,
166            ),
167            Self::Overflow => self.to_string(),
168        }
169    }
170}
171
172impl EqRelAssertionError<I256> {
173    fn format_with_decimals(&self, decimals: &U256) -> String {
174        match self {
175            Self::Failure(f) => format!(
176                "{} !~= {} (max delta: {}, real delta: {})",
177                format_units_int(&f.left, decimals),
178                format_units_int(&f.right, decimals),
179                format_delta_percent(&f.max_delta),
180                &f.real_delta,
181            ),
182            Self::Overflow => self.to_string(),
183        }
184    }
185}
186
/// Outcome of a comparison assertion: `Ok(())` if the relation held, otherwise the operands and
/// the asserted relation.
type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
188
189#[cold]
190fn handle_assertion_result<E>(
191    ccx: &mut CheatsCtxt,
192    executor: &mut dyn CheatcodesExecutor,
193    err: E,
194    error_formatter: Option<&dyn Fn(&E) -> String>,
195    error_msg: Option<&str>,
196) -> Result {
197    let error_msg = error_msg.unwrap_or("assertion failed");
198    let msg = if let Some(error_formatter) = error_formatter {
199        Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
200    } else {
201        Cow::Borrowed(error_msg)
202    };
203    handle_assertion_result_mono(ccx, executor, msg)
204}
205
206fn handle_assertion_result_mono(
207    ccx: &mut CheatsCtxt,
208    executor: &mut dyn CheatcodesExecutor,
209    msg: Cow<'_, str>,
210) -> Result {
211    if ccx.state.config.assertions_revert {
212        Err(msg.into_owned().into())
213    } else {
214        executor.console_log(ccx, &msg);
215        ccx.ecx.journaled_state.sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
216        Ok(Default::default())
217    }
218}
219
/// Implements [`crate::Cheatcode`] for pairs of assertion cheatcodes.
///
/// Accepts a list of pairs of cheatcodes. In each pair, the first cheatcode has no custom error
/// message and the second one carries it in its `error` field.
///
/// Passed `args` are the common arguments for both cheatcode structs (excluding the `error`
/// field). They are destructured from `self` by name, so they are in scope both in the
/// assertion body (which must evaluate to `Result<(), E>`) and in the error formatter.
///
/// The optional second argument selects how the assertion error `E` is appended to the message:
/// - `false`: the error is not formatted; only the message itself is reported.
/// - omitted: the error is formatted with [`ToString::to_string`].
/// - any other expression: used as the formatter (called with `&E`, returning `String`).
macro_rules! impl_assertions {
    // `false` literal: no error formatter. This rule must stay before the generic
    // `$error_formatter:expr` rule below, which would otherwise also accept `false`.
    (|$($arg:ident),*| $body:expr, false, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, None, $(($no_error, $with_error)),* }
    };
    // No formatter given: fall back to the error's `Display` via `ToString::to_string`.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&ToString::to_string), $(($no_error, $with_error)),* }
    };
    // Explicit formatter expression (e.g. a closure or a method path).
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&$error_formatter), $(($no_error, $with_error)),* }
    };

    // We convert args to a single `tt` (the parenthesized list) and later expand it back into a
    // tuple. The arg list and the pair list are independent repetitions, so `$($arg),*` cannot be
    // expanded directly inside the per-pair `$(...)*`. Passing it as one `tt` lets every pair's
    // `@impl` expansion reuse the same args.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $no_error, $with_error, $args, $body, $error_formatter }
        )*
    };

    // Generates the two `Cheatcode` impls for a single `($no_error, $with_error)` pair.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),* } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    // No custom message: `handle_assertion_result` uses its default text.
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, None)
                }
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                // Same fields as `$no_error`, plus the user-supplied `error` message.
                let Self { $($arg,)* error } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, Some(error))
                }
            }
        }
    };
}
277
// Boolean assertions. The body returns `Result<(), ()>`, which carries no detail, so `false`
// disables error formatting and only the (custom or default) message is reported.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
289
// Equality assertions on non-array operands; failures render both operands via `Display`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
    (assertEq_12Call, assertEq_13Call),
}

// Equality assertions on array (`Vec`) operands; failures render operands as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
    (assertEq_26Call, assertEq_27Call),
}

// Equality assertions whose failure message renders operands via `format_with_decimals`,
// using the cheatcode's `decimals` field.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}

// Inequality assertions on non-array operands.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// Inequality assertions on array (`Vec`) operands.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// Inequality assertions with decimal-formatted failure messages.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
351
// Ordering assertions. Each relation has a plain variant (operands rendered via `Display`) and a
// `Decimal` variant (operands rendered via `format_with_decimals` using the `decimals` field).

// `left > right`
impl_assertions! {
    |left, right| assert_gt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `left >= right`
impl_assertions! {
    |left, right| assert_ge(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `left < right`
impl_assertions! {
    |left, right| assert_lt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `left <= right`
impl_assertions! {
    |left, right| assert_le(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
407
// Absolute approximate equality: `|left - right| <= maxDelta`. The variants without a formatter
// argument render the error through its `Display` impl (via `ToString::to_string`).
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}

// Relative approximate equality. `maxPercentDelta` is compared against the relative delta on
// the `10^EQ_REL_DELTA_RESOLUTION` fixed-point scale.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
451
/// Succeeds if and only if `condition` is `true`.
fn assert_true(condition: bool) -> Result<(), ()> {
    condition.then_some(()).ok_or(())
}

/// Succeeds if and only if `condition` is `false`.
fn assert_false(condition: bool) -> Result<(), ()> {
    if condition { Err(()) } else { Ok(()) }
}
459
460fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
461    if left == right {
462        Ok(())
463    } else {
464        Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
465    }
466}
467
468fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
469    if left != right {
470        Ok(())
471    } else {
472        Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
473    }
474}
475
476fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
477    if left > right {
478        Ok(())
479    } else {
480        Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
481    }
482}
483
484fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
485    if left >= right {
486        Ok(())
487    } else {
488        Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
489    }
490}
491
492fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
493    if left < right {
494        Ok(())
495    } else {
496        Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
497    }
498}
499
500fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
501    if left <= right {
502        Ok(())
503    } else {
504        Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
505    }
506}
507
508fn get_delta_int(left: I256, right: I256) -> U256 {
509    let (left_sign, left_abs) = left.into_sign_and_abs();
510    let (right_sign, right_abs) = right.into_sign_and_abs();
511
512    if left_sign == right_sign {
513        if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
514    } else {
515        left_abs.wrapping_add(right_abs)
516    }
517}
518
519/// Calculates the relative delta for an absolute difference.
520///
521/// Avoids overflow in the multiplication by using [`U512`] to hold the intermediary result.
522fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
523    let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
524        / U512::from(right);
525    U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
526}
527
528fn uint_assert_approx_eq_abs(
529    left: U256,
530    right: U256,
531    max_delta: U256,
532) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
533    let delta = left.abs_diff(right);
534
535    if delta <= max_delta {
536        Ok(())
537    } else {
538        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
539    }
540}
541
542fn int_assert_approx_eq_abs(
543    left: I256,
544    right: I256,
545    max_delta: U256,
546) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
547    let delta = get_delta_int(left, right);
548
549    if delta <= max_delta {
550        Ok(())
551    } else {
552        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
553    }
554}
555
556fn uint_assert_approx_eq_rel(
557    left: U256,
558    right: U256,
559    max_delta: U256,
560) -> Result<(), EqRelAssertionError<U256>> {
561    if right.is_zero() {
562        if left.is_zero() {
563            return Ok(());
564        } else {
565            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
566                left,
567                right,
568                max_delta,
569                real_delta: EqRelDelta::Undefined,
570            })));
571        };
572    }
573
574    let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
575
576    if delta <= max_delta {
577        Ok(())
578    } else {
579        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
580            left,
581            right,
582            max_delta,
583            real_delta: EqRelDelta::Defined(delta),
584        })))
585    }
586}
587
588fn int_assert_approx_eq_rel(
589    left: I256,
590    right: I256,
591    max_delta: U256,
592) -> Result<(), EqRelAssertionError<I256>> {
593    if right.is_zero() {
594        if left.is_zero() {
595            return Ok(());
596        } else {
597            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
598                left,
599                right,
600                max_delta,
601                real_delta: EqRelDelta::Undefined,
602            })));
603        }
604    }
605
606    let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
607
608    if delta <= max_delta {
609        Ok(())
610    } else {
611        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
612            left,
613            right,
614            max_delta,
615            real_delta: EqRelDelta::Defined(delta),
616        })))
617    }
618}