1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6 Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13 cmp::Ordering,
14 fs,
15 io::{self, BufRead},
16 path::{Path, PathBuf},
17 str::FromStr,
18 sync::LazyLock,
19};
20use yansi::Paint;
21
22pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
25 Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
26});
27
// Command-line arguments for gas snapshotting: by default the snapshot file is (re)written,
// `--diff` compares against an existing snapshot and `--check` verifies against one.
// NOTE: plain `//` comments are used deliberately; `///` doc comments on clap-derived items
// become `--help` text.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    // `--diff [SNAPSHOT_FILE]`: print a diff against an existing snapshot instead of writing.
    // Outer `None` = flag absent; `Some(None)` = flag given without a path, in which case
    // `run` falls back to `snap`. Declared as conflicting with `snap`.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    // `--check [SNAPSHOT_FILE]`: compare against an existing snapshot; `run` exits the
    // process with code 0 when everything matches and 1 otherwise. Same `Option<Option<_>>`
    // convention as `diff`; conflicts with `diff`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden output format; `Some(Format::Table)` additionally prints a table in `run`
    // before the snapshot file is written.
    #[arg(long, hide(true))]
    format: Option<Format>,

    // Snapshot file that is written, and the fallback file read by `--diff`/`--check`.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    // Allowed relative gas difference, in percent, used by `--check`. The value parser
    // accepts the range 0..100 (0 through 99). `None` requires exact equality.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    // Ordering of the `--diff` output; `run` uses `unwrap_or_default` (i.e. `Percentage`).
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    // Test compilation/selection options, flattened into this command.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    // Gas-range filtering and sorting applied to the test results.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91 pub fn is_watch(&self) -> bool {
93 self.test.is_watch()
94 }
95
96 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98 self.test.watchexec_config()
99 }
100
101 pub async fn run(mut self) -> Result<()> {
102 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
104
105 let outcome = self.test.compile_and_run().await?;
106 outcome.ensure_ok(false)?;
107 let tests = self.config.apply(outcome);
108
109 if let Some(path) = self.diff {
110 let snap = path.as_ref().unwrap_or(&self.snap);
111 let snaps = read_gas_snapshot(snap)?;
112 diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
113 } else if let Some(path) = self.check {
114 let snap = path.as_ref().unwrap_or(&self.snap);
115 let snaps = read_gas_snapshot(snap)?;
116 if check(tests, snaps, self.tolerance) {
117 std::process::exit(0)
118 } else {
119 std::process::exit(1)
120 }
121 } else {
122 if matches!(self.format, Some(Format::Table)) {
123 let table = build_gas_snapshot_table(&tests);
124 sh_println!("\n{}", table)?;
125 }
126 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
127 }
128 Ok(())
129 }
130}
131
/// Alternative output format for the gas snapshot command.
#[derive(Clone, Debug)]
pub enum Format {
    /// Additionally render the collected gas reports as a table.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table`; any other input is rejected with a descriptive message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}
148
// Filtering and sorting options applied to test results before they are snapshotted,
// diffed or checked (see `GasSnapshotConfig::apply`).
// NOTE: `//` rather than `///` so clap's generated help text is not altered.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    // Sort results by gas used, ascending. Takes precedence over `desc` in `apply`.
    #[arg(long)]
    asc: bool,

    // Sort results by gas used, descending; conflicts with `asc`.
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    // Inclusive lower bound: tests using less gas than this are dropped.
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    // Inclusive upper bound: tests using more gas than this are dropped.
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
168
// Ordering of the entries printed by `--diff`.
// NOTE: `//` rather than `///` because clap's `ValueEnum` turns doc comments into
// possible-value help text.
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    // Ascending by the absolute relative (percentage) change. This is the default.
    #[default]
    Percentage,
    // Descending by the absolute relative (percentage) change.
    PercentageDesc,
    // Ascending by the absolute gas change.
    Absolute,
    // Descending by the absolute gas change.
    AbsoluteDesc,
}
182
183impl GasSnapshotConfig {
184 fn is_in_gas_range(&self, gas_used: u64) -> bool {
185 if let Some(min) = self.min
186 && gas_used < min
187 {
188 return false;
189 }
190 if let Some(max) = self.max
191 && gas_used > max
192 {
193 return false;
194 }
195 true
196 }
197
198 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
199 let mut tests = outcome
200 .into_tests()
201 .filter(|test| self.is_in_gas_range(test.gas_used()))
202 .collect::<Vec<_>>();
203
204 if self.asc {
205 tests.sort_by_key(|a| a.gas_used());
206 } else if self.desc {
207 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
208 }
209
210 tests
211 }
212}
213
214#[derive(Clone, Debug, PartialEq, Eq)]
221pub struct GasSnapshotEntry {
222 pub contract_name: String,
223 pub signature: String,
224 pub gas_used: TestKindReport,
225}
226
227impl FromStr for GasSnapshotEntry {
228 type Err = String;
229
230 fn from_str(s: &str) -> Result<Self, Self::Err> {
231 RE_BASIC_SNAPSHOT_ENTRY
232 .captures(s)
233 .and_then(|cap| {
234 cap.name("file").and_then(|file| {
235 cap.name("sig").and_then(|sig| {
236 if let Some(gas) = cap.name("gas") {
237 Some(Self {
238 contract_name: file.as_str().to_string(),
239 signature: sig.as_str().to_string(),
240 gas_used: TestKindReport::Unit {
241 gas: gas.as_str().parse().unwrap(),
242 },
243 })
244 } else if let Some(runs) = cap.name("runs") {
245 cap.name("avg")
246 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
247 .map(|(runs, avg, med)| Self {
248 contract_name: file.as_str().to_string(),
249 signature: sig.as_str().to_string(),
250 gas_used: TestKindReport::Fuzz {
251 runs: runs.as_str().parse().unwrap(),
252 median_gas: med.as_str().parse().unwrap(),
253 mean_gas: avg.as_str().parse().unwrap(),
254 failed_corpus_replays: 0,
255 },
256 })
257 } else {
258 cap.name("invruns")
259 .and_then(|runs| {
260 cap.name("calls").and_then(|avg| {
261 cap.name("reverts").map(|med| (runs, avg, med))
262 })
263 })
264 .map(|(runs, calls, reverts)| Self {
265 contract_name: file.as_str().to_string(),
266 signature: sig.as_str().to_string(),
267 gas_used: TestKindReport::Invariant {
268 runs: runs.as_str().parse().unwrap(),
269 calls: calls.as_str().parse().unwrap(),
270 reverts: reverts.as_str().parse().unwrap(),
271 metrics: HashMap::default(),
272 failed_corpus_replays: 0,
273 optimization_best_value: None,
274 },
275 })
276 }
277 })
278 })
279 })
280 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
281 }
282}
283
284fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
286 let path = path.as_ref();
287 let mut entries = Vec::new();
288 for line in io::BufReader::new(
289 fs::File::open(path)
290 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
291 )
292 .lines()
293 {
294 entries
295 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
296 }
297 Ok(entries)
298}
299
300fn write_to_gas_snapshot_file(
302 tests: &[SuiteTestResult],
303 path: impl AsRef<Path>,
304 _format: Option<Format>,
305) -> Result<()> {
306 let mut reports = tests
307 .iter()
308 .map(|test| {
309 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
310 })
311 .collect::<Vec<_>>();
312
313 reports.sort();
315
316 let content = reports.join("\n");
317 Ok(fs::write(path, content)?)
318}
319
320fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
321 let mut table = Table::new();
322 if shell::is_markdown() {
323 table.load_preset(ASCII_MARKDOWN);
324 } else {
325 table.apply_modifier(UTF8_ROUND_CORNERS);
326 }
327
328 table.set_header(vec![
329 Cell::new("Contract").fg(Color::Cyan),
330 Cell::new("Signature").fg(Color::Cyan),
331 Cell::new("Report").fg(Color::Cyan),
332 ]);
333
334 for test in tests {
335 let mut row = Row::new();
336 row.add_cell(Cell::new(test.contract_name()));
337 row.add_cell(Cell::new(&test.signature));
338 row.add_cell(Cell::new(test.result.kind.report()));
339 table.add_row(row);
340 }
341
342 table
343}
344
345#[derive(Clone, Debug, PartialEq, Eq)]
347pub struct GasSnapshotDiff {
348 pub signature: String,
349 pub source_gas_used: TestKindReport,
350 pub target_gas_used: TestKindReport,
351}
352
353impl GasSnapshotDiff {
354 fn gas_change(&self) -> i128 {
359 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
360 }
361
362 fn gas_diff(&self) -> f64 {
364 self.gas_change() as f64 / self.target_gas_used.gas() as f64
365 }
366}
367
368fn check(
372 tests: Vec<SuiteTestResult>,
373 snaps: Vec<GasSnapshotEntry>,
374 tolerance: Option<u32>,
375) -> bool {
376 let snaps = snaps
377 .into_iter()
378 .map(|s| ((s.contract_name, s.signature), s.gas_used))
379 .collect::<HashMap<_, _>>();
380 let mut has_diff = false;
381 for test in tests {
382 if let Some(target_gas) =
383 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
384 {
385 let source_gas = test.result.kind.report();
386 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
387 let _ = sh_println!(
388 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
389 test.contract_name(),
390 test.signature,
391 source_gas,
392 target_gas
393 );
394 has_diff = true;
395 }
396 } else {
397 let _ = sh_println!(
398 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
399 test.contract_name(),
400 test.signature
401 );
402 has_diff = true;
403 }
404 }
405 !has_diff
406}
407
408fn diff(
410 tests: Vec<SuiteTestResult>,
411 snaps: Vec<GasSnapshotEntry>,
412 sort_order: DiffSortOrder,
413) -> Result<()> {
414 let snaps = snaps
415 .into_iter()
416 .map(|s| ((s.contract_name, s.signature), s.gas_used))
417 .collect::<HashMap<_, _>>();
418 let mut diffs = Vec::with_capacity(tests.len());
419 let mut new_tests = Vec::new();
420
421 for test in tests.into_iter() {
422 if let Some(target_gas_used) =
423 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
424 {
425 diffs.push(GasSnapshotDiff {
426 source_gas_used: test.result.kind.report(),
427 signature: format!("{}::{}", test.contract_name(), test.signature),
428 target_gas_used,
429 });
430 } else {
431 new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
433 }
434 }
435
436 let mut increased = 0;
437 let mut decreased = 0;
438 let mut unchanged = 0;
439 let mut overall_gas_change = 0i128;
440 let mut overall_gas_used = 0i128;
441
442 match sort_order {
444 DiffSortOrder::Percentage => {
445 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
447 }
448 DiffSortOrder::PercentageDesc => {
449 diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
451 }
452 DiffSortOrder::Absolute => {
453 diffs.sort_by_key(|d| d.gas_change().abs());
455 }
456 DiffSortOrder::AbsoluteDesc => {
457 diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
459 }
460 }
461
462 for diff in &diffs {
463 let gas_change = diff.gas_change();
464 overall_gas_change += gas_change;
465 overall_gas_used += diff.target_gas_used.gas() as i128;
466 let gas_diff = diff.gas_diff();
467
468 if gas_change > 0 {
470 increased += 1;
471 } else if gas_change < 0 {
472 decreased += 1;
473 } else {
474 unchanged += 1;
475 }
476
477 let icon = if gas_change > 0 {
479 "↑".red().to_string()
480 } else if gas_change < 0 {
481 "↓".green().to_string()
482 } else {
483 "━".to_string()
484 };
485
486 sh_println!(
487 "{} {} (gas: {} → {} | {} {})",
488 icon,
489 diff.signature,
490 diff.target_gas_used.gas(),
491 diff.source_gas_used.gas(),
492 fmt_change(gas_change),
493 fmt_pct_change(gas_diff)
494 )?;
495 }
496
497 if !new_tests.is_empty() {
499 sh_println!("\n{}", "New tests:".yellow())?;
500 for test in new_tests {
501 sh_println!(" {} {}", "+".green(), test)?;
502 }
503 }
504
505 sh_println!("\n{}", "-".repeat(80))?;
507
508 let overall_gas_diff = if overall_gas_used > 0 {
509 overall_gas_change as f64 / overall_gas_used as f64
510 } else {
511 0.0
512 };
513
514 sh_println!(
515 "Total tests: {}, {} {}, {} {}, {} {}",
516 diffs.len(),
517 "↑".red().to_string(),
518 increased,
519 "↓".green().to_string(),
520 decreased,
521 "━",
522 unchanged
523 )?;
524 sh_println!(
525 "Overall gas change: {} ({})",
526 fmt_change(overall_gas_change),
527 fmt_pct_change(overall_gas_diff)
528 )?;
529 Ok(())
530}
531
532fn fmt_pct_change(change: f64) -> String {
533 let change_pct = change * 100.0;
534 match change.total_cmp(&0.0) {
535 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
536 Ordering::Equal => {
537 format!("{change_pct:.3}%")
538 }
539 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
540 }
541}
542
543fn fmt_change(change: i128) -> String {
544 match change.cmp(&0) {
545 Ordering::Less => format!("{change}").green().to_string(),
546 Ordering::Equal => change.to_string(),
547 Ordering::Greater => format!("{change}").red().to_string(),
548 }
549}
550
/// Returns whether `source_gas` is within `tolerance_pct` percent of `target_gas`.
///
/// The relative difference is measured against the larger of the two values,
/// `(1 - lo / hi) * 100`, and must be strictly below the tolerance. Identical values are
/// always accepted. This includes `0` vs `0` (which would otherwise compute `0 / 0 = NaN`
/// and be rejected) and a tolerance of `0`. With no tolerance, only exact equality passes.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    let (hi, lo) = (source_gas.max(target_gas), source_gas.min(target_gas));
    // The values differ, so `hi > 0` and the division is well-defined.
    let diff = (1. - (lo as f64 / hi as f64)) * 100.;
    diff < f64::from(tolerance)
}
567
#[cfg(test)]
mod tests {
    use super::*;

    // Builds the entry a snapshot line is expected to parse into.
    fn expected(contract_name: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract_name.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    // Parses a snapshot line, failing the test if it is rejected.
    fn parse(line: &str) -> GasSnapshotEntry {
        GasSnapshotEntry::from_str(line).unwrap()
    }

    #[test]
    fn test_tolerance() {
        // ~4.76% apart (relative to the larger value): accepted at 5% in either order.
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        // ~5.66% apart: rejected at 5% in either order.
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        // Without a tolerance, equal values match.
        assert!(within_tolerance(100, 100, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (gas: 7222)"),
            expected("Test", "deposit()", TestKindReport::Unit { gas: 7222 })
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, μ: 100, ~:200)"),
            expected(
                "Test",
                "deposit()",
                TestKindReport::Fuzz {
                    runs: 256,
                    median_gas: 200,
                    mean_gas: 100,
                    failed_corpus_replays: 0
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"),
            expected(
                "Test",
                "deposit()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 100,
                    reverts: 200,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            expected(
                "ERC20Invariants",
                "invariantBalanceSum()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 3840,
                    reverts: 2388,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }
}