foundry_cheatcodes/test/
assert.rs

1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{hex, I256, U256};
3use foundry_evm_core::{
4    abi::console::{format_units_int, format_units_uint},
5    backend::GLOBAL_FAIL_SLOT,
6    constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use revm::context::JournalTr;
10use std::fmt::{Debug, Display};
11
/// Fixed-point resolution used for relative deltas.
///
/// The `*_assert_approx_eq_rel` helpers multiply the absolute delta by
/// `10^EQ_REL_DELTA_RESOLUTION` before dividing by the right-hand operand, and
/// `format_delta_percent` renders the result with two fewer decimals so it reads as a
/// percentage. The value is built from limbs with `18` in the first limb.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
13
/// Error returned by `assert_true` and `assert_false`.
///
/// It carries no operands; its `Display` output is the fixed text "assertion failed".
#[derive(Debug, thiserror::Error)]
#[error("assertion failed")]
struct SimpleAssertionError;
17
/// Failure of a binary comparison assertion.
///
/// Each variant is named after the assertion that failed and borrows both operands so they
/// can be rendered later. No `#[error]` message is declared here; the text is produced by
/// the `format_for_values`, `format_for_arrays` and `format_with_decimals` helpers.
#[derive(thiserror::Error, Debug)]
enum ComparisonAssertionError<'a, T> {
    /// Returned by `assert_not_eq` when the operands are equal.
    Ne { left: &'a T, right: &'a T },
    /// Returned by `assert_eq` when the operands differ.
    Eq { left: &'a T, right: &'a T },
    /// Returned by `assert_ge` when `left >= right` does not hold.
    Ge { left: &'a T, right: &'a T },
    /// Returned by `assert_gt` when `left > right` does not hold.
    Gt { left: &'a T, right: &'a T },
    /// Returned by `assert_le` when `left <= right` does not hold.
    Le { left: &'a T, right: &'a T },
    /// Returned by `assert_lt` when `left < right` does not hold.
    Lt { left: &'a T, right: &'a T },
}
27
/// Renders a `ComparisonAssertionError` as `<left> <op> <right>`.
///
/// `$self` is the error value and `$format_fn` is a callable applied to each operand.
/// The operator shown is the relation that actually held, i.e. the negation of the failed
/// assertion (a failed `Eq` prints `!=`, a failed `Gt` prints `<=`, and so on).
macro_rules! format_values {
    ($self:expr, $format_fn:expr) => {{
        // Select the operands and the operator first, then format exactly once.
        let (lhs, op, rhs) = match $self {
            Self::Ne { left, right } => (left, "==", right),
            Self::Eq { left, right } => (left, "!=", right),
            Self::Ge { left, right } => (left, "<", right),
            Self::Gt { left, right } => (left, "<=", right),
            Self::Le { left, right } => (left, ">", right),
            Self::Lt { left, right } => (left, ">=", right),
        };
        format!("{} {} {}", $format_fn(lhs), op, $format_fn(rhs))
    }};
}
40
41impl<T: Display> ComparisonAssertionError<'_, T> {
42    fn format_for_values(&self) -> String {
43        format_values!(self, T::to_string)
44    }
45}
46
47impl<T: Display> ComparisonAssertionError<'_, Vec<T>> {
48    fn format_for_arrays(&self) -> String {
49        let formatter = |v: &Vec<T>| format!("[{}]", v.iter().format(", "));
50        format_values!(self, formatter)
51    }
52}
53
54impl ComparisonAssertionError<'_, U256> {
55    fn format_with_decimals(&self, decimals: &U256) -> String {
56        let formatter = |v: &U256| format_units_uint(v, decimals);
57        format_values!(self, formatter)
58    }
59}
60
61impl ComparisonAssertionError<'_, I256> {
62    fn format_with_decimals(&self, decimals: &U256) -> String {
63        let formatter = |v: &I256| format_units_int(v, decimals);
64        format_values!(self, formatter)
65    }
66}
67
/// Failure of an absolute approximate-equality assertion.
///
/// `T` is the operand type and `D` the delta type. The derived `Display` output prints the
/// four fields as they are; the `format_with_decimals` methods provide a variant that
/// applies decimal formatting.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left-hand operand of the assertion.
    left: T,
    /// Right-hand operand of the assertion.
    right: T,
    /// Largest delta the caller accepts.
    max_delta: D,
    /// Delta actually measured between `left` and `right`.
    real_delta: D,
}
76
77impl EqAbsAssertionError<U256, U256> {
78    fn format_with_decimals(&self, decimals: &U256) -> String {
79        format!(
80            "{} !~= {} (max delta: {}, real delta: {})",
81            format_units_uint(&self.left, decimals),
82            format_units_uint(&self.right, decimals),
83            format_units_uint(&self.max_delta, decimals),
84            format_units_uint(&self.real_delta, decimals),
85        )
86    }
87}
88
89impl EqAbsAssertionError<I256, U256> {
90    fn format_with_decimals(&self, decimals: &U256) -> String {
91        format!(
92            "{} !~= {} (max delta: {}, real delta: {})",
93            format_units_int(&self.left, decimals),
94            format_units_int(&self.right, decimals),
95            format_units_uint(&self.max_delta, decimals),
96            format_units_uint(&self.real_delta, decimals),
97        )
98    }
99}
100
101fn format_delta_percent(delta: &U256) -> String {
102    format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
103}
104
/// Relative delta reported by a failed `*_assert_approx_eq_rel` check.
#[derive(Debug)]
enum EqRelDelta {
    /// Delta computed by the assertion helpers, scaled by `10^EQ_REL_DELTA_RESOLUTION`.
    Defined(U256),
    /// Used when the right-hand operand is zero and the left-hand one is not, in which
    /// case the helpers do not compute a ratio.
    Undefined,
}
110
111impl Display for EqRelDelta {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        match self {
114            Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
115            Self::Undefined => write!(f, "undefined"),
116        }
117    }
118}
119
/// Details of a failed relative approximate-equality assertion.
///
/// The derived `Display` output prints the operands as they are, the maximum delta as a
/// percentage (via `format_delta_percent`) and the real delta via its own `Display` impl.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left-hand operand of the assertion.
    left: T,
    /// Right-hand operand of the assertion.
    right: T,
    /// Largest relative delta the caller accepts, in the same scaled representation that
    /// the assertion helpers compare against.
    max_delta: U256,
    /// Relative delta actually measured, or `Undefined` when it could not be computed.
    real_delta: EqRelDelta,
}
132
/// Error returned by the relative approximate-equality helpers.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The operands are not within the allowed relative delta. The payload is boxed and
    /// its `Display` output is forwarded unchanged (`transparent`).
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// Scaling the absolute delta by `10^EQ_REL_DELTA_RESOLUTION` did not fit in `U256`.
    #[error("overflow in delta calculation")]
    Overflow,
}
140
141impl EqRelAssertionError<U256> {
142    fn format_with_decimals(&self, decimals: &U256) -> String {
143        match self {
144            Self::Failure(f) => format!(
145                "{} !~= {} (max delta: {}, real delta: {})",
146                format_units_uint(&f.left, decimals),
147                format_units_uint(&f.right, decimals),
148                format_delta_percent(&f.max_delta),
149                &f.real_delta,
150            ),
151            Self::Overflow => self.to_string(),
152        }
153    }
154}
155
156impl EqRelAssertionError<I256> {
157    fn format_with_decimals(&self, decimals: &U256) -> String {
158        match self {
159            Self::Failure(f) => format!(
160                "{} !~= {} (max delta: {}, real delta: {})",
161                format_units_int(&f.left, decimals),
162                format_units_int(&f.right, decimals),
163                format_delta_percent(&f.max_delta),
164                &f.real_delta,
165            ),
166            Self::Overflow => self.to_string(),
167        }
168    }
169}
170
/// Result of a comparison assertion: return data (`Vec<u8>`) on success, or a
/// `ComparisonAssertionError` borrowing the two operands on failure.
type ComparisonResult<'a, T> = Result<Vec<u8>, ComparisonAssertionError<'a, T>>;
172
173fn handle_assertion_result<ERR>(
174    result: core::result::Result<Vec<u8>, ERR>,
175    ccx: &mut CheatsCtxt,
176    executor: &mut dyn CheatcodesExecutor,
177    error_formatter: impl Fn(&ERR) -> String,
178    error_msg: Option<&str>,
179    format_error: bool,
180) -> Result {
181    match result {
182        Ok(_) => Ok(Default::default()),
183        Err(err) => {
184            let error_msg = error_msg.unwrap_or("assertion failed");
185            let msg = if format_error {
186                format!("{error_msg}: {}", error_formatter(&err))
187            } else {
188                error_msg.to_string()
189            };
190            if ccx.state.config.assertions_revert {
191                Err(msg.into())
192            } else {
193                executor.console_log(ccx, &msg);
194                ccx.ecx.journaled_state.sstore(
195                    CHEATCODE_ADDRESS,
196                    GLOBAL_FAIL_SLOT,
197                    U256::from(1),
198                )?;
199                Ok(Default::default())
200            }
201        }
202    }
203}
204
/// Implements [crate::Cheatcode] for pairs of cheatcodes.
///
/// Accepts a list of pairs of cheatcodes, where the first cheatcode is the one that doesn't
/// contain a custom error message, and the second one contains it in its `error` field.
///
/// The leading `|args|` list names the fields common to both cheatcode structs (excluding
/// the `error` field); they are destructured from `self` and are in scope for the assertion
/// body and for the error formatter.
///
/// Three invocation forms are supported:
/// - `|args| body, <bool literal>, pairs...`: the error is rendered with `to_string()` and
///   the literal is forwarded as the `format_error` flag of `handle_assertion_result`.
/// - `|args| body, pairs...`: the error is rendered with `to_string()` and `format_error`
///   is `true`.
/// - `|args| body, <formatter closure>, pairs...`: the given closure renders the error and
///   `format_error` is `true`.
macro_rules! impl_assertions {
    // Form with an explicit `format_error` literal and the default `to_string()` formatter.
    (|$($arg:ident),*| $body:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), $format_error, $(($no_error, $with_error),)*);
    };
    // Form with neither a formatter nor a literal: default formatter, `format_error = true`.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), true, $(($no_error, $with_error),)*);
    };
    // Form with a custom error formatter: `format_error = true`.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, $error_formatter, true, $(($no_error, $with_error)),*);
    };
    // The argument list is captured as a single `tt` and expanded back into a tuple later, so
    // that the same arguments can be used inside the expansion for every cheatcode pair.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions!(@impl $no_error, $with_error, $args, $body, $error_formatter, $format_error);
        )*
    };
    // Generates the two `Cheatcode` impls for one pair. They differ only in whether the
    // `error` field is destructured and passed as the custom message.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr, $format_error:literal) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),* } = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, None, $format_error)
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),*, error} = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, Some(error), $format_error)
            }
        }
    };
}
254
// `assertTrue`: there are no operands to print, so error formatting is disabled (`false`)
// and the failure message is either the default one or the caller-supplied `error` string.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

// `assertFalse`: same wiring as `assertTrue`, but the check is `assert_false`.
impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
266
// `assertEq` overloads whose operands are compared directly and printed via `Display`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_values(),
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
}

// `assertEq` overload whose operands are converted with `hex::encode_prefixed` before the
// comparison, so a failure message shows the encoded strings.
impl_assertions! {
    |left, right| assert_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` overloads on arrays: compared as whole vectors, printed as bracketed lists.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_arrays(),
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
}

// `assertEq` overload on arrays whose elements are each converted with
// `hex::encode_prefixed` before the comparison.
impl_assertions! {
    |left, right| assert_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: the comparison is on the raw values; `decimals` only affects how the
// operands are rendered in the failure message.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}
310
// `assertNotEq` overloads whose operands are compared directly and printed via `Display`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_values(),
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
}

// `assertNotEq` overload whose operands are converted with `hex::encode_prefixed` before
// the comparison, so a failure message shows the encoded strings.
impl_assertions! {
    |left, right| assert_not_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` overloads on arrays: compared as whole vectors, printed as bracketed lists.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_arrays(),
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
}

// `assertNotEq` overload on arrays whose elements are each converted with
// `hex::encode_prefixed` before the comparison.
impl_assertions! {
    |left, right| assert_not_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: the comparison is on the raw values; `decimals` only affects how
// the operands are rendered in the failure message.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
354
// Ordering assertions. Each relation comes in a plain form (operands printed via `Display`)
// and a `*Decimal` form where `decimals` only affects how the operands are rendered.

// `assertGt`: passes when `left > right`.
impl_assertions! {
    |left, right| assert_gt(left, right),
    |e| e.format_for_values(),
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

// `assertGtDecimal`
impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`: passes when `left >= right`.
impl_assertions! {
    |left, right| assert_ge(left, right),
    |e| e.format_for_values(),
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

// `assertGeDecimal`
impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`: passes when `left < right`.
impl_assertions! {
    |left, right| assert_lt(left, right),
    |e| e.format_for_values(),
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

// `assertLtDecimal`
impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`: passes when `left <= right`.
impl_assertions! {
    |left, right| assert_le(left, right),
    |e| e.format_for_values(),
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

// `assertLeDecimal`
impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
410
// `assertApproxEqAbs` on unsigned operands. No formatter is given, so the failure is
// rendered with the error's `to_string()`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs` on signed operands, also rendered with `to_string()`.
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

// `assertApproxEqAbsDecimal` on unsigned operands: `decimals` is used only when rendering
// the failure message, not in the delta check.
impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

// `assertApproxEqAbsDecimal` on signed operands.
impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}
432
// `assertApproxEqRel` on unsigned operands. No formatter is given, so the failure is
// rendered with the error's `to_string()`.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

// `assertApproxEqRel` on signed operands, also rendered with `to_string()`.
impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

// `assertApproxEqRelDecimal` on unsigned operands: `decimals` is used only when rendering
// the operands in the failure message, not in the delta check.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

// `assertApproxEqRelDecimal` on signed operands.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
454
455fn assert_true(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
456    if condition {
457        Ok(Default::default())
458    } else {
459        Err(SimpleAssertionError)
460    }
461}
462
463fn assert_false(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
464    if !condition {
465        Ok(Default::default())
466    } else {
467        Err(SimpleAssertionError)
468    }
469}
470
471fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
472    if left == right {
473        Ok(Default::default())
474    } else {
475        Err(ComparisonAssertionError::Eq { left, right })
476    }
477}
478
479fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
480    if left != right {
481        Ok(Default::default())
482    } else {
483        Err(ComparisonAssertionError::Ne { left, right })
484    }
485}
486
487fn get_delta_uint(left: U256, right: U256) -> U256 {
488    if left > right {
489        left - right
490    } else {
491        right - left
492    }
493}
494
495fn get_delta_int(left: I256, right: I256) -> U256 {
496    let (left_sign, left_abs) = left.into_sign_and_abs();
497    let (right_sign, right_abs) = right.into_sign_and_abs();
498
499    if left_sign == right_sign {
500        if left_abs > right_abs {
501            left_abs - right_abs
502        } else {
503            right_abs - left_abs
504        }
505    } else {
506        left_abs + right_abs
507    }
508}
509
510fn uint_assert_approx_eq_abs(
511    left: U256,
512    right: U256,
513    max_delta: U256,
514) -> Result<Vec<u8>, Box<EqAbsAssertionError<U256, U256>>> {
515    let delta = get_delta_uint(left, right);
516
517    if delta <= max_delta {
518        Ok(Default::default())
519    } else {
520        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
521    }
522}
523
524fn int_assert_approx_eq_abs(
525    left: I256,
526    right: I256,
527    max_delta: U256,
528) -> Result<Vec<u8>, Box<EqAbsAssertionError<I256, U256>>> {
529    let delta = get_delta_int(left, right);
530
531    if delta <= max_delta {
532        Ok(Default::default())
533    } else {
534        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
535    }
536}
537
538fn uint_assert_approx_eq_rel(
539    left: U256,
540    right: U256,
541    max_delta: U256,
542) -> Result<Vec<u8>, EqRelAssertionError<U256>> {
543    if right.is_zero() {
544        if left.is_zero() {
545            return Ok(Default::default())
546        } else {
547            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
548                left,
549                right,
550                max_delta,
551                real_delta: EqRelDelta::Undefined,
552            })))
553        };
554    }
555
556    let delta = get_delta_uint(left, right)
557        .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
558        .ok_or(EqRelAssertionError::Overflow)? /
559        right;
560
561    if delta <= max_delta {
562        Ok(Default::default())
563    } else {
564        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
565            left,
566            right,
567            max_delta,
568            real_delta: EqRelDelta::Defined(delta),
569        })))
570    }
571}
572
573fn int_assert_approx_eq_rel(
574    left: I256,
575    right: I256,
576    max_delta: U256,
577) -> Result<Vec<u8>, EqRelAssertionError<I256>> {
578    if right.is_zero() {
579        if left.is_zero() {
580            return Ok(Default::default())
581        } else {
582            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
583                left,
584                right,
585                max_delta,
586                real_delta: EqRelDelta::Undefined,
587            })))
588        }
589    }
590
591    let (_, abs_right) = right.into_sign_and_abs();
592    let delta = get_delta_int(left, right)
593        .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
594        .ok_or(EqRelAssertionError::Overflow)? /
595        abs_right;
596
597    if delta <= max_delta {
598        Ok(Default::default())
599    } else {
600        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
601            left,
602            right,
603            max_delta,
604            real_delta: EqRelDelta::Defined(delta),
605        })))
606    }
607}
608
609fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
610    if left > right {
611        Ok(Default::default())
612    } else {
613        Err(ComparisonAssertionError::Gt { left, right })
614    }
615}
616
617fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
618    if left >= right {
619        Ok(Default::default())
620    } else {
621        Err(ComparisonAssertionError::Ge { left, right })
622    }
623}
624
625fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
626    if left < right {
627        Ok(Default::default())
628    } else {
629        Err(ComparisonAssertionError::Lt { left, right })
630    }
631}
632
633fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
634    if left <= right {
635        Ok(Default::default())
636    } else {
637        Err(ComparisonAssertionError::Le { left, right })
638    }
639}