1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6 Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13 cmp::Ordering,
14 fs,
15 io::{self, BufRead},
16 path::{Path, PathBuf},
17 str::FromStr,
18 sync::LazyLock,
19};
20use yansi::Paint;
21
/// Matches one line of a gas snapshot file, `<contract>:<signature> (<report>)`, where the
/// report takes one of three shapes:
///
/// - unit test: `Test:deposit() (gas: 7222)` (the `gas:` label is optional);
/// - fuzz test: `Test:deposit() (runs: 256, μ: 100, ~:200)` (`μ` = mean, `~` = median);
/// - invariant test: `Test:deposit() (runs: 256, calls: 100, reverts: 200)`.
///
/// Named groups: `file` and `sig` always; then either `gas`, or `runs`/`avg`/`med`, or
/// `invruns`/`calls`/`reverts`, depending on the shape.
///
/// Note: `\d` is Unicode-aware in the `regex` crate and unbounded in length, so a captured
/// number may still fail to parse into a fixed-width integer.
pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
});
27
/// Command-line arguments for taking, diffing and checking gas snapshots.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// If no file is given, the snapshot file (`--snap`, default `.gas-snapshot`) is used.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Mismatching tests, and tests missing from the snapshot, are printed.
    ///
    /// If no file is given, the snapshot file (`--snap`, default `.gas-snapshot`) is used.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden because `Format` currently has a single variant.
    /// How to format the output.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerate gas deviations smaller than the given percentage (0-99) when using `--check`.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// How to sort `--diff` results.
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    /// Arguments for compiling and running the tests.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Filtering and sorting options applied to the test results.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91 pub const fn is_watch(&self) -> bool {
93 self.test.is_watch()
94 }
95
96 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98 self.test.watchexec_config()
99 }
100
101 pub async fn run(mut self) -> Result<()> {
102 if self.test.fuzz_seed.is_none() {
105 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
106 }
107
108 let outcome = self.test.compile_and_run().await?;
109 outcome.ensure_ok(false)?;
110 let tests = self.config.apply(outcome);
111
112 if let Some(path) = self.diff {
113 let snap = path.as_ref().unwrap_or(&self.snap);
114 let snaps = read_gas_snapshot(snap)?;
115 diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
116 } else if let Some(path) = self.check {
117 let snap = path.as_ref().unwrap_or(&self.snap);
118 let snaps = read_gas_snapshot(snap)?;
119 if check(tests, snaps, self.tolerance) {
120 std::process::exit(0)
121 } else {
122 std::process::exit(1)
123 }
124 } else {
125 if matches!(self.format, Some(Format::Table)) {
126 let table = build_gas_snapshot_table(&tests);
127 sh_println!("\n{}", table)?;
128 }
129 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
130 }
131 Ok(())
132 }
133}
134
/// Output formats accepted by the snapshot command's hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    /// Also print the results as a table (the snapshot file is still written).
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table` (case-sensitive); any other input is rejected with a message
    /// naming it.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const TABLE_ALIASES: [&str; 2] = ["t", "table"];
        if TABLE_ALIASES.contains(&s) {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
151
/// Additional filters and ordering applied to the test results before they are written or
/// compared.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    /// Sort results by gas used (ascending).
    #[arg(long)]
    asc: bool,

    /// Sort results by gas used (descending).
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    /// Only include tests that used at least this much gas.
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    /// Only include tests that used at most this much gas.
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
171
/// Sort order for `--diff` output.
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    /// Sort by the magnitude of the percentage change, smallest first (default).
    #[default]
    Percentage,
    /// Sort by the magnitude of the percentage change, largest first.
    PercentageDesc,
    /// Sort by the magnitude of the absolute gas change, smallest first.
    Absolute,
    /// Sort by the magnitude of the absolute gas change, largest first.
    AbsoluteDesc,
}
185
186impl GasSnapshotConfig {
187 const fn is_in_gas_range(&self, gas_used: u64) -> bool {
188 if let Some(min) = self.min
189 && gas_used < min
190 {
191 return false;
192 }
193 if let Some(max) = self.max
194 && gas_used > max
195 {
196 return false;
197 }
198 true
199 }
200
201 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
202 let mut tests = outcome
203 .into_tests()
204 .filter(|test| self.is_in_gas_range(test.gas_used()))
205 .collect::<Vec<_>>();
206
207 if self.asc {
208 tests.sort_by_key(|a| a.gas_used());
209 } else if self.desc {
210 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
211 }
212
213 tests
214 }
215}
216
/// A single parsed line of a gas snapshot file.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// Contract name: the text before the `:` separator on the line.
    pub contract_name: String,
    /// Test function signature, e.g. `deposit()`.
    pub signature: String,
    /// Recorded gas report (unit, fuzz or invariant).
    pub gas_used: TestKindReport,
}
229
230impl FromStr for GasSnapshotEntry {
231 type Err = String;
232
233 fn from_str(s: &str) -> Result<Self, Self::Err> {
234 RE_BASIC_SNAPSHOT_ENTRY
235 .captures(s)
236 .and_then(|cap| {
237 cap.name("file").and_then(|file| {
238 cap.name("sig").and_then(|sig| {
239 if let Some(gas) = cap.name("gas") {
240 Some(Self {
241 contract_name: file.as_str().to_string(),
242 signature: sig.as_str().to_string(),
243 gas_used: TestKindReport::Unit {
244 gas: gas.as_str().parse().unwrap(),
245 },
246 })
247 } else if let Some(runs) = cap.name("runs") {
248 cap.name("avg")
249 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
250 .map(|(runs, avg, med)| Self {
251 contract_name: file.as_str().to_string(),
252 signature: sig.as_str().to_string(),
253 gas_used: TestKindReport::Fuzz {
254 runs: runs.as_str().parse().unwrap(),
255 median_gas: med.as_str().parse().unwrap(),
256 mean_gas: avg.as_str().parse().unwrap(),
257 failed_corpus_replays: 0,
258 },
259 })
260 } else {
261 cap.name("invruns")
262 .and_then(|runs| {
263 cap.name("calls").and_then(|avg| {
264 cap.name("reverts").map(|med| (runs, avg, med))
265 })
266 })
267 .map(|(runs, calls, reverts)| Self {
268 contract_name: file.as_str().to_string(),
269 signature: sig.as_str().to_string(),
270 gas_used: TestKindReport::Invariant {
271 runs: runs.as_str().parse().unwrap(),
272 calls: calls.as_str().parse().unwrap(),
273 reverts: reverts.as_str().parse().unwrap(),
274 metrics: HashMap::default(),
275 failed_corpus_replays: 0,
276 optimization_best_value: None,
277 },
278 })
279 }
280 })
281 })
282 })
283 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
284 }
285}
286
287fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
289 let path = path.as_ref();
290 let mut entries = Vec::new();
291 for line in io::BufReader::new(
292 fs::File::open(path)
293 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
294 )
295 .lines()
296 {
297 entries
298 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
299 }
300 Ok(entries)
301}
302
303fn write_to_gas_snapshot_file(
305 tests: &[SuiteTestResult],
306 path: impl AsRef<Path>,
307 _format: Option<Format>,
308) -> Result<()> {
309 let mut reports = tests
310 .iter()
311 .map(|test| {
312 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
313 })
314 .collect::<Vec<_>>();
315
316 reports.sort();
318
319 let content = reports.join("\n");
320 Ok(fs::write(path, content)?)
321}
322
323fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
324 let mut table = Table::new();
325 if shell::is_markdown() {
326 table.load_preset(ASCII_MARKDOWN);
327 } else {
328 table.apply_modifier(UTF8_ROUND_CORNERS);
329 }
330
331 table.set_header(vec![
332 Cell::new("Contract").fg(Color::Cyan),
333 Cell::new("Signature").fg(Color::Cyan),
334 Cell::new("Report").fg(Color::Cyan),
335 ]);
336
337 for test in tests {
338 let mut row = Row::new();
339 row.add_cell(Cell::new(test.contract_name()));
340 row.add_cell(Cell::new(&test.signature));
341 row.add_cell(Cell::new(test.result.kind.report()));
342 table.add_row(row);
343 }
344
345 table
346}
347
/// A test present in both the current run and the snapshot, with both gas reports.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Fully qualified test name, `<contract>::<signature>`.
    pub signature: String,
    /// Gas report from the current test run.
    pub source_gas_used: TestKindReport,
    /// Gas report recorded in the snapshot file.
    pub target_gas_used: TestKindReport,
}
355
356impl GasSnapshotDiff {
357 const fn gas_change(&self) -> i128 {
362 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
363 }
364
365 fn gas_diff(&self) -> f64 {
367 self.gas_change() as f64 / self.target_gas_used.gas() as f64
368 }
369}
370
371fn check(
375 tests: Vec<SuiteTestResult>,
376 snaps: Vec<GasSnapshotEntry>,
377 tolerance: Option<u32>,
378) -> bool {
379 let snaps = snaps
380 .into_iter()
381 .map(|s| ((s.contract_name, s.signature), s.gas_used))
382 .collect::<HashMap<_, _>>();
383 let mut has_diff = false;
384 for test in tests {
385 if let Some(target_gas) =
386 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
387 {
388 let source_gas = test.result.kind.report();
389 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
390 let _ = sh_println!(
391 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
392 test.contract_name(),
393 test.signature,
394 source_gas,
395 target_gas
396 );
397 has_diff = true;
398 }
399 } else {
400 let _ = sh_println!(
401 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
402 test.contract_name(),
403 test.signature
404 );
405 has_diff = true;
406 }
407 }
408 !has_diff
409}
410
411fn diff(
413 tests: Vec<SuiteTestResult>,
414 snaps: Vec<GasSnapshotEntry>,
415 sort_order: DiffSortOrder,
416) -> Result<()> {
417 let snaps = snaps
418 .into_iter()
419 .map(|s| ((s.contract_name, s.signature), s.gas_used))
420 .collect::<HashMap<_, _>>();
421 let mut diffs = Vec::with_capacity(tests.len());
422 let mut new_tests = Vec::new();
423
424 for test in tests {
425 if let Some(target_gas_used) =
426 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
427 {
428 diffs.push(GasSnapshotDiff {
429 source_gas_used: test.result.kind.report(),
430 signature: format!("{}::{}", test.contract_name(), test.signature),
431 target_gas_used,
432 });
433 } else {
434 new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
436 }
437 }
438
439 let mut increased = 0;
440 let mut decreased = 0;
441 let mut unchanged = 0;
442 let mut overall_gas_change = 0i128;
443 let mut overall_gas_used = 0i128;
444
445 match sort_order {
447 DiffSortOrder::Percentage => {
448 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
450 }
451 DiffSortOrder::PercentageDesc => {
452 diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
454 }
455 DiffSortOrder::Absolute => {
456 diffs.sort_by_key(|d| d.gas_change().abs());
458 }
459 DiffSortOrder::AbsoluteDesc => {
460 diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
462 }
463 }
464
465 for diff in &diffs {
466 let gas_change = diff.gas_change();
467 overall_gas_change += gas_change;
468 overall_gas_used += diff.target_gas_used.gas() as i128;
469 let gas_diff = diff.gas_diff();
470
471 if gas_change > 0 {
473 increased += 1;
474 } else if gas_change < 0 {
475 decreased += 1;
476 } else {
477 unchanged += 1;
478 }
479
480 let icon = if gas_change > 0 {
482 "↑".red().to_string()
483 } else if gas_change < 0 {
484 "↓".green().to_string()
485 } else {
486 "━".to_string()
487 };
488
489 sh_println!(
490 "{} {} (gas: {} → {} | {} {})",
491 icon,
492 diff.signature,
493 diff.target_gas_used.gas(),
494 diff.source_gas_used.gas(),
495 fmt_change(gas_change),
496 fmt_pct_change(gas_diff)
497 )?;
498 }
499
500 if !new_tests.is_empty() {
502 sh_println!("\n{}", "New tests:".yellow())?;
503 for test in new_tests {
504 sh_println!(" {} {}", "+".green(), test)?;
505 }
506 }
507
508 sh_println!("\n{}", "-".repeat(80))?;
510
511 let overall_gas_diff = if overall_gas_used > 0 {
512 overall_gas_change as f64 / overall_gas_used as f64
513 } else {
514 0.0
515 };
516
517 sh_println!(
518 "Total tests: {}, {} {}, {} {}, {} {}",
519 diffs.len(),
520 "↑".red().to_string(),
521 increased,
522 "↓".green().to_string(),
523 decreased,
524 "━",
525 unchanged
526 )?;
527 sh_println!(
528 "Overall gas change: {} ({})",
529 fmt_change(overall_gas_change),
530 fmt_pct_change(overall_gas_diff)
531 )?;
532 Ok(())
533}
534
535fn fmt_pct_change(change: f64) -> String {
536 let change_pct = change * 100.0;
537 match change.total_cmp(&0.0) {
538 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
539 Ordering::Equal => {
540 format!("{change_pct:.3}%")
541 }
542 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
543 }
544}
545
546fn fmt_change(change: i128) -> String {
547 match change.cmp(&0) {
548 Ordering::Less => format!("{change}").green().to_string(),
549 Ordering::Equal => change.to_string(),
550 Ordering::Greater => format!("{change}").red().to_string(),
551 }
552}
553
/// Returns whether `source_gas` is close enough to `target_gas`.
///
/// Without a tolerance, the values must be equal. With `Some(pct)`, the gap between the
/// larger and smaller value must be strictly below `pct`, measured as a percentage of the
/// larger value. Equal values always pass, including when both are zero or `pct` is 0.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Checking equality first avoids two false rejections:
    // - both values zero: the ratio below would be 0/0 = NaN, and `NaN < tol` is false;
    // - a 0% tolerance: `0.0 < 0.0` is false.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    let hi = source_gas.max(target_gas);
    let lo = source_gas.min(target_gas);
    // The values differ here, so `hi > 0` and the ratio is finite.
    let deviation_pct = (1.0 - lo as f64 / hi as f64) * 100.0;
    deviation_pct < f64::from(tolerance)
}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the entry the parser is expected to produce from the given parts.
    fn entry(contract_name: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract_name.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    #[test]
    fn test_tolerance() {
        // (source, target, tolerance, expected)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(
                within_tolerance(source, target, tolerance),
                expected,
                "within_tolerance({source}, {target}, {tolerance:?})"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let parsed: GasSnapshotEntry = "Test:deposit() (gas: 7222)".parse().unwrap();
        assert_eq!(parsed, entry("Test", "deposit()", TestKindReport::Unit { gas: 7222 }));
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let parsed: GasSnapshotEntry =
            "Test:deposit() (runs: 256, μ: 100, ~:200)".parse().unwrap();
        let report = TestKindReport::Fuzz {
            runs: 256,
            median_gas: 200,
            mean_gas: 100,
            failed_corpus_replays: 0,
        };
        assert_eq!(parsed, entry("Test", "deposit()", report));
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let parsed: GasSnapshotEntry =
            "Test:deposit() (runs: 256, calls: 100, reverts: 200)".parse().unwrap();
        let report = TestKindReport::Invariant {
            runs: 256,
            calls: 100,
            reverts: 200,
            metrics: HashMap::default(),
            failed_corpus_replays: 0,
            optimization_best_value: None,
        };
        assert_eq!(parsed, entry("Test", "deposit()", report));
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let parsed: GasSnapshotEntry =
            "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"
                .parse()
                .unwrap();
        let report = TestKindReport::Invariant {
            runs: 256,
            calls: 3840,
            reverts: 2388,
            metrics: HashMap::default(),
            failed_corpus_replays: 0,
            optimization_best_value: None,
        };
        assert_eq!(parsed, entry("ERC20Invariants", "invariantBalanceSum()", report));
    }
}