1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6 Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13 cmp::Ordering,
14 fs,
15 io::{self, BufRead},
16 path::{Path, PathBuf},
17 str::FromStr,
18 sync::LazyLock,
19};
20use yansi::Paint;
21
22pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
26});
27
/// CLI arguments for the gas snapshot command.
///
/// By default the tests are run and their gas reports are written to the `--snap` file.
/// `--diff` and `--check` instead compare the fresh results against an existing snapshot.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// Without a path, the default snapshot file (`.gas-snapshot`) is used.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if any test's gas
    /// differs (beyond `--tolerance`, if given) or has no entry in the snapshot.
    ///
    /// Without a path, the `--snap` file is used.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    /// Output format. Only `table` (alias `t`) is recognised; it additionally prints the
    /// results as a table before writing the snapshot file.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerance for the snapshot comparison, as a percentage (0-99).
    ///
    /// Only used by `--check`; without it, gas values must match exactly.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// How to sort `--diff` results (defaults to percentage change, smallest first).
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    /// All test arguments are supported.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Filtering and sorting options applied to the test results.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91 pub fn is_watch(&self) -> bool {
93 self.test.is_watch()
94 }
95
96 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98 self.test.watchexec_config()
99 }
100
101 pub async fn run(mut self) -> Result<()> {
102 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
104
105 let outcome = self.test.compile_and_run().await?;
106 outcome.ensure_ok(false)?;
107 let tests = self.config.apply(outcome);
108
109 if let Some(path) = self.diff {
110 let snap = path.as_ref().unwrap_or(&self.snap);
111 let snaps = read_gas_snapshot(snap)?;
112 diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
113 } else if let Some(path) = self.check {
114 let snap = path.as_ref().unwrap_or(&self.snap);
115 let snaps = read_gas_snapshot(snap)?;
116 if check(tests, snaps, self.tolerance) {
117 std::process::exit(0)
118 } else {
119 std::process::exit(1)
120 }
121 } else {
122 if matches!(self.format, Some(Format::Table)) {
123 let table = build_gas_snapshot_table(&tests);
124 sh_println!("\n{}", table)?;
125 }
126 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
127 }
128 Ok(())
129 }
130}
131
/// Output format accepted by the hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    /// Additionally print the results as a table.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `table` or its short alias `t` (case-sensitive); anything else is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "t" || s == "table" {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
148
/// Additional filters and ordering applied to the test results before they are
/// written to, or compared against, a snapshot.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    /// Sort results by gas used (ascending).
    #[arg(long)]
    asc: bool,

    /// Sort results by gas used (descending).
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    /// Only include tests that used at least the given amount of gas.
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    /// Only include tests that used at most the given amount of gas.
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
168
/// Sort order for `--diff` output.
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    /// Sort by absolute percentage change, smallest first.
    #[default]
    Percentage,
    /// Sort by absolute percentage change, largest first.
    PercentageDesc,
    /// Sort by absolute gas change, smallest first.
    Absolute,
    /// Sort by absolute gas change, largest first.
    AbsoluteDesc,
}
182
183impl GasSnapshotConfig {
184 fn is_in_gas_range(&self, gas_used: u64) -> bool {
185 if let Some(min) = self.min
186 && gas_used < min
187 {
188 return false;
189 }
190 if let Some(max) = self.max
191 && gas_used > max
192 {
193 return false;
194 }
195 true
196 }
197
198 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
199 let mut tests = outcome
200 .into_tests()
201 .filter(|test| self.is_in_gas_range(test.gas_used()))
202 .collect::<Vec<_>>();
203
204 if self.asc {
205 tests.sort_by_key(|a| a.gas_used());
206 } else if self.desc {
207 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
208 }
209
210 tests
211 }
212}
213
/// A single line of a gas snapshot file, written as `<contract>:<signature> <report>`.
///
/// Parsed via `FromStr`, and looked up by `(contract_name, signature)` when comparing a
/// snapshot against fresh test results.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// Contract name, as written before the `:` separator.
    pub contract_name: String,
    /// Test function signature, e.g. `deposit()`.
    pub signature: String,
    /// Recorded gas report (unit, fuzz or invariant).
    pub gas_used: TestKindReport,
}
226
227impl FromStr for GasSnapshotEntry {
228 type Err = String;
229
230 fn from_str(s: &str) -> Result<Self, Self::Err> {
231 RE_BASIC_SNAPSHOT_ENTRY
232 .captures(s)
233 .and_then(|cap| {
234 cap.name("file").and_then(|file| {
235 cap.name("sig").and_then(|sig| {
236 if let Some(gas) = cap.name("gas") {
237 Some(Self {
238 contract_name: file.as_str().to_string(),
239 signature: sig.as_str().to_string(),
240 gas_used: TestKindReport::Unit {
241 gas: gas.as_str().parse().unwrap(),
242 },
243 })
244 } else if let Some(runs) = cap.name("runs") {
245 cap.name("avg")
246 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
247 .map(|(runs, avg, med)| Self {
248 contract_name: file.as_str().to_string(),
249 signature: sig.as_str().to_string(),
250 gas_used: TestKindReport::Fuzz {
251 runs: runs.as_str().parse().unwrap(),
252 median_gas: med.as_str().parse().unwrap(),
253 mean_gas: avg.as_str().parse().unwrap(),
254 failed_corpus_replays: 0,
255 },
256 })
257 } else {
258 cap.name("invruns")
259 .and_then(|runs| {
260 cap.name("calls").and_then(|avg| {
261 cap.name("reverts").map(|med| (runs, avg, med))
262 })
263 })
264 .map(|(runs, calls, reverts)| Self {
265 contract_name: file.as_str().to_string(),
266 signature: sig.as_str().to_string(),
267 gas_used: TestKindReport::Invariant {
268 runs: runs.as_str().parse().unwrap(),
269 calls: calls.as_str().parse().unwrap(),
270 reverts: reverts.as_str().parse().unwrap(),
271 metrics: HashMap::default(),
272 failed_corpus_replays: 0,
273 },
274 })
275 }
276 })
277 })
278 })
279 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
280 }
281}
282
283fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
285 let path = path.as_ref();
286 let mut entries = Vec::new();
287 for line in io::BufReader::new(
288 fs::File::open(path)
289 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
290 )
291 .lines()
292 {
293 entries
294 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
295 }
296 Ok(entries)
297}
298
299fn write_to_gas_snapshot_file(
301 tests: &[SuiteTestResult],
302 path: impl AsRef<Path>,
303 _format: Option<Format>,
304) -> Result<()> {
305 let mut reports = tests
306 .iter()
307 .map(|test| {
308 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
309 })
310 .collect::<Vec<_>>();
311
312 reports.sort();
314
315 let content = reports.join("\n");
316 Ok(fs::write(path, content)?)
317}
318
319fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
320 let mut table = Table::new();
321 if shell::is_markdown() {
322 table.load_preset(ASCII_MARKDOWN);
323 } else {
324 table.apply_modifier(UTF8_ROUND_CORNERS);
325 }
326
327 table.set_header(vec![
328 Cell::new("Contract").fg(Color::Cyan),
329 Cell::new("Signature").fg(Color::Cyan),
330 Cell::new("Report").fg(Color::Cyan),
331 ]);
332
333 for test in tests {
334 let mut row = Row::new();
335 row.add_cell(Cell::new(test.contract_name()));
336 row.add_cell(Cell::new(&test.signature));
337 row.add_cell(Cell::new(test.result.kind.report().to_string()));
338 table.add_row(row);
339 }
340
341 table
342}
343
/// A test found both in the current run and in the snapshot being diffed against.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Display name of the test, formatted as `Contract::signature`.
    pub signature: String,
    /// Gas report from the current test run.
    pub source_gas_used: TestKindReport,
    /// Gas report recorded in the snapshot file.
    pub target_gas_used: TestKindReport,
}
351
352impl GasSnapshotDiff {
353 fn gas_change(&self) -> i128 {
358 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
359 }
360
361 fn gas_diff(&self) -> f64 {
363 self.gas_change() as f64 / self.target_gas_used.gas() as f64
364 }
365}
366
367fn check(
371 tests: Vec<SuiteTestResult>,
372 snaps: Vec<GasSnapshotEntry>,
373 tolerance: Option<u32>,
374) -> bool {
375 let snaps = snaps
376 .into_iter()
377 .map(|s| ((s.contract_name, s.signature), s.gas_used))
378 .collect::<HashMap<_, _>>();
379 let mut has_diff = false;
380 for test in tests {
381 if let Some(target_gas) =
382 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
383 {
384 let source_gas = test.result.kind.report();
385 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
386 let _ = sh_println!(
387 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
388 test.contract_name(),
389 test.signature,
390 source_gas,
391 target_gas
392 );
393 has_diff = true;
394 }
395 } else {
396 let _ = sh_println!(
397 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
398 test.contract_name(),
399 test.signature
400 );
401 has_diff = true;
402 }
403 }
404 !has_diff
405}
406
407fn diff(
409 tests: Vec<SuiteTestResult>,
410 snaps: Vec<GasSnapshotEntry>,
411 sort_order: DiffSortOrder,
412) -> Result<()> {
413 let snaps = snaps
414 .into_iter()
415 .map(|s| ((s.contract_name, s.signature), s.gas_used))
416 .collect::<HashMap<_, _>>();
417 let mut diffs = Vec::with_capacity(tests.len());
418 let mut new_tests = Vec::new();
419
420 for test in tests.into_iter() {
421 if let Some(target_gas_used) =
422 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
423 {
424 diffs.push(GasSnapshotDiff {
425 source_gas_used: test.result.kind.report(),
426 signature: format!("{}::{}", test.contract_name(), test.signature),
427 target_gas_used,
428 });
429 } else {
430 new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
432 }
433 }
434
435 let mut increased = 0;
436 let mut decreased = 0;
437 let mut unchanged = 0;
438 let mut overall_gas_change = 0i128;
439 let mut overall_gas_used = 0i128;
440
441 match sort_order {
443 DiffSortOrder::Percentage => {
444 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
446 }
447 DiffSortOrder::PercentageDesc => {
448 diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
450 }
451 DiffSortOrder::Absolute => {
452 diffs.sort_by_key(|d| d.gas_change().abs());
454 }
455 DiffSortOrder::AbsoluteDesc => {
456 diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
458 }
459 }
460
461 for diff in &diffs {
462 let gas_change = diff.gas_change();
463 overall_gas_change += gas_change;
464 overall_gas_used += diff.target_gas_used.gas() as i128;
465 let gas_diff = diff.gas_diff();
466
467 if gas_change > 0 {
469 increased += 1;
470 } else if gas_change < 0 {
471 decreased += 1;
472 } else {
473 unchanged += 1;
474 }
475
476 let icon = if gas_change > 0 {
478 "↑".red().to_string()
479 } else if gas_change < 0 {
480 "↓".green().to_string()
481 } else {
482 "━".to_string()
483 };
484
485 sh_println!(
486 "{} {} (gas: {} → {} | {} {})",
487 icon,
488 diff.signature,
489 diff.target_gas_used.gas(),
490 diff.source_gas_used.gas(),
491 fmt_change(gas_change),
492 fmt_pct_change(gas_diff)
493 )?;
494 }
495
496 if !new_tests.is_empty() {
498 sh_println!("\n{}", "New tests:".yellow())?;
499 for test in new_tests {
500 sh_println!(" {} {}", "+".green(), test)?;
501 }
502 }
503
504 sh_println!("\n{}", "-".repeat(80))?;
506
507 let overall_gas_diff = if overall_gas_used > 0 {
508 overall_gas_change as f64 / overall_gas_used as f64
509 } else {
510 0.0
511 };
512
513 sh_println!(
514 "Total tests: {}, {} {}, {} {}, {} {}",
515 diffs.len(),
516 "↑".red().to_string(),
517 increased,
518 "↓".green().to_string(),
519 decreased,
520 "━",
521 unchanged
522 )?;
523 sh_println!(
524 "Overall gas change: {} ({})",
525 fmt_change(overall_gas_change),
526 fmt_pct_change(overall_gas_diff)
527 )?;
528 Ok(())
529}
530
531fn fmt_pct_change(change: f64) -> String {
532 let change_pct = change * 100.0;
533 match change.total_cmp(&0.0) {
534 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
535 Ordering::Equal => {
536 format!("{change_pct:.3}%")
537 }
538 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
539 }
540}
541
542fn fmt_change(change: i128) -> String {
543 match change.cmp(&0) {
544 Ordering::Less => format!("{change}").green().to_string(),
545 Ordering::Equal => {
546 format!("{change}")
547 }
548 Ordering::Greater => format!("{change}").red().to_string(),
549 }
550}
551
/// Returns whether `source_gas` is close enough to `target_gas` to count as unchanged.
///
/// - Without a tolerance, the two values must be equal.
/// - With `tolerance_pct = Some(t)`, the relative difference, measured against the larger
///   of the two values, must be strictly below `t` percent.
///
/// Identical values always pass. This includes the case where both are zero and the case
/// of a tolerance of `0`.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Exact matches are always within tolerance. This short-circuit also fixes two bugs:
    // `--tolerance 0` rejected identical values (since `0 < 0` is false), and two zero
    // values computed `0 / 0 = NaN`, which never compares `<` anything.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    // The values differ here, so `hi > 0` and the division below is well defined.
    let (hi, lo) = (source_gas.max(target_gas), source_gas.min(target_gas));
    let diff_pct = (1. - (lo as f64 / hi as f64)) * 100.;
    diff_pct < f64::from(tolerance)
}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a snapshot line, failing the test if it is rejected.
    fn parse(line: &str) -> GasSnapshotEntry {
        line.parse::<GasSnapshotEntry>().unwrap()
    }

    /// Builds the entry expected for `contract:signature` with the given report.
    fn entry(contract: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    #[test]
    fn test_tolerance() {
        // A 5% tolerance accepts these pairs in either direction...
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        // ...and rejects a slightly larger gap.
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        // Without a tolerance, equal values match.
        assert!(within_tolerance(100, 100, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let expected = entry("Test", "deposit()", TestKindReport::Unit { gas: 7222 });
        assert_eq!(parse("Test:deposit() (gas: 7222)"), expected);
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        // The median (`~`) may be written without a space after its colon.
        let expected = entry(
            "Test",
            "deposit()",
            TestKindReport::Fuzz {
                runs: 256,
                median_gas: 200,
                mean_gas: 100,
                failed_corpus_replays: 0,
            },
        );
        assert_eq!(parse("Test:deposit() (runs: 256, μ: 100, ~:200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let expected = entry(
            "Test",
            "deposit()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
        assert_eq!(parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let expected = entry(
            "ERC20Invariants",
            "invariantBalanceSum()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            expected
        );
    }
}