1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{map::HashMap, U256};
4use clap::{builder::RangedU64ValueParser, Parser, ValueHint};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9 cmp::Ordering,
10 fs,
11 io::{self, BufRead},
12 path::{Path, PathBuf},
13 str::FromStr,
14 sync::LazyLock,
15};
16use yansi::Paint;
17
18pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
21 Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
22});
23
// CLI arguments for the gas snapshot command.
//
// NOTE(review): plain `//` comments are used on purpose throughout this clap-derived type —
// `///` doc comments would be picked up by `#[derive(Parser)]` as help text and change the
// `--help` output.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    // `--diff [SNAPSHOT_FILE]`: print a per-test gas diff against an existing snapshot instead of
    // writing one. Given without a value (`Some(None)`), `run` falls back to the `snap` path.
    // Conflicts with an explicitly passed `--snap`.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    // `--check [SNAPSHOT_FILE]`: compare against an existing snapshot; `run` then terminates the
    // process with exit code 0 if every test matches and 1 otherwise. Given without a value, the
    // `snap` path is used. Conflicts with `--diff`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden output-format option; only `table` (alias `t`) parses, and the snapshot writer
    // currently ignores the value.
    #[arg(long, hide(true))]
    format: Option<Format>,

    // `--snap <FILE>`: snapshot file written by default (`.gas-snapshot`), and the fallback
    // comparison file for a value-less `--diff` / `--check`.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    // `--tolerance <SNAPSHOT_THRESHOLD>`: allowed gas deviation in percent for `--check`.
    // The exclusive range `0..100` restricts accepted values to 0–99.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    // All test-runner arguments, flattened into this command.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    // Filtering/sorting options applied to the results before they are written, checked or diffed.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
81
82impl GasSnapshotArgs {
83 pub fn is_watch(&self) -> bool {
85 self.test.is_watch()
86 }
87
88 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90 self.test.watchexec_config()
91 }
92
93 pub async fn run(mut self) -> Result<()> {
94 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97 let outcome = self.test.execute_tests().await?;
98 outcome.ensure_ok(false)?;
99 let tests = self.config.apply(outcome);
100
101 if let Some(path) = self.diff {
102 let snap = path.as_ref().unwrap_or(&self.snap);
103 let snaps = read_gas_snapshot(snap)?;
104 diff(tests, snaps)?;
105 } else if let Some(path) = self.check {
106 let snap = path.as_ref().unwrap_or(&self.snap);
107 let snaps = read_gas_snapshot(snap)?;
108 if check(tests, snaps, self.tolerance) {
109 std::process::exit(0)
110 } else {
111 std::process::exit(1)
112 }
113 } else {
114 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115 }
116 Ok(())
117 }
118}
119
/// Output format for the gas snapshot. `Table` is currently the only variant.
#[derive(Clone, Debug)]
pub enum Format {
    Table,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts `"t"` or `"table"`. Any other input yields an error message naming the input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if matches!(s, "t" | "table") {
            Ok(Self::Table)
        } else {
            Err(format!("Unrecognized format `{s}`"))
        }
    }
}
136
// Filtering and sorting options applied to the test results (see `GasSnapshotConfig::apply`).
//
// NOTE(review): `//` rather than `///` is deliberate — on a `#[derive(Parser)]` type, doc
// comments become `--help` text.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    // `--asc`: sort the results by gas used, lowest first.
    #[arg(long)]
    asc: bool,

    // `--desc`: sort the results by gas used, highest first. Mutually exclusive with `--asc`.
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    // `--min <MIN_GAS>`: keep only tests whose gas used is at least this value (inclusive).
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    // `--max <MAX_GAS>`: keep only tests whose gas used is at most this value (inclusive).
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
156
157impl GasSnapshotConfig {
158 fn is_in_gas_range(&self, gas_used: u64) -> bool {
159 if let Some(min) = self.min {
160 if gas_used < min {
161 return false
162 }
163 }
164 if let Some(max) = self.max {
165 if gas_used > max {
166 return false
167 }
168 }
169 true
170 }
171
172 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173 let mut tests = outcome
174 .into_tests()
175 .filter(|test| self.is_in_gas_range(test.gas_used()))
176 .collect::<Vec<_>>();
177
178 if self.asc {
179 tests.sort_by_key(|a| a.gas_used());
180 } else if self.desc {
181 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182 }
183
184 tests
185 }
186}
187
/// A single parsed line of a gas snapshot file.
///
/// Lines have the shape `<contract_name>:<signature> (<report>)` and are parsed via [`FromStr`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// The part of the line before the `:` that separates it from the signature.
    pub contract_name: String,
    /// Test function signature including its parameter list, e.g. `deposit()`.
    pub signature: String,
    /// Recorded gas report: unit, fuzz or invariant, depending on the line's shape.
    pub gas_used: TestKindReport,
}
200
201impl FromStr for GasSnapshotEntry {
202 type Err = String;
203
204 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 RE_BASIC_SNAPSHOT_ENTRY
206 .captures(s)
207 .and_then(|cap| {
208 cap.name("file").and_then(|file| {
209 cap.name("sig").and_then(|sig| {
210 if let Some(gas) = cap.name("gas") {
211 Some(Self {
212 contract_name: file.as_str().to_string(),
213 signature: sig.as_str().to_string(),
214 gas_used: TestKindReport::Unit {
215 gas: gas.as_str().parse().unwrap(),
216 },
217 })
218 } else if let Some(runs) = cap.name("runs") {
219 cap.name("avg")
220 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221 .map(|(runs, avg, med)| Self {
222 contract_name: file.as_str().to_string(),
223 signature: sig.as_str().to_string(),
224 gas_used: TestKindReport::Fuzz {
225 runs: runs.as_str().parse().unwrap(),
226 median_gas: med.as_str().parse().unwrap(),
227 mean_gas: avg.as_str().parse().unwrap(),
228 },
229 })
230 } else {
231 cap.name("invruns")
232 .and_then(|runs| {
233 cap.name("calls").and_then(|avg| {
234 cap.name("reverts").map(|med| (runs, avg, med))
235 })
236 })
237 .map(|(runs, calls, reverts)| Self {
238 contract_name: file.as_str().to_string(),
239 signature: sig.as_str().to_string(),
240 gas_used: TestKindReport::Invariant {
241 runs: runs.as_str().parse().unwrap(),
242 calls: calls.as_str().parse().unwrap(),
243 reverts: reverts.as_str().parse().unwrap(),
244 metrics: HashMap::default(),
245 },
246 })
247 }
248 })
249 })
250 })
251 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
252 }
253}
254
255fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
257 let path = path.as_ref();
258 let mut entries = Vec::new();
259 for line in io::BufReader::new(
260 fs::File::open(path)
261 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
262 )
263 .lines()
264 {
265 entries
266 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
267 }
268 Ok(entries)
269}
270
271fn write_to_gas_snapshot_file(
273 tests: &[SuiteTestResult],
274 path: impl AsRef<Path>,
275 _format: Option<Format>,
276) -> Result<()> {
277 let mut reports = tests
278 .iter()
279 .map(|test| {
280 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
281 })
282 .collect::<Vec<_>>();
283
284 reports.sort();
286
287 let content = reports.join("\n");
288 Ok(fs::write(path, content)?)
289}
290
/// Gas comparison for one test between the current run and a snapshot file.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Test function signature (without the contract name).
    pub signature: String,
    /// Gas report measured by the current run.
    pub source_gas_used: TestKindReport,
    /// Gas report recorded in the snapshot file.
    pub target_gas_used: TestKindReport,
}
298
299impl GasSnapshotDiff {
300 fn gas_change(&self) -> i128 {
305 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
306 }
307
308 fn gas_diff(&self) -> f64 {
310 self.gas_change() as f64 / self.target_gas_used.gas() as f64
311 }
312}
313
314fn check(
318 tests: Vec<SuiteTestResult>,
319 snaps: Vec<GasSnapshotEntry>,
320 tolerance: Option<u32>,
321) -> bool {
322 let snaps = snaps
323 .into_iter()
324 .map(|s| ((s.contract_name, s.signature), s.gas_used))
325 .collect::<HashMap<_, _>>();
326 let mut has_diff = false;
327 for test in tests {
328 if let Some(target_gas) =
329 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
330 {
331 let source_gas = test.result.kind.report();
332 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
333 let _ = sh_println!(
334 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
335 test.contract_name(),
336 test.signature,
337 source_gas,
338 target_gas
339 );
340 has_diff = true;
341 }
342 } else {
343 let _ = sh_println!(
344 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
345 test.contract_name(),
346 test.signature
347 );
348 has_diff = true;
349 }
350 }
351 !has_diff
352}
353
354fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
356 let snaps = snaps
357 .into_iter()
358 .map(|s| ((s.contract_name, s.signature), s.gas_used))
359 .collect::<HashMap<_, _>>();
360 let mut diffs = Vec::with_capacity(tests.len());
361 for test in tests.into_iter() {
362 if let Some(target_gas_used) =
363 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
364 {
365 diffs.push(GasSnapshotDiff {
366 source_gas_used: test.result.kind.report(),
367 signature: test.signature,
368 target_gas_used,
369 });
370 }
371 }
372 let mut overall_gas_change = 0i128;
373 let mut overall_gas_used = 0i128;
374
375 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
376
377 for diff in diffs {
378 let gas_change = diff.gas_change();
379 overall_gas_change += gas_change;
380 overall_gas_used += diff.target_gas_used.gas() as i128;
381 let gas_diff = diff.gas_diff();
382 sh_println!(
383 "{} (gas: {} ({})) ",
384 diff.signature,
385 fmt_change(gas_change),
386 fmt_pct_change(gas_diff)
387 )?;
388 }
389
390 let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
391 sh_println!(
392 "Overall gas change: {} ({})",
393 fmt_change(overall_gas_change),
394 fmt_pct_change(overall_gas_diff)
395 )?;
396 Ok(())
397}
398
399fn fmt_pct_change(change: f64) -> String {
400 let change_pct = change * 100.0;
401 match change.total_cmp(&0.0) {
402 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
403 Ordering::Equal => {
404 format!("{change_pct:.3}%")
405 }
406 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
407 }
408}
409
410fn fmt_change(change: i128) -> String {
411 match change.cmp(&0) {
412 Ordering::Less => format!("{change}").green().to_string(),
413 Ordering::Equal => {
414 format!("{change}")
415 }
416 Ordering::Greater => format!("{change}").red().to_string(),
417 }
418}
419
/// Returns whether `source_gas` and `target_gas` should be considered a match.
///
/// - `None`: the values must be identical.
/// - `Some(pct)`: the difference, measured relative to the larger of the two values, must be
///   strictly below `pct` percent.
///
/// Identical values always match. Previously this failed in two cases when a tolerance was set:
/// with a tolerance of `0`, because the strict `<` rejected a 0% difference, and when both values
/// were `0`, because the ratio was `0 / 0 = NaN`.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };
    // The values differ, so `hi > 0` and the division below is well-defined.
    let (hi, lo) = (source_gas.max(target_gas), source_gas.min(target_gas));
    let diff = (1. - (lo as f64 / hi as f64)) * 100.;
    diff < tolerance as f64
}
436
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tolerance() {
        assert!(within_tolerance(100, 105, Some(5)));
        assert!(within_tolerance(105, 100, Some(5)));
        assert!(!within_tolerance(100, 106, Some(5)));
        assert!(!within_tolerance(106, 100, Some(5)));
        assert!(within_tolerance(100, 100, None));
        // Without a tolerance, any difference is a mismatch.
        assert!(!within_tolerance(100, 101, None));
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        let s = "Test:deposit() (gas: 7222)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Unit { gas: 7222 }
            }
        );
    }

    #[test]
    fn can_parse_gas_snapshot_entry_without_gas_label() {
        // The `gas:` label is optional in the entry regex.
        let entry = GasSnapshotEntry::from_str("Test:deposit() (7222)").unwrap();
        assert_eq!(entry.contract_name, "Test");
        assert_eq!(entry.signature, "deposit()");
        assert_eq!(entry.gas_used, TestKindReport::Unit { gas: 7222 });
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        let s = "Test:deposit() (runs: 256, μ: 100, ~:200)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Fuzz { runs: 256, median_gas: 200, mean_gas: 100 }
            }
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry_with_spaces() {
        let s = "Test:deposit() (runs: 256, μ: 100, ~: 200)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry.gas_used,
            TestKindReport::Fuzz { runs: 256, median_gas: 200, mean_gas: 100 }
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        let s = "Test:deposit() (runs: 256, calls: 100, reverts: 200)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "Test".to_string(),
                signature: "deposit()".to_string(),
                gas_used: TestKindReport::Invariant {
                    runs: 256,
                    calls: 100,
                    reverts: 200,
                    metrics: HashMap::default()
                }
            }
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        let s = "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)";
        let entry = GasSnapshotEntry::from_str(s).unwrap();
        assert_eq!(
            entry,
            GasSnapshotEntry {
                contract_name: "ERC20Invariants".to_string(),
                signature: "invariantBalanceSum()".to_string(),
                gas_used: TestKindReport::Invariant {
                    runs: 256,
                    calls: 3840,
                    reverts: 2388,
                    metrics: HashMap::default()
                }
            }
        );
    }

    #[test]
    fn rejects_malformed_gas_snapshot_entry() {
        assert!(GasSnapshotEntry::from_str("").is_err());
        assert!(GasSnapshotEntry::from_str("not a snapshot line").is_err());
        assert!(GasSnapshotEntry::from_str("Test:deposit()").is_err());
        assert!(GasSnapshotEntry::from_str("Test:deposit() (gas: abc)").is_err());
    }
}