1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{hex, I256, U256};
3use foundry_evm_core::{
4 abi::console::{format_units_int, format_units_uint},
5 backend::GLOBAL_FAIL_SLOT,
6 constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use revm::context::JournalTr;
10use std::fmt::{Debug, Display};
11
/// Number of decimal places (the value 18) used for relative deltas: a delta is scaled by
/// `10^EQ_REL_DELTA_RESOLUTION`, so `10^18` represents a ratio of 1, i.e. 100%.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
13
14#[derive(Debug, thiserror::Error)]
15#[error("assertion failed")]
16struct SimpleAssertionError;
17
18#[derive(thiserror::Error, Debug)]
19enum ComparisonAssertionError<'a, T> {
20 Ne { left: &'a T, right: &'a T },
21 Eq { left: &'a T, right: &'a T },
22 Ge { left: &'a T, right: &'a T },
23 Gt { left: &'a T, right: &'a T },
24 Le { left: &'a T, right: &'a T },
25 Lt { left: &'a T, right: &'a T },
26}
27
/// Renders a failed comparison as the relation that was actually observed.
///
/// Must be expanded inside an `impl` of an enum with `Ne`/`Eq`/`Ge`/`Gt`/`Le`/`Lt` variants
/// that carry `left`/`right` fields. `$format_fn` is applied to each operand to obtain its text;
/// for example, a failed `Eq` renders as `"<left> != <right>"`.
macro_rules! format_values {
    ($self:expr, $format_fn:expr) => {{
        // Each variant names the expected relation, so print its negation.
        let (op, left, right) = match $self {
            Self::Ne { left, right } => ("==", left, right),
            Self::Eq { left, right } => ("!=", left, right),
            Self::Ge { left, right } => ("<", left, right),
            Self::Gt { left, right } => ("<=", left, right),
            Self::Le { left, right } => (">", left, right),
            Self::Lt { left, right } => (">=", left, right),
        };
        format!("{} {} {}", $format_fn(left), op, $format_fn(right))
    }};
}
40
41impl<T: Display> ComparisonAssertionError<'_, T> {
42 fn format_for_values(&self) -> String {
43 format_values!(self, T::to_string)
44 }
45}
46
47impl<T: Display> ComparisonAssertionError<'_, Vec<T>> {
48 fn format_for_arrays(&self) -> String {
49 let formatter = |v: &Vec<T>| format!("[{}]", v.iter().format(", "));
50 format_values!(self, formatter)
51 }
52}
53
54impl ComparisonAssertionError<'_, U256> {
55 fn format_with_decimals(&self, decimals: &U256) -> String {
56 let formatter = |v: &U256| format_units_uint(v, decimals);
57 format_values!(self, formatter)
58 }
59}
60
61impl ComparisonAssertionError<'_, I256> {
62 fn format_with_decimals(&self, decimals: &U256) -> String {
63 let formatter = |v: &I256| format_units_int(v, decimals);
64 format_values!(self, formatter)
65 }
66}
67
/// Failure of an absolute approximate-equality assertion.
///
/// `T` is the operand type and `D` the type used for deltas.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand.
    right: T,
    /// Largest tolerated absolute difference between the operands.
    max_delta: D,
    /// Observed absolute difference between the operands.
    real_delta: D,
}
76
77impl EqAbsAssertionError<U256, U256> {
78 fn format_with_decimals(&self, decimals: &U256) -> String {
79 format!(
80 "{} !~= {} (max delta: {}, real delta: {})",
81 format_units_uint(&self.left, decimals),
82 format_units_uint(&self.right, decimals),
83 format_units_uint(&self.max_delta, decimals),
84 format_units_uint(&self.real_delta, decimals),
85 )
86 }
87}
88
89impl EqAbsAssertionError<I256, U256> {
90 fn format_with_decimals(&self, decimals: &U256) -> String {
91 format!(
92 "{} !~= {} (max delta: {}, real delta: {})",
93 format_units_int(&self.left, decimals),
94 format_units_int(&self.right, decimals),
95 format_units_uint(&self.max_delta, decimals),
96 format_units_uint(&self.real_delta, decimals),
97 )
98 }
99}
100
101fn format_delta_percent(delta: &U256) -> String {
102 format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
103}
104
105#[derive(Debug)]
106enum EqRelDelta {
107 Defined(U256),
108 Undefined,
109}
110
111impl Display for EqRelDelta {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 match self {
114 Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
115 Self::Undefined => write!(f, "undefined"),
116 }
117 }
118}
119
/// Details of a failed relative approximate-equality assertion.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand; the relative delta is computed against its magnitude.
    right: T,
    /// Largest tolerated relative delta, scaled so that `10^18` is 100%.
    max_delta: U256,
    /// Observed relative delta, or `Undefined` when it could not be computed.
    real_delta: EqRelDelta,
}
132
133#[derive(thiserror::Error, Debug)]
134enum EqRelAssertionError<T> {
135 #[error(transparent)]
136 Failure(Box<EqRelAssertionFailure<T>>),
137 #[error("overflow in delta calculation")]
138 Overflow,
139}
140
141impl EqRelAssertionError<U256> {
142 fn format_with_decimals(&self, decimals: &U256) -> String {
143 match self {
144 Self::Failure(f) => format!(
145 "{} !~= {} (max delta: {}, real delta: {})",
146 format_units_uint(&f.left, decimals),
147 format_units_uint(&f.right, decimals),
148 format_delta_percent(&f.max_delta),
149 &f.real_delta,
150 ),
151 Self::Overflow => self.to_string(),
152 }
153 }
154}
155
156impl EqRelAssertionError<I256> {
157 fn format_with_decimals(&self, decimals: &U256) -> String {
158 match self {
159 Self::Failure(f) => format!(
160 "{} !~= {} (max delta: {}, real delta: {})",
161 format_units_int(&f.left, decimals),
162 format_units_int(&f.right, decimals),
163 format_delta_percent(&f.max_delta),
164 &f.real_delta,
165 ),
166 Self::Overflow => self.to_string(),
167 }
168 }
169}
170
/// Outcome of a comparison assertion: `Ok` carries the cheatcode return bytes (the comparison
/// helpers in this module always produce an empty vector), `Err` the failed comparison.
type ComparisonResult<'a, T> = Result<Vec<u8>, ComparisonAssertionError<'a, T>>;
172
173fn handle_assertion_result<ERR>(
174 result: core::result::Result<Vec<u8>, ERR>,
175 ccx: &mut CheatsCtxt,
176 executor: &mut dyn CheatcodesExecutor,
177 error_formatter: impl Fn(&ERR) -> String,
178 error_msg: Option<&str>,
179 format_error: bool,
180) -> Result {
181 match result {
182 Ok(_) => Ok(Default::default()),
183 Err(err) => {
184 let error_msg = error_msg.unwrap_or("assertion failed");
185 let msg = if format_error {
186 format!("{error_msg}: {}", error_formatter(&err))
187 } else {
188 error_msg.to_string()
189 };
190 if ccx.state.config.assertions_revert {
191 Err(msg.into())
192 } else {
193 executor.console_log(ccx, &msg);
194 ccx.ecx.journaled_state.sstore(
195 CHEATCODE_ADDRESS,
196 GLOBAL_FAIL_SLOT,
197 U256::from(1),
198 )?;
199 Ok(Default::default())
200 }
201 }
202 }
203}
204
/// Generates `crate::Cheatcode` implementations for pairs of assertion cheatcode call types.
///
/// Each pair is `(no_error, with_error)`. Both destructure the listed argument fields from
/// `self` and pass `$body` (the assertion outcome) to `handle_assertion_result`. The
/// `with_error` variant additionally destructures an `error` field and passes it as the message
/// used in place of `"assertion failed"`.
///
/// Accepted forms:
/// - `|args| body, <bool literal>, pairs...` uses the default `to_string` formatter and an
///   explicit `format_error` flag.
/// - `|args| body, pairs...` uses the default `to_string` formatter with `format_error = true`.
/// - `|args| body, formatter, pairs...` uses a custom formatter with `format_error = true`.
macro_rules! impl_assertions {
    // Explicit `format_error` flag, default `to_string` formatter.
    (|$($arg:ident),*| $body:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), $format_error, $(($no_error, $with_error),)*);
    };
    // No formatter and no flag: default `to_string` formatter, errors are appended.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), true, $(($no_error, $with_error),)*);
    };
    // Custom formatter, errors are appended.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, $error_formatter, true, $(($no_error, $with_error)),*);
    };
    // Internal: the argument list is captured as a single token tree so it can be forwarded
    // once per pair without nesting it inside the unrelated pair repetition.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions!(@impl $no_error, $with_error, $args, $body, $error_formatter, $format_error);
        )*
    };
    // Internal: emits the two `Cheatcode` impls for one pair.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr, $format_error:literal) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                // Bind the call's fields by name so `$body` / `$error_formatter` can refer to them.
                let Self { $($arg),* } = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, None, $format_error)
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                // Same as above, plus the caller-supplied `error` message.
                let Self { $($arg),*, error} = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, Some(error), $format_error)
            }
        }
    };
}
254
// `assertTrue`: `false` disables appending the formatted error, so a failure reports only the
// custom or default message.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

// `assertFalse`: same message handling as `assertTrue`.
impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
266
// `assertEq` for operands rendered with `Display` via `format_for_values`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_values(),
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
}

// `assertEq` for byte-like operands: compared and reported as `0x`-prefixed hex strings.
impl_assertions! {
    |left, right| assert_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` for vector operands, rendered as `[a, b, ...]` via `format_for_arrays`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_arrays(),
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
}

// `assertEq` for vectors of byte-like elements: each element is hex-encoded before comparing.
impl_assertions! {
    |left, right| assert_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` is bound from the call and only affects failure formatting.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}
310
// `assertNotEq` for operands rendered with `Display` via `format_for_values`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_values(),
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
}

// `assertNotEq` for byte-like operands: compared and reported as `0x`-prefixed hex strings.
impl_assertions! {
    |left, right| assert_not_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` for vector operands, rendered as `[a, b, ...]` via `format_for_arrays`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_arrays(),
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
}

// `assertNotEq` for vectors of byte-like elements: each element is hex-encoded before comparing.
impl_assertions! {
    |left, right| assert_not_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` is bound from the call and only affects failure formatting.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
354
// `assertGt`: `left > right`.
impl_assertions! {
    |left, right| assert_gt(left, right),
    |e| e.format_for_values(),
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

// `assertGtDecimal`: as `assertGt`, with `decimals` used only for failure formatting.
impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`: `left >= right`.
impl_assertions! {
    |left, right| assert_ge(left, right),
    |e| e.format_for_values(),
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

// `assertGeDecimal`: as `assertGe`, with `decimals` used only for failure formatting.
impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`: `left < right`.
impl_assertions! {
    |left, right| assert_lt(left, right),
    |e| e.format_for_values(),
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

// `assertLtDecimal`: as `assertLt`, with `decimals` used only for failure formatting.
impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`: `left <= right`.
impl_assertions! {
    |left, right| assert_le(left, right),
    |e| e.format_for_values(),
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

// `assertLeDecimal`: as `assertLe`, with `decimals` used only for failure formatting.
impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
410
// `assertApproxEqAbs` for unsigned operands: passes when `|left - right| <= maxDelta`.
// Uses the default `to_string` formatter.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs` for signed operands.
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

// `assertApproxEqAbsDecimal` (unsigned): `decimals` only affects failure formatting.
impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

// `assertApproxEqAbsDecimal` (signed): `decimals` only affects failure formatting.
impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}
432
// `assertApproxEqRel` for unsigned operands. Passes when `|left - right| * 10^18 / right` does
// not exceed `maxPercentDelta`, where `10^18` is 100%.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

// `assertApproxEqRel` for signed operands (the delta is taken relative to `|right|`).
impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

// `assertApproxEqRelDecimal` (unsigned): `decimals` only affects failure formatting.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

// `assertApproxEqRelDecimal` (signed): `decimals` only affects failure formatting.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
454
455fn assert_true(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
456 if condition {
457 Ok(Default::default())
458 } else {
459 Err(SimpleAssertionError)
460 }
461}
462
463fn assert_false(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
464 if !condition {
465 Ok(Default::default())
466 } else {
467 Err(SimpleAssertionError)
468 }
469}
470
471fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
472 if left == right {
473 Ok(Default::default())
474 } else {
475 Err(ComparisonAssertionError::Eq { left, right })
476 }
477}
478
479fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
480 if left != right {
481 Ok(Default::default())
482 } else {
483 Err(ComparisonAssertionError::Ne { left, right })
484 }
485}
486
487fn get_delta_uint(left: U256, right: U256) -> U256 {
488 if left > right {
489 left - right
490 } else {
491 right - left
492 }
493}
494
495fn get_delta_int(left: I256, right: I256) -> U256 {
496 let (left_sign, left_abs) = left.into_sign_and_abs();
497 let (right_sign, right_abs) = right.into_sign_and_abs();
498
499 if left_sign == right_sign {
500 if left_abs > right_abs {
501 left_abs - right_abs
502 } else {
503 right_abs - left_abs
504 }
505 } else {
506 left_abs + right_abs
507 }
508}
509
510fn uint_assert_approx_eq_abs(
511 left: U256,
512 right: U256,
513 max_delta: U256,
514) -> Result<Vec<u8>, Box<EqAbsAssertionError<U256, U256>>> {
515 let delta = get_delta_uint(left, right);
516
517 if delta <= max_delta {
518 Ok(Default::default())
519 } else {
520 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
521 }
522}
523
524fn int_assert_approx_eq_abs(
525 left: I256,
526 right: I256,
527 max_delta: U256,
528) -> Result<Vec<u8>, Box<EqAbsAssertionError<I256, U256>>> {
529 let delta = get_delta_int(left, right);
530
531 if delta <= max_delta {
532 Ok(Default::default())
533 } else {
534 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
535 }
536}
537
538fn uint_assert_approx_eq_rel(
539 left: U256,
540 right: U256,
541 max_delta: U256,
542) -> Result<Vec<u8>, EqRelAssertionError<U256>> {
543 if right.is_zero() {
544 if left.is_zero() {
545 return Ok(Default::default())
546 } else {
547 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
548 left,
549 right,
550 max_delta,
551 real_delta: EqRelDelta::Undefined,
552 })))
553 };
554 }
555
556 let delta = get_delta_uint(left, right)
557 .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
558 .ok_or(EqRelAssertionError::Overflow)? /
559 right;
560
561 if delta <= max_delta {
562 Ok(Default::default())
563 } else {
564 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
565 left,
566 right,
567 max_delta,
568 real_delta: EqRelDelta::Defined(delta),
569 })))
570 }
571}
572
573fn int_assert_approx_eq_rel(
574 left: I256,
575 right: I256,
576 max_delta: U256,
577) -> Result<Vec<u8>, EqRelAssertionError<I256>> {
578 if right.is_zero() {
579 if left.is_zero() {
580 return Ok(Default::default())
581 } else {
582 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
583 left,
584 right,
585 max_delta,
586 real_delta: EqRelDelta::Undefined,
587 })))
588 }
589 }
590
591 let (_, abs_right) = right.into_sign_and_abs();
592 let delta = get_delta_int(left, right)
593 .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
594 .ok_or(EqRelAssertionError::Overflow)? /
595 abs_right;
596
597 if delta <= max_delta {
598 Ok(Default::default())
599 } else {
600 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
601 left,
602 right,
603 max_delta,
604 real_delta: EqRelDelta::Defined(delta),
605 })))
606 }
607}
608
609fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
610 if left > right {
611 Ok(Default::default())
612 } else {
613 Err(ComparisonAssertionError::Gt { left, right })
614 }
615}
616
617fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
618 if left >= right {
619 Ok(Default::default())
620 } else {
621 Err(ComparisonAssertionError::Ge { left, right })
622 }
623}
624
625fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
626 if left < right {
627 Ok(Default::default())
628 } else {
629 Err(ComparisonAssertionError::Lt { left, right })
630 }
631}
632
633fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
634 if left <= right {
635 Ok(Default::default())
636 } else {
637 Err(ComparisonAssertionError::Le { left, right })
638 }
639}