forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use comfy_table::{
6    Cell, Color, Row, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
7};
8use eyre::{Context, Result};
9use foundry_cli::utils::STATIC_FUZZ_SEED;
10use foundry_common::shell;
11use regex::Regex;
12use std::{
13    cmp::Ordering,
14    fs,
15    io::{self, BufRead},
16    path::{Path, PathBuf},
17    str::FromStr,
18    sync::LazyLock,
19};
20use yansi::Paint;
21
22/// A regex that matches a basic snapshot entry like
23/// `Test:testDeposit() (gas: 58804)`
24pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
25    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
26});
27
28/// CLI arguments for `forge snapshot`.
29#[derive(Clone, Debug, Parser)]
30pub struct GasSnapshotArgs {
31    /// Output a diff against a pre-existing gas snapshot.
32    ///
33    /// By default, the comparison is done with .gas-snapshot.
34    #[arg(
35        conflicts_with = "snap",
36        long,
37        value_hint = ValueHint::FilePath,
38        value_name = "SNAPSHOT_FILE",
39    )]
40    diff: Option<Option<PathBuf>>,
41
42    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
43    ///
44    /// Outputs a diff if the gas snapshots do not match.
45    ///
46    /// By default, the comparison is done with .gas-snapshot.
47    #[arg(
48        conflicts_with = "diff",
49        long,
50        value_hint = ValueHint::FilePath,
51        value_name = "SNAPSHOT_FILE",
52    )]
53    check: Option<Option<PathBuf>>,
54
55    // Hidden because there is only one option
56    /// How to format the output.
57    #[arg(long, hide(true))]
58    format: Option<Format>,
59
60    /// Output file for the gas snapshot.
61    #[arg(
62        long,
63        default_value = ".gas-snapshot",
64        value_hint = ValueHint::FilePath,
65        value_name = "FILE",
66    )]
67    snap: PathBuf,
68
69    /// Tolerates gas deviations up to the specified percentage.
70    #[arg(
71        long,
72        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
73        value_name = "SNAPSHOT_THRESHOLD"
74    )]
75    tolerance: Option<u32>,
76
77    /// How to sort diff results.
78    #[arg(long, value_name = "ORDER")]
79    diff_sort: Option<DiffSortOrder>,
80
81    /// All test arguments are supported
82    #[command(flatten)]
83    pub(crate) test: test::TestArgs,
84
85    /// Additional configs for test results
86    #[command(flatten)]
87    config: GasSnapshotConfig,
88}
89
90impl GasSnapshotArgs {
91    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
92    pub fn is_watch(&self) -> bool {
93        self.test.is_watch()
94    }
95
96    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
97    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
98        self.test.watchexec_config()
99    }
100
101    pub async fn run(mut self) -> Result<()> {
102        // Set fuzz seed so gas snapshots are deterministic
103        self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
104
105        let outcome = self.test.compile_and_run().await?;
106        outcome.ensure_ok(false)?;
107        let tests = self.config.apply(outcome);
108
109        if let Some(path) = self.diff {
110            let snap = path.as_ref().unwrap_or(&self.snap);
111            let snaps = read_gas_snapshot(snap)?;
112            diff(tests, snaps, self.diff_sort.unwrap_or_default())?;
113        } else if let Some(path) = self.check {
114            let snap = path.as_ref().unwrap_or(&self.snap);
115            let snaps = read_gas_snapshot(snap)?;
116            if check(tests, snaps, self.tolerance) {
117                std::process::exit(0)
118            } else {
119                std::process::exit(1)
120            }
121        } else {
122            if matches!(self.format, Some(Format::Table)) {
123                let table = build_gas_snapshot_table(&tests);
124                sh_println!("\n{}", table)?;
125            }
126            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
127        }
128        Ok(())
129    }
130}
131
/// Output format of the gas report printed to stdout.
#[derive(Clone, Debug)]
pub enum Format {
    /// Print the report as a table.
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `"t"` or `"table"` (case-sensitive); every other input is an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}
148
149/// Additional filters that can be applied on the test results
150#[derive(Clone, Debug, Default, Parser)]
151struct GasSnapshotConfig {
152    /// Sort results by gas used (ascending).
153    #[arg(long)]
154    asc: bool,
155
156    /// Sort results by gas used (descending).
157    #[arg(conflicts_with = "asc", long)]
158    desc: bool,
159
160    /// Only include tests that used more gas that the given amount.
161    #[arg(long, value_name = "MIN_GAS")]
162    min: Option<u64>,
163
164    /// Only include tests that used less gas that the given amount.
165    #[arg(long, value_name = "MAX_GAS")]
166    max: Option<u64>,
167}
168
169/// Sort order for diff output
170#[derive(Clone, Debug, Default, clap::ValueEnum)]
171enum DiffSortOrder {
172    /// Sort by percentage change (smallest to largest) - default behavior
173    #[default]
174    Percentage,
175    /// Sort by percentage change (largest to smallest)
176    PercentageDesc,
177    /// Sort by absolute gas change (smallest to largest)
178    Absolute,
179    /// Sort by absolute gas change (largest to smallest)
180    AbsoluteDesc,
181}
182
183impl GasSnapshotConfig {
184    fn is_in_gas_range(&self, gas_used: u64) -> bool {
185        if let Some(min) = self.min
186            && gas_used < min
187        {
188            return false;
189        }
190        if let Some(max) = self.max
191            && gas_used > max
192        {
193            return false;
194        }
195        true
196    }
197
198    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
199        let mut tests = outcome
200            .into_tests()
201            .filter(|test| self.is_in_gas_range(test.gas_used()))
202            .collect::<Vec<_>>();
203
204        if self.asc {
205            tests.sort_by_key(|a| a.gas_used());
206        } else if self.desc {
207            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
208        }
209
210        tests
211    }
212}
213
214/// A general entry in a gas snapshot file
215///
216/// Has the form:
217///   `<signature>(gas:? 40181)` for normal tests
218///   `<signature>(runs: 256, μ: 40181, ~: 40181)` for fuzz tests
219///   `<signature>(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
220#[derive(Clone, Debug, PartialEq, Eq)]
221pub struct GasSnapshotEntry {
222    pub contract_name: String,
223    pub signature: String,
224    pub gas_used: TestKindReport,
225}
226
227impl FromStr for GasSnapshotEntry {
228    type Err = String;
229
230    fn from_str(s: &str) -> Result<Self, Self::Err> {
231        RE_BASIC_SNAPSHOT_ENTRY
232            .captures(s)
233            .and_then(|cap| {
234                cap.name("file").and_then(|file| {
235                    cap.name("sig").and_then(|sig| {
236                        if let Some(gas) = cap.name("gas") {
237                            Some(Self {
238                                contract_name: file.as_str().to_string(),
239                                signature: sig.as_str().to_string(),
240                                gas_used: TestKindReport::Unit {
241                                    gas: gas.as_str().parse().unwrap(),
242                                },
243                            })
244                        } else if let Some(runs) = cap.name("runs") {
245                            cap.name("avg")
246                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
247                                .map(|(runs, avg, med)| Self {
248                                    contract_name: file.as_str().to_string(),
249                                    signature: sig.as_str().to_string(),
250                                    gas_used: TestKindReport::Fuzz {
251                                        runs: runs.as_str().parse().unwrap(),
252                                        median_gas: med.as_str().parse().unwrap(),
253                                        mean_gas: avg.as_str().parse().unwrap(),
254                                        failed_corpus_replays: 0,
255                                    },
256                                })
257                        } else {
258                            cap.name("invruns")
259                                .and_then(|runs| {
260                                    cap.name("calls").and_then(|avg| {
261                                        cap.name("reverts").map(|med| (runs, avg, med))
262                                    })
263                                })
264                                .map(|(runs, calls, reverts)| Self {
265                                    contract_name: file.as_str().to_string(),
266                                    signature: sig.as_str().to_string(),
267                                    gas_used: TestKindReport::Invariant {
268                                        runs: runs.as_str().parse().unwrap(),
269                                        calls: calls.as_str().parse().unwrap(),
270                                        reverts: reverts.as_str().parse().unwrap(),
271                                        metrics: HashMap::default(),
272                                        failed_corpus_replays: 0,
273                                    },
274                                })
275                        }
276                    })
277                })
278            })
279            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
280    }
281}
282
283/// Reads a list of gas snapshot entries from a gas snapshot file.
284fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
285    let path = path.as_ref();
286    let mut entries = Vec::new();
287    for line in io::BufReader::new(
288        fs::File::open(path)
289            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
290    )
291    .lines()
292    {
293        entries
294            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
295    }
296    Ok(entries)
297}
298
299/// Writes a series of tests to a gas snapshot file after sorting them.
300fn write_to_gas_snapshot_file(
301    tests: &[SuiteTestResult],
302    path: impl AsRef<Path>,
303    _format: Option<Format>,
304) -> Result<()> {
305    let mut reports = tests
306        .iter()
307        .map(|test| {
308            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
309        })
310        .collect::<Vec<_>>();
311
312    // sort all reports
313    reports.sort();
314
315    let content = reports.join("\n");
316    Ok(fs::write(path, content)?)
317}
318
319fn build_gas_snapshot_table(tests: &[SuiteTestResult]) -> Table {
320    let mut table = Table::new();
321    if shell::is_markdown() {
322        table.load_preset(ASCII_MARKDOWN);
323    } else {
324        table.apply_modifier(UTF8_ROUND_CORNERS);
325    }
326
327    table.set_header(vec![
328        Cell::new("Contract").fg(Color::Cyan),
329        Cell::new("Signature").fg(Color::Cyan),
330        Cell::new("Report").fg(Color::Cyan),
331    ]);
332
333    for test in tests {
334        let mut row = Row::new();
335        row.add_cell(Cell::new(test.contract_name()));
336        row.add_cell(Cell::new(&test.signature));
337        row.add_cell(Cell::new(test.result.kind.report().to_string()));
338        table.add_row(row);
339    }
340
341    table
342}
343
344/// A Gas snapshot entry diff.
345#[derive(Clone, Debug, PartialEq, Eq)]
346pub struct GasSnapshotDiff {
347    pub signature: String,
348    pub source_gas_used: TestKindReport,
349    pub target_gas_used: TestKindReport,
350}
351
352impl GasSnapshotDiff {
353    /// Returns the gas diff
354    ///
355    /// `> 0` if the source used more gas
356    /// `< 0` if the target used more gas
357    fn gas_change(&self) -> i128 {
358        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
359    }
360
361    /// Determines the percentage change
362    fn gas_diff(&self) -> f64 {
363        self.gas_change() as f64 / self.target_gas_used.gas() as f64
364    }
365}
366
367/// Compares the set of tests with an existing gas snapshot.
368///
369/// Returns true all tests match
370fn check(
371    tests: Vec<SuiteTestResult>,
372    snaps: Vec<GasSnapshotEntry>,
373    tolerance: Option<u32>,
374) -> bool {
375    let snaps = snaps
376        .into_iter()
377        .map(|s| ((s.contract_name, s.signature), s.gas_used))
378        .collect::<HashMap<_, _>>();
379    let mut has_diff = false;
380    for test in tests {
381        if let Some(target_gas) =
382            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
383        {
384            let source_gas = test.result.kind.report();
385            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
386                let _ = sh_println!(
387                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
388                    test.contract_name(),
389                    test.signature,
390                    source_gas,
391                    target_gas
392                );
393                has_diff = true;
394            }
395        } else {
396            let _ = sh_println!(
397                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
398                test.contract_name(),
399                test.signature
400            );
401            has_diff = true;
402        }
403    }
404    !has_diff
405}
406
407/// Compare the set of tests with an existing gas snapshot.
408fn diff(
409    tests: Vec<SuiteTestResult>,
410    snaps: Vec<GasSnapshotEntry>,
411    sort_order: DiffSortOrder,
412) -> Result<()> {
413    let snaps = snaps
414        .into_iter()
415        .map(|s| ((s.contract_name, s.signature), s.gas_used))
416        .collect::<HashMap<_, _>>();
417    let mut diffs = Vec::with_capacity(tests.len());
418    let mut new_tests = Vec::new();
419
420    for test in tests.into_iter() {
421        if let Some(target_gas_used) =
422            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
423        {
424            diffs.push(GasSnapshotDiff {
425                source_gas_used: test.result.kind.report(),
426                signature: format!("{}::{}", test.contract_name(), test.signature),
427                target_gas_used,
428            });
429        } else {
430            // Track new tests
431            new_tests.push(format!("{}::{}", test.contract_name(), test.signature));
432        }
433    }
434
435    let mut increased = 0;
436    let mut decreased = 0;
437    let mut unchanged = 0;
438    let mut overall_gas_change = 0i128;
439    let mut overall_gas_used = 0i128;
440
441    // Sort based on user preference
442    match sort_order {
443        DiffSortOrder::Percentage => {
444            // Default: sort by percentage change (smallest to largest)
445            diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
446        }
447        DiffSortOrder::PercentageDesc => {
448            // Sort by percentage change (largest to smallest)
449            diffs.sort_by(|a, b| b.gas_diff().abs().total_cmp(&a.gas_diff().abs()));
450        }
451        DiffSortOrder::Absolute => {
452            // Sort by absolute gas change (smallest to largest)
453            diffs.sort_by_key(|d| d.gas_change().abs());
454        }
455        DiffSortOrder::AbsoluteDesc => {
456            // Sort by absolute gas change (largest to smallest)
457            diffs.sort_by_key(|d| std::cmp::Reverse(d.gas_change().abs()));
458        }
459    }
460
461    for diff in &diffs {
462        let gas_change = diff.gas_change();
463        overall_gas_change += gas_change;
464        overall_gas_used += diff.target_gas_used.gas() as i128;
465        let gas_diff = diff.gas_diff();
466
467        // Classify changes
468        if gas_change > 0 {
469            increased += 1;
470        } else if gas_change < 0 {
471            decreased += 1;
472        } else {
473            unchanged += 1;
474        }
475
476        // Display with icon and before/after values
477        let icon = if gas_change > 0 {
478            "↑".red().to_string()
479        } else if gas_change < 0 {
480            "↓".green().to_string()
481        } else {
482            "━".to_string()
483        };
484
485        sh_println!(
486            "{} {} (gas: {} → {} | {} {})",
487            icon,
488            diff.signature,
489            diff.target_gas_used.gas(),
490            diff.source_gas_used.gas(),
491            fmt_change(gas_change),
492            fmt_pct_change(gas_diff)
493        )?;
494    }
495
496    // Display new tests if any
497    if !new_tests.is_empty() {
498        sh_println!("\n{}", "New tests:".yellow())?;
499        for test in new_tests {
500            sh_println!("  {} {}", "+".green(), test)?;
501        }
502    }
503
504    // Summary separator
505    sh_println!("\n{}", "-".repeat(80))?;
506
507    let overall_gas_diff = if overall_gas_used > 0 {
508        overall_gas_change as f64 / overall_gas_used as f64
509    } else {
510        0.0
511    };
512
513    sh_println!(
514        "Total tests: {}, {} {}, {} {}, {} {}",
515        diffs.len(),
516        "↑".red().to_string(),
517        increased,
518        "↓".green().to_string(),
519        decreased,
520        "━",
521        unchanged
522    )?;
523    sh_println!(
524        "Overall gas change: {} ({})",
525        fmt_change(overall_gas_change),
526        fmt_pct_change(overall_gas_diff)
527    )?;
528    Ok(())
529}
530
531fn fmt_pct_change(change: f64) -> String {
532    let change_pct = change * 100.0;
533    match change.total_cmp(&0.0) {
534        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
535        Ordering::Equal => {
536            format!("{change_pct:.3}%")
537        }
538        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
539    }
540}
541
542fn fmt_change(change: i128) -> String {
543    match change.cmp(&0) {
544        Ordering::Less => format!("{change}").green().to_string(),
545        Ordering::Equal => {
546            format!("{change}")
547        }
548        Ordering::Greater => format!("{change}").red().to_string(),
549    }
550}
551
/// Returns `true` if the two gas values are considered a match.
///
/// Equal values always match. Otherwise, if `tolerance_pct` is `None` the values do not match;
/// if it is `Some`, they match when the deviation of the smaller value from the larger one, in
/// percent of the larger one, is strictly below the tolerance.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Identical values must match regardless of the tolerance. Without this early return a
    // tolerance of 0 would reject them (`0 < 0`), and two zeros would divide 0 by 0 (NaN).
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };

    let (hi, lo) =
        if source_gas > target_gas { (source_gas, target_gas) } else { (target_gas, source_gas) };
    // `hi > lo >= 0` holds here, so the division is well defined.
    let deviation_pct = (1. - (lo as f64 / hi as f64)) * 100.;
    deviation_pct < f64::from(tolerance)
}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` as a snapshot entry, panicking if it is rejected.
    fn parse(line: &str) -> GasSnapshotEntry {
        line.parse().unwrap()
    }

    /// Builds the entry a parsed line is expected to equal.
    fn entry(contract: &str, signature: &str, gas_used: TestKindReport) -> GasSnapshotEntry {
        GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: signature.to_string(),
            gas_used,
        }
    }

    /// Builds an invariant report with no metrics and no failed corpus replays.
    fn invariant(runs: u32, calls: usize, reverts: usize) -> TestKindReport {
        TestKindReport::Invariant {
            runs: runs as _,
            calls: calls as _,
            reverts: reverts as _,
            metrics: HashMap::default(),
            failed_corpus_replays: 0,
        }
    }

    #[test]
    fn test_tolerance() {
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(
                within_tolerance(source, target, tolerance),
                expected,
                "source={source}, target={target}, tolerance={tolerance:?}"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (gas: 7222)"),
            entry("Test", "deposit()", TestKindReport::Unit { gas: 7222 })
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let expected = entry(
            "Test",
            "deposit()",
            TestKindReport::Fuzz {
                runs: 256,
                median_gas: 200,
                mean_gas: 100,
                failed_corpus_replays: 0,
            },
        );
        assert_eq!(parse("Test:deposit() (runs: 256, μ: 100, ~:200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_eq!(
            parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"),
            entry("Test", "deposit()", invariant(256, 100, 200))
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            entry("ERC20Invariants", "invariantBalanceSum()", invariant(256, 3840, 2388))
        );
    }
}