1use super::test;
2use crate::result::{SuiteTestResult, TestKindReport, TestOutcome};
3use alloy_primitives::{U256, map::HashMap};
4use clap::{Parser, ValueHint, builder::RangedU64ValueParser};
5use eyre::{Context, Result};
6use foundry_cli::utils::STATIC_FUZZ_SEED;
7use regex::Regex;
8use std::{
9 cmp::Ordering,
10 fs,
11 io::{self, BufRead},
12 path::{Path, PathBuf},
13 str::FromStr,
14 sync::LazyLock,
15};
16use yansi::Paint;
17
/// A regex that matches one line of a gas snapshot file, such as
/// `Test:testDeposit() (gas: 58804)`.
///
/// A line has the shape `<file>:<sig> (<report>)`, where `<report>` is one of:
/// - `gas: <gas>` (the `gas:` label itself is optional),
/// - `runs: <runs>, μ: <avg>, ~: <med>`,
/// - `runs: <invruns>, calls: <calls>, reverts: <reverts>`.
///
/// The angle-bracketed names above are the named capture groups of the pattern.
pub static RE_BASIC_SNAPSHOT_ENTRY: LazyLock<Regex> = LazyLock::new(|| {
    // The pattern is a constant, so a failure to compile it is a programming error.
    Regex::new(r"(?P<file>(.*?)):(?P<sig>(\w+)\s*\((.*?)\))\s*\(((gas:)?\s*(?P<gas>\d+)|(runs:\s*(?P<runs>\d+),\s*μ:\s*(?P<avg>\d+),\s*~:\s*(?P<med>\d+))|(runs:\s*(?P<invruns>\d+),\s*calls:\s*(?P<calls>\d+),\s*reverts:\s*(?P<reverts>\d+)))\)").unwrap()
});
23
/// CLI arguments for creating, diffing, or checking a gas snapshot of the tests.
#[derive(Clone, Debug, Parser)]
pub struct GasSnapshotArgs {
    /// Output a diff against a pre-existing gas snapshot.
    ///
    /// When no file is given, the comparison is done with the default snapshot file.
    #[arg(
        conflicts_with = "snap",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    diff: Option<Option<PathBuf>>,

    /// Compare against a pre-existing gas snapshot, exiting with code 1 if they do not match.
    ///
    /// Every mismatching or missing entry is printed.
    ///
    /// When no file is given, the comparison is done with the file set by `--snap`.
    #[arg(
        conflicts_with = "diff",
        long,
        value_hint = ValueHint::FilePath,
        value_name = "SNAPSHOT_FILE",
    )]
    check: Option<Option<PathBuf>>,

    // Hidden from the help output; the value is currently passed to the snapshot writer,
    // which ignores it.
    /// How to format the output.
    #[arg(long, hide(true))]
    format: Option<Format>,

    /// Output file for the gas snapshot.
    #[arg(
        long,
        default_value = ".gas-snapshot",
        value_hint = ValueHint::FilePath,
        value_name = "FILE",
    )]
    snap: PathBuf,

    /// Tolerates gas deviations up to the specified percentage (0 - 99) when checking.
    #[arg(
        long,
        value_parser = RangedU64ValueParser::<u32>::new().range(0..100),
        value_name = "SNAPSHOT_THRESHOLD"
    )]
    tolerance: Option<u32>,

    /// All test arguments are supported.
    #[command(flatten)]
    pub(crate) test: test::TestArgs,

    /// Additional filtering and sorting applied to the test results.
    #[command(flatten)]
    config: GasSnapshotConfig,
}
81
82impl GasSnapshotArgs {
83 pub fn is_watch(&self) -> bool {
85 self.test.is_watch()
86 }
87
88 pub(crate) fn watchexec_config(&self) -> Result<watchexec::Config> {
90 self.test.watchexec_config()
91 }
92
93 pub async fn run(mut self) -> Result<()> {
94 self.test.fuzz_seed = Some(U256::from_be_bytes(STATIC_FUZZ_SEED));
96
97 let outcome = self.test.execute_tests().await?;
98 outcome.ensure_ok(false)?;
99 let tests = self.config.apply(outcome);
100
101 if let Some(path) = self.diff {
102 let snap = path.as_ref().unwrap_or(&self.snap);
103 let snaps = read_gas_snapshot(snap)?;
104 diff(tests, snaps)?;
105 } else if let Some(path) = self.check {
106 let snap = path.as_ref().unwrap_or(&self.snap);
107 let snaps = read_gas_snapshot(snap)?;
108 if check(tests, snaps, self.tolerance) {
109 std::process::exit(0)
110 } else {
111 std::process::exit(1)
112 }
113 } else {
114 write_to_gas_snapshot_file(&tests, self.snap, self.format)?;
115 }
116 Ok(())
117 }
118}
119
/// Output formats for the gas snapshot; `Table` is the only one defined.
#[derive(Clone, Debug)]
pub enum Format {
    /// Selected by the strings `t` or `table`.
    Table,
}
125
126impl FromStr for Format {
127 type Err = String;
128
129 fn from_str(s: &str) -> Result<Self, Self::Err> {
130 match s {
131 "t" | "table" => Ok(Self::Table),
132 _ => Err(format!("Unrecognized format `{s}`")),
133 }
134 }
135}
136
/// Additional filters and ordering that can be applied to the test results.
#[derive(Clone, Debug, Default, Parser)]
struct GasSnapshotConfig {
    /// Sort results by gas used (ascending).
    #[arg(long)]
    asc: bool,

    /// Sort results by gas used (descending).
    #[arg(conflicts_with = "asc", long)]
    desc: bool,

    /// Only include tests that used at least the given amount of gas.
    #[arg(long, value_name = "MIN_GAS")]
    min: Option<u64>,

    /// Only include tests that used at most the given amount of gas.
    #[arg(long, value_name = "MAX_GAS")]
    max: Option<u64>,
}
156
157impl GasSnapshotConfig {
158 fn is_in_gas_range(&self, gas_used: u64) -> bool {
159 if let Some(min) = self.min
160 && gas_used < min
161 {
162 return false;
163 }
164 if let Some(max) = self.max
165 && gas_used > max
166 {
167 return false;
168 }
169 true
170 }
171
172 fn apply(&self, outcome: TestOutcome) -> Vec<SuiteTestResult> {
173 let mut tests = outcome
174 .into_tests()
175 .filter(|test| self.is_in_gas_range(test.gas_used()))
176 .collect::<Vec<_>>();
177
178 if self.asc {
179 tests.sort_by_key(|a| a.gas_used());
180 } else if self.desc {
181 tests.sort_by_key(|b| std::cmp::Reverse(b.gas_used()))
182 }
183
184 tests
185 }
186}
187
/// A single entry of a gas snapshot file, i.e. one parsed line such as
/// `Test:deposit() (gas: 7222)`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotEntry {
    /// Text in front of the `:` separator (the `file` capture of the line regex).
    pub contract_name: String,
    /// Test signature including its parentheses, e.g. `deposit()`.
    pub signature: String,
    /// The report recorded for the test (unit, fuzz, or invariant).
    pub gas_used: TestKindReport,
}
200
201impl FromStr for GasSnapshotEntry {
202 type Err = String;
203
204 fn from_str(s: &str) -> Result<Self, Self::Err> {
205 RE_BASIC_SNAPSHOT_ENTRY
206 .captures(s)
207 .and_then(|cap| {
208 cap.name("file").and_then(|file| {
209 cap.name("sig").and_then(|sig| {
210 if let Some(gas) = cap.name("gas") {
211 Some(Self {
212 contract_name: file.as_str().to_string(),
213 signature: sig.as_str().to_string(),
214 gas_used: TestKindReport::Unit {
215 gas: gas.as_str().parse().unwrap(),
216 },
217 })
218 } else if let Some(runs) = cap.name("runs") {
219 cap.name("avg")
220 .and_then(|avg| cap.name("med").map(|med| (runs, avg, med)))
221 .map(|(runs, avg, med)| Self {
222 contract_name: file.as_str().to_string(),
223 signature: sig.as_str().to_string(),
224 gas_used: TestKindReport::Fuzz {
225 runs: runs.as_str().parse().unwrap(),
226 median_gas: med.as_str().parse().unwrap(),
227 mean_gas: avg.as_str().parse().unwrap(),
228 failed_corpus_replays: 0,
229 },
230 })
231 } else {
232 cap.name("invruns")
233 .and_then(|runs| {
234 cap.name("calls").and_then(|avg| {
235 cap.name("reverts").map(|med| (runs, avg, med))
236 })
237 })
238 .map(|(runs, calls, reverts)| Self {
239 contract_name: file.as_str().to_string(),
240 signature: sig.as_str().to_string(),
241 gas_used: TestKindReport::Invariant {
242 runs: runs.as_str().parse().unwrap(),
243 calls: calls.as_str().parse().unwrap(),
244 reverts: reverts.as_str().parse().unwrap(),
245 metrics: HashMap::default(),
246 failed_corpus_replays: 0,
247 },
248 })
249 }
250 })
251 })
252 })
253 .ok_or_else(|| format!("Could not extract Snapshot Entry for {s}"))
254 }
255}
256
257fn read_gas_snapshot(path: impl AsRef<Path>) -> Result<Vec<GasSnapshotEntry>> {
259 let path = path.as_ref();
260 let mut entries = Vec::new();
261 for line in io::BufReader::new(
262 fs::File::open(path)
263 .wrap_err(format!("failed to read snapshot file \"{}\"", path.display()))?,
264 )
265 .lines()
266 {
267 entries
268 .push(GasSnapshotEntry::from_str(line?.as_str()).map_err(|err| eyre::eyre!("{err}"))?);
269 }
270 Ok(entries)
271}
272
273fn write_to_gas_snapshot_file(
275 tests: &[SuiteTestResult],
276 path: impl AsRef<Path>,
277 _format: Option<Format>,
278) -> Result<()> {
279 let mut reports = tests
280 .iter()
281 .map(|test| {
282 format!("{}:{} {}", test.contract_name(), test.signature, test.result.kind.report())
283 })
284 .collect::<Vec<_>>();
285
286 reports.sort();
288
289 let content = reports.join("\n");
290 Ok(fs::write(path, content)?)
291}
292
/// A test's report from the current run paired with the report stored in the snapshot.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct GasSnapshotDiff {
    /// Signature of the test the two reports belong to.
    pub signature: String,
    /// Report produced by the current test run.
    pub source_gas_used: TestKindReport,
    /// Report read from the snapshot file.
    pub target_gas_used: TestKindReport,
}
300
301impl GasSnapshotDiff {
302 fn gas_change(&self) -> i128 {
307 self.source_gas_used.gas() as i128 - self.target_gas_used.gas() as i128
308 }
309
310 fn gas_diff(&self) -> f64 {
312 self.gas_change() as f64 / self.target_gas_used.gas() as f64
313 }
314}
315
316fn check(
320 tests: Vec<SuiteTestResult>,
321 snaps: Vec<GasSnapshotEntry>,
322 tolerance: Option<u32>,
323) -> bool {
324 let snaps = snaps
325 .into_iter()
326 .map(|s| ((s.contract_name, s.signature), s.gas_used))
327 .collect::<HashMap<_, _>>();
328 let mut has_diff = false;
329 for test in tests {
330 if let Some(target_gas) =
331 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
332 {
333 let source_gas = test.result.kind.report();
334 if !within_tolerance(source_gas.gas(), target_gas.gas(), tolerance) {
335 let _ = sh_println!(
336 "Diff in \"{}::{}\": consumed \"{}\" gas, expected \"{}\" gas ",
337 test.contract_name(),
338 test.signature,
339 source_gas,
340 target_gas
341 );
342 has_diff = true;
343 }
344 } else {
345 let _ = sh_println!(
346 "No matching snapshot entry found for \"{}::{}\" in snapshot file",
347 test.contract_name(),
348 test.signature
349 );
350 has_diff = true;
351 }
352 }
353 !has_diff
354}
355
356fn diff(tests: Vec<SuiteTestResult>, snaps: Vec<GasSnapshotEntry>) -> Result<()> {
358 let snaps = snaps
359 .into_iter()
360 .map(|s| ((s.contract_name, s.signature), s.gas_used))
361 .collect::<HashMap<_, _>>();
362 let mut diffs = Vec::with_capacity(tests.len());
363 for test in tests.into_iter() {
364 if let Some(target_gas_used) =
365 snaps.get(&(test.contract_name().to_string(), test.signature.clone())).cloned()
366 {
367 diffs.push(GasSnapshotDiff {
368 source_gas_used: test.result.kind.report(),
369 signature: test.signature,
370 target_gas_used,
371 });
372 }
373 }
374 let mut overall_gas_change = 0i128;
375 let mut overall_gas_used = 0i128;
376
377 diffs.sort_by(|a, b| a.gas_diff().abs().total_cmp(&b.gas_diff().abs()));
378
379 for diff in diffs {
380 let gas_change = diff.gas_change();
381 overall_gas_change += gas_change;
382 overall_gas_used += diff.target_gas_used.gas() as i128;
383 let gas_diff = diff.gas_diff();
384 sh_println!(
385 "{} (gas: {} ({})) ",
386 diff.signature,
387 fmt_change(gas_change),
388 fmt_pct_change(gas_diff)
389 )?;
390 }
391
392 let overall_gas_diff = overall_gas_change as f64 / overall_gas_used as f64;
393 sh_println!(
394 "Overall gas change: {} ({})",
395 fmt_change(overall_gas_change),
396 fmt_pct_change(overall_gas_diff)
397 )?;
398 Ok(())
399}
400
401fn fmt_pct_change(change: f64) -> String {
402 let change_pct = change * 100.0;
403 match change.total_cmp(&0.0) {
404 Ordering::Less => format!("{change_pct:.3}%").green().to_string(),
405 Ordering::Equal => {
406 format!("{change_pct:.3}%")
407 }
408 Ordering::Greater => format!("{change_pct:.3}%").red().to_string(),
409 }
410}
411
412fn fmt_change(change: i128) -> String {
413 match change.cmp(&0) {
414 Ordering::Less => format!("{change}").green().to_string(),
415 Ordering::Equal => {
416 format!("{change}")
417 }
418 Ordering::Greater => format!("{change}").red().to_string(),
419 }
420}
421
/// Returns `true` if the two gas values are close enough to be considered a match.
///
/// Identical values always match. Otherwise, without a tolerance nothing matches; with
/// a tolerance, the relative difference `(1 - lo / hi) * 100` (in percent, where `hi`
/// is the larger of the two values) must be strictly below `tolerance_pct`.
fn within_tolerance(source_gas: u64, target_gas: u64, tolerance_pct: Option<u32>) -> bool {
    // Handle equality up front: `0 / 0` is NaN and `0 < 0` is false, so the formula
    // below would reject identical values when both are zero or the tolerance is zero.
    if source_gas == target_gas {
        return true;
    }
    let Some(tolerance) = tolerance_pct else {
        return false;
    };

    let (hi, lo) = if source_gas > target_gas {
        (source_gas, target_gas)
    } else {
        (target_gas, source_gas)
    };
    // The values differ, so `hi > lo >= 0` and the division is well defined.
    let diff = (1. - (lo as f64 / hi as f64)) * 100.;
    diff < tolerance as f64
}
438
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `line` and asserts that it yields exactly the described entry.
    fn assert_parses_to(line: &str, contract: &str, sig: &str, gas_used: TestKindReport) {
        let entry = GasSnapshotEntry::from_str(line).unwrap();
        let expected = GasSnapshotEntry {
            contract_name: contract.to_string(),
            signature: sig.to_string(),
            gas_used,
        };
        assert_eq!(entry, expected);
    }

    #[test]
    fn test_tolerance() {
        // (source, target, tolerance, expected result)
        let cases = [
            (100, 105, Some(5), true),
            (105, 100, Some(5), true),
            (100, 106, Some(5), false),
            (106, 100, Some(5), false),
            (100, 100, None, true),
        ];
        for (source, target, tolerance, expected) in cases {
            assert_eq!(within_tolerance(source, target, tolerance), expected);
        }
    }

    #[test]
    fn can_parse_basic_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (gas: 7222)",
            "Test",
            "deposit()",
            TestKindReport::Unit { gas: 7222 },
        );
    }

    #[test]
    fn can_parse_fuzz_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, μ: 100, ~:200)",
            "Test",
            "deposit()",
            TestKindReport::Fuzz {
                runs: 256,
                median_gas: 200,
                mean_gas: 100,
                failed_corpus_replays: 0,
            },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry() {
        assert_parses_to(
            "Test:deposit() (runs: 256, calls: 100, reverts: 200)",
            "Test",
            "deposit()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 100,
                reverts: 200,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }

    #[test]
    fn can_parse_invariant_gas_snapshot_entry2() {
        assert_parses_to(
            "ERC20Invariants:invariantBalanceSum() (runs: 256, calls: 3840, reverts: 2388)",
            "ERC20Invariants",
            "invariantBalanceSum()",
            TestKindReport::Invariant {
                runs: 256,
                calls: 3840,
                reverts: 2388,
                metrics: HashMap::default(),
                failed_corpus_replays: 0,
            },
        );
    }
}