// forge/cmd/snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6    Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13    cmp::Ordering,
14    fs,
15    io::{self, BufRead},
16    path::{Path, PathBuf},
17    str::FromStr,
18    sync::LazyLock,
19};
20use yansi::Paint;
21
22/// A regex that matches a basic snapshot entry like
23/// `Test:testDeposit() (gas: 58804)`
24pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
26});
27
28/// CLI arguments for `forge snapshot`.
29#[derive(Clone, Debug, Parser)]
30pub struct GasSnapshotArgs {
31    /// Output a diff against a pre-existing gas snapshot.
32    ///
33    /// By default, the comparison is done with .gas-snapshot.
34    #[arg(
35        conflicts_with = "snap",
36        long,
37        value_hint = ValueHint::FilePath,
38        value_name = "SNAPSHOT_FILE",
39    )]
40    diff: Option<Option<PathBuf>>,
41
42    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
43    ///
44    /// Outputs a diff if the gas snapshots do not match.
45    ///
46    /// By default, the comparison is done with .gas-snapshot.
47    #[arg(
48        conflicts_with = "diff",
49        long,
50        value_hint = ValueHint::FilePath,
51        value_name = "SNAPSHOT_FILE",
52    )]
53    check: Option<Option<PathBuf>>,
54
55    // Hidden because there is only one option
56    /// How to format the output.
57    #[arg(long, hide(true))]
58    format: Option<Format>,
59
60    /// Output file for the gas snapshot.
61    #[arg(
62        long,
63        default_value = ".gas-snapshot",
64        value_hint = ValueHint::FilePath,
65        value_name = "FILE",
66    )]
67    snap: PathBuf,
68
69    /// Tolerates gas deviations up to the specified percentage.
70    #[arg(
71        long,
72        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
73        value_name = "SNAPSHOT_THRESHOLD"
74    )]
75    tolerance: Option<u32>,
76
77    /// How to sort diff results.
78    #[arg(long, value_name = "ORDER")]
79    diff_sort: Option<DiffSortOrder>,
80
81    /// All test arguments are supported
82    #[command(flatten)]
83    pub(crate) test: test::TestArgs,
84
85    /// Additional configs for test results
86    #[command(flatten)]
87    config: GasSnapshotConfig,
88}
89
90impl GasSnapshotArgs {
91    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
92    pub const fn is_watch(&self) -> bool {
93        self.test.is_watch()
94    }
95
96    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
97    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98        self.test.watchexec_config()
99    }
100
101    pub async fn run(mut self) -> Result<()> {
102        // Default to a static fuzz seed so gas snapshots are deterministic,
103        // but allow the user to override it via `--fuzz-seed`.
104        if self.test.fuzz_seed.is_none() {
105            self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
106        }
107
108        let outcome = self.test.compile_and_run().await?;
109        outcome.ensure_ok(false)?;
110        let tests = self.config.apply(outcome);
111
112        if let Some(path) = self.diff {
113            let snap = path.as_ref().unwrap_or(&self.snap);
114            let snaps = read_gas_snapshot(snap)?;
115            diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
116        } else if let Some(path) = self.check {
117            let snap = path.as_ref().unwrap_or(&self.snap);
118            let snaps = read_gas_snapshot(snap)?;
119            if check(tests, snaps, self.tolerance) {
120                std::process::exit(0)
121            } else {
122                std::process::exit(1)
123            }
124        } else {
125            if matches!(self.format, Some(Format::Table)) {
126                let table = build_gas_snapshot_table(&tests);
127                sh_println!("\n{}", table)?;
128            }
129            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
130        }
131        Ok(())
132    }
133}
134
/// Gas report format on stdout, selected with the hidden `--format` flag.
#[derive(Clone, Debug)]
pub enum Format {
    /// Print the gas snapshot as a table before writing the snapshot file.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `t` or `table` (case-sensitive); anything else is rejected with a message that
    /// quotes the offending input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}

152/// Additional filters that can be applied on the test results
153#[derive(Clone, Debug, Default, Parser)]
154struct GasSnapshotConfig {
155    /// Sort results by gas used (ascending).
156    #[arg(long)]
157    asc: bool,
158
159    /// Sort results by gas used (descending).
160    #[arg(conflicts_with = "asc", long)]
161    desc: bool,
162
163    /// Only include tests that used more gas that the given amount.
164    #[arg(long, value_name = "MIN_GAS")]
165    min: Option<u64>,
166
167    /// Only include tests that used less gas that the given amount.
168    #[arg(long, value_name = "MAX_GAS")]
169    max: Option<u64>,
170}
171
172/// Sort order for diff output
173#[derive(Clone, Debug, Default, clap::ValueEnum)]
174enum DiffSortOrder {
175    /// Sort by percentage change (smallest to largest) - default behavior
176    #[default]
177    Percentage,
178    /// Sort by percentage change (largest to smallest)
179    PercentageDesc,
180    /// Sort by absolute gas change (smallest to largest)
181    Absolute,
182    /// Sort by absolute gas change (largest to smallest)
183    AbsoluteDesc,
184}
185
186impl GasSnapshotConfig {
187    const fn is_in_gas_range(&self, gas_used: u64) -> bool {
188        if let Some(min) = self.min
189            && gas_used < min
190        {
191            return false;
192        }
193        if let Some(max) = self.max
194            && gas_used > max
195        {
196            return false;
197        }
198        true
199    }
200
201    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
202        let mut tests = outcome
203            .into_tests()
204            .filter(|test| self.is_in_gas_range(test.gas_used()))
205            .collect::<Vec<_>>();
206
207        if self.asc {
208            tests.sort_by_key(|a| a.gas_used());
209        } else if self.desc {
210            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
211        }
212
213        tests
214    }
215}
216
217/// A general entry in a gas snapshot file
218///
219/// Has the form:
220///   `<signature>(gas:? 40181)` for normal tests
221///   `<signature>(runs: 256, μ: 40181, ~: 40181)` for fuzz tests
222///   `<signature>(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
223#[derive(Clone, Debug, PartialEq, Eq)]
224pub struct GasSnapshotEntry {
225    pub contract_name: String,
226    pub signature: String,
227    pub gas_used: TestKindReport,
228}
229
230impl FromStr for GasSnapshotEntry {
231    type Err = String;
232
233    fn from_str(s: &str) -> Result<Self, Self::Err> {
234        RE_BASIC_SNAPSHOT_ENTRY
235            .captures(s)
236            .and_then(|cap| {
237                cap.name("file").and_then(|file| {
238                    cap.name("sig").and_then(|sig| {
239                        if let Some(gas) = cap.name("gas") {
240                            Some(Self {
241                                contract_name: file.as_str().to_string(),
242                                signature: sig.as_str().to_string(),
243                                gas_used: TestKindReport::Unit {
244                                    gas: gas.as_str().parse().unwrap(),
245                                },
246                            })
247                        } else if let Some(runs) = cap.name("runs") {
248                            cap.name("avg")
249                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
250                                .map(|(runs, avg, med)| Self {
251                                    contract_name: file.as_str().to_string(),
252                                    signature: sig.as_str().to_string(),
253                                    gas_used: TestKindReport::Fuzz {
254                                        runs: runs.as_str().parse().unwrap(),
255                                        median_gas: med.as_str().parse().unwrap(),
256                                        mean_gas: avg.as_str().parse().unwrap(),
257                                        failed_corpus_replays: 0,
258                                    },
259                                })
260                        } else {
261                            cap.name("invruns")
262                                .and_then(|runs| {
263                                    cap.name("calls").and_then(|avg| {
264                                        cap.name("reverts").map(|med| (runs, avg, med))
265                                    })
266                                })
267                                .map(|(runs, calls, reverts)| Self {
268                                    contract_name: file.as_str().to_string(),
269                                    signature: sig.as_str().to_string(),
270                                    gas_used: TestKindReport::Invariant {
271                                        runs: runs.as_str().parse().unwrap(),
272                                        calls: calls.as_str().parse().unwrap(),
273                                        reverts: reverts.as_str().parse().unwrap(),
274                                        metrics: HashMap::default(),
275                                        failed_corpus_replays: 0,
276                                        optimization_best_value: None,
277                                    },
278                                })
279                        }
280                    })
281                })
282            })
283            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
284    }
285}
286
287/// Reads a list of gas snapshot entries from a gas snapshot file.
288fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
289    let path = path.as_ref();
290    let mut entries = Vec::new();
291    for line in io::BufReader::new(
292        fs::File::open(path)
293            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
294    )
295    .lines()
296    {
297        entries
298            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
299    }
300    Ok(entries)
301}
302
303/// Writes a series of tests to a gas snapshot file after sorting them.
304fn write_to_gas_snapshot_file(
305    tests: &[SuiteTestResult],
306    path: impl AsRef<Path>,
307    _format: Option<Format>,
308) -> Result<()> {
309    let mut reports = tests
310        .iter()
311        .map(|test| {
312            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
313        })
314        .collect::<Vec<_>>();
315
316    // sort all reports
317    reports.sort();
318
319    let content = reports.join("\n");
320    Ok(fs::write(path, content)?)
321}
322
323fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
324    let mut table = Table::new();
325    if shell::is_markdown() {
326        table.load_preset(ASCII_MARKDOWN);
327    } else {
328        table.apply_modifier(UTF8_ROUND_CORNERS);
329    }
330
331    table.set_header(vec![
332        Cell::new("Contract").fg(Color::Cyan),
333        Cell::new("Signature").fg(Color::Cyan),
334        Cell::new("Report").fg(Color::Cyan),
335    ]);
336
337    for test in tests {
338        let mut row = Row::new();
339        row.add_cell(Cell::new(test.contract_name()));
340        row.add_cell(Cell::new(&test.signature));
341        row.add_cell(Cell::new(test.result.kind.report()));
342        table.add_row(row);
343    }
344
345    table
346}
347
348/// A Gas snapshot entry diff.
349#[derive(Clone, Debug, PartialEq, Eq)]
350pub struct GasSnapshotDiff {
351    pub signature: String,
352    pub source_gas_used: TestKindReport,
353    pub target_gas_used: TestKindReport,
354}
355
356impl GasSnapshotDiff {
357    /// Returns the gas diff
358    ///
359    /// `> 0` if the source used more gas
360    /// `< 0` if the target used more gas
361    const fn gas_change(&self) -> i128 {
362        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
363    }
364
365    /// Determines the percentage change
366    fn gas_diff(&self) -> f64 {
367        self.gas_change() as f64 / self.target_gas_used.gas() as f64
368    }
369}
370
371/// Compares the set of tests with an existing gas snapshot.
372///
373/// Returns true all tests match
374fn check(
375    tests: Vec<SuiteTestResult>,
376    snaps: Vec<GasSnapshotEntry>,
377    tolerance: Option<u32>,
378) -> bool {
379    let snaps = snaps
380        .into_iter()
381        .map(|s| ((s.contract_name, s.signature), s.gas_used))
382        .collect::<HashMap<_, _>>();
383    let mut has_diff = false;
384    for test in tests {
385        if let Some(target_gas) =
386            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
387        {
388            let source_gas = test.result.kind.report();
389            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
390                let _ = sh_println!(
391                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
392                    test.contract_name(),
393                    test.signature,
394                    source_gas,
395                    target_gas
396                );
397                has_diff = true;
398            }
399        } else {
400            let _ = sh_println!(
401                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
402                test.contract_name(),
403                test.signature
404            );
405            has_diff = true;
406        }
407    }
408    !has_diff
409}
410
411/// Compare the set of tests with an existing gas snapshot.
412fn diff(
413    tests: Vec<SuiteTestResult>,
414    snaps: Vec<GasSnapshotEntry>,
415    sort_order: DiffSortOrder,
416) -> Result<()> {
417    let snaps = snaps
418        .into_iter()
419        .map(|s| ((s.contract_name, s.signature), s.gas_used))
420        .collect::<HashMap<_, _>>();
421    let mut diffs = Vec::with_capacity(tests.len());
422    let mut new_tests = Vec::new();
423
424    for test in tests {
425        if let Some(target_gas_used) =
426            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
427        {
428            diffs.push(GasSnapshotDiff {
429                source_gas_used: test.result.kind.report(),
430                signature: format!("{}::{}", test.contract_name(), test.signature),
431                target_gas_used,
432            });
433        } else {
434            // Track new tests
435            new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
436        }
437    }
438
439    let mut increased = 0;
440    let mut decreased = 0;
441    let mut unchanged = 0;
442    let mut overall_gas_change = 0i128;
443    let mut overall_gas_used = 0i128;
444
445    // Sort based on user preference
446    match sort_order {
447        DiffSortOrder::Percentage => {
448            // Default: sort by percentage change (smallest to largest)
449            diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
450        }
451        DiffSortOrder::PercentageDesc => {
452            // Sort by percentage change (largest to smallest)
453            diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
454        }
455        DiffSortOrder::Absolute => {
456            // Sort by absolute gas change (smallest to largest)
457            diffs.sort_by_key(|d| d.gas_change().abs());
458        }
459        DiffSortOrder::AbsoluteDesc => {
460            // Sort by absolute gas change (largest to smallest)
461            diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
462        }
463    }
464
465    for diff in &diffs {
466        let gas_change = diff.gas_change();
467        overall_gas_change += gas_change;
468        overall_gas_used += diff.target_gas_used.gas() as i128;
469        let gas_diff = diff.gas_diff();
470
471        // Classify changes
472        if gas_change > 0 {
473            increased += 1;
474        } else if gas_change < 0 {
475            decreased += 1;
476        } else {
477            unchanged += 1;
478        }
479
480        // Display with icon and before/after values
481        let icon = if gas_change > 0 {
482            "↑".red().to_string()
483        } else if gas_change < 0 {
484            "↓".green().to_string()
485        } else {
486            "━".to_string()
487        };
488
489        sh_println!(
490            "{} {} (gas: {} → {} | {} {})",
491            icon,
492            diff.signature,
493            diff.target_gas_used.gas(),
494            diff.source_gas_used.gas(),
495            fmt_change(gas_change),
496            fmt_pct_change(gas_diff)
497        )?;
498    }
499
500    // Display new tests if any
501    if !new_tests.is_empty() {
502        sh_println!("\n{}", "New tests:".yellow())?;
503        for test in new_tests {
504            sh_println!("  {} {}", "+".green(), test)?;
505        }
506    }
507
508    // Summary separator
509    sh_println!("\n{}", "-".repeat(80))?;
510
511    let overall_gas_diff = if overall_gas_used > 0 {
512        overall_gas_change as f64 / overall_gas_used as f64
513    } else {
514        0.0
515    };
516
517    sh_println!(
518        "Total tests: {}, {} {}, {} {}, {} {}",
519        diffs.len(),
520        "↑".red().to_string(),
521        increased,
522        "↓".green().to_string(),
523        decreased,
524        "━",
525        unchanged
526    )?;
527    sh_println!(
528        "Overall gas change: {} ({})",
529        fmt_change(overall_gas_change),
530        fmt_pct_change(overall_gas_diff)
531    )?;
532    Ok(())
533}
534
535fn fmt_pct_change(change: f64) -> String {
536    let change_pct = change * 100.0;
537    match change.total_cmp(&0.0) {
538        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
539        Ordering::Equal => {
540            format!("{change_pct:.3}%")
541        }
542        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
543    }
544}
545
546fn fmt_change(change: i128) -> String {
547    match change.cmp(&0) {
548        Ordering::Less => format!("{change}").green().to_string(),
549        Ordering::Equal => change.to_string(),
550        Ordering::Greater => format!("{change}").red().to_string(),
551    }
552}
553
/// Returns `true` if the two gas values are close enough to count as a match.
///
/// With `Some(tolerance_pct)`, the values match when their difference is at most
/// `tolerance_pct` percent of the larger value (inclusive). For example, with a tolerance of
/// `5`, `100` vs `105` and `95` vs `100` match, but `100` vs `106` does not. A tolerance of `0`
/// therefore means exact equality.
///
/// If `tolerance` is `None`, then this returns `true` if both gas values are equal
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    let Some(tolerance) = tolerance_pct else {
        return source_gas == target_gas;
    };

    let deviation = source_gas.abs_diff(target_gas);
    let larger = source_gas.max(target_gas);

    // Checks `deviation / larger * 100 <= tolerance` in exact integer arithmetic. The previous
    // float version used a strict `<` and divided `0 / 0` when both values were zero. As a
    // result, `--tolerance 0` rejected even identical values, two zero-gas values never
    // matched (NaN), and the exact boundary was lost to rounding. `u128` cannot overflow here,
    // since both products are below 2^64 * 2^32.
    u128::from(deviation) * 100 <= u128::from(larger) * u128::from(tolerance)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` as a snapshot entry, failing the test if it is invalid.
    fn parse(line: &str) -> GasSnapshotEntry {
        GasSnapshotEntry::from_str(line).unwrap()
    }

    /// Builds the entry expected for `contract:signature` with the given report.
    fn entry(contract: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    #[test]
    fn test_tolerance() {
        // (source gas, target gas, tolerance in percent, expected result)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(
                within_tolerance(source, target, tolerance),
                expected,
                "within_tolerance({source}, {target}, {tolerance:?})"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (gas: 7222)"),
            entry("Test", "deposit()", TestKindReport::Unit { gas: 7222 })
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, μ: 100, ~:200)"),
            entry(
                "Test",
                "deposit()",
                TestKindReport::Fuzz {
                    runs: 256,
                    median_gas: 200,
                    mean_gas: 100,
                    failed_corpus_replays: 0
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"),
            entry(
                "Test",
                "deposit()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 100,
                    reverts: 200,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            entry(
                "ERC20Invariants",
                "invariantBalanceSum()",
                TestKindReport::Invariant {
                    runs: 256,
                    calls: 3840,
                    reverts: 2388,
                    metrics: HashMap::default(),
                    failed_corpus_replays: 0,
                    optimization_best_value: None,
                }
            )
        );
    }
}