Skip to main content

foundry_cheatcodes/test/
assert.rs

1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4    abi::console::{format_units_int, format_units_uint},
5    backend::GLOBAL_FAIL_SLOT,
6    constants::CHEATCODE_ADDRESS,
7    decode::ASSERTION_FAILED_PREFIX,
8    evm::FoundryEvmNetwork,
9};
10use itertools::Itertools;
11use revm::context::{ContextTr, JournalTr};
12use std::{borrow::Cow, fmt};
13
14const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
15
16struct ComparisonAssertionError<'a, T> {
17    kind: AssertionKind,
18    left: &'a T,
19    right: &'a T,
20}
21
/// The comparison operator checked by an assertion.
#[derive(Clone, Copy)]
enum AssertionKind {
    Eq,
    Ne,
    Gt,
    Ge,
    Lt,
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this operator.
    ///
    /// When an assertion of kind `self` fails, the returned operator describes the relation
    /// that actually holds between the operands.
    const fn inverse(self) -> Self {
        match self {
            // Equality pair.
            Self::Eq => Self::Ne,
            Self::Ne => Self::Eq,
            // Strict ordering negates to non-strict in the opposite direction.
            Self::Gt => Self::Le,
            Self::Lt => Self::Ge,
            // Non-strict ordering negates to strict in the opposite direction.
            Self::Ge => Self::Lt,
            Self::Le => Self::Gt,
        }
    }

    /// Returns the operator's source-code symbol, e.g. `">="`.
    const fn to_str(self) -> &'static str {
        match self {
            Self::Eq => "==",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Gt => ">",
            Self::Ge => ">=",
        }
    }
}
55
56impl<T> ComparisonAssertionError<'_, T> {
57    fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
58        format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
59    }
60}
61
62impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
63    fn format_for_values(&self) -> String {
64        self.format_values(T::to_string)
65    }
66}
67
68impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
69    fn format_for_arrays(&self) -> String {
70        self.format_values(|v| format!("[{}]", v.iter().format(", ")))
71    }
72}
73
74impl ComparisonAssertionError<'_, U256> {
75    fn format_with_decimals(&self, decimals: &U256) -> String {
76        self.format_values(|v| format_units_uint(v, decimals))
77    }
78}
79
80impl ComparisonAssertionError<'_, I256> {
81    fn format_with_decimals(&self, decimals: &U256) -> String {
82        self.format_values(|v| format_units_int(v, decimals))
83    }
84}
85
86#[derive(thiserror::Error, Debug)]
87#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
88struct EqAbsAssertionError<T, D> {
89    left: T,
90    right: T,
91    max_delta: D,
92    real_delta: D,
93}
94
95impl EqAbsAssertionError<U256, U256> {
96    fn format_with_decimals(&self, decimals: &U256) -> String {
97        format!(
98            "{} !~= {} (max delta: {}, real delta: {})",
99            format_units_uint(&self.left, decimals),
100            format_units_uint(&self.right, decimals),
101            format_units_uint(&self.max_delta, decimals),
102            format_units_uint(&self.real_delta, decimals),
103        )
104    }
105}
106
107impl EqAbsAssertionError<I256, U256> {
108    fn format_with_decimals(&self, decimals: &U256) -> String {
109        format!(
110            "{} !~= {} (max delta: {}, real delta: {})",
111            format_units_int(&self.left, decimals),
112            format_units_int(&self.right, decimals),
113            format_units_uint(&self.max_delta, decimals),
114            format_units_uint(&self.real_delta, decimals),
115        )
116    }
117}
118
119fn format_delta_percent(delta: &U256) -> String {
120    format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
121}
122
123#[derive(Debug)]
124enum EqRelDelta {
125    Defined(U256),
126    Undefined,
127}
128
129impl fmt::Display for EqRelDelta {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        match self {
132            Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
133            Self::Undefined => write!(f, "undefined"),
134        }
135    }
136}
137
138#[derive(thiserror::Error, Debug)]
139#[error(
140    "{left} !~= {right} (max delta: {}, real delta: {})",
141    format_delta_percent(max_delta),
142    real_delta
143)]
144struct EqRelAssertionFailure<T> {
145    left: T,
146    right: T,
147    max_delta: U256,
148    real_delta: EqRelDelta,
149}
150
151#[derive(thiserror::Error, Debug)]
152enum EqRelAssertionError<T> {
153    #[error(transparent)]
154    Failure(Box<EqRelAssertionFailure<T>>),
155    #[error("overflow in delta calculation")]
156    Overflow,
157}
158
159impl EqRelAssertionError<U256> {
160    fn format_with_decimals(&self, decimals: &U256) -> String {
161        match self {
162            Self::Failure(f) => format!(
163                "{} !~= {} (max delta: {}, real delta: {})",
164                format_units_uint(&f.left, decimals),
165                format_units_uint(&f.right, decimals),
166                format_delta_percent(&f.max_delta),
167                f.real_delta,
168            ),
169            Self::Overflow => self.to_string(),
170        }
171    }
172}
173
174impl EqRelAssertionError<I256> {
175    fn format_with_decimals(&self, decimals: &U256) -> String {
176        match self {
177            Self::Failure(f) => format!(
178                "{} !~= {} (max delta: {}, real delta: {})",
179                format_units_int(&f.left, decimals),
180                format_units_int(&f.right, decimals),
181                format_delta_percent(&f.max_delta),
182                f.real_delta,
183            ),
184            Self::Overflow => self.to_string(),
185        }
186    }
187}
188
189type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
190
191#[cold]
192fn handle_assertion_result<FEN: FoundryEvmNetwork, E>(
193    ccx: &mut CheatsCtxt<'_, '_, FEN>,
194    executor: &mut dyn CheatcodesExecutor<FEN>,
195    err: E,
196    error_formatter: Option<&dyn Fn(&E) -> String>,
197    error_msg: Option<&str>,
198) -> Result {
199    let error_msg = error_msg.unwrap_or(ASSERTION_FAILED_PREFIX);
200    let msg = if let Some(error_formatter) = error_formatter {
201        Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
202    } else {
203        Cow::Borrowed(error_msg)
204    };
205    handle_assertion_result_mono(ccx, executor, msg)
206}
207
208fn handle_assertion_result_mono<FEN: FoundryEvmNetwork>(
209    ccx: &mut CheatsCtxt<'_, '_, FEN>,
210    executor: &mut dyn CheatcodesExecutor<FEN>,
211    msg: Cow<'_, str>,
212) -> Result {
213    if ccx.state.config.assertions_revert {
214        Err(msg.into_owned().into())
215    } else {
216        executor.console_log(&msg);
217        ccx.ecx.journal_mut().sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
218        Ok(Default::default())
219    }
220}
221
/// Implements [crate::Cheatcode] for pairs of assertion cheatcodes.
///
/// Each pair is `(WithoutMessage, WithMessage)`. The first struct has exactly the fields named in
/// the closure header. The second has the same fields plus a custom failure message in `err`.
///
/// The closure header lists the shared fields, which are destructured from `self` by name. The
/// closure body is the check to run and should evaluate to `Result<(), E>`.
///
/// The optional second position selects how a failure `E` is rendered:
/// - `false`: nothing is rendered; only the (default or custom) message is reported.
/// - omitted: `E` is rendered with `ToString`.
/// - any other expression: used as the formatter and called with `&E`.
macro_rules! impl_assertions {
    // NOTE: arm order matters. The literal `false` would also parse as a formatter expression.
    // A bare pair list followed by a trailing comma would also match the explicit-formatter arm,
    // as a tuple expression followed by zero pairs. The more specific arms therefore come first.
    (|$($field:ident),*| $check:expr, false, $(($plain:ident, $custom:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($field),*)| $check, None, $(($plain, $custom)),* }
    };
    (|$($field:ident),*| $check:expr, $(($plain:ident, $custom:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($field),*)| $check, Some(&ToString::to_string), $(($plain, $custom)),* }
    };
    (|$($field:ident),*| $check:expr, $formatter:expr, $(($plain:ident, $custom:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($field),*)| $check, Some(&$formatter), $(($plain, $custom)),* }
    };

    // The field list repeats independently of the pair list, so it cannot be expanded inside the
    // per-pair repetition directly. It is carried as a single `tt` and re-expanded in `@impl`.
    (@args_tt |$fields:tt| $check:expr, $formatter:expr, $(($plain:ident, $custom:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $plain, $custom, $fields, $check, $formatter }
        )*
    };

    (@impl $plain:ident, $custom:ident, ($($field:ident),*), $check:expr, $formatter:expr) => {
        impl crate::Cheatcode for $plain {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                let Self { $($field),* } = self;
                if let Err(failure) = $check {
                    return handle_assertion_result(ccx, executor, failure, $formatter, None);
                }
                Ok(Default::default())
            }
        }

        impl crate::Cheatcode for $custom {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                let Self { $($field,)* err } = self;
                if let Err(failure) = $check {
                    return handle_assertion_result(ccx, executor, failure, $formatter, Some(err));
                }
                Ok(Default::default())
            }
        }
    };
}
281
282impl_assertions! {
283    |condition| assert_true(*condition),
284    false,
285    (assertTrue_0Call, assertTrue_1Call),
286}
287
288impl_assertions! {
289    |condition| assert_false(*condition),
290    false,
291    (assertFalse_0Call, assertFalse_1Call),
292}
293
294impl_assertions! {
295    |left, right| assert_eq(left, right),
296    ComparisonAssertionError::format_for_values,
297    (assertEq_0Call, assertEq_1Call),
298    (assertEq_2Call, assertEq_3Call),
299    (assertEq_4Call, assertEq_5Call),
300    (assertEq_6Call, assertEq_7Call),
301    (assertEq_8Call, assertEq_9Call),
302    (assertEq_10Call, assertEq_11Call),
303    (assertEq_12Call, assertEq_13Call),
304}
305
306impl_assertions! {
307    |left, right| assert_eq(left, right),
308    ComparisonAssertionError::format_for_arrays,
309    (assertEq_14Call, assertEq_15Call),
310    (assertEq_16Call, assertEq_17Call),
311    (assertEq_18Call, assertEq_19Call),
312    (assertEq_20Call, assertEq_21Call),
313    (assertEq_22Call, assertEq_23Call),
314    (assertEq_24Call, assertEq_25Call),
315    (assertEq_26Call, assertEq_27Call),
316}
317
318impl_assertions! {
319    |left, right, decimals| assert_eq(left, right),
320    |e| e.format_with_decimals(decimals),
321    (assertEqDecimal_0Call, assertEqDecimal_1Call),
322    (assertEqDecimal_2Call, assertEqDecimal_3Call),
323}
324
325impl_assertions! {
326    |left, right| assert_not_eq(left, right),
327    ComparisonAssertionError::format_for_values,
328    (assertNotEq_0Call, assertNotEq_1Call),
329    (assertNotEq_2Call, assertNotEq_3Call),
330    (assertNotEq_4Call, assertNotEq_5Call),
331    (assertNotEq_6Call, assertNotEq_7Call),
332    (assertNotEq_8Call, assertNotEq_9Call),
333    (assertNotEq_10Call, assertNotEq_11Call),
334    (assertNotEq_12Call, assertNotEq_13Call),
335}
336
337impl_assertions! {
338    |left, right| assert_not_eq(left, right),
339    ComparisonAssertionError::format_for_arrays,
340    (assertNotEq_14Call, assertNotEq_15Call),
341    (assertNotEq_16Call, assertNotEq_17Call),
342    (assertNotEq_18Call, assertNotEq_19Call),
343    (assertNotEq_20Call, assertNotEq_21Call),
344    (assertNotEq_22Call, assertNotEq_23Call),
345    (assertNotEq_24Call, assertNotEq_25Call),
346    (assertNotEq_26Call, assertNotEq_27Call),
347}
348
349impl_assertions! {
350    |left, right, decimals| assert_not_eq(left, right),
351    |e| e.format_with_decimals(decimals),
352    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
353    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
354}
355
356impl_assertions! {
357    |left, right| assert_gt(left, right),
358    ComparisonAssertionError::format_for_values,
359    (assertGt_0Call, assertGt_1Call),
360    (assertGt_2Call, assertGt_3Call),
361}
362
363impl_assertions! {
364    |left, right, decimals| assert_gt(left, right),
365    |e| e.format_with_decimals(decimals),
366    (assertGtDecimal_0Call, assertGtDecimal_1Call),
367    (assertGtDecimal_2Call, assertGtDecimal_3Call),
368}
369
370impl_assertions! {
371    |left, right| assert_ge(left, right),
372    ComparisonAssertionError::format_for_values,
373    (assertGe_0Call, assertGe_1Call),
374    (assertGe_2Call, assertGe_3Call),
375}
376
377impl_assertions! {
378    |left, right, decimals| assert_ge(left, right),
379    |e| e.format_with_decimals(decimals),
380    (assertGeDecimal_0Call, assertGeDecimal_1Call),
381    (assertGeDecimal_2Call, assertGeDecimal_3Call),
382}
383
384impl_assertions! {
385    |left, right| assert_lt(left, right),
386    ComparisonAssertionError::format_for_values,
387    (assertLt_0Call, assertLt_1Call),
388    (assertLt_2Call, assertLt_3Call),
389}
390
391impl_assertions! {
392    |left, right, decimals| assert_lt(left, right),
393    |e| e.format_with_decimals(decimals),
394    (assertLtDecimal_0Call, assertLtDecimal_1Call),
395    (assertLtDecimal_2Call, assertLtDecimal_3Call),
396}
397
398impl_assertions! {
399    |left, right| assert_le(left, right),
400    ComparisonAssertionError::format_for_values,
401    (assertLe_0Call, assertLe_1Call),
402    (assertLe_2Call, assertLe_3Call),
403}
404
405impl_assertions! {
406    |left, right, decimals| assert_le(left, right),
407    |e| e.format_with_decimals(decimals),
408    (assertLeDecimal_0Call, assertLeDecimal_1Call),
409    (assertLeDecimal_2Call, assertLeDecimal_3Call),
410}
411
412impl_assertions! {
413    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
414    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
415}
416
417impl_assertions! {
418    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
419    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
420}
421
422impl_assertions! {
423    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
424    |e| e.format_with_decimals(decimals),
425    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
426}
427
428impl_assertions! {
429    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
430    |e| e.format_with_decimals(decimals),
431    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
432}
433
434impl_assertions! {
435    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
436    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
437}
438
439impl_assertions! {
440    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
441    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
442}
443
444impl_assertions! {
445    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
446    |e| e.format_with_decimals(decimals),
447    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
448}
449
450impl_assertions! {
451    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
452    |e| e.format_with_decimals(decimals),
453    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
454}
455
/// Succeeds exactly when `condition` is `true`; the error carries no details.
const fn assert_true(condition: bool) -> Result<(), ()> {
    if !condition {
        return Err(());
    }
    Ok(())
}

/// Succeeds exactly when `condition` is `false`; the error carries no details.
const fn assert_false(condition: bool) -> Result<(), ()> {
    if condition {
        return Err(());
    }
    Ok(())
}
463
464fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
465    if left == right {
466        Ok(())
467    } else {
468        Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
469    }
470}
471
472fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
473    if left == right {
474        Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
475    } else {
476        Ok(())
477    }
478}
479
480fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
481    if left > right {
482        Ok(())
483    } else {
484        Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
485    }
486}
487
488fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
489    if left >= right {
490        Ok(())
491    } else {
492        Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
493    }
494}
495
496fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
497    if left < right {
498        Ok(())
499    } else {
500        Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
501    }
502}
503
504fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
505    if left <= right {
506        Ok(())
507    } else {
508        Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
509    }
510}
511
512fn get_delta_int(left: I256, right: I256) -> U256 {
513    let (left_sign, left_abs) = left.into_sign_and_abs();
514    let (right_sign, right_abs) = right.into_sign_and_abs();
515
516    if left_sign == right_sign {
517        if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
518    } else {
519        left_abs.wrapping_add(right_abs)
520    }
521}
522
523/// Calculates the relative delta for an absolute difference.
524///
525/// Avoids overflow in the multiplication by using [`U512`] to hold the intermediary result.
526fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
527    let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
528        / U512::from(right);
529    U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
530}
531
532fn uint_assert_approx_eq_abs(
533    left: U256,
534    right: U256,
535    max_delta: U256,
536) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
537    let delta = left.abs_diff(right);
538
539    if delta <= max_delta {
540        Ok(())
541    } else {
542        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
543    }
544}
545
546fn int_assert_approx_eq_abs(
547    left: I256,
548    right: I256,
549    max_delta: U256,
550) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
551    let delta = get_delta_int(left, right);
552
553    if delta <= max_delta {
554        Ok(())
555    } else {
556        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
557    }
558}
559
560fn uint_assert_approx_eq_rel(
561    left: U256,
562    right: U256,
563    max_delta: U256,
564) -> Result<(), EqRelAssertionError<U256>> {
565    if right.is_zero() {
566        if left.is_zero() {
567            return Ok(());
568        }
569        return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
570            left,
571            right,
572            max_delta,
573            real_delta: EqRelDelta::Undefined,
574        })));
575    }
576
577    let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
578
579    if delta <= max_delta {
580        Ok(())
581    } else {
582        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
583            left,
584            right,
585            max_delta,
586            real_delta: EqRelDelta::Defined(delta),
587        })))
588    }
589}
590
591fn int_assert_approx_eq_rel(
592    left: I256,
593    right: I256,
594    max_delta: U256,
595) -> Result<(), EqRelAssertionError<I256>> {
596    if right.is_zero() {
597        if left.is_zero() {
598            return Ok(());
599        }
600        return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
601            left,
602            right,
603            max_delta,
604            real_delta: EqRelDelta::Undefined,
605        })));
606    }
607
608    let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
609
610    if delta <= max_delta {
611        Ok(())
612    } else {
613        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
614            left,
615            right,
616            max_delta,
617            real_delta: EqRelDelta::Defined(delta),
618        })))
619    }
620}