foundry_cheatcodes/test/
assert.rs

1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{hex, I256, U256};
3use foundry_evm_core::{
4    abi::console::{format_units_int, format_units_uint},
5    backend::GLOBAL_FAIL_SLOT,
6    constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use std::fmt::{Debug, Display};
10
/// Decimal precision of relative deltas in the `assertApproxEqRel*` family (the value 18).
///
/// A relative delta is `|left - right| * 10^EQ_REL_DELTA_RESOLUTION / |right|`, so `10^18`
/// corresponds to 100%; percentages are therefore printed with 16 decimal places. Built with
/// `from_limbs` because it is usable in a `const` initializer.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
12
/// Failure of a boolean assertion (`assertTrue` / `assertFalse`); it carries no data.
#[derive(Debug, thiserror::Error)]
#[error("assertion failed")]
struct SimpleAssertionError;
16
/// Failure of a binary comparison assertion, holding references to both operands.
///
/// The variant names the comparison that was asserted and did NOT hold. The human-readable text
/// is produced by the `format_for_values` / `format_for_arrays` / `format_with_decimals`
/// methods, which print the relation that actually held.
/// NOTE(review): there is no `#[error]` attribute here, so those methods (not `Display`) are the
/// intended way to render this error.
#[derive(thiserror::Error, Debug)]
enum ComparisonAssertionError<'a, T> {
    /// `assertNotEq` failed: the operands were equal.
    Ne { left: &'a T, right: &'a T },
    /// `assertEq` failed: the operands differed.
    Eq { left: &'a T, right: &'a T },
    /// `assertGe` failed: `left < right`.
    Ge { left: &'a T, right: &'a T },
    /// `assertGt` failed: `left <= right`.
    Gt { left: &'a T, right: &'a T },
    /// `assertLe` failed: `left > right`.
    Le { left: &'a T, right: &'a T },
    /// `assertLt` failed: `left >= right`.
    Lt { left: &'a T, right: &'a T },
}
26
/// Renders a `ComparisonAssertionError` as `"<left> <op> <right>"`, formatting each operand with
/// `$format_fn`.
///
/// The printed operator is the negation of the asserted relation, i.e. what was actually
/// observed. For example, a failed `Eq` prints `left != right` and a failed `Ge` prints
/// `left < right`.
macro_rules! format_values {
    ($self:expr, $format_fn:expr) => {
        match $self {
            Self::Ne { left, right } => format!("{} == {}", $format_fn(left), $format_fn(right)),
            Self::Eq { left, right } => format!("{} != {}", $format_fn(left), $format_fn(right)),
            Self::Ge { left, right } => format!("{} < {}", $format_fn(left), $format_fn(right)),
            Self::Gt { left, right } => format!("{} <= {}", $format_fn(left), $format_fn(right)),
            Self::Le { left, right } => format!("{} > {}", $format_fn(left), $format_fn(right)),
            Self::Lt { left, right } => format!("{} >= {}", $format_fn(left), $format_fn(right)),
        }
    };
}
39
40impl<T: Display> ComparisonAssertionError<'_, T> {
41    fn format_for_values(&self) -> String {
42        format_values!(self, T::to_string)
43    }
44}
45
46impl<T: Display> ComparisonAssertionError<'_, Vec<T>> {
47    fn format_for_arrays(&self) -> String {
48        let formatter = |v: &Vec<T>| format!("[{}]", v.iter().format(", "));
49        format_values!(self, formatter)
50    }
51}
52
53impl ComparisonAssertionError<'_, U256> {
54    fn format_with_decimals(&self, decimals: &U256) -> String {
55        let formatter = |v: &U256| format_units_uint(v, decimals);
56        format_values!(self, formatter)
57    }
58}
59
60impl ComparisonAssertionError<'_, I256> {
61    fn format_with_decimals(&self, decimals: &U256) -> String {
62        let formatter = |v: &I256| format_units_int(v, decimals);
63        format_values!(self, formatter)
64    }
65}
66
/// Failure of an absolute-tolerance approximate equality assertion (`assertApproxEqAbs*`).
///
/// `T` is the operand type and `D` the type of the (always non-negative) deltas.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left operand.
    left: T,
    /// Right operand.
    right: T,
    /// Largest tolerated absolute difference.
    max_delta: D,
    /// Observed absolute difference `|left - right|`.
    real_delta: D,
}
75
76impl EqAbsAssertionError<U256, U256> {
77    fn format_with_decimals(&self, decimals: &U256) -> String {
78        format!(
79            "{} !~= {} (max delta: {}, real delta: {})",
80            format_units_uint(&self.left, decimals),
81            format_units_uint(&self.right, decimals),
82            format_units_uint(&self.max_delta, decimals),
83            format_units_uint(&self.real_delta, decimals),
84        )
85    }
86}
87
88impl EqAbsAssertionError<I256, U256> {
89    fn format_with_decimals(&self, decimals: &U256) -> String {
90        format!(
91            "{} !~= {} (max delta: {}, real delta: {})",
92            format_units_int(&self.left, decimals),
93            format_units_int(&self.right, decimals),
94            format_units_uint(&self.max_delta, decimals),
95            format_units_uint(&self.real_delta, decimals),
96        )
97    }
98}
99
100fn format_delta_percent(delta: &U256) -> String {
101    format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
102}
103
/// Observed relative delta of an `assertApproxEqRel*` check, in units where
/// `10^EQ_REL_DELTA_RESOLUTION` is 100%.
#[derive(Debug)]
enum EqRelDelta {
    /// A computed relative delta.
    Defined(U256),
    /// No relative delta exists because `right` (the reference value) is zero.
    Undefined,
}
109
110impl Display for EqRelDelta {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        match self {
113            Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
114            Self::Undefined => write!(f, "undefined"),
115        }
116    }
117}
118
/// Details of a failed relative approximate equality assertion (`assertApproxEqRel*`).
///
/// Both deltas use the fixed-point scale where `10^EQ_REL_DELTA_RESOLUTION` is 100% and are
/// printed as percentages.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left operand.
    left: T,
    /// Right operand; the delta is measured relative to its magnitude.
    right: T,
    /// Largest tolerated relative delta.
    max_delta: U256,
    /// Observed relative delta, or `Undefined` when `right` is zero.
    real_delta: EqRelDelta,
}
131
/// Error returned by the `assertApproxEqRel*` helpers.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The relative delta exceeded the maximum, or `right` was zero while `left` was not.
    /// NOTE(review): presumably boxed to keep the `Err` variant small — confirm.
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// `|left - right| * 10^EQ_REL_DELTA_RESOLUTION` did not fit in a `U256`, so no
    /// delta could be computed.
    #[error("overflow in delta calculation")]
    Overflow,
}
139
140impl EqRelAssertionError<U256> {
141    fn format_with_decimals(&self, decimals: &U256) -> String {
142        match self {
143            Self::Failure(f) => format!(
144                "{} !~= {} (max delta: {}, real delta: {})",
145                format_units_uint(&f.left, decimals),
146                format_units_uint(&f.right, decimals),
147                format_delta_percent(&f.max_delta),
148                &f.real_delta,
149            ),
150            Self::Overflow => self.to_string(),
151        }
152    }
153}
154
155impl EqRelAssertionError<I256> {
156    fn format_with_decimals(&self, decimals: &U256) -> String {
157        match self {
158            Self::Failure(f) => format!(
159                "{} !~= {} (max delta: {}, real delta: {})",
160                format_units_int(&f.left, decimals),
161                format_units_int(&f.right, decimals),
162                format_delta_percent(&f.max_delta),
163                &f.real_delta,
164            ),
165            Self::Overflow => self.to_string(),
166        }
167    }
168}
169
/// Return type of the comparison helpers (`assert_eq`, `assert_gt`, ...). On success it holds
/// empty return data; on failure it holds the failed comparison with borrowed operands.
type ComparisonResult<'a, T> = Result<Vec<u8>, ComparisonAssertionError<'a, T>>;
171
172fn handle_assertion_result<ERR>(
173    result: core::result::Result<Vec<u8>, ERR>,
174    ccx: &mut CheatsCtxt,
175    executor: &mut dyn CheatcodesExecutor,
176    error_formatter: impl Fn(&ERR) -> String,
177    error_msg: Option<&str>,
178    format_error: bool,
179) -> Result {
180    match result {
181        Ok(_) => Ok(Default::default()),
182        Err(err) => {
183            let error_msg = error_msg.unwrap_or("assertion failed");
184            let msg = if format_error {
185                format!("{error_msg}: {}", error_formatter(&err))
186            } else {
187                error_msg.to_string()
188            };
189            if ccx.state.config.assertions_revert {
190                Err(msg.into())
191            } else {
192                executor.console_log(ccx, &msg);
193                ccx.ecx.sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
194                Ok(Default::default())
195            }
196        }
197    }
198}
199
/// Implements [crate::Cheatcode] for pairs of assertion cheatcodes.
///
/// Each pair is `(no_error, with_error)`. The first struct holds only the common arguments. The
/// second also carries a custom message in its `error` field, which replaces the default
/// `"assertion failed"` prefix.
///
/// `|args|` lists the common fields, excluding `error`. They are destructured from `self`, so
/// they are in scope both in `$body` and in the error formatter. `$body` must evaluate to a
/// `Result<Vec<u8>, E>` for some error type `E`.
///
/// After the body you may give ONE of:
/// - a closure that formats the assertion error; the formatted text is always appended;
/// - a `bool` literal that controls whether the error text is appended at all, with the
///   error formatted via `to_string`;
/// - nothing, which means: append the error's `to_string` output.
macro_rules! impl_assertions {
    // Explicit `format_error` flag; the error is formatted with `to_string`.
    (|$($arg:ident),*| $body:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), $format_error, $(($no_error, $with_error),)*);
    };
    // Neither flag nor formatter: always append the error's `to_string` output.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), true, $(($no_error, $with_error),)*);
    };
    // Custom error formatter; the formatted error is always appended.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, $error_formatter, true, $(($no_error, $with_error)),*);
    };
    // We convert args to `tt` and later expand them back into tuple to allow usage of expanded args inside of
    // each assertion type context.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions!(@impl $no_error, $with_error, $args, $body, $error_formatter, $format_error);
        )*
    };
    // Emits both `Cheatcode` impls for one pair. The `with_error` variant passes
    // `Some(error)`, so its custom message replaces the default prefix.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr, $format_error:literal) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),* } = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, None, $format_error)
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),*, error} = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, Some(error), $format_error)
            }
        }
    };
}
249
// `assertTrue` / `assertFalse`. The `false` flag means only the (default or custom) message is
// reported: `SimpleAssertionError` carries no extra information worth appending.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
261
// `assertEq` overloads whose operands are printed with `Display` on failure.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_values(),
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
}

// `assertEq` on byte strings: both sides are compared and printed as hex strings produced by
// `hex::encode_prefixed`.
impl_assertions! {
    |left, right| assert_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` on arrays; failures print both operands as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_arrays(),
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
}

// `assertEq` on arrays of byte strings: every element is hex-encoded before comparing/printing.
impl_assertions! {
    |left, right| assert_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: plain equality, but the failure message shows both operands scaled by
// `decimals`.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}
305
// `assertNotEq` overloads whose operands are printed with `Display` on failure.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_values(),
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
}

// `assertNotEq` on byte strings, compared and printed as `hex::encode_prefixed` strings.
impl_assertions! {
    |left, right| assert_not_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` on arrays; failures print both operands as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_arrays(),
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
}

// `assertNotEq` on arrays of byte strings: every element is hex-encoded first.
impl_assertions! {
    |left, right| assert_not_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: failure message shows both operands scaled by `decimals`.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
349
// Ordering assertions. Each has a plain form (operands printed with `Display`) and a
// `*Decimal` form (operands scaled by `decimals` in the failure message).

// `assertGt`
impl_assertions! {
    |left, right| assert_gt(left, right),
    |e| e.format_for_values(),
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`
impl_assertions! {
    |left, right| assert_ge(left, right),
    |e| e.format_for_values(),
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`
impl_assertions! {
    |left, right| assert_lt(left, right),
    |e| e.format_for_values(),
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`
impl_assertions! {
    |left, right| assert_le(left, right),
    |e| e.format_for_values(),
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
405
// `assertApproxEqAbs`: passes when `|left - right| <= maxDelta`. The plain forms use the
// default formatter (the error's `Display` text); the `*Decimal` forms scale every number in
// the message by `decimals`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}
427
// `assertApproxEqRel`: passes when the delta relative to `right` is at most `maxPercentDelta`,
// where `10^EQ_REL_DELTA_RESOLUTION` (1e18) is 100%. The plain forms use the error's `Display`
// text; the `*Decimal` forms scale `left`/`right` by `decimals` in the message.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
449
450fn assert_true(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
451    if condition {
452        Ok(Default::default())
453    } else {
454        Err(SimpleAssertionError)
455    }
456}
457
458fn assert_false(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
459    if !condition {
460        Ok(Default::default())
461    } else {
462        Err(SimpleAssertionError)
463    }
464}
465
466fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
467    if left == right {
468        Ok(Default::default())
469    } else {
470        Err(ComparisonAssertionError::Eq { left, right })
471    }
472}
473
474fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
475    if left != right {
476        Ok(Default::default())
477    } else {
478        Err(ComparisonAssertionError::Ne { left, right })
479    }
480}
481
482fn get_delta_uint(left: U256, right: U256) -> U256 {
483    if left > right {
484        left - right
485    } else {
486        right - left
487    }
488}
489
490fn get_delta_int(left: I256, right: I256) -> U256 {
491    let (left_sign, left_abs) = left.into_sign_and_abs();
492    let (right_sign, right_abs) = right.into_sign_and_abs();
493
494    if left_sign == right_sign {
495        if left_abs > right_abs {
496            left_abs - right_abs
497        } else {
498            right_abs - left_abs
499        }
500    } else {
501        left_abs + right_abs
502    }
503}
504
505fn uint_assert_approx_eq_abs(
506    left: U256,
507    right: U256,
508    max_delta: U256,
509) -> Result<Vec<u8>, Box<EqAbsAssertionError<U256, U256>>> {
510    let delta = get_delta_uint(left, right);
511
512    if delta <= max_delta {
513        Ok(Default::default())
514    } else {
515        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
516    }
517}
518
519fn int_assert_approx_eq_abs(
520    left: I256,
521    right: I256,
522    max_delta: U256,
523) -> Result<Vec<u8>, Box<EqAbsAssertionError<I256, U256>>> {
524    let delta = get_delta_int(left, right);
525
526    if delta <= max_delta {
527        Ok(Default::default())
528    } else {
529        Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
530    }
531}
532
533fn uint_assert_approx_eq_rel(
534    left: U256,
535    right: U256,
536    max_delta: U256,
537) -> Result<Vec<u8>, EqRelAssertionError<U256>> {
538    if right.is_zero() {
539        if left.is_zero() {
540            return Ok(Default::default())
541        } else {
542            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
543                left,
544                right,
545                max_delta,
546                real_delta: EqRelDelta::Undefined,
547            })))
548        };
549    }
550
551    let delta = get_delta_uint(left, right)
552        .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
553        .ok_or(EqRelAssertionError::Overflow)? /
554        right;
555
556    if delta <= max_delta {
557        Ok(Default::default())
558    } else {
559        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
560            left,
561            right,
562            max_delta,
563            real_delta: EqRelDelta::Defined(delta),
564        })))
565    }
566}
567
568fn int_assert_approx_eq_rel(
569    left: I256,
570    right: I256,
571    max_delta: U256,
572) -> Result<Vec<u8>, EqRelAssertionError<I256>> {
573    if right.is_zero() {
574        if left.is_zero() {
575            return Ok(Default::default())
576        } else {
577            return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
578                left,
579                right,
580                max_delta,
581                real_delta: EqRelDelta::Undefined,
582            })))
583        }
584    }
585
586    let (_, abs_right) = right.into_sign_and_abs();
587    let delta = get_delta_int(left, right)
588        .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
589        .ok_or(EqRelAssertionError::Overflow)? /
590        abs_right;
591
592    if delta <= max_delta {
593        Ok(Default::default())
594    } else {
595        Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
596            left,
597            right,
598            max_delta,
599            real_delta: EqRelDelta::Defined(delta),
600        })))
601    }
602}
603
604fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
605    if left > right {
606        Ok(Default::default())
607    } else {
608        Err(ComparisonAssertionError::Gt { left, right })
609    }
610}
611
612fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
613    if left >= right {
614        Ok(Default::default())
615    } else {
616        Err(ComparisonAssertionError::Ge { left, right })
617    }
618}
619
620fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
621    if left < right {
622        Ok(Default::default())
623    } else {
624        Err(ComparisonAssertionError::Lt { left, right })
625    }
626}
627
628fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
629    if left <= right {
630        Ok(Default::default())
631    } else {
632        Err(ComparisonAssertionError::Le { left, right })
633    }
634}