Skip to main content

forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6    Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13    cmp::Ordering,
14    fs,
15    io::{self, BufRead},
16    path::{Path, PathBuf},
17    str::FromStr,
18    sync::LazyLock,
19};
20use yansi::Paint;
21
/// A regex that matches a basic snapshot entry like
/// `Test:testDeposit() (gas: 58804)`
///
/// Named capture groups:
/// - `file`: the text in front of the `:` separator
/// - `sig`: the test signature, i.e. a word followed by a parenthesized argument list
/// - `gas`: the gas figure of a plain entry (the `gas:` label itself is optional)
/// - `runs`, `avg`, `med`: the `runs`, `μ` and `~` figures of a fuzz entry
/// - `invruns`, `calls`, `reverts`: the `runs`, `calls` and `reverts` figures of an invariant
///   entry
///
/// The pattern is compiled lazily, on first access.
pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
});
27
/// CLI arguments for `forge snapshot`.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // Nested `Option`: the outer level records whether the flag was given at all, the inner
    // level whether a path accompanied it. `run` falls back to `snap` when the inner value is
    // `None`.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Outputs a diff if the gas snapshots do not match.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // Same nested-`Option` shape as `diff`; `run` only looks at this field when `diff` is
    // `None`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden because there is only one option
    /// How to format the output.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    // Also serves as the default comparison file for `diff` / `check` when those flags carry
    // no path of their own.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerates gas deviations up to the specified percentage.
    // The accepted range is half-open (`0..100`), so 100 itself is rejected by the parser.
    // Only consulted by the `check` code path in `run`.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// How to sort diff results.
    // Only consulted by the `diff` code path in `run`; `None` resolves to the enum's default.
    #[arg(long, value_name = "ORDER")]
    diff_sort: Option<DiffSortOrder>,

    /// All test arguments are supported
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Additional configs for test results
    #[command(flatten)]
    config: GasSnapshotConfig,
}
89
90impl GasSnapshotArgs {
91    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
92    pub fn is_watch(&self) -> bool {
93        self.test.is_watch()
94    }
95
96    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
97    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98        self.test.watchexec_config()
99    }
100
101    pub async fn run(mut self) -> Result<()> {
102        // Set fuzz seed so gas snapshots are deterministic
103        self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
104
105        let outcome = self.test.compile_and_run().await?;
106        outcome.ensure_ok(false)?;
107        let tests = self.config.apply(outcome);
108
109        if let Some(path) = self.diff {
110            let snap = path.as_ref().unwrap_or(&self.snap);
111            let snaps = read_gas_snapshot(snap)?;
112            diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
113        } else if let Some(path) = self.check {
114            let snap = path.as_ref().unwrap_or(&self.snap);
115            let snaps = read_gas_snapshot(snap)?;
116            if check(tests, snaps, self.tolerance) {
117                std::process::exit(0)
118            } else {
119                std::process::exit(1)
120            }
121        } else {
122            if matches!(self.format, Some(Format::Table)) {
123                let table = build_gas_snapshot_table(&tests);
124                sh_println!("\n{}", table)?;
125            }
126            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
127        }
128        Ok(())
129    }
130}
131
/// Gas report format on stdout.
///
/// `Table` is currently the only variant.
#[derive(Clone, Debug)]
pub enum Format {
    /// Tabular output; selected by the strings `t` or `table`.
    Table,
}
137
138impl FromStr for Format {
139    type Err = String;
140
141    fn from_str(s: &str) -> Result<Self, Self::Err> {
142        match s {
143            "t" | "table" => Ok(Self::Table),
144            _ => Err(format!("Unrecognized format `{s}`")),
145        }
146    }
147}
148
149/// Additional filters that can be applied on the test results
150#[derive(Clone, Debug, Default, Parser)]
151struct GasSnapshotConfig {
152    /// Sort results by gas used (ascending).
153    #[arg(long)]
154    asc: bool,
155
156    /// Sort results by gas used (descending).
157    #[arg(conflicts_with = "asc", long)]
158    desc: bool,
159
160    /// Only include tests that used more gas that the given amount.
161    #[arg(long, value_name = "MIN_GAS")]
162    min: Option<u64>,
163
164    /// Only include tests that used less gas that the given amount.
165    #[arg(long, value_name = "MAX_GAS")]
166    max: Option<u64>,
167}
168
/// Sort order for diff output
// NOTE: for every variant `diff` compares the absolute value of the change, so "smallest" and
// "largest" below refer to the magnitude of the change, not to its signed value.
#[derive(Clone, Debug, Default, clap::ValueEnum)]
enum DiffSortOrder {
    /// Sort by percentage change (smallest to largest) - default behavior
    #[default]
    Percentage,
    /// Sort by percentage change (largest to smallest)
    PercentageDesc,
    /// Sort by absolute gas change (smallest to largest)
    Absolute,
    /// Sort by absolute gas change (largest to smallest)
    AbsoluteDesc,
}
182
183impl GasSnapshotConfig {
184    fn is_in_gas_range(&self, gas_used: u64) -> bool {
185        if let Some(min) = self.min
186            && gas_used < min
187        {
188            return false;
189        }
190        if let Some(max) = self.max
191            && gas_used > max
192        {
193            return false;
194        }
195        true
196    }
197
198    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
199        let mut tests = outcome
200            .into_tests()
201            .filter(|test| self.is_in_gas_range(test.gas_used()))
202            .collect::<Vec<_>>();
203
204        if self.asc {
205            tests.sort_by_key(|a| a.gas_used());
206        } else if self.desc {
207            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
208        }
209
210        tests
211    }
212}
213
/// A general entry in a gas snapshot file
///
/// Has the form:
///   `<signature>(gas:? 40181)` for normal tests
///   `<signature>(runs: 256, μ: 40181, ~: 40181)` for fuzz tests
///   `<signature>(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
///
/// In the file, each signature is preceded by the contract name and a `:` separator, e.g.
/// `Test:deposit() (gas: 7222)`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// The text in front of the `:` separator (the `file` capture group of the entry regex).
    pub contract_name: String,
    /// The test signature, e.g. `deposit()`.
    pub signature: String,
    /// The figures recorded for the test; the variant reflects which of the three entry forms
    /// was parsed.
    pub gas_used: TestKindReport,
}
226
227impl FromStr for GasSnapshotEntry {
228    type Err = String;
229
230    fn from_str(s: &str) -> Result<Self, Self::Err> {
231        RE_BASIC_SNAPSHOT_ENTRY
232            .captures(s)
233            .and_then(|cap| {
234                cap.name("file").and_then(|file| {
235                    cap.name("sig").and_then(|sig| {
236                        if let Some(gas) = cap.name("gas") {
237                            Some(Self {
238                                contract_name: file.as_str().to_string(),
239                                signature: sig.as_str().to_string(),
240                                gas_used: TestKindReport::Unit {
241                                    gas: gas.as_str().parse().unwrap(),
242                                },
243                            })
244                        } else if let Some(runs) = cap.name("runs") {
245                            cap.name("avg")
246                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
247                                .map(|(runs, avg, med)| Self {
248                                    contract_name: file.as_str().to_string(),
249                                    signature: sig.as_str().to_string(),
250                                    gas_used: TestKindReport::Fuzz {
251                                        runs: runs.as_str().parse().unwrap(),
252                                        median_gas: med.as_str().parse().unwrap(),
253                                        mean_gas: avg.as_str().parse().unwrap(),
254                                        failed_corpus_replays: 0,
255                                    },
256                                })
257                        } else {
258                            cap.name("invruns")
259                                .and_then(|runs| {
260                                    cap.name("calls").and_then(|avg| {
261                                        cap.name("reverts").map(|med| (runs, avg, med))
262                                    })
263                                })
264                                .map(|(runs, calls, reverts)| Self {
265                                    contract_name: file.as_str().to_string(),
266                                    signature: sig.as_str().to_string(),
267                                    gas_used: TestKindReport::Invariant {
268                                        runs: runs.as_str().parse().unwrap(),
269                                        calls: calls.as_str().parse().unwrap(),
270                                        reverts: reverts.as_str().parse().unwrap(),
271                                        metrics: HashMap::default(),
272                                        failed_corpus_replays: 0,
273                                        optimization_best_value: None,
274                                    },
275                                })
276                        }
277                    })
278                })
279            })
280            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
281    }
282}
283
284/// Reads a list of gas snapshot entries from a gas snapshot file.
285fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
286    let path = path.as_ref();
287    let mut entries = Vec::new();
288    for line in io::BufReader::new(
289        fs::File::open(path)
290            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
291    )
292    .lines()
293    {
294        entries
295            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
296    }
297    Ok(entries)
298}
299
300/// Writes a series of tests to a gas snapshot file after sorting them.
301fn write_to_gas_snapshot_file(
302    tests: &[SuiteTestResult],
303    path: impl AsRef<Path>,
304    _format: Option<Format>,
305) -> Result<()> {
306    let mut reports = tests
307        .iter()
308        .map(|test| {
309            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
310        })
311        .collect::<Vec<_>>();
312
313    // sort all reports
314    reports.sort();
315
316    let content = reports.join("\n");
317    Ok(fs::write(path, content)?)
318}
319
320fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
321    let mut table = Table::new();
322    if shell::is_markdown() {
323        table.load_preset(ASCII_MARKDOWN);
324    } else {
325        table.apply_modifier(UTF8_ROUND_CORNERS);
326    }
327
328    table.set_header(vec![
329        Cell::new("Contract").fg(Color::Cyan),
330        Cell::new("Signature").fg(Color::Cyan),
331        Cell::new("Report").fg(Color::Cyan),
332    ]);
333
334    for test in tests {
335        let mut row = Row::new();
336        row.add_cell(Cell::new(test.contract_name()));
337        row.add_cell(Cell::new(&test.signature));
338        row.add_cell(Cell::new(test.result.kind.report()));
339        table.add_row(row);
340    }
341
342    table
343}
344
/// A Gas snapshot entry diff.
///
/// Pairs the report of a test from the current run (the "source") with the report recorded
/// for the same test in the snapshot file (the "target").
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Display name of the test; `diff` fills this with `<contract>::<signature>`.
    pub signature: String,
    /// Report produced by the current test run.
    pub source_gas_used: TestKindReport,
    /// Report read from the snapshot file.
    pub target_gas_used: TestKindReport,
}
352
353impl GasSnapshotDiff {
354    /// Returns the gas diff
355    ///
356    /// `> 0` if the source used more gas
357    /// `< 0` if the target used more gas
358    fn gas_change(&self) -> i128 {
359        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
360    }
361
362    /// Determines the percentage change
363    fn gas_diff(&self) -> f64 {
364        self.gas_change() as f64 / self.target_gas_used.gas() as f64
365    }
366}
367
368/// Compares the set of tests with an existing gas snapshot.
369///
370/// Returns true all tests match
371fn check(
372    tests: Vec<SuiteTestResult>,
373    snaps: Vec<GasSnapshotEntry>,
374    tolerance: Option<u32>,
375) -> bool {
376    let snaps = snaps
377        .into_iter()
378        .map(|s| ((s.contract_name, s.signature), s.gas_used))
379        .collect::<HashMap<_, _>>();
380    let mut has_diff = false;
381    for test in tests {
382        if let Some(target_gas) =
383            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
384        {
385            let source_gas = test.result.kind.report();
386            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
387                let _ = sh_println!(
388                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
389                    test.contract_name(),
390                    test.signature,
391                    source_gas,
392                    target_gas
393                );
394                has_diff = true;
395            }
396        } else {
397            let _ = sh_println!(
398                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
399                test.contract_name(),
400                test.signature
401            );
402            has_diff = true;
403        }
404    }
405    !has_diff
406}
407
408/// Compare the set of tests with an existing gas snapshot.
409fn diff(
410    tests: Vec<SuiteTestResult>,
411    snaps: Vec<GasSnapshotEntry>,
412    sort_order: DiffSortOrder,
413) -> Result<()> {
414    let snaps = snaps
415        .into_iter()
416        .map(|s| ((s.contract_name, s.signature), s.gas_used))
417        .collect::<HashMap<_, _>>();
418    let mut diffs = Vec::with_capacity(tests.len());
419    let mut new_tests = Vec::new();
420
421    for test in tests.into_iter() {
422        if let Some(target_gas_used) =
423            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
424        {
425            diffs.push(GasSnapshotDiff {
426                source_gas_used: test.result.kind.report(),
427                signature: format!("{}::{}", test.contract_name(), test.signature),
428                target_gas_used,
429            });
430        } else {
431            // Track new tests
432            new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
433        }
434    }
435
436    let mut increased = 0;
437    let mut decreased = 0;
438    let mut unchanged = 0;
439    let mut overall_gas_change = 0i128;
440    let mut overall_gas_used = 0i128;
441
442    // Sort based on user preference
443    match sort_order {
444        DiffSortOrder::Percentage => {
445            // Default: sort by percentage change (smallest to largest)
446            diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
447        }
448        DiffSortOrder::PercentageDesc => {
449            // Sort by percentage change (largest to smallest)
450            diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
451        }
452        DiffSortOrder::Absolute => {
453            // Sort by absolute gas change (smallest to largest)
454            diffs.sort_by_key(|d| d.gas_change().abs());
455        }
456        DiffSortOrder::AbsoluteDesc => {
457            // Sort by absolute gas change (largest to smallest)
458            diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
459        }
460    }
461
462    for diff in &diffs {
463        let gas_change = diff.gas_change();
464        overall_gas_change += gas_change;
465        overall_gas_used += diff.target_gas_used.gas() as i128;
466        let gas_diff = diff.gas_diff();
467
468        // Classify changes
469        if gas_change > 0 {
470            increased += 1;
471        } else if gas_change < 0 {
472            decreased += 1;
473        } else {
474            unchanged += 1;
475        }
476
477        // Display with icon and before/after values
478        let icon = if gas_change > 0 {
479            "↑".red().to_string()
480        } else if gas_change < 0 {
481            "↓".green().to_string()
482        } else {
483            "━".to_string()
484        };
485
486        sh_println!(
487            "{} {} (gas: {} → {} | {} {})",
488            icon,
489            diff.signature,
490            diff.target_gas_used.gas(),
491            diff.source_gas_used.gas(),
492            fmt_change(gas_change),
493            fmt_pct_change(gas_diff)
494        )?;
495    }
496
497    // Display new tests if any
498    if !new_tests.is_empty() {
499        sh_println!("\n{}", "New tests:".yellow())?;
500        for test in new_tests {
501            sh_println!("  {} {}", "+".green(), test)?;
502        }
503    }
504
505    // Summary separator
506    sh_println!("\n{}", "-".repeat(80))?;
507
508    let overall_gas_diff = if overall_gas_used > 0 {
509        overall_gas_change as f64 / overall_gas_used as f64
510    } else {
511        0.0
512    };
513
514    sh_println!(
515        "Total tests: {}, {} {}, {} {}, {} {}",
516        diffs.len(),
517        "↑".red().to_string(),
518        increased,
519        "↓".green().to_string(),
520        decreased,
521        "━",
522        unchanged
523    )?;
524    sh_println!(
525        "Overall gas change: {} ({})",
526        fmt_change(overall_gas_change),
527        fmt_pct_change(overall_gas_diff)
528    )?;
529    Ok(())
530}
531
532fn fmt_pct_change(change: f64) -> String {
533    let change_pct = change * 100.0;
534    match change.total_cmp(&0.0) {
535        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
536        Ordering::Equal => {
537            format!("{change_pct:.3}%")
538        }
539        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
540    }
541}
542
543fn fmt_change(change: i128) -> String {
544    match change.cmp(&0) {
545        Ordering::Less => format!("{change}").green().to_string(),
546        Ordering::Equal => change.to_string(),
547        Ordering::Greater => format!("{change}").red().to_string(),
548    }
549}
550
/// Returns `true` if the two gas values are equal, or if they differ by less than the
/// tolerance.
///
/// The deviation is measured relative to the larger of the two values and expressed in
/// percent; it must be strictly smaller than `tolerance_pct` to be accepted.
///
/// If `tolerance_pct` is `None`, then this returns `true` only if both gas values are equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Identical values always match. Checking this first also keeps a tolerance of 0 usable
    // (a deviation of 0 is not strictly below 0) and avoids the 0/0 = NaN case when both
    // values are zero.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };

    let (hi, lo) =
        if source_gas > target_gas { (source_gas, target_gas) } else { (target_gas, source_gas) };
    // `hi` is strictly greater than `lo` here, hence non-zero.
    let diff = (1. - (lo as f64 / hi as f64)) * 100.;
    diff < tolerance as f64
}
567
// Unit tests for the tolerance check and for parsing the three snapshot entry forms.
#[cfg(test)]
mod tests {
    use super::*;

    // The deviation is measured against the larger value: 100 vs 105 is about 4.76%, which is
    // below the 5% tolerance, while 100 vs 106 is about 5.66%, which is not. The order of the
    // two arguments does not matter.
    #[test]
    fn test_tolerance() {
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        assert!(within_tolerance(100, 100, None));
    }

    // Plain entry: `<contract>:<signature> (gas: N)`.
    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let s = "Test:deposit() (gas: 7222)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Unit { gas: 7222 }
            }
        );
    }

    // Fuzz entry: `μ` maps to `mean_gas` and `~` to `median_gas`. The input has no space
    // after `~:`, which exercises the optional whitespace in the pattern.
    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let s = "Test:deposit() (runs: 256, μ: 100, ~:200)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Fuzz {
                    runs: 256,
                    median_gas: 200,
                    mean_gas: 100,
                    failed_corpus_replays: 0
                }
            }
        );
    }

    // Invariant entry: the fields not present in the snapshot line are expected to hold their
    // empty / zero defaults.
    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let s = "Test:deposit() (runs: 256, calls: 100, reverts: 200)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Invariant {
                    runs: 256,
                    calls: 100,
                    reverts: 200,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            }
        );
    }

    // Same as above, with a longer contract name and multi-digit figures.
    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let s = "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "ERC20Invariants".to_string(),
                signature: "invariantBalanceSum()".to_string(),
                gas_used: TestKindReport::Invariant {
                    runs: 256,
                    calls: 3840,
                    reverts: 2388,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            }
        );
    }
}
655}