1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4 abi::console::{format_units_int, format_units_uint},
5 backend::GLOBAL_FAIL_SLOT,
6 constants::CHEATCODE_ADDRESS,
7};
8use itertools::Itertools;
9use revm::context::JournalTr;
10use std::{borrow::Cow, fmt};
11
/// Decimal exponent of the fixed-point scale used for relative deltas.
///
/// `calc_delta_full` multiplies the absolute difference by `10^EQ_REL_DELTA_RESOLUTION`
/// before dividing by the reference value, so a relative delta of `10^18` means 100%.
/// `format_delta_percent` renders it with `EQ_REL_DELTA_RESOLUTION - 2` decimals plus `%`.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
13
/// Failure of a binary comparison assertion (`==`, `!=`, `>`, `>=`, `<`, `<=`).
///
/// Borrows both operands. They are only formatted into a message when the error is
/// reported.
struct ComparisonAssertionError<'a, T> {
    /// The comparison that was asserted (and did not hold).
    kind: AssertionKind,
    /// Left-hand operand.
    left: &'a T,
    /// Right-hand operand.
    right: &'a T,
}
19
/// The comparison operator that an assertion cheatcode checks for.
#[derive(Clone, Copy)]
enum AssertionKind {
    Eq,
    Ne,
    Gt,
    Ge,
    Lt,
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this comparison, i.e. the relation that
    /// actually holds when an assertion of this kind fails.
    fn inverse(self) -> Self {
        match self {
            // Equality and inequality negate each other.
            Self::Ne => Self::Eq,
            Self::Eq => Self::Ne,
            // Negating an ordering flips both its direction and its strictness.
            Self::Le => Self::Gt,
            Self::Lt => Self::Ge,
            Self::Ge => Self::Lt,
            Self::Gt => Self::Le,
        }
    }

    /// Returns the operator symbol used when rendering this comparison.
    fn to_str(self) -> &'static str {
        match self {
            Self::Le => "<=",
            Self::Lt => "<",
            Self::Ge => ">=",
            Self::Gt => ">",
            Self::Ne => "!=",
            Self::Eq => "==",
        }
    }
}
53
54impl<T> ComparisonAssertionError<'_, T> {
55 fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
56 format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
57 }
58}
59
60impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
61 fn format_for_values(&self) -> String {
62 self.format_values(T::to_string)
63 }
64}
65
66impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
67 fn format_for_arrays(&self) -> String {
68 self.format_values(|v| format!("[{}]", v.iter().format(", ")))
69 }
70}
71
72impl ComparisonAssertionError<'_, U256> {
73 fn format_with_decimals(&self, decimals: &U256) -> String {
74 self.format_values(|v| format_units_uint(v, decimals))
75 }
76}
77
78impl ComparisonAssertionError<'_, I256> {
79 fn format_with_decimals(&self, decimals: &U256) -> String {
80 self.format_values(|v| format_units_int(v, decimals))
81 }
82}
83
/// Failure of an absolute approximate-equality assertion: the absolute difference
/// between `left` and `right` exceeded `max_delta`.
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand.
    right: T,
    /// Largest tolerated absolute difference.
    max_delta: D,
    /// Actual absolute difference between `left` and `right`.
    real_delta: D,
}
92
93impl EqAbsAssertionError<U256, U256> {
94 fn format_with_decimals(&self, decimals: &U256) -> String {
95 format!(
96 "{} !~= {} (max delta: {}, real delta: {})",
97 format_units_uint(&self.left, decimals),
98 format_units_uint(&self.right, decimals),
99 format_units_uint(&self.max_delta, decimals),
100 format_units_uint(&self.real_delta, decimals),
101 )
102 }
103}
104
105impl EqAbsAssertionError<I256, U256> {
106 fn format_with_decimals(&self, decimals: &U256) -> String {
107 format!(
108 "{} !~= {} (max delta: {}, real delta: {})",
109 format_units_int(&self.left, decimals),
110 format_units_int(&self.right, decimals),
111 format_units_uint(&self.max_delta, decimals),
112 format_units_uint(&self.real_delta, decimals),
113 )
114 }
115}
116
117fn format_delta_percent(delta: &U256) -> String {
118 format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
119}
120
121#[derive(Debug)]
122enum EqRelDelta {
123 Defined(U256),
124 Undefined,
125}
126
127impl fmt::Display for EqRelDelta {
128 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
129 match self {
130 Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
131 Self::Undefined => write!(f, "undefined"),
132 }
133 }
134}
135
/// Details of a failed relative approximate-equality assertion.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    /// Left-hand operand.
    left: T,
    /// Right-hand operand; the delta is computed relative to this value.
    right: T,
    /// Largest tolerated relative delta, scaled so that `10^EQ_REL_DELTA_RESOLUTION` is 100%.
    max_delta: U256,
    /// Actual relative delta, or `Undefined` when `right` is zero.
    real_delta: EqRelDelta,
}
148
/// Error produced by a relative approximate-equality assertion.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The relative delta exceeded the allowed maximum, or could not be defined.
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// The scaled relative delta does not fit in a `U256`.
    #[error("overflow in delta calculation")]
    Overflow,
}
156
157impl EqRelAssertionError<U256> {
158 fn format_with_decimals(&self, decimals: &U256) -> String {
159 match self {
160 Self::Failure(f) => format!(
161 "{} !~= {} (max delta: {}, real delta: {})",
162 format_units_uint(&f.left, decimals),
163 format_units_uint(&f.right, decimals),
164 format_delta_percent(&f.max_delta),
165 &f.real_delta,
166 ),
167 Self::Overflow => self.to_string(),
168 }
169 }
170}
171
172impl EqRelAssertionError<I256> {
173 fn format_with_decimals(&self, decimals: &U256) -> String {
174 match self {
175 Self::Failure(f) => format!(
176 "{} !~= {} (max delta: {}, real delta: {})",
177 format_units_int(&f.left, decimals),
178 format_units_int(&f.right, decimals),
179 format_delta_percent(&f.max_delta),
180 &f.real_delta,
181 ),
182 Self::Overflow => self.to_string(),
183 }
184 }
185}
186
/// Outcome of a comparison assertion: `Ok(())` if it held, otherwise the failed
/// comparison together with borrowed operands.
type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
188
189#[cold]
190fn handle_assertion_result<E>(
191 ccx: &mut CheatsCtxt,
192 executor: &mut dyn CheatcodesExecutor,
193 err: E,
194 error_formatter: Option<&dyn Fn(&E) -> String>,
195 error_msg: Option<&str>,
196) -> Result {
197 let error_msg = error_msg.unwrap_or("assertion failed");
198 let msg = if let Some(error_formatter) = error_formatter {
199 Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
200 } else {
201 Cow::Borrowed(error_msg)
202 };
203 handle_assertion_result_mono(ccx, executor, msg)
204}
205
206fn handle_assertion_result_mono(
207 ccx: &mut CheatsCtxt,
208 executor: &mut dyn CheatcodesExecutor,
209 msg: Cow<'_, str>,
210) -> Result {
211 if ccx.state.config.assertions_revert {
212 Err(msg.into_owned().into())
213 } else {
214 executor.console_log(ccx, &msg);
215 ccx.ecx.journaled_state.sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
216 Ok(Default::default())
217 }
218}
219
/// Implements `crate::Cheatcode` for pairs of assertion cheatcode call structs.
///
/// Each `($no_error, $with_error)` pair names two call structs. The second has the same
/// fields plus a custom `error` message. The `|a, b, ...|` header lists the fields to
/// destructure from `self` (bound by reference). `$body` is evaluated with them in scope
/// and must produce `Result<(), E>`.
///
/// The optional argument after `$body` selects how `E` appears in the failure message:
/// - `false`: no formatter; only the custom or default message is reported.
/// - omitted: `E` is rendered with `ToString::to_string`.
/// - any other expression: used as the formatter (`Fn(&E) -> String`).
macro_rules! impl_assertions {
    // NOTE: arm order matters. The literal `false` must be matched before the generic
    // `$error_formatter:expr` arm, which would otherwise accept it as an expression.
    (|$($arg:ident),*| $body:expr, false, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, None, $(($no_error, $with_error)),* }
    };
    // No formatter argument: fall back to `ToString::to_string`.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&ToString::to_string), $(($no_error, $with_error)),* }
    };
    // Explicit formatter expression.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&$error_formatter), $(($no_error, $with_error)),* }
    };

    // Internal: the field list is captured as one token tree so it can be forwarded
    // unchanged into each per-pair `@impl` expansion.
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $no_error, $with_error, $args, $body, $error_formatter }
        )*
    };

    // Internal: generates both `Cheatcode` impls for one pair.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                let Self { $($arg),* } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    // No custom message: `handle_assertion_result` uses its default.
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, None)
                }
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full(
                &self,
                ccx: &mut CheatsCtxt,
                executor: &mut dyn CheatcodesExecutor,
            ) -> Result {
                // Same fields plus the user-supplied failure message.
                let Self { $($arg,)* error } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, Some(error))
                }
            }
        }
    };
}
277
// `assertTrue`: the error type is `()`, so there is nothing to format (`false`).
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

// `assertFalse`: likewise reported without a formatted value.
impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
289
// `assertEq`, non-array overloads: operands rendered with their `Display` impl.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq`, array overloads: operands rendered as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` only affects how a failure is rendered.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}

// `assertNotEq`, non-array overloads.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq`, array overloads.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` only affects how a failure is rendered.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
351
// `assertGt`: operands rendered with their `Display` impl.
impl_assertions! {
    |left, right| assert_gt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

// `assertGtDecimal`: `decimals` only affects how a failure is rendered.
impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`.
impl_assertions! {
    |left, right| assert_ge(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

// `assertGeDecimal`.
impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`.
impl_assertions! {
    |left, right| assert_lt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

// `assertLtDecimal`.
impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`.
impl_assertions! {
    |left, right| assert_le(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

// `assertLeDecimal`.
impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
407
// `assertApproxEqAbs`, unsigned overloads. No formatter argument, so the error is
// rendered with `ToString::to_string`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs`, signed overloads.
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

// `assertApproxEqAbsDecimal`, unsigned: values rendered using `decimals`.
impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

// `assertApproxEqAbsDecimal`, signed.
impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}
429
// `assertApproxEqRel`, unsigned overloads. `maxPercentDelta` is scaled so that
// `10^EQ_REL_DELTA_RESOLUTION` is 100%.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

// `assertApproxEqRel`, signed overloads.
impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

// `assertApproxEqRelDecimal`, unsigned: operands rendered using `decimals`.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

// `assertApproxEqRelDecimal`, signed.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
451
/// Succeeds when `condition` is `true`. The error carries no information.
fn assert_true(condition: bool) -> Result<(), ()> {
    if !condition {
        return Err(());
    }
    Ok(())
}

/// Succeeds when `condition` is `false`. The error carries no information.
fn assert_false(condition: bool) -> Result<(), ()> {
    if condition { Err(()) } else { Ok(()) }
}
459
460fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
461 if left == right {
462 Ok(())
463 } else {
464 Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
465 }
466}
467
468fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
469 if left != right {
470 Ok(())
471 } else {
472 Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
473 }
474}
475
476fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
477 if left > right {
478 Ok(())
479 } else {
480 Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
481 }
482}
483
484fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
485 if left >= right {
486 Ok(())
487 } else {
488 Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
489 }
490}
491
492fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
493 if left < right {
494 Ok(())
495 } else {
496 Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
497 }
498}
499
500fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
501 if left <= right {
502 Ok(())
503 } else {
504 Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
505 }
506}
507
508fn get_delta_int(left: I256, right: I256) -> U256 {
509 let (left_sign, left_abs) = left.into_sign_and_abs();
510 let (right_sign, right_abs) = right.into_sign_and_abs();
511
512 if left_sign == right_sign {
513 if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
514 } else {
515 left_abs.wrapping_add(right_abs)
516 }
517}
518
519fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
523 let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
524 / U512::from(right);
525 U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
526}
527
528fn uint_assert_approx_eq_abs(
529 left: U256,
530 right: U256,
531 max_delta: U256,
532) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
533 let delta = left.abs_diff(right);
534
535 if delta <= max_delta {
536 Ok(())
537 } else {
538 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
539 }
540}
541
542fn int_assert_approx_eq_abs(
543 left: I256,
544 right: I256,
545 max_delta: U256,
546) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
547 let delta = get_delta_int(left, right);
548
549 if delta <= max_delta {
550 Ok(())
551 } else {
552 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
553 }
554}
555
556fn uint_assert_approx_eq_rel(
557 left: U256,
558 right: U256,
559 max_delta: U256,
560) -> Result<(), EqRelAssertionError<U256>> {
561 if right.is_zero() {
562 if left.is_zero() {
563 return Ok(());
564 } else {
565 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
566 left,
567 right,
568 max_delta,
569 real_delta: EqRelDelta::Undefined,
570 })));
571 };
572 }
573
574 let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
575
576 if delta <= max_delta {
577 Ok(())
578 } else {
579 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
580 left,
581 right,
582 max_delta,
583 real_delta: EqRelDelta::Defined(delta),
584 })))
585 }
586}
587
588fn int_assert_approx_eq_rel(
589 left: I256,
590 right: I256,
591 max_delta: U256,
592) -> Result<(), EqRelAssertionError<I256>> {
593 if right.is_zero() {
594 if left.is_zero() {
595 return Ok(());
596 } else {
597 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
598 left,
599 right,
600 max_delta,
601 real_delta: EqRelDelta::Undefined,
602 })));
603 }
604 }
605
606 let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
607
608 if delta <= max_delta {
609 Ok(())
610 } else {
611 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
612 left,
613 right,
614 max_delta,
615 real_delta: EqRelDelta::Defined(delta),
616 })))
617 }
618}