forge/cmd/
snapshot.rs

1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{map::HashMap, U256};
4use clap::{builder::RangedU64ValueParser, Parser, ValueHint};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9    cmp::Ordering,
10    fs,
11    io::{self, BufRead},
12    path::{Path, PathBuf},
13    str::FromStr,
14    sync::LazyLock,
15};
16use yansi::Paint;
17
/// A regex that matches a basic snapshot entry like
/// `Test:testDeposit() (gas: 58804)`
///
/// The parenthesized report may take one of three shapes, each exposed through named groups:
/// - unit tests: `(gas: 58804)` (the `gas:` label is optional) -> `gas`
/// - fuzz tests: `(runs: 256, μ: 100, ~: 200)` -> `runs`, `avg` (μ), `med` (~)
/// - invariant tests: `(runs: 256, calls: 100, reverts: 200)` -> `invruns`, `calls`, `reverts`
///
/// The `file` group captures the text before the `:` separator and `sig` captures the function
/// signature including its parameter list, e.g. `testDeposit()`. The pattern is not anchored,
/// so it matches the leftmost such entry anywhere in the input.
pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
    // The pattern is a literal, so `unwrap` can only fail on a programming error.
    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
});
23
/// CLI arguments for `forge snapshot`.
// Maintainer note: clap renders the `///` comments in this struct as `--help` text, so internal
// notes here use plain `//` comments instead.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // `Option<Option<_>>`: `None` when `--diff` is absent, `Some(None)` when it is passed without
    // a path, in which case `run` falls back to the `snap` path.
    // NOTE(review): `--diff` conflicts with an explicit `--snap`, while `--check` does not —
    // confirm this asymmetry is intended.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Outputs a diff if the gas snapshots do not match.
    ///
    /// By default, the comparison is done with .gas-snapshot.
    // Same `Option<Option<_>>` encoding as `diff`; a bare `--check` falls back to `snap`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden because there is only one option
    /// How to format the output.
    // Currently not acted upon: `write_to_gas_snapshot_file` receives it as `_format`.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    // Also the fallback input file for `--diff` / `--check` when they are given without a path.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerates gas deviations up to the specified percentage.
    // Only consulted on the `--check` path. The range `0..100` is half-open, so the accepted
    // values are 0 through 99.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// All test arguments are supported
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Additional configs for test results
    // Gas-range filters and sort order, applied to the results before any diff/check/write.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
81
82impl GasSnapshotArgs {
83    /// Returns whether `GasSnapshotArgs` was configured with `--watch`
84    pub fn is_watch(&self) -> bool {
85        self.test.is_watch()
86    }
87
88    /// Returns the [`watchexec::Config`] necessary to bootstrap a new watch loop.
89    pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90        self.test.watchexec_config()
91    }
92
93    pub async fn run(mut self) -> Result<()> {
94        // Set fuzz seed so gas snapshots are deterministic
95        self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97        let outcome = self.test.execute_tests().await?;
98        outcome.ensure_ok(false)?;
99        let tests = self.config.apply(outcome);
100
101        if let Some(path) = self.diff {
102            let snap = path.as_ref().unwrap_or(&self.snap);
103            let snaps = read_gas_snapshot(snap)?;
104            diff(tests, snaps)?;
105        } else if let Some(path) = self.check {
106            let snap = path.as_ref().unwrap_or(&self.snap);
107            let snaps = read_gas_snapshot(snap)?;
108            if check(tests, snaps, self.tolerance) {
109                std::process::exit(0)
110            } else {
111                std::process::exit(1)
112            }
113        } else {
114            write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115        }
116        Ok(())
117    }
118}
119
// TODO implement pretty tables
/// Output format selector for the snapshot command; `Table` is the only variant.
#[derive(Clone, Debug)]
pub enum Format {
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `"t"` or `"table"`; any other input yields an "Unrecognized format" message.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
136
137/// Additional filters that can be applied on the test results
138#[derive(Clone, Debug, Default, Parser)]
139struct GasSnapshotConfig {
140    /// Sort results by gas used (ascending).
141    #[arg(long)]
142    asc: bool,
143
144    /// Sort results by gas used (descending).
145    #[arg(conflicts_with = "asc", long)]
146    desc: bool,
147
148    /// Only include tests that used more gas that the given amount.
149    #[arg(long, value_name = "MIN_GAS")]
150    min: Option<u64>,
151
152    /// Only include tests that used less gas that the given amount.
153    #[arg(long, value_name = "MAX_GAS")]
154    max: Option<u64>,
155}
156
157impl GasSnapshotConfig {
158    fn is_in_gas_range(&self, gas_used: u64) -> bool {
159        if let Some(min) = self.min {
160            if gas_used < min {
161                return false
162            }
163        }
164        if let Some(max) = self.max {
165            if gas_used > max {
166                return false
167            }
168        }
169        true
170    }
171
172    fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173        let mut tests = outcome
174            .into_tests()
175            .filter(|test| self.is_in_gas_range(test.gas_used()))
176            .collect::<Vec<_>>();
177
178        if self.asc {
179            tests.sort_by_key(|a| a.gas_used());
180        } else if self.desc {
181            tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182        }
183
184        tests
185    }
186}
187
/// A general entry in a gas snapshot file.
///
/// Each line of the file has the form `<contract>:<signature> (<report>)`, where `<report>` is
/// one of:
/// - `gas: 40181` for unit tests (the `gas:` label is optional when parsing)
/// - `runs: 256, μ: 40181, ~: 40181` for fuzz tests
/// - `runs: 256, calls: 40181, reverts: 40181` for invariant tests
///
/// For example: `Test:deposit() (gas: 7222)`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// The text before the `:` separator, e.g. `Test`.
    pub contract_name: String,
    /// The test function signature including its parameter list, e.g. `deposit()`.
    pub signature: String,
    /// The gas report parsed from the parenthesized suffix.
    pub gas_used: TestKindReport,
}
200
201impl FromStr for GasSnapshotEntry {
202    type Err = String;
203
204    fn from_str(s: &str) -> Result<Self, Self::Err> {
205        RE_BASIC_SNAPSHOT_ENTRY
206            .captures(s)
207            .and_then(|cap| {
208                cap.name("file").and_then(|file| {
209                    cap.name("sig").and_then(|sig| {
210                        if let Some(gas) = cap.name("gas") {
211                            Some(Self {
212                                contract_name: file.as_str().to_string(),
213                                signature: sig.as_str().to_string(),
214                                gas_used: TestKindReport::Unit {
215                                    gas: gas.as_str().parse().unwrap(),
216                                },
217                            })
218                        } else if let Some(runs) = cap.name("runs") {
219                            cap.name("avg")
220                                .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221                                .map(|(runs, avg, med)| Self {
222                                    contract_name: file.as_str().to_string(),
223                                    signature: sig.as_str().to_string(),
224                                    gas_used: TestKindReport::Fuzz {
225                                        runs: runs.as_str().parse().unwrap(),
226                                        median_gas: med.as_str().parse().unwrap(),
227                                        mean_gas: avg.as_str().parse().unwrap(),
228                                    },
229                                })
230                        } else {
231                            cap.name("invruns")
232                                .and_then(|runs| {
233                                    cap.name("calls").and_then(|avg| {
234                                        cap.name("reverts").map(|med| (runs, avg, med))
235                                    })
236                                })
237                                .map(|(runs, calls, reverts)| Self {
238                                    contract_name: file.as_str().to_string(),
239                                    signature: sig.as_str().to_string(),
240                                    gas_used: TestKindReport::Invariant {
241                                        runs: runs.as_str().parse().unwrap(),
242                                        calls: calls.as_str().parse().unwrap(),
243                                        reverts: reverts.as_str().parse().unwrap(),
244                                        metrics: HashMap::default(),
245                                    },
246                                })
247                        }
248                    })
249                })
250            })
251            .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
252    }
253}
254
255/// Reads a list of gas snapshot entries from a gas snapshot file.
256fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
257    let path = path.as_ref();
258    let mut entries = Vec::new();
259    for line in io::BufReader::new(
260        fs::File::open(path)
261            .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
262    )
263    .lines()
264    {
265        entries
266            .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
267    }
268    Ok(entries)
269}
270
271/// Writes a series of tests to a gas snapshot file after sorting them.
272fn write_to_gas_snapshot_file(
273    tests: &[SuiteTestResult],
274    path: impl AsRef<Path>,
275    _format: Option<Format>,
276) -> Result<()> {
277    let mut reports = tests
278        .iter()
279        .map(|test| {
280            format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
281        })
282        .collect::<Vec<_>>();
283
284    // sort all reports
285    reports.sort();
286
287    let content = reports.join("\n");
288    Ok(fs::write(path, content)?)
289}
290
291/// A Gas snapshot entry diff.
292#[derive(Clone, Debug, PartialEq, Eq)]
293pub struct GasSnapshotDiff {
294    pub signature: String,
295    pub source_gas_used: TestKindReport,
296    pub target_gas_used: TestKindReport,
297}
298
299impl GasSnapshotDiff {
300    /// Returns the gas diff
301    ///
302    /// `> 0` if the source used more gas
303    /// `< 0` if the target used more gas
304    fn gas_change(&self) -> i128 {
305        self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
306    }
307
308    /// Determines the percentage change
309    fn gas_diff(&self) -> f64 {
310        self.gas_change() as f64 / self.target_gas_used.gas() as f64
311    }
312}
313
314/// Compares the set of tests with an existing gas snapshot.
315///
316/// Returns true all tests match
317fn check(
318    tests: Vec<SuiteTestResult>,
319    snaps: Vec<GasSnapshotEntry>,
320    tolerance: Option<u32>,
321) -> bool {
322    let snaps = snaps
323        .into_iter()
324        .map(|s| ((s.contract_name, s.signature), s.gas_used))
325        .collect::<HashMap<_, _>>();
326    let mut has_diff = false;
327    for test in tests {
328        if let Some(target_gas) =
329            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
330        {
331            let source_gas = test.result.kind.report();
332            if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
333                let _ = sh_println!(
334                    "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
335                    test.contract_name(),
336                    test.signature,
337                    source_gas,
338                    target_gas
339                );
340                has_diff = true;
341            }
342        } else {
343            let _ = sh_println!(
344                "No matching snapshot entry found for \"{}::{}\" in snapshot file",
345                test.contract_name(),
346                test.signature
347            );
348            has_diff = true;
349        }
350    }
351    !has_diff
352}
353
354/// Compare the set of tests with an existing gas snapshot.
355fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
356    let snaps = snaps
357        .into_iter()
358        .map(|s| ((s.contract_name, s.signature), s.gas_used))
359        .collect::<HashMap<_, _>>();
360    let mut diffs = Vec::with_capacity(tests.len());
361    for test in tests.into_iter() {
362        if let Some(target_gas_used) =
363            snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
364        {
365            diffs.push(GasSnapshotDiff {
366                source_gas_used: test.result.kind.report(),
367                signature: test.signature,
368                target_gas_used,
369            });
370        }
371    }
372    let mut overall_gas_change = 0i128;
373    let mut overall_gas_used = 0i128;
374
375    diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
376
377    for diff in diffs {
378        let gas_change = diff.gas_change();
379        overall_gas_change += gas_change;
380        overall_gas_used += diff.target_gas_used.gas() as i128;
381        let gas_diff = diff.gas_diff();
382        sh_println!(
383            "{} (gas: {} ({})) ",
384            diff.signature,
385            fmt_change(gas_change),
386            fmt_pct_change(gas_diff)
387        )?;
388    }
389
390    let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
391    sh_println!(
392        "Overall gas change: {} ({})",
393        fmt_change(overall_gas_change),
394        fmt_pct_change(overall_gas_diff)
395    )?;
396    Ok(())
397}
398
399fn fmt_pct_change(change: f64) -> String {
400    let change_pct = change * 100.0;
401    match change.total_cmp(&0.0) {
402        Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
403        Ordering::Equal => {
404            format!("{change_pct:.3}%")
405        }
406        Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
407    }
408}
409
410fn fmt_change(change: i128) -> String {
411    match change.cmp(&0) {
412        Ordering::Less => format!("{change}").green().to_string(),
413        Ordering::Equal => {
414            format!("{change}")
415        }
416        Ordering::Greater => format!("{change}").red().to_string(),
417    }
418}
419
/// Returns `true` if `source_gas` and `target_gas` are close enough to be considered matching.
///
/// With `Some(tolerance_pct)`, the relative difference, measured against the larger of the two
/// values, must be strictly below `tolerance_pct` percent. Identical values always match, even
/// with a tolerance of `0` or when both values are `0`.
///
/// If `tolerance_pct` is `None`, then this returns `true` only if both gas values are equal.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    let Some(tolerance) = tolerance_pct else {
        return source_gas == target_gas;
    };
    if source_gas == target_gas {
        // Without this, `Some(0)` rejected identical values (0 < 0 is false), and two zero
        // values produced `0 / 0 = NaN`, which never compares as "within".
        return true;
    }
    let (hi, lo) = if source_gas > target_gas {
        (source_gas, target_gas)
    } else {
        (target_gas, source_gas)
    };
    let diff_pct = (1.0 - lo as f64 / hi as f64) * 100.0;
    diff_pct < f64::from(tolerance)
}
436
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a snapshot line, failing the test if it is rejected.
    fn parse(line: &str) -> GasSnapshotEntry {
        line.parse::<GasSnapshotEntry>().expect("snapshot entry should parse")
    }

    #[test]
    fn test_tolerance() {
        // (source gas, target gas, tolerance, expected result)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(
                within_tolerance(source, target, tolerance),
                expected,
                "within_tolerance({source}, {target}, {tolerance:?})"
            );
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Unit { gas: 7222 },
        };
        assert_eq!(parse("Test:deposit() (gas: 7222)"), expected);
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Fuzz { runs: 256, median_gas: 200, mean_gas: 100 },
        };
        assert_eq!(parse("Test:deposit() (runs: 256, μ: 100, ~:200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let expected = GasSnapshotEntry {
            contract_name: "Test".to_string(),
            signature: "deposit()".to_string(),
            gas_used: TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
            },
        };
        assert_eq!(parse("Test:deposit() (runs: 256, calls: 100, reverts: 200)"), expected);
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let expected = GasSnapshotEntry {
            contract_name: "ERC20Invariants".to_string(),
            signature: "invariantBalanceSum()".to_string(),
            gas_used: TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
            },
        };
        assert_eq!(
            parse("ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)"),
            expected
        );
    }
}