1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6 Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13 cmp::Ordering,
14 fs,
15 io::{self, BufRead},
16 path::{Path, PathBuf},
17 str::FromStr,
18 sync::LazyLock,
19};
20use yansi::Paint;
21
/// Matches one line of a gas snapshot file, `<contract>:<signature> (<report>)`.
///
/// Named capture groups:
/// - `file`: the contract part in front of the `:<signature>`;
/// - `sig`: the test name followed by its parenthesized parameter list, e.g. `deposit()`;
/// - unit report `(gas: N)` (the `gas:` label is optional): `gas`;
/// - fuzz report `(runs: R, μ: MEAN, ~: MEDIAN)`: `runs`, `avg`, `med`;
/// - invariant report `(runs: R, calls: C, reverts: V)`: `invruns`, `calls`, `reverts`.
///
/// The three report shapes are alternatives, so only one set of report groups participates
/// in any given match.
pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
    // The pattern is a compile-time constant, so `unwrap` can only fail on a programming error.
    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
});
27
// CLI arguments of the gas snapshot command. Plain `//` comments are used in this struct on
// purpose: clap turns `///` doc comments into `--help` text.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    // Compare the run against a snapshot file and print the differences instead of writing a
    // snapshot. Outer `None`: flag absent. `Some(None)`: flag given without a path, so the
    // `snap` path is used. `Some(Some(path))`: compare against `path`.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    // Compare the run against a snapshot file and terminate the process with exit code 0 if
    // every test matches, 1 otherwise. The path is resolved the same way as for `diff`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden output-format switch. With `table` (or `t`) the results are also printed as a
    // table before the snapshot file is written.
    #[arg(long, hide(true))]
    format: Option<Format>,

    // Snapshot file to write, and the default file read by `diff` / `check`.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    // Allowed gas deviation, in percent, when running `check`. The parser accepts 0..100
    // (i.e. 0 through 99). Without it, `check` requires exact matches.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    // Ordering of the `diff` output; falls back to `DiffSortOrder::default()` (Percentage).
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    // Test-runner options; `run` uses them to compile and run the tests and to detect watch
    // mode.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    // Gas-range filtering and gas-based sorting applied to the results before use.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91 pub const fn is_watch(&self) -> bool {
93 self.test.is_watch()
94 }
95
96 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98 self.test.watchexec_config()
99 }
100
101 pub async fn run(mut self) -> Result<()> {
102 if self.test.fuzz_seed.is_none() {
105 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
106 }
107
108 let outcome = self.test.compile_and_run().await?;
109 if !shell::is_quiet()
110 && !outcome.allow_failure
111 && self.diff.is_none()
112 && self.check.is_none()
113 && outcome.failed() > 0
114 {
115 sh_eprintln!(
116 "Error: gas snapshot file \"{}\" was not written because the test run failed",
117 self.snap.display()
118 )?;
119 }
120 outcome.ensure_ok(false)?;
121 let tests = self.config.apply(outcome);
122
123 if let Some(path) = self.diff {
124 let snap = path.as_ref().unwrap_or(&self.snap);
125 let snaps = read_gas_snapshot(snap)?;
126 diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
127 } else if let Some(path) = self.check {
128 let snap = path.as_ref().unwrap_or(&self.snap);
129 let snaps = read_gas_snapshot(snap)?;
130 if check(tests, snaps, self.tolerance) {
131 std::process::exit(0)
132 } else {
133 std::process::exit(1)
134 }
135 } else {
136 if matches!(self.format, Some(Format::Table)) {
137 let table = build_gas_snapshot_table(&tests);
138 sh_println!("\n{}", table)?;
139 }
140 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
141 }
142 Ok(())
143 }
144}
145
/// Alternative presentation for the snapshot output, selected with the hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    /// Also print the results as a table before writing the snapshot file.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table` (case-sensitive); any other input yields an error message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
162
// Filtering and ordering options for the snapshot results (flattened into the CLI args).
// Plain `//` comments keep clap from turning them into `--help` text.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    // Sort results by ascending gas used. The snapshot file itself is always written in
    // lexicographic order, so this only changes the order of the table and `check` messages.
    #[arg(long)]
    asc: bool,

    // Sort results by descending gas used; cannot be combined with `asc`.
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    // Drop tests that used less gas than this (inclusive lower bound).
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    // Drop tests that used more gas than this (inclusive upper bound).
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
182
// Ordering of the `--diff` output. Plain `//` comments keep clap from turning them into
// possible-value help text.
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    // Ascending by absolute percentage change (the default).
    #[default]
    Percentage,
    // Descending by absolute percentage change.
    PercentageDesc,
    // Ascending by absolute gas change.
    Absolute,
    // Descending by absolute gas change.
    AbsoluteDesc,
}
196
197impl GasSnapshotConfig {
198 const fn is_in_gas_range(&self, gas_used: u64) -> bool {
199 if let Some(min) = self.min
200 && gas_used < min
201 {
202 return false;
203 }
204 if let Some(max) = self.max
205 && gas_used > max
206 {
207 return false;
208 }
209 true
210 }
211
212 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
213 let mut tests = outcome
214 .into_tests()
215 .filter(|test| self.is_in_gas_range(test.gas_used()))
216 .flat_map(expand_invariant_snapshot_entries)
217 .collect::<Vec<_>>();
218
219 if self.asc {
220 tests.sort_by_key(|a| a.gas_used());
221 } else if self.desc {
222 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
223 }
224
225 tests
226 }
227}
228
229fn expand_invariant_snapshot_entries(test: SuiteTestResult) -> Vec<SuiteTestResult> {
231 if !test.result.kind.is_invariant() || test.result.invariant_predicate_results.len() <= 1 {
232 return vec![test];
233 }
234
235 test.result
236 .invariant_predicate_results
237 .iter()
238 .map(|predicate| {
239 let mut expanded = test.clone();
240 expanded.signature = format!("{}()", predicate.name);
241 expanded
242 })
243 .collect()
244}
245
/// One parsed line of a gas snapshot file (see [`RE_BASIC_SNAPSHOT_ENTRY`] for the format).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// Contract part of the line (the `file` capture), matched against a test's contract name.
    pub contract_name: String,
    /// Test signature including its parenthesized parameter list, e.g. `deposit()`.
    pub signature: String,
    /// Recorded report: unit gas, fuzz runs/mean/median, or invariant runs/calls/reverts.
    pub gas_used: TestKindReport,
}
258
259impl FromStr for GasSnapshotEntry {
260 type Err = String;
261
262 fn from_str(s: &str) -> Result<Self, Self::Err> {
263 RE_BASIC_SNAPSHOT_ENTRY
264 .captures(s)
265 .and_then(|cap| {
266 cap.name("file").and_then(|file| {
267 cap.name("sig").and_then(|sig| {
268 if let Some(gas) = cap.name("gas") {
269 Some(Self {
270 contract_name: file.as_str().to_string(),
271 signature: sig.as_str().to_string(),
272 gas_used: TestKindReport::Unit {
273 gas: gas.as_str().parse().unwrap(),
274 },
275 })
276 } else if let Some(runs) = cap.name("runs") {
277 cap.name("avg")
278 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
279 .map(|(runs, avg, med)| Self {
280 contract_name: file.as_str().to_string(),
281 signature: sig.as_str().to_string(),
282 gas_used: TestKindReport::Fuzz {
283 runs: runs.as_str().parse().unwrap(),
284 median_gas: med.as_str().parse().unwrap(),
285 mean_gas: avg.as_str().parse().unwrap(),
286 failed_corpus_replays: 0,
287 },
288 })
289 } else {
290 cap.name("invruns")
291 .and_then(|runs| {
292 cap.name("calls").and_then(|avg| {
293 cap.name("reverts").map(|med| (runs, avg, med))
294 })
295 })
296 .map(|(runs, calls, reverts)| Self {
297 contract_name: file.as_str().to_string(),
298 signature: sig.as_str().to_string(),
299 gas_used: TestKindReport::Invariant {
300 runs: runs.as_str().parse().unwrap(),
301 calls: calls.as_str().parse().unwrap(),
302 reverts: reverts.as_str().parse().unwrap(),
303 metrics: HashMap::default(),
304 failed_corpus_replays: 0,
305 optimization_best_value: None,
306 },
307 })
308 }
309 })
310 })
311 })
312 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
313 }
314}
315
316fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
318 let path = path.as_ref();
319 let mut entries = Vec::new();
320 for line in io::BufReader::new(
321 fs::File::open(path)
322 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
323 )
324 .lines()
325 {
326 entries
327 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
328 }
329 Ok(entries)
330}
331
332fn write_to_gas_snapshot_file(
334 tests: &[SuiteTestResult],
335 path: impl AsRef<Path>,
336 _format: Option<Format>,
337) -> Result<()> {
338 let mut reports = tests
339 .iter()
340 .map(|test| {
341 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
342 })
343 .collect::<Vec<_>>();
344
345 reports.sort();
347
348 let content = reports.join("\n");
349 Ok(fs::write(path, content)?)
350}
351
352fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
353 let mut table = Table::new();
354 if shell::is_markdown() {
355 table.load_preset(ASCII_MARKDOWN);
356 } else {
357 table.apply_modifier(UTF8_ROUND_CORNERS);
358 }
359
360 table.set_header(vec![
361 Cell::new("Contract").fg(Color::Cyan),
362 Cell::new("Signature").fg(Color::Cyan),
363 Cell::new("Report").fg(Color::Cyan),
364 ]);
365
366 for test in tests {
367 let mut row = Row::new();
368 row.add_cell(Cell::new(test.contract_name()));
369 row.add_cell(Cell::new(&test.signature));
370 row.add_cell(Cell::new(test.result.kind.report()));
371 table.add_row(row);
372 }
373
374 table
375}
376
/// A test present both in the current run and in the snapshot, with both gas reports.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Display name of the test, `<contract>::<signature>`.
    pub signature: String,
    /// Report produced by the current test run.
    pub source_gas_used: TestKindReport,
    /// Report recorded in the snapshot file.
    pub target_gas_used: TestKindReport,
}
384
385impl GasSnapshotDiff {
386 const fn gas_change(&self) -> i128 {
391 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
392 }
393
394 fn gas_diff(&self) -> f64 {
396 self.gas_change() as f64 / self.target_gas_used.gas() as f64
397 }
398}
399
400fn check(
404 tests: Vec<SuiteTestResult>,
405 snaps: Vec<GasSnapshotEntry>,
406 tolerance: Option<u32>,
407) -> bool {
408 let snaps = snaps
409 .into_iter()
410 .map(|s| ((s.contract_name, s.signature), s.gas_used))
411 .collect::<HashMap<_, _>>();
412 let mut has_diff = false;
413 for test in tests {
414 if let Some(target_gas) =
415 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
416 {
417 let source_gas = test.result.kind.report();
418 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
419 let _ = sh_eprintln!(
420 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
421 test.contract_name(),
422 test.signature,
423 source_gas,
424 target_gas
425 );
426 has_diff = true;
427 }
428 } else {
429 let _ = sh_eprintln!(
430 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
431 test.contract_name(),
432 test.signature
433 );
434 has_diff = true;
435 }
436 }
437 !has_diff
438}
439
440fn diff(
442 tests: Vec<SuiteTestResult>,
443 snaps: Vec<GasSnapshotEntry>,
444 sort_order: DiffSortOrder,
445) -> Result<()> {
446 let snaps = snaps
447 .into_iter()
448 .map(|s| ((s.contract_name, s.signature), s.gas_used))
449 .collect::<HashMap<_, _>>();
450 let mut diffs = Vec::with_capacity(tests.len());
451 let mut new_tests = Vec::new();
452
453 for test in tests {
454 if let Some(target_gas_used) =
455 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
456 {
457 diffs.push(GasSnapshotDiff {
458 source_gas_used: test.result.kind.report(),
459 signature: format!("{}::{}", test.contract_name(), test.signature),
460 target_gas_used,
461 });
462 } else {
463 new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
465 }
466 }
467
468 let mut increased = 0;
469 let mut decreased = 0;
470 let mut unchanged = 0;
471 let mut overall_gas_change = 0i128;
472 let mut overall_gas_used = 0i128;
473
474 match sort_order {
476 DiffSortOrder::Percentage => {
477 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
479 }
480 DiffSortOrder::PercentageDesc => {
481 diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
483 }
484 DiffSortOrder::Absolute => {
485 diffs.sort_by_key(|d| d.gas_change().abs());
487 }
488 DiffSortOrder::AbsoluteDesc => {
489 diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
491 }
492 }
493
494 for diff in &diffs {
495 let gas_change = diff.gas_change();
496 overall_gas_change += gas_change;
497 overall_gas_used += diff.target_gas_used.gas() as i128;
498 let gas_diff = diff.gas_diff();
499
500 if gas_change > 0 {
502 increased += 1;
503 } else if gas_change < 0 {
504 decreased += 1;
505 } else {
506 unchanged += 1;
507 }
508
509 let icon = if gas_change > 0 {
511 "↑".red().to_string()
512 } else if gas_change < 0 {
513 "↓".green().to_string()
514 } else {
515 "━".to_string()
516 };
517
518 sh_println!(
519 "{} {} (gas: {} → {} | {} {})",
520 icon,
521 diff.signature,
522 diff.target_gas_used.gas(),
523 diff.source_gas_used.gas(),
524 fmt_change(gas_change),
525 fmt_pct_change(gas_diff)
526 )?;
527 }
528
529 if !new_tests.is_empty() {
531 sh_eprintln!("\n{}", "New tests:".yellow())?;
532 for test in new_tests {
533 sh_eprintln!(" {} {}", "+".green(), test)?;
534 }
535 }
536
537 sh_eprintln!("\n{}", "-".repeat(80))?;
539
540 let overall_gas_diff = if overall_gas_used > 0 {
541 overall_gas_change as f64 / overall_gas_used as f64
542 } else {
543 0.0
544 };
545
546 sh_eprintln!(
547 "Total tests: {}, {} {}, {} {}, {} {}",
548 diffs.len(),
549 "↑".red().to_string(),
550 increased,
551 "↓".green().to_string(),
552 decreased,
553 "━",
554 unchanged
555 )?;
556 sh_eprintln!(
557 "Overall gas change: {} ({})",
558 fmt_change(overall_gas_change),
559 fmt_pct_change(overall_gas_diff)
560 )?;
561 Ok(())
562}
563
564fn fmt_pct_change(change: f64) -> String {
565 let change_pct = change * 100.0;
566 match change.total_cmp(&0.0) {
567 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
568 Ordering::Equal => {
569 format!("{change_pct:.3}%")
570 }
571 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
572 }
573}
574
575fn fmt_change(change: i128) -> String {
576 match change.cmp(&0) {
577 Ordering::Less => format!("{change}").green().to_string(),
578 Ordering::Equal => change.to_string(),
579 Ordering::Greater => format!("{change}").red().to_string(),
580 }
581}
582
/// Returns whether `source_gas` is close enough to `target_gas`.
///
/// Identical values always match. With `tolerance_pct = Some(t)`, differing values match when
/// the smaller one is less than `t` percent below the larger one. The deviation is measured
/// against the larger value, so the check is symmetric. With `None`, only identical values match.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Checked first so that equal values pass even with `Some(0)` (a deviation of `0.0` is not
    // `< 0.0`), and so that two zero readings never reach the `0 / 0 = NaN` division below.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    let hi = source_gas.max(target_gas);
    let lo = source_gas.min(target_gas);
    // Here `hi > lo >= 0`, so `hi` is non-zero.
    let deviation_pct = (1. - (lo as f64 / hi as f64)) * 100.;
    deviation_pct < f64::from(tolerance)
}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a snapshot line, panicking if the parser rejects it.
    fn parse(line: &str) -> GasSnapshotEntry {
        GasSnapshotEntry::from_str(line).unwrap()
    }

    /// Builds the entry expected for `contract:signature` with the given report.
    fn entry(contract: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    #[test]
    fn test_tolerance() {
        // 100 vs 105 deviates by ~4.76%: inside a 5% tolerance in either direction.
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        // 100 vs 106 deviates by ~5.66%: outside a 5% tolerance.
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        // Without a tolerance, equal values match.
        assert!(within_tolerance(100, 100, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (gas: 7222)"),
            entry("Test", "deposit()", TestKindReport::Unit { gas: 7222 })
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        // The median (`~:`) may be written without a following space.
        assert_eq!(
            parse("Test:deposit() (runs: 256, μ: 100, ~:200)"),
            entry(
                "Test",
                "deposit()",
                TestKindReport::Fuzz {
                    runs: 256,
                    median_gas: 200,
                    mean_gas: 100,
                    failed_corpus_replays: 0
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"),
            entry(
                "Test",
                "deposit()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 100,
                    reverts: 200,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            entry(
                "ERC20Invariants",
                "invariantBalanceSum()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 3840,
                    reverts: 2388,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }
}