forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9    cmp::Ordering,
10    fs,
11    io::{self, BufRead},
12    path::{Path, PathBuf},
13    str::FromStr,
14    sync::LazyLock,
15};
16use yansi::Paint;
17
18/// A regex that matches a basic snapshot entry like
19/// `Test:testDeposit() (gas: 58804)`
20pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
21    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
22});
23
24/// CLI arguments for `forge snapshot`.
25#[derive(Clone, Debug, Parser)]
26pub struct GasSnapshotArgs {
27    /// Output a diff against a pre-existing gas snapshot.
28    ///
29    /// By default, the comparison is done with .gas-snapshot.
30    #[arg(
31        conflicts_with = "snap",
32        long,
33        value_hint = ValueHint::FilePath,
34        value_name = "SNAPSHOT_FILE",
35    )]
36    diff: Option<Option<PathBuf>>,
37
38    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
39    ///
40    /// Outputs a diff if the gas snapshots do not match.
41    ///
42    /// By default, the comparison is done with .gas-snapshot.
43    #[arg(
44        conflicts_with = "diff",
45        long,
46        value_hint = ValueHint::FilePath,
47        value_name = "SNAPSHOT_FILE",
48    )]
49    check: Option<Option<PathBuf>>,
50
51    // Hidden because there is only one option
52    /// How to format the output.
53    #[arg(long, hide(true))]
54    format: Option<Format>,
55
56    /// Output file for the gas snapshot.
57    #[arg(
58        long,
59        default_value = ".gas-snapshot",
60        value_hint = ValueHint::FilePath,
61        value_name = "FILE",
62    )]
63    snap: PathBuf,
64
65    /// Tolerates gas deviations up to the specified percentage.
66    #[arg(
67        long,
68        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
69        value_name = "SNAPSHOT_THRESHOLD"
70    )]
71    tolerance: Option<u32>,
72
73    /// All test arguments are supported
74    #[command(flatten)]
75    pub(crate) test: test::TestArgs,
76
77    /// Additional configs for test results
78    #[command(flatten)]
79    config: GasSnapshotConfig,
80}
81
82impl GasSnapshotArgs {
83    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
84    pub fn is_watch(&self) -> bool {
85        self.test.is_watch()
86    }
87
88    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
89    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90        self.test.watchexec_config()
91    }
92
93    pub async fn run(mut self) -> Result<()> {
94        // Set fuzz seed so gas snapshots are deterministic
95        self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97        let outcome = self.test.execute_tests().await?;
98        outcome.ensure_ok(false)?;
99        let tests = self.config.apply(outcome);
100
101        if let Some(path) = self.diff {
102            let snap = path.as_ref().unwrap_or(&self.snap);
103            let snaps = read_gas_snapshot(snap)?;
104            diff(tests, snaps)?;
105        } else if let Some(path) = self.check {
106            let snap = path.as_ref().unwrap_or(&self.snap);
107            let snaps = read_gas_snapshot(snap)?;
108            if check(tests, snaps, self.tolerance) {
109                std::process::exit(0)
110            } else {
111                std::process::exit(1)
112            }
113        } else {
114            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115        }
116        Ok(())
117    }
118}
119
// TODO implement pretty tables
/// Output format for the gas snapshot; `Table` is currently the only variant.
#[derive(Clone, Debug)]
pub enum Format {
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts exactly `"t"` or `"table"` (case-sensitive); anything else is an error naming
    /// the rejected input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            return Ok(Self::Table);
        }
        Err(format!("Unrecognized format `{s}`"))
    }
}
136
137/// Additional filters that can be applied on the test results
138#[derive(Clone, Debug, Default, Parser)]
139struct GasSnapshotConfig {
140    /// Sort results by gas used (ascending).
141    #[arg(long)]
142    asc: bool,
143
144    /// Sort results by gas used (descending).
145    #[arg(conflicts_with = "asc", long)]
146    desc: bool,
147
148    /// Only include tests that used more gas that the given amount.
149    #[arg(long, value_name = "MIN_GAS")]
150    min: Option<u64>,
151
152    /// Only include tests that used less gas that the given amount.
153    #[arg(long, value_name = "MAX_GAS")]
154    max: Option<u64>,
155}
156
157impl GasSnapshotConfig {
158    fn is_in_gas_range(&self, gas_used: u64) -> bool {
159        if let Some(min) = self.min
160            && gas_used < min
161        {
162            return false;
163        }
164        if let Some(max) = self.max
165            && gas_used > max
166        {
167            return false;
168        }
169        true
170    }
171
172    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173        let mut tests = outcome
174            .into_tests()
175            .filter(|test| self.is_in_gas_range(test.gas_used()))
176            .collect::<Vec<_>>();
177
178        if self.asc {
179            tests.sort_by_key(|a| a.gas_used());
180        } else if self.desc {
181            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182        }
183
184        tests
185    }
186}
187
188/// A general entry in a gas snapshot file
189///
190/// Has the form:
191///   `<signature>(gas:? 40181)` for normal tests
192///   `<signature>(runs: 256, μ: 40181, ~: 40181)` for fuzz tests
193///   `<signature>(runs: 256, calls: 40181, reverts: 40181)` for invariant tests
194#[derive(Clone, Debug, PartialEq, Eq)]
195pub struct GasSnapshotEntry {
196    pub contract_name: String,
197    pub signature: String,
198    pub gas_used: TestKindReport,
199}
200
201impl FromStr for GasSnapshotEntry {
202    type Err = String;
203
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        RE_BASIC_SNAPSHOT_ENTRY
206            .captures(s)
207            .and_then(|cap| {
208                cap.name("file").and_then(|file| {
209                    cap.name("sig").and_then(|sig| {
210                        if let Some(gas) = cap.name("gas") {
211                            Some(Self {
212                                contract_name: file.as_str().to_string(),
213                                signature: sig.as_str().to_string(),
214                                gas_used: TestKindReport::Unit {
215                                    gas: gas.as_str().parse().unwrap(),
216                                },
217                            })
218                        } else if let Some(runs) = cap.name("runs") {
219                            cap.name("avg")
220                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221                                .map(|(runs, avg, med)| Self {
222                                    contract_name: file.as_str().to_string(),
223                                    signature: sig.as_str().to_string(),
224                                    gas_used: TestKindReport::Fuzz {
225                                        runs: runs.as_str().parse().unwrap(),
226                                        median_gas: med.as_str().parse().unwrap(),
227                                        mean_gas: avg.as_str().parse().unwrap(),
228                                        failed_corpus_replays: 0,
229                                    },
230                                })
231                        } else {
232                            cap.name("invruns")
233                                .and_then(|runs| {
234                                    cap.name("calls").and_then(|avg| {
235                                        cap.name("reverts").map(|med| (runs, avg, med))
236                                    })
237                                })
238                                .map(|(runs, calls, reverts)| Self {
239                                    contract_name: file.as_str().to_string(),
240                                    signature: sig.as_str().to_string(),
241                                    gas_used: TestKindReport::Invariant {
242                                        runs: runs.as_str().parse().unwrap(),
243                                        calls: calls.as_str().parse().unwrap(),
244                                        reverts: reverts.as_str().parse().unwrap(),
245                                        metrics: HashMap::default(),
246                                        failed_corpus_replays: 0,
247                                    },
248                                })
249                        }
250                    })
251                })
252            })
253            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
254    }
255}
256
257/// Reads a list of gas snapshot entries from a gas snapshot file.
258fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
259    let path = path.as_ref();
260    let mut entries = Vec::new();
261    for line in io::BufReader::new(
262        fs::File::open(path)
263            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
264    )
265    .lines()
266    {
267        entries
268            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
269    }
270    Ok(entries)
271}
272
273/// Writes a series of tests to a gas snapshot file after sorting them.
274fn write_to_gas_snapshot_file(
275    tests: &[SuiteTestResult],
276    path: impl AsRef<Path>,
277    _format: Option<Format>,
278) -> Result<()> {
279    let mut reports = tests
280        .iter()
281        .map(|test| {
282            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
283        })
284        .collect::<Vec<_>>();
285
286    // sort all reports
287    reports.sort();
288
289    let content = reports.join("\n");
290    Ok(fs::write(path, content)?)
291}
292
293/// A Gas snapshot entry diff.
294#[derive(Clone, Debug, PartialEq, Eq)]
295pub struct GasSnapshotDiff {
296    pub signature: String,
297    pub source_gas_used: TestKindReport,
298    pub target_gas_used: TestKindReport,
299}
300
301impl GasSnapshotDiff {
302    /// Returns the gas diff
303    ///
304    /// `> 0` if the source used more gas
305    /// `< 0` if the target used more gas
306    fn gas_change(&self) -> i128 {
307        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
308    }
309
310    /// Determines the percentage change
311    fn gas_diff(&self) -> f64 {
312        self.gas_change() as f64 / self.target_gas_used.gas() as f64
313    }
314}
315
316/// Compares the set of tests with an existing gas snapshot.
317///
318/// Returns true all tests match
319fn check(
320    tests: Vec<SuiteTestResult>,
321    snaps: Vec<GasSnapshotEntry>,
322    tolerance: Option<u32>,
323) -> bool {
324    let snaps = snaps
325        .into_iter()
326        .map(|s| ((s.contract_name, s.signature), s.gas_used))
327        .collect::<HashMap<_, _>>();
328    let mut has_diff = false;
329    for test in tests {
330        if let Some(target_gas) =
331            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
332        {
333            let source_gas = test.result.kind.report();
334            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
335                let _ = sh_println!(
336                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
337                    test.contract_name(),
338                    test.signature,
339                    source_gas,
340                    target_gas
341                );
342                has_diff = true;
343            }
344        } else {
345            let _ = sh_println!(
346                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
347                test.contract_name(),
348                test.signature
349            );
350            has_diff = true;
351        }
352    }
353    !has_diff
354}
355
356/// Compare the set of tests with an existing gas snapshot.
357fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
358    let snaps = snaps
359        .into_iter()
360        .map(|s| ((s.contract_name, s.signature), s.gas_used))
361        .collect::<HashMap<_, _>>();
362    let mut diffs = Vec::with_capacity(tests.len());
363    for test in tests.into_iter() {
364        if let Some(target_gas_used) =
365            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
366        {
367            diffs.push(GasSnapshotDiff {
368                source_gas_used: test.result.kind.report(),
369                signature: test.signature,
370                target_gas_used,
371            });
372        }
373    }
374    let mut overall_gas_change = 0i128;
375    let mut overall_gas_used = 0i128;
376
377    diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
378
379    for diff in diffs {
380        let gas_change = diff.gas_change();
381        overall_gas_change += gas_change;
382        overall_gas_used += diff.target_gas_used.gas() as i128;
383        let gas_diff = diff.gas_diff();
384        sh_println!(
385            "{} (gas: {} ({})) ",
386            diff.signature,
387            fmt_change(gas_change),
388            fmt_pct_change(gas_diff)
389        )?;
390    }
391
392    let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
393    sh_println!(
394        "Overall gas change: {} ({})",
395        fmt_change(overall_gas_change),
396        fmt_pct_change(overall_gas_diff)
397    )?;
398    Ok(())
399}
400
401fn fmt_pct_change(change: f64) -> String {
402    let change_pct = change * 100.0;
403    match change.total_cmp(&0.0) {
404        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
405        Ordering::Equal => {
406            format!("{change_pct:.3}%")
407        }
408        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
409    }
410}
411
412fn fmt_change(change: i128) -> String {
413    match change.cmp(&0) {
414        Ordering::Less => format!("{change}").green().to_string(),
415        Ordering::Equal => {
416            format!("{change}")
417        }
418        Ordering::Greater => format!("{change}").red().to_string(),
419    }
420}
421
/// Returns true if the two gas values are equal or differ by less than the tolerance.
///
/// The difference is measured as a percentage of the larger value: `(1 - lo / hi) * 100`.
/// Equal values always match, including with a tolerance of `0` and when both are zero.
///
/// If `tolerance_pct` is `None`, then this returns `true` only if both gas values are equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Settle equality first: with a tolerance of 0 the strict `<` below would reject equal
    // values, and two zero values would compute `0 / 0` (NaN), which compares as false.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };

    // The values differ, so `hi` is at least 1 and the division is well defined.
    let hi = source_gas.max(target_gas);
    let lo = source_gas.min(target_gas);
    let diff = (1. - (lo as f64 / hi as f64)) * 100.;
    diff < f64::from(tolerance)
}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` and asserts that it yields exactly the described entry.
    fn assert_parses_to(line: &str, contract: &str, signature: &str, gas_used: TestKindReport) {
        let entry = GasSnapshotEntry::from_str(line).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: contract.to_string(),
                signature: signature.to_string(),
                gas_used,
            }
        );
    }

    #[test]
    fn test_tolerance() {
        // (source, target, tolerance, expected result)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(
                within_tolerance(source, target, tolerance),
                expected,
                "source={source}, target={target}, tolerance={tolerance:?}"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (gas: 7222)",
            "Test",
            "deposit()",
            TestKindReport::Unit { gas: 7222 },
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, μ: 100, ~:200)",
            "Test",
            "deposit()",
            TestKindReport::Fuzz {
                runs: 256,
                median_gas: 200,
                mean_gas: 100,
                failed_corpus_replays: 0,
            },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, calls: 100, reverts: 200)",
            "Test",
            "deposit()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_parses_to(
            "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)",
            "ERC20Invariants",
            "invariantBalanceSum()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }
}