1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{hex, I256, U256};
3use foundry_evm_core::{
4 abi::console::{format_units_int, format_units_uint},
5 backend::GLOBAL_FAIL_SLOT,
6 constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use std::fmt::{Debug, Display};
10
11const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
12
13#[derive(Debug, thiserror::Error)]
14#[error("assertion failed")]
15struct SimpleAssertionError;
16
17#[derive(thiserror::Error, Debug)]
18enum ComparisonAssertionError<'a, T> {
19 Ne { left: &'a T, right: &'a T },
20 Eq { left: &'a T, right: &'a T },
21 Ge { left: &'a T, right: &'a T },
22 Gt { left: &'a T, right: &'a T },
23 Le { left: &'a T, right: &'a T },
24 Lt { left: &'a T, right: &'a T },
25}
26
/// Renders a `ComparisonAssertionError` (`$self`) as `"<left> <op> <right>"`.
///
/// `<op>` is the relation that actually held, i.e. the negation of the asserted one. Each
/// operand is rendered with `$format_fn`.
macro_rules! format_values {
    ($self:expr, $format_fn:expr) => {{
        let (lhs, rhs, op) = match $self {
            Self::Ne { left, right } => (left, right, "=="),
            Self::Eq { left, right } => (left, right, "!="),
            Self::Ge { left, right } => (left, right, "<"),
            Self::Gt { left, right } => (left, right, "<="),
            Self::Le { left, right } => (left, right, ">"),
            Self::Lt { left, right } => (left, right, ">="),
        };
        format!("{} {} {}", $format_fn(lhs), op, $format_fn(rhs))
    }};
}
39
40impl<T: Display> ComparisonAssertionError<'_, T> {
41 fn format_for_values(&self) -> String {
42 format_values!(self, T::to_string)
43 }
44}
45
46impl<T: Display> ComparisonAssertionError<'_, Vec<T>> {
47 fn format_for_arrays(&self) -> String {
48 let formatter = |v: &Vec<T>| format!("[{}]", v.iter().format(", "));
49 format_values!(self, formatter)
50 }
51}
52
53impl ComparisonAssertionError<'_, U256> {
54 fn format_with_decimals(&self, decimals: &U256) -> String {
55 let formatter = |v: &U256| format_units_uint(v, decimals);
56 format_values!(self, formatter)
57 }
58}
59
60impl ComparisonAssertionError<'_, I256> {
61 fn format_with_decimals(&self, decimals: &U256) -> String {
62 let formatter = |v: &I256| format_units_int(v, decimals);
63 format_values!(self, formatter)
64 }
65}
66
67#[derive(thiserror::Error, Debug)]
68#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
69struct EqAbsAssertionError<T, D> {
70 left: T,
71 right: T,
72 max_delta: D,
73 real_delta: D,
74}
75
76impl EqAbsAssertionError<U256, U256> {
77 fn format_with_decimals(&self, decimals: &U256) -> String {
78 format!(
79 "{} !~= {} (max delta: {}, real delta: {})",
80 format_units_uint(&self.left, decimals),
81 format_units_uint(&self.right, decimals),
82 format_units_uint(&self.max_delta, decimals),
83 format_units_uint(&self.real_delta, decimals),
84 )
85 }
86}
87
88impl EqAbsAssertionError<I256, U256> {
89 fn format_with_decimals(&self, decimals: &U256) -> String {
90 format!(
91 "{} !~= {} (max delta: {}, real delta: {})",
92 format_units_int(&self.left, decimals),
93 format_units_int(&self.right, decimals),
94 format_units_uint(&self.max_delta, decimals),
95 format_units_uint(&self.real_delta, decimals),
96 )
97 }
98}
99
100fn format_delta_percent(delta: &U256) -> String {
101 format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
102}
103
104#[derive(Debug)]
105enum EqRelDelta {
106 Defined(U256),
107 Undefined,
108}
109
110impl Display for EqRelDelta {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 match self {
113 Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
114 Self::Undefined => write!(f, "undefined"),
115 }
116 }
117}
118
119#[derive(thiserror::Error, Debug)]
120#[error(
121 "{left} !~= {right} (max delta: {}, real delta: {})",
122 format_delta_percent(max_delta),
123 real_delta
124)]
125struct EqRelAssertionFailure<T> {
126 left: T,
127 right: T,
128 max_delta: U256,
129 real_delta: EqRelDelta,
130}
131
132#[derive(thiserror::Error, Debug)]
133enum EqRelAssertionError<T> {
134 #[error(transparent)]
135 Failure(Box<EqRelAssertionFailure<T>>),
136 #[error("overflow in delta calculation")]
137 Overflow,
138}
139
140impl EqRelAssertionError<U256> {
141 fn format_with_decimals(&self, decimals: &U256) -> String {
142 match self {
143 Self::Failure(f) => format!(
144 "{} !~= {} (max delta: {}, real delta: {})",
145 format_units_uint(&f.left, decimals),
146 format_units_uint(&f.right, decimals),
147 format_delta_percent(&f.max_delta),
148 &f.real_delta,
149 ),
150 Self::Overflow => self.to_string(),
151 }
152 }
153}
154
155impl EqRelAssertionError<I256> {
156 fn format_with_decimals(&self, decimals: &U256) -> String {
157 match self {
158 Self::Failure(f) => format!(
159 "{} !~= {} (max delta: {}, real delta: {})",
160 format_units_int(&f.left, decimals),
161 format_units_int(&f.right, decimals),
162 format_delta_percent(&f.max_delta),
163 &f.real_delta,
164 ),
165 Self::Overflow => self.to_string(),
166 }
167 }
168}
169
170type ComparisonResult<'a, T> = Result<Vec<u8>, ComparisonAssertionError<'a, T>>;
171
172fn handle_assertion_result<ERR>(
173 result: core::result::Result<Vec<u8>, ERR>,
174 ccx: &mut CheatsCtxt,
175 executor: &mut dyn CheatcodesExecutor,
176 error_formatter: impl Fn(&ERR) -> String,
177 error_msg: Option<&str>,
178 format_error: bool,
179) -> Result {
180 match result {
181 Ok(_) => Ok(Default::default()),
182 Err(err) => {
183 let error_msg = error_msg.unwrap_or("assertion failed");
184 let msg = if format_error {
185 format!("{error_msg}: {}", error_formatter(&err))
186 } else {
187 error_msg.to_string()
188 };
189 if ccx.state.config.assertions_revert {
190 Err(msg.into())
191 } else {
192 executor.console_log(ccx, &msg);
193 ccx.ecx.sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
194 Ok(Default::default())
195 }
196 }
197 }
198}
199
/// Implements `crate::Cheatcode` for pairs of assertion call types.
///
/// Each pair is `(NoMessageCall, WithMessageCall)`. The first type is destructured into
/// exactly the listed arguments. The second additionally has an `error` field, which is passed
/// as the failure-message prefix. The fields are bound by reference (`let Self { .. } = self`).
/// `$body` runs the check on them and must yield the `Result<Vec<u8>, E>` consumed by
/// `handle_assertion_result`.
///
/// Accepted forms:
/// - `|args| body, <bool literal>, pairs...`: errors are rendered with `to_string()`, and the
///   literal is `format_error` (whether the rendered error is appended to the message).
/// - `|args| body, pairs...`: errors are rendered with `to_string()` and always appended.
/// - `|args| body, |e| formatter, pairs...`: custom error rendering, always appended.
macro_rules! impl_assertions {
    // Form 1: explicit `format_error` literal, default `to_string()` formatter.
    (|$($arg:ident),*| $body:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), $format_error, $(($no_error, $with_error),)*);
    };
    // Form 2: no formatter given; default `to_string()` formatter, error always appended.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, |e| e.to_string(), true, $(($no_error, $with_error),)*);
    };
    // Form 3: custom formatter, error always appended. This arm must stay after form 1:
    // a bare `true`/`false` is also a valid `expr`, so it would otherwise be taken as a formatter.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions!(@args_tt |($($arg),*)| $body, $error_formatter, true, $(($no_error, $with_error)),*);
    };
    // Internal: the argument list arrives as one parenthesized token tree so it can be reused
    // once per pair. Nesting the `$arg` repetition directly inside the pair repetition would
    // tie their repetition counts together.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $format_error:literal, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions!(@impl $no_error, $with_error, $args, $body, $error_formatter, $format_error);
        )*
    };
    // Internal: emits both impls for one pair. The message-less type passes `None`; the other
    // also destructures its `error` field and passes it as `Some(error)`.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr, $format_error:literal) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),* } = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, None, $format_error)
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),*, error} = self;
                handle_assertion_result($body, ccx, executor, $error_formatter, Some(error), $format_error)
            }
        }
    };
}
249
// Boolean assertions. The `false` literal disables error formatting, so a failure message is
// just the user-supplied message (or "assertion failed") with no `: <error>` suffix.
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
261
// `assertEq` overloads compared directly, with operands printed via `Display`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_values(),
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
}

// `assertEq` overloads 12/13: both operands are encoded with `hex::encode_prefixed` before the
// comparison, so the failure message shows the encoded strings.
impl_assertions! {
    |left, right| assert_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertEq_12Call, assertEq_13Call),
}

// Array `assertEq` overloads, printed as `[a, b, ...]` on failure.
impl_assertions! {
    |left, right| assert_eq(left, right),
    |e| e.format_for_arrays(),
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
}

// Array overloads 26/27: every element is hex-encoded before comparison and printing.
impl_assertions! {
    |left, right| assert_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` only affects how the failure message renders the operands;
// the comparison itself ignores it.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}
305
// `assertNotEq` overloads compared directly, with operands printed via `Display`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_values(),
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
}

// `assertNotEq` overloads 12/13: operands are hex-encoded with `hex::encode_prefixed` first.
impl_assertions! {
    |left, right| assert_not_eq(&hex::encode_prefixed(left), &hex::encode_prefixed(right)),
    |e| e.format_for_values(),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// Array `assertNotEq` overloads, printed as `[a, b, ...]` on failure.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    |e| e.format_for_arrays(),
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
}

// Array overloads 26/27: every element is hex-encoded before comparison and printing.
impl_assertions! {
    |left, right| assert_not_eq(
        &left.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
        &right.iter().map(hex::encode_prefixed).collect::<Vec<_>>(),
    ),
    |e| e.format_for_arrays(),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` is used only when rendering the failure message.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
349
// Ordering assertions. Each plain variant prints operands via `Display`; each `*Decimal`
// variant uses `decimals` only to render the failure message.

// `assertGt`: expects `left > right`.
impl_assertions! {
    |left, right| assert_gt(left, right),
    |e| e.format_for_values(),
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`: expects `left >= right`.
impl_assertions! {
    |left, right| assert_ge(left, right),
    |e| e.format_for_values(),
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`: expects `left < right`.
impl_assertions! {
    |left, right| assert_lt(left, right),
    |e| e.format_for_values(),
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`: expects `left <= right`.
impl_assertions! {
    |left, right| assert_le(left, right),
    |e| e.format_for_values(),
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
405
// Approximate-equality assertions. The non-decimal variants use the errors' own `Display`
// messages; the `*Decimal` variants use `decimals` only to render the failure message.

// `assertApproxEqAbs` (unsigned operands): `|left - right| <= maxDelta`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs` (signed operands, unsigned tolerance).
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}

// `assertApproxEqRel`: the delta relative to `|right|` must not exceed `maxPercentDelta`.
// Both are scaled by `10^EQ_REL_DELTA_RESOLUTION`.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
449
450fn assert_true(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
451 if condition {
452 Ok(Default::default())
453 } else {
454 Err(SimpleAssertionError)
455 }
456}
457
458fn assert_false(condition: bool) -> Result<Vec<u8>, SimpleAssertionError> {
459 if !condition {
460 Ok(Default::default())
461 } else {
462 Err(SimpleAssertionError)
463 }
464}
465
466fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
467 if left == right {
468 Ok(Default::default())
469 } else {
470 Err(ComparisonAssertionError::Eq { left, right })
471 }
472}
473
474fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
475 if left != right {
476 Ok(Default::default())
477 } else {
478 Err(ComparisonAssertionError::Ne { left, right })
479 }
480}
481
482fn get_delta_uint(left: U256, right: U256) -> U256 {
483 if left > right {
484 left - right
485 } else {
486 right - left
487 }
488}
489
490fn get_delta_int(left: I256, right: I256) -> U256 {
491 let (left_sign, left_abs) = left.into_sign_and_abs();
492 let (right_sign, right_abs) = right.into_sign_and_abs();
493
494 if left_sign == right_sign {
495 if left_abs > right_abs {
496 left_abs - right_abs
497 } else {
498 right_abs - left_abs
499 }
500 } else {
501 left_abs + right_abs
502 }
503}
504
505fn uint_assert_approx_eq_abs(
506 left: U256,
507 right: U256,
508 max_delta: U256,
509) -> Result<Vec<u8>, Box<EqAbsAssertionError<U256, U256>>> {
510 let delta = get_delta_uint(left, right);
511
512 if delta <= max_delta {
513 Ok(Default::default())
514 } else {
515 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
516 }
517}
518
519fn int_assert_approx_eq_abs(
520 left: I256,
521 right: I256,
522 max_delta: U256,
523) -> Result<Vec<u8>, Box<EqAbsAssertionError<I256, U256>>> {
524 let delta = get_delta_int(left, right);
525
526 if delta <= max_delta {
527 Ok(Default::default())
528 } else {
529 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
530 }
531}
532
533fn uint_assert_approx_eq_rel(
534 left: U256,
535 right: U256,
536 max_delta: U256,
537) -> Result<Vec<u8>, EqRelAssertionError<U256>> {
538 if right.is_zero() {
539 if left.is_zero() {
540 return Ok(Default::default())
541 } else {
542 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
543 left,
544 right,
545 max_delta,
546 real_delta: EqRelDelta::Undefined,
547 })))
548 };
549 }
550
551 let delta = get_delta_uint(left, right)
552 .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
553 .ok_or(EqRelAssertionError::Overflow)? /
554 right;
555
556 if delta <= max_delta {
557 Ok(Default::default())
558 } else {
559 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
560 left,
561 right,
562 max_delta,
563 real_delta: EqRelDelta::Defined(delta),
564 })))
565 }
566}
567
568fn int_assert_approx_eq_rel(
569 left: I256,
570 right: I256,
571 max_delta: U256,
572) -> Result<Vec<u8>, EqRelAssertionError<I256>> {
573 if right.is_zero() {
574 if left.is_zero() {
575 return Ok(Default::default())
576 } else {
577 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
578 left,
579 right,
580 max_delta,
581 real_delta: EqRelDelta::Undefined,
582 })))
583 }
584 }
585
586 let (_, abs_right) = right.into_sign_and_abs();
587 let delta = get_delta_int(left, right)
588 .checked_mul(U256::pow(U256::from(10), EQ_REL_DELTA_RESOLUTION))
589 .ok_or(EqRelAssertionError::Overflow)? /
590 abs_right;
591
592 if delta <= max_delta {
593 Ok(Default::default())
594 } else {
595 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
596 left,
597 right,
598 max_delta,
599 real_delta: EqRelDelta::Defined(delta),
600 })))
601 }
602}
603
604fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
605 if left > right {
606 Ok(Default::default())
607 } else {
608 Err(ComparisonAssertionError::Gt { left, right })
609 }
610}
611
612fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
613 if left >= right {
614 Ok(Default::default())
615 } else {
616 Err(ComparisonAssertionError::Ge { left, right })
617 }
618}
619
620fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
621 if left < right {
622 Ok(Default::default())
623 } else {
624 Err(ComparisonAssertionError::Lt { left, right })
625 }
626}
627
628fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
629 if left <= right {
630 Ok(Default::default())
631 } else {
632 Err(ComparisonAssertionError::Le { left, right })
633 }
634}