Skip to main content

foundry_cheatcodes/test/
assert.rs

1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4    abi::console::{format_units_int, format_units_uint},
5    backend::GLOBAL_FAIL_SLOT,
6    constants::CHEATCODE_ADDRESS,
7    decode::ASSERTION_FAILED_PREFIX,
8    evm::FoundryEvmNetwork,
9};
10use itertools::Itertools;
11use revm::context::{ContextTr, JournalTr};
12use std::{borrow::Cow, fmt};
13
/// Number of decimal places used for relative deltas (value: 18).
///
/// `calc_delta_full` scales `|left - right| / right` by `10^EQ_REL_DELTA_RESOLUTION`, so a
/// relative delta of `10^18` means the difference equals `right` (100%); `format_delta_percent`
/// prints it with two fewer decimals to express it as a percentage.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
15
16struct ComparisonAssertionError<'a, T> {
17    kind: AssertionKind,
18    left: &'a T,
19    right: &'a T,
20}
21
/// The relation a comparison assertion expects between its `left` and `right` operands.
#[derive(Clone, Copy)]
enum AssertionKind {
    /// `left == right`
    Eq,
    /// `left != right`
    Ne,
    /// `left > right`
    Gt,
    /// `left >= right`
    Ge,
    /// `left < right`
    Lt,
    /// `left <= right`
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this relation, i.e. the relation that does hold between
    /// the operands when an assertion of this kind fails.
    const fn inverse(self) -> Self {
        match self {
            // Equality and inequality negate each other.
            Self::Eq => Self::Ne,
            Self::Ne => Self::Eq,
            // A strict ordering negates to the opposite non-strict one, and vice versa.
            Self::Gt => Self::Le,
            Self::Le => Self::Gt,
            Self::Lt => Self::Ge,
            Self::Ge => Self::Lt,
        }
    }

    /// Returns the operator symbol used when rendering this relation in messages.
    const fn to_str(self) -> &'static str {
        match self {
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Gt => ">",
            Self::Ge => ">=",
            Self::Eq => "==",
            Self::Ne => "!=",
        }
    }
}
55
56impl<T> ComparisonAssertionError<'_, T> {
57    fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
58        format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
59    }
60}
61
62impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
63    fn format_for_values(&self) -> String {
64        self.format_values(T::to_string)
65    }
66}
67
68impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
69    fn format_for_arrays(&self) -> String {
70        self.format_values(|v| format!("[{}]", v.iter().format(", ")))
71    }
72}
73
74impl ComparisonAssertionError<'_, U256> {
75    fn format_with_decimals(&self, decimals: &U256) -> String {
76        self.format_values(|v| format_units_uint(v, decimals))
77    }
78}
79
80impl ComparisonAssertionError<'_, I256> {
81    fn format_with_decimals(&self, decimals: &U256) -> String {
82        self.format_values(|v| format_units_int(v, decimals))
83    }
84}
85
86#[derive(thiserror::Error, Debug)]
87#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
88struct EqAbsAssertionError<T, D> {
89    left: T,
90    right: T,
91    max_delta: D,
92    real_delta: D,
93}
94
95impl EqAbsAssertionError<U256, U256> {
96    fn format_with_decimals(&self, decimals: &U256) -> String {
97        format!(
98            "{} !~= {} (max delta: {}, real delta: {})",
99            format_units_uint(&self.left, decimals),
100            format_units_uint(&self.right, decimals),
101            format_units_uint(&self.max_delta, decimals),
102            format_units_uint(&self.real_delta, decimals),
103        )
104    }
105}
106
107impl EqAbsAssertionError<I256, U256> {
108    fn format_with_decimals(&self, decimals: &U256) -> String {
109        format!(
110            "{} !~= {} (max delta: {}, real delta: {})",
111            format_units_int(&self.left, decimals),
112            format_units_int(&self.right, decimals),
113            format_units_uint(&self.max_delta, decimals),
114            format_units_uint(&self.real_delta, decimals),
115        )
116    }
117}
118
119fn format_delta_percent(delta: &U256) -> String {
120    format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
121}
122
123#[derive(Debug)]
124enum EqRelDelta {
125    Defined(U256),
126    Undefined,
127}
128
129impl fmt::Display for EqRelDelta {
130    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131        match self {
132            Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
133            Self::Undefined => write!(f, "undefined"),
134        }
135    }
136}
137
138#[derive(thiserror::Error, Debug)]
139#[error(
140    "{left} !~= {right} (max delta: {}, real delta: {})",
141    format_delta_percent(max_delta),
142    real_delta
143)]
144struct EqRelAssertionFailure<T> {
145    left: T,
146    right: T,
147    max_delta: U256,
148    real_delta: EqRelDelta,
149}
150
151#[derive(thiserror::Error, Debug)]
152enum EqRelAssertionError<T> {
153    #[error(transparent)]
154    Failure(Box<EqRelAssertionFailure<T>>),
155    #[error("overflow in delta calculation")]
156    Overflow,
157}
158
159impl EqRelAssertionError<U256> {
160    fn format_with_decimals(&self, decimals: &U256) -> String {
161        match self {
162            Self::Failure(f) => format!(
163                "{} !~= {} (max delta: {}, real delta: {})",
164                format_units_uint(&f.left, decimals),
165                format_units_uint(&f.right, decimals),
166                format_delta_percent(&f.max_delta),
167                &f.real_delta,
168            ),
169            Self::Overflow => self.to_string(),
170        }
171    }
172}
173
174impl EqRelAssertionError<I256> {
175    fn format_with_decimals(&self, decimals: &U256) -> String {
176        match self {
177            Self::Failure(f) => format!(
178                "{} !~= {} (max delta: {}, real delta: {})",
179                format_units_int(&f.left, decimals),
180                format_units_int(&f.right, decimals),
181                format_delta_percent(&f.max_delta),
182                &f.real_delta,
183            ),
184            Self::Overflow => self.to_string(),
185        }
186    }
187}
188
189type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
190
191#[cold]
192fn handle_assertion_result<FEN: FoundryEvmNetwork, E>(
193    ccx: &mut CheatsCtxt<'_, '_, FEN>,
194    executor: &mut dyn CheatcodesExecutor<FEN>,
195    err: E,
196    error_formatter: Option<&dyn Fn(&E) -> String>,
197    error_msg: Option<&str>,
198) -> Result {
199    let error_msg = error_msg.unwrap_or(ASSERTION_FAILED_PREFIX);
200    let msg = if let Some(error_formatter) = error_formatter {
201        Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
202    } else {
203        Cow::Borrowed(error_msg)
204    };
205    handle_assertion_result_mono(ccx, executor, msg)
206}
207
208fn handle_assertion_result_mono<FEN: FoundryEvmNetwork>(
209    ccx: &mut CheatsCtxt<'_, '_, FEN>,
210    executor: &mut dyn CheatcodesExecutor<FEN>,
211    msg: Cow<'_, str>,
212) -> Result {
213    if ccx.state.config.assertions_revert {
214        Err(msg.into_owned().into())
215    } else {
216        executor.console_log(&msg);
217        ccx.ecx.journal_mut().sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
218        Ok(Default::default())
219    }
220}
221
/// Implements [crate::Cheatcode] for pairs of cheatcodes.
///
/// Accepts a list of pairs of cheatcodes, where the first cheatcode is the one that doesn't contain
/// a custom error message, and the second one contains it at `error` field.
///
/// Passed `args` are the common arguments for both cheatcode structs (excluding `error` field).
///
/// Macro also accepts an optional closure that formats the error returned by the assertion:
/// - the literal `false` means "no formatter": only the prefix / custom `error` is reported;
/// - omitting the formatter renders the error with `ToString::to_string`;
/// - any other expression (a path or a closure) is used as the formatter.
///
/// Inside `$body` and the formatter, each listed field of `self` is available by reference under
/// its own name.
macro_rules! impl_assertions {
    // Explicit `false`: no error formatter.
    (|$($arg:ident),*| $body:expr, false, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, None, $(($no_error, $with_error)),* }
    };
    // Formatter omitted: fall back to the error's `ToString` impl.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&ToString::to_string), $(($no_error, $with_error)),* }
    };
    // Explicit formatter expression.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&$error_formatter), $(($no_error, $with_error)),* }
    };

    // We convert args to `tt` and later expand them back into tuple to allow usage of expanded args inside of
    // each assertion type context.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $no_error, $with_error, $args, $body, $error_formatter }
        )*
    };

    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr) => {
        // Variant without a custom message: failures use the default prefix.
        impl crate::Cheatcode for $no_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                // No `..` in the pattern: the listed names must cover every field of the struct.
                let Self { $($arg),* } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, None)
                }
            }
        }

        // Variant with a custom message: the `error` field replaces the default prefix.
        impl crate::Cheatcode for $with_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                let Self { $($arg,)* error } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, Some(error))
                }
            }
        }
    };
}
279
// Boolean assertions. `false` selects "no error formatter", so a failure reports only the
// default assertion prefix or the caller-supplied `error` message.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
291
// `assertEq` overloads whose operands are rendered with `Display` in the failure message.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` overloads with `Vec` operands, rendered as `[a, b, ...]` in the failure message.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` only affects how operands are printed, not the comparison.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}

// `assertNotEq` overloads whose operands are rendered with `Display` in the failure message.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` overloads with `Vec` operands, rendered as `[a, b, ...]` in the failure message.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` only affects how operands are printed.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
353
// Ordering assertions. Each relation has a plain variant (operands printed with `Display`) and a
// `*Decimal` variant whose `decimals` argument is used only for printing the operands.
impl_assertions! {
    |left, right| assert_gt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

impl_assertions! {
    |left, right| assert_ge(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

impl_assertions! {
    |left, right| assert_lt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

impl_assertions! {
    |left, right| assert_le(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
409
// Approximate equality assertions. Fields are bound by reference, hence the `*` derefs.
// Invocations without a formatter render failures through the error's `ToString` impl; the
// `*Decimal` variants format amounts with `decimals` instead.

// Absolute tolerance: `|left - right| <= maxDelta`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}

// Relative tolerance: delta relative to `right`, scaled by `10^EQ_REL_DELTA_RESOLUTION`.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
453
/// Succeeds exactly when `condition` is `true`.
const fn assert_true(condition: bool) -> Result<(), ()> {
    match condition {
        true => Ok(()),
        false => Err(()),
    }
}

/// Succeeds exactly when `condition` is `false`.
const fn assert_false(condition: bool) -> Result<(), ()> {
    match condition {
        false => Ok(()),
        true => Err(()),
    }
}
461
462fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
463    if left == right {
464        Ok(())
465    } else {
466        Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
467    }
468}
469
470fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
471    if left == right {
472        Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
473    } else {
474        Ok(())
475    }
476}
477
478fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
479    if left > right {
480        Ok(())
481    } else {
482        Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
483    }
484}
485
486fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
487    if left >= right {
488        Ok(())
489    } else {
490        Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
491    }
492}
493
494fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
495    if left < right {
496        Ok(())
497    } else {
498        Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
499    }
500}
501
502fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
503    if left <= right {
504        Ok(())
505    } else {
506        Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
507    }
508}
509
510fn get_delta_int(left: I256, right: I256) -> U256 {
511    let (left_sign, left_abs) = left.into_sign_and_abs();
512    let (right_sign, right_abs) = right.into_sign_and_abs();
513
514    if left_sign == right_sign {
515        if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
516    } else {
517        left_abs.wrapping_add(right_abs)
518    }
519}
520
521/// Calculates the relative delta for an absolute difference.
522///
523/// Avoids overflow in the multiplication by using [`U512`] to hold the intermediary result.
524fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
525    let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
526        / U512::from(right);
527    U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
528}
529
530fn uint_assert_approx_eq_abs(
531    left: U256,
532    right: U256,
533    max_delta: U256,
534) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
535    let delta = left.abs_diff(right);
536
537    if delta <= max_delta {
538        Ok(())
539    } else {
540        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
541    }
542}
543
544fn int_assert_approx_eq_abs(
545    left: I256,
546    right: I256,
547    max_delta: U256,
548) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
549    let delta = get_delta_int(left, right);
550
551    if delta <= max_delta {
552        Ok(())
553    } else {
554        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
555    }
556}
557
558fn uint_assert_approx_eq_rel(
559    left: U256,
560    right: U256,
561    max_delta: U256,
562) -> Result<(), EqRelAssertionError<U256>> {
563    if right.is_zero() {
564        if left.is_zero() {
565            return Ok(());
566        }
567        return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
568            left,
569            right,
570            max_delta,
571            real_delta: EqRelDelta::Undefined,
572        })));
573    }
574
575    let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
576
577    if delta <= max_delta {
578        Ok(())
579    } else {
580        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
581            left,
582            right,
583            max_delta,
584            real_delta: EqRelDelta::Defined(delta),
585        })))
586    }
587}
588
589fn int_assert_approx_eq_rel(
590    left: I256,
591    right: I256,
592    max_delta: U256,
593) -> Result<(), EqRelAssertionError<I256>> {
594    if right.is_zero() {
595        if left.is_zero() {
596            return Ok(());
597        }
598        return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
599            left,
600            right,
601            max_delta,
602            real_delta: EqRelDelta::Undefined,
603        })));
604    }
605
606    let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
607
608    if delta <= max_delta {
609        Ok(())
610    } else {
611        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
612            left,
613            right,
614            max_delta,
615            real_delta: EqRelDelta::Defined(delta),
616        })))
617    }
618}