Skip to main content

forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6    Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13    cmp::Ordering,
14    fs,
15    io::{self, BufRead},
16    path::{Path, PathBuf},
17    str::FromStr,
18    sync::LazyLock,
19};
20use yansi::Paint;
21
22/// A regex that matches a basic snapshot entry like
23/// `Test:testDeposit() (gas: 58804)`
24pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
26});
27
/// CLI arguments for `forge snapshot`.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    // `Option<Option<PathBuf>>`: `None` when `--diff` is absent, `Some(None)` when it is given
    // without a path, in which case `run` compares against the `--snap` file instead.
    // NOTE(review): `--diff` conflicts with `--snap` while `--check` does not, so
    // `--check --snap <FILE>` is accepted but `--diff --snap <FILE>` is rejected — confirm intended.
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    // Same `Option<Option<PathBuf>>` shape as `diff`; `run` terminates the process with exit
    // code 0 or 1 depending on the comparison result.
    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Outputs a diff if the gas snapshots do not match.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden because there is only one option
    /// How to format the output.
    #[arg(long, hide(true))]
    format: Option<Format>,

    // Also used as the comparison file when `--diff`/`--check` are given without a path.
    /// Output file for the gas snapshot.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    // Only consulted by `--check`. `range(0..100)` is half-open, so accepted values are 0-99.
    /// Tolerates gas deviations up to the specified percentage.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    // Only consulted by `--diff`; `run` falls back to `DiffSortOrder::default()` when omitted.
    /// How to sort diff results.
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    /// All test arguments are supported
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    // Gas-range filters and sort order, applied to the results in `run` before any output.
    /// Additional configs for test results
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
92    pub const fn is_watch(&self) -> bool {
93        self.test.is_watch()
94    }
95
96    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
97    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98        self.test.watchexec_config()
99    }
100
101    pub async fn run(mut self) -> Result<()> {
102        // Default to a static fuzz seed so gas snapshots are deterministic,
103        // but allow the user to override it via `--fuzz-seed`.
104        if self.test.fuzz_seed.is_none() {
105            self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
106        }
107
108        let outcome = self.test.compile_and_run().await?;
109        if !shell::is_quiet()
110            && !outcome.allow_failure
111            && self.diff.is_none()
112            && self.check.is_none()
113            && outcome.failed() > 0
114        {
115            sh_eprintln!(
116                "Error: gas snapshot file \"{}\" was not written because the test run failed",
117                self.snap.display()
118            )?;
119        }
120        outcome.ensure_ok(false)?;
121        let tests = self.config.apply(outcome);
122
123        if let Some(path) = self.diff {
124            let snap = path.as_ref().unwrap_or(&self.snap);
125            let snaps = read_gas_snapshot(snap)?;
126            diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
127        } else if let Some(path) = self.check {
128            let snap = path.as_ref().unwrap_or(&self.snap);
129            let snaps = read_gas_snapshot(snap)?;
130            if check(tests, snaps, self.tolerance) {
131                std::process::exit(0)
132            } else {
133                std::process::exit(1)
134            }
135        } else {
136            if matches!(self.format, Some(Format::Table)) {
137                let table = build_gas_snapshot_table(&tests);
138                sh_println!("\n{}", table)?;
139            }
140            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
141        }
142        Ok(())
143    }
144}
145
/// Gas report format on stdout.
#[derive(Clone, Debug)]
pub enum Format {
    /// Print the gas snapshot as a table before writing the snapshot file.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table`; anything else is rejected with a descriptive message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}
162
163/// Additional filters that can be applied on the test results
164#[derive(Clone, Debug, Default, Parser)]
165struct GasSnapshotConfig {
166    /// Sort results by gas used (ascending).
167    #[arg(long)]
168    asc: bool,
169
170    /// Sort results by gas used (descending).
171    #[arg(conflicts_with = "asc", long)]
172    desc: bool,
173
174    /// Only include tests that used more gas that the given amount.
175    #[arg(long, value_name = "MIN_GAS")]
176    min: Option<u64>,
177
178    /// Only include tests that used less gas that the given amount.
179    #[arg(long, value_name = "MAX_GAS")]
180    max: Option<u64>,
181}
182
/// Sort order for diff output
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    // The `///` comments on the variants presumably surface as `--diff-sort` value help via
    // the `ValueEnum` derive, so they are left untouched.
    // Percentage orderings compare `GasSnapshotDiff::gas_diff().abs()`, absolute orderings
    // compare `GasSnapshotDiff::gas_change().abs()` (see `diff`).
    /// Sort by percentage change (smallest to largest) - default behavior
    #[default]
    Percentage,
    /// Sort by percentage change (largest to smallest)
    PercentageDesc,
    /// Sort by absolute gas change (smallest to largest)
    Absolute,
    /// Sort by absolute gas change (largest to smallest)
    AbsoluteDesc,
}
196
197impl GasSnapshotConfig {
198    const fn is_in_gas_range(&self, gas_used: u64) -> bool {
199        if let Some(min) = self.min
200            && gas_used < min
201        {
202            return false;
203        }
204        if let Some(max) = self.max
205            && gas_used > max
206        {
207            return false;
208        }
209        true
210    }
211
212    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
213        let mut tests = outcome
214            .into_tests()
215            .filter(|test| self.is_in_gas_range(test.gas_used()))
216            .flat_map(expand_invariant_snapshot_entries)
217            .collect::<Vec<_>>();
218
219        if self.asc {
220            tests.sort_by_key(|a| a.gas_used());
221        } else if self.desc {
222            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
223        }
224
225        tests
226    }
227}
228
229/// Expands merged invariant campaigns into per-predicate gas snapshot rows.
230fn expand_invariant_snapshot_entries(test: SuiteTestResult) -> Vec<SuiteTestResult> {
231    if !test.result.kind.is_invariant() || test.result.invariant_predicate_results.len() <= 1 {
232        return vec![test];
233    }
234
235    test.result
236        .invariant_predicate_results
237        .iter()
238        .map(|predicate| {
239            let mut expanded = test.clone();
240            expanded.signature = format!("{}()", predicate.name);
241            expanded
242        })
243        .collect()
244}
245
/// A general entry in a gas snapshot file
///
/// Each line has the form `<contract>:<signature> <report>`, where `<report>` is one of:
///   `(gas: 40181)` for unit tests (the `gas:` label is optional when parsing)
///   `(runs: 256, μ: 40181, ~: 40181)` for fuzz tests (`μ` = mean gas, `~` = median gas)
///   `(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// The `<contract>` part of the line, before the first `:`.
    pub contract_name: String,
    /// The test function signature, e.g. `testDeposit()`.
    pub signature: String,
    /// The parsed report; only the fields present in the snapshot line are populated.
    pub gas_used: TestKindReport,
}
258
259impl FromStr for GasSnapshotEntry {
260    type Err = String;
261
262    fn from_str(s: &str) -> Result<Self, Self::Err> {
263        RE_BASIC_SNAPSHOT_ENTRY
264            .captures(s)
265            .and_then(|cap| {
266                cap.name("file").and_then(|file| {
267                    cap.name("sig").and_then(|sig| {
268                        if let Some(gas) = cap.name("gas") {
269                            Some(Self {
270                                contract_name: file.as_str().to_string(),
271                                signature: sig.as_str().to_string(),
272                                gas_used: TestKindReport::Unit {
273                                    gas: gas.as_str().parse().unwrap(),
274                                },
275                            })
276                        } else if let Some(runs) = cap.name("runs") {
277                            cap.name("avg")
278                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
279                                .map(|(runs, avg, med)| Self {
280                                    contract_name: file.as_str().to_string(),
281                                    signature: sig.as_str().to_string(),
282                                    gas_used: TestKindReport::Fuzz {
283                                        runs: runs.as_str().parse().unwrap(),
284                                        median_gas: med.as_str().parse().unwrap(),
285                                        mean_gas: avg.as_str().parse().unwrap(),
286                                        failed_corpus_replays: 0,
287                                    },
288                                })
289                        } else {
290                            cap.name("invruns")
291                                .and_then(|runs| {
292                                    cap.name("calls").and_then(|avg| {
293                                        cap.name("reverts").map(|med| (runs, avg, med))
294                                    })
295                                })
296                                .map(|(runs, calls, reverts)| Self {
297                                    contract_name: file.as_str().to_string(),
298                                    signature: sig.as_str().to_string(),
299                                    gas_used: TestKindReport::Invariant {
300                                        runs: runs.as_str().parse().unwrap(),
301                                        calls: calls.as_str().parse().unwrap(),
302                                        reverts: reverts.as_str().parse().unwrap(),
303                                        metrics: HashMap::default(),
304                                        failed_corpus_replays: 0,
305                                        optimization_best_value: None,
306                                    },
307                                })
308                        }
309                    })
310                })
311            })
312            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
313    }
314}
315
316/// Reads a list of gas snapshot entries from a gas snapshot file.
317fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
318    let path = path.as_ref();
319    let mut entries = Vec::new();
320    for line in io::BufReader::new(
321        fs::File::open(path)
322            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
323    )
324    .lines()
325    {
326        entries
327            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
328    }
329    Ok(entries)
330}
331
332/// Writes a series of tests to a gas snapshot file after sorting them.
333fn write_to_gas_snapshot_file(
334    tests: &[SuiteTestResult],
335    path: impl AsRef<Path>,
336    _format: Option<Format>,
337) -> Result<()> {
338    let mut reports = tests
339        .iter()
340        .map(|test| {
341            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
342        })
343        .collect::<Vec<_>>();
344
345    // sort all reports
346    reports.sort();
347
348    let content = reports.join("\n");
349    Ok(fs::write(path, content)?)
350}
351
352fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
353    let mut table = Table::new();
354    if shell::is_markdown() {
355        table.load_preset(ASCII_MARKDOWN);
356    } else {
357        table.apply_modifier(UTF8_ROUND_CORNERS);
358    }
359
360    table.set_header(vec![
361        Cell::new("Contract").fg(Color::Cyan),
362        Cell::new("Signature").fg(Color::Cyan),
363        Cell::new("Report").fg(Color::Cyan),
364    ]);
365
366    for test in tests {
367        let mut row = Row::new();
368        row.add_cell(Cell::new(test.contract_name()));
369        row.add_cell(Cell::new(&test.signature));
370        row.add_cell(Cell::new(test.result.kind.report()));
371        table.add_row(row);
372    }
373
374    table
375}
376
/// A Gas snapshot entry diff.
///
/// Pairs the gas report of the current test run (the "source") with the report recorded in
/// the existing snapshot file (the "target") for a single test.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Qualified test name; `diff` builds it as `<contract>::<signature>`.
    pub signature: String,
    /// Gas report produced by the current test run.
    pub source_gas_used: TestKindReport,
    /// Gas report read from the existing snapshot file.
    pub target_gas_used: TestKindReport,
}
384
385impl GasSnapshotDiff {
386    /// Returns the gas diff
387    ///
388    /// `> 0` if the source used more gas
389    /// `< 0` if the target used more gas
390    const fn gas_change(&self) -> i128 {
391        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
392    }
393
394    /// Determines the percentage change
395    fn gas_diff(&self) -> f64 {
396        self.gas_change() as f64 / self.target_gas_used.gas() as f64
397    }
398}
399
400/// Compares the set of tests with an existing gas snapshot.
401///
402/// Returns true all tests match
403fn check(
404    tests: Vec<SuiteTestResult>,
405    snaps: Vec<GasSnapshotEntry>,
406    tolerance: Option<u32>,
407) -> bool {
408    let snaps = snaps
409        .into_iter()
410        .map(|s| ((s.contract_name, s.signature), s.gas_used))
411        .collect::<HashMap<_, _>>();
412    let mut has_diff = false;
413    for test in tests {
414        if let Some(target_gas) =
415            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
416        {
417            let source_gas = test.result.kind.report();
418            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
419                let _ = sh_eprintln!(
420                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
421                    test.contract_name(),
422                    test.signature,
423                    source_gas,
424                    target_gas
425                );
426                has_diff = true;
427            }
428        } else {
429            let _ = sh_eprintln!(
430                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
431                test.contract_name(),
432                test.signature
433            );
434            has_diff = true;
435        }
436    }
437    !has_diff
438}
439
440/// Compare the set of tests with an existing gas snapshot.
441fn diff(
442    tests: Vec<SuiteTestResult>,
443    snaps: Vec<GasSnapshotEntry>,
444    sort_order: DiffSortOrder,
445) -> Result<()> {
446    let snaps = snaps
447        .into_iter()
448        .map(|s| ((s.contract_name, s.signature), s.gas_used))
449        .collect::<HashMap<_, _>>();
450    let mut diffs = Vec::with_capacity(tests.len());
451    let mut new_tests = Vec::new();
452
453    for test in tests {
454        if let Some(target_gas_used) =
455            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
456        {
457            diffs.push(GasSnapshotDiff {
458                source_gas_used: test.result.kind.report(),
459                signature: format!("{}::{}", test.contract_name(), test.signature),
460                target_gas_used,
461            });
462        } else {
463            // Track new tests
464            new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
465        }
466    }
467
468    let mut increased = 0;
469    let mut decreased = 0;
470    let mut unchanged = 0;
471    let mut overall_gas_change = 0i128;
472    let mut overall_gas_used = 0i128;
473
474    // Sort based on user preference
475    match sort_order {
476        DiffSortOrder::Percentage => {
477            // Default: sort by percentage change (smallest to largest)
478            diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
479        }
480        DiffSortOrder::PercentageDesc => {
481            // Sort by percentage change (largest to smallest)
482            diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
483        }
484        DiffSortOrder::Absolute => {
485            // Sort by absolute gas change (smallest to largest)
486            diffs.sort_by_key(|d| d.gas_change().abs());
487        }
488        DiffSortOrder::AbsoluteDesc => {
489            // Sort by absolute gas change (largest to smallest)
490            diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
491        }
492    }
493
494    for diff in &diffs {
495        let gas_change = diff.gas_change();
496        overall_gas_change += gas_change;
497        overall_gas_used += diff.target_gas_used.gas() as i128;
498        let gas_diff = diff.gas_diff();
499
500        // Classify changes
501        if gas_change > 0 {
502            increased += 1;
503        } else if gas_change < 0 {
504            decreased += 1;
505        } else {
506            unchanged += 1;
507        }
508
509        // Display with icon and before/after values
510        let icon = if gas_change > 0 {
511            "↑".red().to_string()
512        } else if gas_change < 0 {
513            "↓".green().to_string()
514        } else {
515            "━".to_string()
516        };
517
518        sh_println!(
519            "{} {} (gas: {} → {} | {} {})",
520            icon,
521            diff.signature,
522            diff.target_gas_used.gas(),
523            diff.source_gas_used.gas(),
524            fmt_change(gas_change),
525            fmt_pct_change(gas_diff)
526        )?;
527    }
528
529    // Display new tests if any
530    if !new_tests.is_empty() {
531        sh_eprintln!("\n{}", "New tests:".yellow())?;
532        for test in new_tests {
533            sh_eprintln!("  {} {}", "+".green(), test)?;
534        }
535    }
536
537    // Summary separator
538    sh_eprintln!("\n{}", "-".repeat(80))?;
539
540    let overall_gas_diff = if overall_gas_used > 0 {
541        overall_gas_change as f64 / overall_gas_used as f64
542    } else {
543        0.0
544    };
545
546    sh_eprintln!(
547        "Total tests: {}, {} {}, {} {}, {} {}",
548        diffs.len(),
549        "↑".red().to_string(),
550        increased,
551        "↓".green().to_string(),
552        decreased,
553        "━",
554        unchanged
555    )?;
556    sh_eprintln!(
557        "Overall gas change: {} ({})",
558        fmt_change(overall_gas_change),
559        fmt_pct_change(overall_gas_diff)
560    )?;
561    Ok(())
562}
563
564fn fmt_pct_change(change: f64) -> String {
565    let change_pct = change * 100.0;
566    match change.total_cmp(&0.0) {
567        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
568        Ordering::Equal => {
569            format!("{change_pct:.3}%")
570        }
571        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
572    }
573}
574
575fn fmt_change(change: i128) -> String {
576    match change.cmp(&0) {
577        Ordering::Less => format!("{change}").green().to_string(),
578        Ordering::Equal => change.to_string(),
579        Ordering::Greater => format!("{change}").red().to_string(),
580    }
581}
582
/// Returns `true` if the difference between the gas values is within the tolerance.
///
/// The deviation is measured in percent relative to the larger of the two values and must be
/// strictly less than `tolerance_pct`. Identical values always match, including when both are
/// zero or the tolerance is `0`.
///
/// If `tolerance_pct` is `None`, this returns `true` only if both gas values are equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Checked first: with a tolerance of 0 the strict comparison below would reject equal
    // values, and two zero values would compute `0 / 0 = NaN`, which never compares as less.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    let hi = source_gas.max(target_gas);
    let lo = source_gas.min(target_gas);
    // `hi > lo` here, so `hi` is non-zero and the division is well defined.
    let deviation_pct = (1. - (lo as f64 / hi as f64)) * 100.;
    deviation_pct < tolerance as f64
}
599
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses one snapshot line, panicking with the parser's message if it is malformed.
    fn parse_entry(line: &str) -> GasSnapshotEntry {
        line.parse().unwrap()
    }

    #[test]
    fn test_tolerance() {
        // The deviation is measured against the larger value: 100 vs 105 is ~4.76%.
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        // Without a tolerance only identical values match.
        assert!(within_tolerance(100, 100, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Unit { gas: 7222 },
        };
        assert_eq!(parse_entry("Test:deposit() (gas: 7222)"), expected);
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Fuzz {
                runs: 256,
                median_gas: 200,
                mean_gas: 100,
                failed_corpus_replays: 0,
            },
        };
        // `μ` is the mean and `~` the median; whitespace after `~:` is optional.
        assert_eq!(parse_entry("Test:deposit() (runs: 256, μ: 100, ~:200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
                optimization_best_value: None,
            },
        };
        assert_eq!(parse_entry("Test:deposit() (runs: 256, calls: 100, reverts: 200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let expected = GasSnapshotEntry {
            contract_name: "ERC20Invariants".to_string(),
            signature: "invariantBalanceSum()".to_string(),
            gas_used: TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
                optimization_best_value: None,
            },
        };
        assert_eq!(
            parse_entry(
                "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"
            ),
            expected
        );
    }
}