1use crate::{CheatcodesExecutor, CheatsCtxt, Result, Vm::*};
2use alloy_primitives::{I256, U256, U512};
3use foundry_evm_core::{
4 abi::console::{format_units_int, format_units_uint},
5 backend::GLOBAL_FAIL_SLOT,
6 constants::CHEATCODE_ADDRESS,
7 decode::ASSERTION_FAILED_PREFIX,
8 evm::FoundryEvmNetwork,
9};
10use itertools::Itertools;
11use revm::context::{ContextTr, JournalTr};
12use std::{borrow::Cow, fmt};
13
/// Number of decimal digits in the fixed-point scale used for relative deltas
/// (`assertApproxEqRel*`): a delta of `10^EQ_REL_DELTA_RESOLUTION` (i.e. `1e18`) means 100%.
///
/// Written with `from_limbs` (lowest limb = 18) because this is a `const` initializer.
const EQ_REL_DELTA_RESOLUTION: U256 = U256::from_limbs([18, 0, 0, 0]);
15
/// Failure details of a binary comparison assertion (`assertEq`, `assertGt`, ...).
///
/// Borrows both operands and records the comparison that was required to hold; the
/// `format_*` methods render it as `left <negated operator> right`.
struct ComparisonAssertionError<'a, T> {
    /// The comparison the assertion required.
    kind: AssertionKind,
    /// Left-hand operand.
    left: &'a T,
    /// Right-hand operand.
    right: &'a T,
}
21
/// Relational operator checked by a comparison assertion.
#[derive(Copy, Clone)]
enum AssertionKind {
    /// `left == right`
    Eq,
    /// `left != right`
    Ne,
    /// `left > right`
    Gt,
    /// `left >= right`
    Ge,
    /// `left < right`
    Lt,
    /// `left <= right`
    Le,
}

impl AssertionKind {
    /// Returns the logical negation of this comparison (for example `>` becomes `<=`).
    ///
    /// Failure messages print the negated operator between the two operands.
    const fn inverse(self) -> Self {
        match self {
            // Equality pair.
            Self::Eq => Self::Ne,
            Self::Ne => Self::Eq,
            // Strict / non-strict ordering pairs.
            Self::Gt => Self::Le,
            Self::Le => Self::Gt,
            Self::Lt => Self::Ge,
            Self::Ge => Self::Lt,
        }
    }

    /// Returns the source-code spelling of the operator.
    const fn to_str(self) -> &'static str {
        match self {
            Self::Eq => "==",
            Self::Ne => "!=",
            Self::Lt => "<",
            Self::Le => "<=",
            Self::Gt => ">",
            Self::Ge => ">=",
        }
    }
}
55
56impl<T> ComparisonAssertionError<'_, T> {
57 fn format_values<D: fmt::Display>(&self, f: impl Fn(&T) -> D) -> String {
58 format!("{} {} {}", f(self.left), self.kind.inverse().to_str(), f(self.right))
59 }
60}
61
62impl<T: fmt::Display> ComparisonAssertionError<'_, T> {
63 fn format_for_values(&self) -> String {
64 self.format_values(T::to_string)
65 }
66}
67
68impl<T: fmt::Display> ComparisonAssertionError<'_, Vec<T>> {
69 fn format_for_arrays(&self) -> String {
70 self.format_values(|v| format!("[{}]", v.iter().format(", ")))
71 }
72}
73
74impl ComparisonAssertionError<'_, U256> {
75 fn format_with_decimals(&self, decimals: &U256) -> String {
76 self.format_values(|v| format_units_uint(v, decimals))
77 }
78}
79
80impl ComparisonAssertionError<'_, I256> {
81 fn format_with_decimals(&self, decimals: &U256) -> String {
82 self.format_values(|v| format_units_int(v, decimals))
83 }
84}
85
/// Failure details of an `assertApproxEqAbs*` check: both operands, the allowed absolute delta
/// and the absolute difference actually observed.
///
/// `T` is the operand type and `D` the delta type; signed operands use an unsigned delta
/// (`EqAbsAssertionError<I256, U256>`).
#[derive(thiserror::Error, Debug)]
#[error("{left} !~= {right} (max delta: {max_delta}, real delta: {real_delta})")]
struct EqAbsAssertionError<T, D> {
    left: T,
    right: T,
    max_delta: D,
    real_delta: D,
}
94
95impl EqAbsAssertionError<U256, U256> {
96 fn format_with_decimals(&self, decimals: &U256) -> String {
97 format!(
98 "{} !~= {} (max delta: {}, real delta: {})",
99 format_units_uint(&self.left, decimals),
100 format_units_uint(&self.right, decimals),
101 format_units_uint(&self.max_delta, decimals),
102 format_units_uint(&self.real_delta, decimals),
103 )
104 }
105}
106
107impl EqAbsAssertionError<I256, U256> {
108 fn format_with_decimals(&self, decimals: &U256) -> String {
109 format!(
110 "{} !~= {} (max delta: {}, real delta: {})",
111 format_units_int(&self.left, decimals),
112 format_units_int(&self.right, decimals),
113 format_units_uint(&self.max_delta, decimals),
114 format_units_uint(&self.real_delta, decimals),
115 )
116 }
117}
118
119fn format_delta_percent(delta: &U256) -> String {
120 format!("{}%", format_units_uint(delta, &(EQ_REL_DELTA_RESOLUTION - U256::from(2))))
121}
122
123#[derive(Debug)]
124enum EqRelDelta {
125 Defined(U256),
126 Undefined,
127}
128
129impl fmt::Display for EqRelDelta {
130 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
131 match self {
132 Self::Defined(delta) => write!(f, "{}", format_delta_percent(delta)),
133 Self::Undefined => write!(f, "undefined"),
134 }
135 }
136}
137
/// Failure details of an `assertApproxEqRel*` check.
///
/// `max_delta` is a fixed-point fraction scaled by `10^EQ_REL_DELTA_RESOLUTION` and is printed
/// as a percentage via `format_delta_percent`.
#[derive(thiserror::Error, Debug)]
#[error(
    "{left} !~= {right} (max delta: {}, real delta: {})",
    format_delta_percent(max_delta),
    real_delta
)]
struct EqRelAssertionFailure<T> {
    left: T,
    right: T,
    max_delta: U256,
    real_delta: EqRelDelta,
}

/// Error returned by the relative approximate-equality helpers.
#[derive(thiserror::Error, Debug)]
enum EqRelAssertionError<T> {
    /// The relative delta exceeds the allowed maximum, or is undefined. Displays exactly like
    /// the inner failure. Boxed (presumably to keep the error value small).
    #[error(transparent)]
    Failure(Box<EqRelAssertionFailure<T>>),
    /// The scaled relative delta did not fit in a `U256`.
    #[error("overflow in delta calculation")]
    Overflow,
}
158
159impl EqRelAssertionError<U256> {
160 fn format_with_decimals(&self, decimals: &U256) -> String {
161 match self {
162 Self::Failure(f) => format!(
163 "{} !~= {} (max delta: {}, real delta: {})",
164 format_units_uint(&f.left, decimals),
165 format_units_uint(&f.right, decimals),
166 format_delta_percent(&f.max_delta),
167 &f.real_delta,
168 ),
169 Self::Overflow => self.to_string(),
170 }
171 }
172}
173
174impl EqRelAssertionError<I256> {
175 fn format_with_decimals(&self, decimals: &U256) -> String {
176 match self {
177 Self::Failure(f) => format!(
178 "{} !~= {} (max delta: {}, real delta: {})",
179 format_units_int(&f.left, decimals),
180 format_units_int(&f.right, decimals),
181 format_delta_percent(&f.max_delta),
182 &f.real_delta,
183 ),
184 Self::Overflow => self.to_string(),
185 }
186 }
187}
188
/// Outcome of a comparison helper: `Ok(())` when the comparison holds, otherwise the operands
/// and the comparison that was expected.
type ComparisonResult<'a, T> = Result<(), ComparisonAssertionError<'a, T>>;
190
191#[cold]
192fn handle_assertion_result<FEN: FoundryEvmNetwork, E>(
193 ccx: &mut CheatsCtxt<'_, '_, FEN>,
194 executor: &mut dyn CheatcodesExecutor<FEN>,
195 err: E,
196 error_formatter: Option<&dyn Fn(&E) -> String>,
197 error_msg: Option<&str>,
198) -> Result {
199 let error_msg = error_msg.unwrap_or(ASSERTION_FAILED_PREFIX);
200 let msg = if let Some(error_formatter) = error_formatter {
201 Cow::Owned(format!("{error_msg}: {}", error_formatter(&err)))
202 } else {
203 Cow::Borrowed(error_msg)
204 };
205 handle_assertion_result_mono(ccx, executor, msg)
206}
207
208fn handle_assertion_result_mono<FEN: FoundryEvmNetwork>(
209 ccx: &mut CheatsCtxt<'_, '_, FEN>,
210 executor: &mut dyn CheatcodesExecutor<FEN>,
211 msg: Cow<'_, str>,
212) -> Result {
213 if ccx.state.config.assertions_revert {
214 Err(msg.into_owned().into())
215 } else {
216 executor.console_log(&msg);
217 ccx.ecx.journal_mut().sstore(CHEATCODE_ADDRESS, GLOBAL_FAIL_SLOT, U256::from(1))?;
218 Ok(Default::default())
219 }
220}
221
/// Implements `crate::Cheatcode` for pairs of assertion cheatcode call types.
///
/// Each pair is `(no_error, with_error)`: the second type has the same fields as the first plus
/// an `error` field, which replaces the default failure-message prefix.
///
/// Invocation forms (`body` must evaluate to `Result<(), E>`):
/// - `|fields| body, false, pairs...`: no formatter; the message is only the prefix.
/// - `|fields| body, pairs...`: the error is rendered with `ToString::to_string`.
/// - `|fields| body, formatter, pairs...`: the error is rendered with `formatter`.
///
/// On `Err`, the error is handed to `handle_assertion_result`.
macro_rules! impl_assertions {
    // Arm order matters: `false` must be matched before it could be taken as a formatter
    // expression, and the bare-pairs arm must precede the explicit-formatter arm, which would
    // otherwise parse the first `(no_error, with_error)` pair as a tuple expression.

    // No formatter: report only the message prefix.
    (|$($arg:ident),*| $body:expr, false, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, None, $(($no_error, $with_error)),* }
    };
    // Default formatter: the error's `Display` output.
    (|$($arg:ident),*| $body:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&ToString::to_string), $(($no_error, $with_error)),* }
    };
    // Explicit formatter expression.
    (|$($arg:ident),*| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        impl_assertions! { @args_tt |($($arg),*)| $body, Some(&$error_formatter), $(($no_error, $with_error)),* }
    };

    // Internal: the field list is captured as a single token tree so it can be reused inside
    // the per-pair repetition (independent repetitions cannot be nested directly).
    (@args_tt |$args:tt| $body:expr, $error_formatter:expr, $(($no_error:ident, $with_error:ident)),* $(,)?) => {
        $(
            impl_assertions! { @impl $no_error, $with_error, $args, $body, $error_formatter }
        )*
    };

    // Internal: emit both `Cheatcode` impls for one pair.
    (@impl $no_error:ident, $with_error:ident, ($($arg:ident),*), $body:expr, $error_formatter:expr) => {
        impl crate::Cheatcode for $no_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                // Bind the call's fields (by reference) under the names `$body` uses.
                let Self { $($arg),* } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, None)
                }
            }
        }

        impl crate::Cheatcode for $with_error {
            fn apply_full<FEN: FoundryEvmNetwork>(
                &self,
                ccx: &mut CheatsCtxt<'_, '_, FEN>,
                executor: &mut dyn CheatcodesExecutor<FEN>,
            ) -> Result {
                // Same fields plus the user-supplied `error`, used as the message prefix.
                let Self { $($arg,)* error } = self;
                match $body {
                    Ok(()) => Ok(Default::default()),
                    Err(err) => handle_assertion_result(ccx, executor, err, $error_formatter, Some(error))
                }
            }
        }
    };
}
279
// `assertTrue`: the failure message is only the prefix (no operands to show).
impl_assertions! {
    |condition| assert_true(*condition),
    false,
    (assertTrue_0Call, assertTrue_1Call),
}

// `assertFalse`: the failure message is only the prefix (no operands to show).
impl_assertions! {
    |condition| assert_false(*condition),
    false,
    (assertFalse_0Call, assertFalse_1Call),
}
291
// `assertEq` overloads whose operands are rendered with their `Display` impl.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertEq_0Call, assertEq_1Call),
    (assertEq_2Call, assertEq_3Call),
    (assertEq_4Call, assertEq_5Call),
    (assertEq_6Call, assertEq_7Call),
    (assertEq_8Call, assertEq_9Call),
    (assertEq_10Call, assertEq_11Call),
    (assertEq_12Call, assertEq_13Call),
}

// `assertEq` overloads with `Vec` operands, rendered as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertEq_14Call, assertEq_15Call),
    (assertEq_16Call, assertEq_17Call),
    (assertEq_18Call, assertEq_19Call),
    (assertEq_20Call, assertEq_21Call),
    (assertEq_22Call, assertEq_23Call),
    (assertEq_24Call, assertEq_25Call),
    (assertEq_26Call, assertEq_27Call),
}

// `assertEqDecimal`: `decimals` only affects how the operands are rendered on failure.
impl_assertions! {
    |left, right, decimals| assert_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertEqDecimal_0Call, assertEqDecimal_1Call),
    (assertEqDecimal_2Call, assertEqDecimal_3Call),
}

// `assertNotEq` overloads whose operands are rendered with their `Display` impl.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_values,
    (assertNotEq_0Call, assertNotEq_1Call),
    (assertNotEq_2Call, assertNotEq_3Call),
    (assertNotEq_4Call, assertNotEq_5Call),
    (assertNotEq_6Call, assertNotEq_7Call),
    (assertNotEq_8Call, assertNotEq_9Call),
    (assertNotEq_10Call, assertNotEq_11Call),
    (assertNotEq_12Call, assertNotEq_13Call),
}

// `assertNotEq` overloads with `Vec` operands, rendered as `[a, b, ...]`.
impl_assertions! {
    |left, right| assert_not_eq(left, right),
    ComparisonAssertionError::format_for_arrays,
    (assertNotEq_14Call, assertNotEq_15Call),
    (assertNotEq_16Call, assertNotEq_17Call),
    (assertNotEq_18Call, assertNotEq_19Call),
    (assertNotEq_20Call, assertNotEq_21Call),
    (assertNotEq_22Call, assertNotEq_23Call),
    (assertNotEq_24Call, assertNotEq_25Call),
    (assertNotEq_26Call, assertNotEq_27Call),
}

// `assertNotEqDecimal`: `decimals` only affects how the operands are rendered on failure.
impl_assertions! {
    |left, right, decimals| assert_not_eq(left, right),
    |e| e.format_with_decimals(decimals),
    (assertNotEqDecimal_0Call, assertNotEqDecimal_1Call),
    (assertNotEqDecimal_2Call, assertNotEqDecimal_3Call),
}
353
// Ordering assertions. Plain variants render operands with `Display`; `*Decimal` variants
// render them with `format_with_decimals(decimals)`, which only affects the failure message.

// `assertGt`
impl_assertions! {
    |left, right| assert_gt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGt_0Call, assertGt_1Call),
    (assertGt_2Call, assertGt_3Call),
}

// `assertGtDecimal`
impl_assertions! {
    |left, right, decimals| assert_gt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGtDecimal_0Call, assertGtDecimal_1Call),
    (assertGtDecimal_2Call, assertGtDecimal_3Call),
}

// `assertGe`
impl_assertions! {
    |left, right| assert_ge(left, right),
    ComparisonAssertionError::format_for_values,
    (assertGe_0Call, assertGe_1Call),
    (assertGe_2Call, assertGe_3Call),
}

// `assertGeDecimal`
impl_assertions! {
    |left, right, decimals| assert_ge(left, right),
    |e| e.format_with_decimals(decimals),
    (assertGeDecimal_0Call, assertGeDecimal_1Call),
    (assertGeDecimal_2Call, assertGeDecimal_3Call),
}

// `assertLt`
impl_assertions! {
    |left, right| assert_lt(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLt_0Call, assertLt_1Call),
    (assertLt_2Call, assertLt_3Call),
}

// `assertLtDecimal`
impl_assertions! {
    |left, right, decimals| assert_lt(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLtDecimal_0Call, assertLtDecimal_1Call),
    (assertLtDecimal_2Call, assertLtDecimal_3Call),
}

// `assertLe`
impl_assertions! {
    |left, right| assert_le(left, right),
    ComparisonAssertionError::format_for_values,
    (assertLe_0Call, assertLe_1Call),
    (assertLe_2Call, assertLe_3Call),
}

// `assertLeDecimal`
impl_assertions! {
    |left, right, decimals| assert_le(left, right),
    |e| e.format_with_decimals(decimals),
    (assertLeDecimal_0Call, assertLeDecimal_1Call),
    (assertLeDecimal_2Call, assertLeDecimal_3Call),
}
409
// Approximate-equality assertions. Pairs without an explicit formatter use the error's
// `Display` output; `*Decimal` variants render amounts with `format_with_decimals(decimals)`.

// `assertApproxEqAbs`, unsigned operands: `|left - right| <= maxDelta`.
impl_assertions! {
    |left, right, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_0Call, assertApproxEqAbs_1Call),
}

// `assertApproxEqAbs`, signed operands (unsigned delta).
impl_assertions! {
    |left, right, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    (assertApproxEqAbs_2Call, assertApproxEqAbs_3Call),
}

// `assertApproxEqAbsDecimal`, unsigned operands.
impl_assertions! {
    |left, right, decimals, maxDelta| uint_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_0Call, assertApproxEqAbsDecimal_1Call),
}

// `assertApproxEqAbsDecimal`, signed operands.
impl_assertions! {
    |left, right, decimals, maxDelta| int_assert_approx_eq_abs(*left, *right, *maxDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqAbsDecimal_2Call, assertApproxEqAbsDecimal_3Call),
}

// `assertApproxEqRel`, unsigned operands. `maxPercentDelta` is fixed-point with
// `10^EQ_REL_DELTA_RESOLUTION` meaning 100%, measured relative to `right`.
impl_assertions! {
    |left, right, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_0Call, assertApproxEqRel_1Call),
}

// `assertApproxEqRel`, signed operands.
impl_assertions! {
    |left, right, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    (assertApproxEqRel_2Call, assertApproxEqRel_3Call),
}

// `assertApproxEqRelDecimal`, unsigned operands.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| uint_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_0Call, assertApproxEqRelDecimal_1Call),
}

// `assertApproxEqRelDecimal`, signed operands.
impl_assertions! {
    |left, right, decimals, maxPercentDelta| int_assert_approx_eq_rel(*left, *right, *maxPercentDelta),
    |e| e.format_with_decimals(decimals),
    (assertApproxEqRelDecimal_2Call, assertApproxEqRelDecimal_3Call),
}
453
/// Passes (`Ok(())`) exactly when `condition` is `true`.
const fn assert_true(condition: bool) -> Result<(), ()> {
    if !condition {
        return Err(());
    }
    Ok(())
}

/// Passes (`Ok(())`) exactly when `condition` is `false`.
const fn assert_false(condition: bool) -> Result<(), ()> {
    if condition { Err(()) } else { Ok(()) }
}
461
462fn assert_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
463 if left == right {
464 Ok(())
465 } else {
466 Err(ComparisonAssertionError { kind: AssertionKind::Eq, left, right })
467 }
468}
469
470fn assert_not_eq<'a, T: PartialEq>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
471 if left == right {
472 Err(ComparisonAssertionError { kind: AssertionKind::Ne, left, right })
473 } else {
474 Ok(())
475 }
476}
477
478fn assert_gt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
479 if left > right {
480 Ok(())
481 } else {
482 Err(ComparisonAssertionError { kind: AssertionKind::Gt, left, right })
483 }
484}
485
486fn assert_ge<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
487 if left >= right {
488 Ok(())
489 } else {
490 Err(ComparisonAssertionError { kind: AssertionKind::Ge, left, right })
491 }
492}
493
494fn assert_lt<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
495 if left < right {
496 Ok(())
497 } else {
498 Err(ComparisonAssertionError { kind: AssertionKind::Lt, left, right })
499 }
500}
501
502fn assert_le<'a, T: PartialOrd>(left: &'a T, right: &'a T) -> ComparisonResult<'a, T> {
503 if left <= right {
504 Ok(())
505 } else {
506 Err(ComparisonAssertionError { kind: AssertionKind::Le, left, right })
507 }
508}
509
510fn get_delta_int(left: I256, right: I256) -> U256 {
511 let (left_sign, left_abs) = left.into_sign_and_abs();
512 let (right_sign, right_abs) = right.into_sign_and_abs();
513
514 if left_sign == right_sign {
515 if left_abs > right_abs { left_abs - right_abs } else { right_abs - left_abs }
516 } else {
517 left_abs.wrapping_add(right_abs)
518 }
519}
520
521fn calc_delta_full<T>(abs_diff: U256, right: U256) -> Result<U256, EqRelAssertionError<T>> {
525 let delta = U512::from(abs_diff) * U512::from(10).pow(U512::from(EQ_REL_DELTA_RESOLUTION))
526 / U512::from(right);
527 U256::checked_from_limbs_slice(delta.as_limbs()).ok_or(EqRelAssertionError::Overflow)
528}
529
530fn uint_assert_approx_eq_abs(
531 left: U256,
532 right: U256,
533 max_delta: U256,
534) -> Result<(), Box<EqAbsAssertionError<U256, U256>>> {
535 let delta = left.abs_diff(right);
536
537 if delta <= max_delta {
538 Ok(())
539 } else {
540 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
541 }
542}
543
544fn int_assert_approx_eq_abs(
545 left: I256,
546 right: I256,
547 max_delta: U256,
548) -> Result<(), Box<EqAbsAssertionError<I256, U256>>> {
549 let delta = get_delta_int(left, right);
550
551 if delta <= max_delta {
552 Ok(())
553 } else {
554 Err(Box::new(EqAbsAssertionError { left, right, max_delta, real_delta: delta }))
555 }
556}
557
558fn uint_assert_approx_eq_rel(
559 left: U256,
560 right: U256,
561 max_delta: U256,
562) -> Result<(), EqRelAssertionError<U256>> {
563 if right.is_zero() {
564 if left.is_zero() {
565 return Ok(());
566 }
567 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
568 left,
569 right,
570 max_delta,
571 real_delta: EqRelDelta::Undefined,
572 })));
573 }
574
575 let delta = calc_delta_full::<U256>(left.abs_diff(right), right)?;
576
577 if delta <= max_delta {
578 Ok(())
579 } else {
580 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
581 left,
582 right,
583 max_delta,
584 real_delta: EqRelDelta::Defined(delta),
585 })))
586 }
587}
588
589fn int_assert_approx_eq_rel(
590 left: I256,
591 right: I256,
592 max_delta: U256,
593) -> Result<(), EqRelAssertionError<I256>> {
594 if right.is_zero() {
595 if left.is_zero() {
596 return Ok(());
597 }
598 return Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
599 left,
600 right,
601 max_delta,
602 real_delta: EqRelDelta::Undefined,
603 })));
604 }
605
606 let delta = calc_delta_full::<I256>(get_delta_int(left, right), right.unsigned_abs())?;
607
608 if delta <= max_delta {
609 Ok(())
610 } else {
611 Err(EqRelAssertionError::Failure(Box::new(EqRelAssertionFailure {
612 left,
613 right,
614 max_delta,
615 real_delta: EqRelDelta::Defined(delta),
616 })))
617 }
618}