Skip to main content

forge/
gas_report.rs

1//! Gas reports.
2
3use crate::{
4    constants::{CHEATCODE_ADDRESS, HARDHAT_CONSOLE_ADDRESS},
5    traces::{CallTraceArena, CallTraceDecoder, CallTraceNode, DecodedCallData},
6};
7use alloy_primitives::map::HashSet;
8use comfy_table::{
9    Cell, CellAlignment, Color, Table, modifiers::UTF8_ROUND_CORNERS, presets::ASCII_MARKDOWN,
10};
11use foundry_common::{TestFunctionExt, calc, shell};
12use foundry_evm::traces::CallKind;
13
14use serde::{Deserialize, Serialize};
15use serde_json::json;
16use std::{collections::BTreeMap, fmt::Display};
17
/// Represents the gas report for a set of contracts.
///
/// Built with [`GasReport::new`], populated by [`GasReport::analyze`], and turned into summary
/// statistics by [`GasReport::finalize`] before being displayed.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GasReport {
    /// Whether to report any contracts.
    ///
    /// Set by `new` when `report_for` is empty or contains the wildcard `*`.
    report_any: bool,
    /// Contracts to generate the report for.
    ///
    /// Matched against the contract name (the part of the identifier after the last `:`).
    report_for: HashSet<String>,
    /// Contracts to ignore when generating the report.
    ///
    /// A contract that also appears in `report_for` is still reported, with a warning.
    ignore: HashSet<String>,
    /// Whether to include gas reports for tests.
    ///
    /// When `false`, functions recognised as test/setup functions are not recorded.
    include_tests: bool,
    /// All contracts that were analyzed, grouped by their identifier, e.g.
    /// `test/Counter.t.sol:CounterTest`.
    pub contracts: BTreeMap<String, ContractInfo>,
}
33
34impl GasReport {
35    pub fn new(
36        report_for: impl IntoIterator<Item = String>,
37        ignore: impl IntoIterator<Item = String>,
38        include_tests: bool,
39    ) -> Self {
40        let report_for = report_for.into_iter().collect::<HashSet<_>>();
41        let ignore = ignore.into_iter().collect::<HashSet<_>>();
42        let report_any = report_for.is_empty() || report_for.contains("*");
43        Self { report_any, report_for, ignore, include_tests, ..Default::default() }
44    }
45
46    /// Whether the given contract should be reported.
47    #[instrument(level = "trace", skip(self), ret)]
48    fn should_report(&self, contract_name: &str) -> bool {
49        if self.ignore.contains(contract_name) {
50            let contains_anyway = self.report_for.contains(contract_name);
51            if contains_anyway {
52                // If the user listed the contract in 'gas_reports' (the foundry.toml field) a
53                // report for the contract is generated even if it's listed in the ignore
54                // list. This is addressed this way because getting a report you don't expect is
55                // preferable than not getting one you expect. A warning is printed to stderr
56                // indicating the "double listing".
57                let _ = sh_warn!(
58                    "{contract_name} is listed in both 'gas_reports' and 'gas_reports_ignore'."
59                );
60            }
61            return contains_anyway;
62        }
63        self.report_any || self.report_for.contains(contract_name)
64    }
65
66    /// Analyzes the given traces and generates a gas report.
67    pub async fn analyze(
68        &mut self,
69        arenas: impl IntoIterator<Item = &CallTraceArena>,
70        decoder: &CallTraceDecoder,
71    ) {
72        for node in arenas.into_iter().flat_map(|arena| arena.nodes()) {
73            self.analyze_node(node, decoder).await;
74        }
75    }
76
77    async fn analyze_node(&mut self, node: &CallTraceNode, decoder: &CallTraceDecoder) {
78        let trace = &node.trace;
79
80        if trace.address == CHEATCODE_ADDRESS || trace.address == HARDHAT_CONSOLE_ADDRESS {
81            return;
82        }
83
84        let Some(name) = decoder.contracts.get(&node.trace.address) else { return };
85        let contract_name = name.rsplit(':').next().unwrap_or(name);
86
87        if !self.should_report(contract_name) {
88            return;
89        }
90        let contract_info = self.contracts.entry(name.to_string()).or_default();
91        let is_create_call = trace.kind.is_any_create();
92
93        // Record contract deployment size.
94        if is_create_call {
95            trace!(contract_name, "adding create size info");
96            contract_info.size = trace.data.len();
97        }
98
99        // Only include top-level calls which account for calldata and base (21.000) cost.
100        // Only include Calls and Creates as only these calls are isolated in inspector.
101        if trace.depth > 1 && (trace.kind == CallKind::Call || is_create_call) {
102            return;
103        }
104
105        let decoded = || decoder.decode_function(&node.trace);
106
107        if is_create_call {
108            trace!(contract_name, "adding create gas info");
109            contract_info.gas = trace.gas_used;
110        } else if let Some(DecodedCallData { signature, .. }) = decoded().await.call_data {
111            let name = signature.split('(').next().unwrap();
112            // ignore any test/setup functions
113            if self.include_tests || !name.test_function_kind().is_known() {
114                trace!(contract_name, signature, "adding gas info");
115                let gas_info = contract_info
116                    .functions
117                    .entry(name.to_string())
118                    .or_default()
119                    .entry(signature.clone())
120                    .or_default();
121                gas_info.frames.push(trace.gas_used);
122            }
123        }
124    }
125
126    /// Finalizes the gas report by calculating the min, max, mean, and median for each function.
127    #[must_use]
128    pub fn finalize(mut self) -> Self {
129        trace!("finalizing gas report");
130        for contract in self.contracts.values_mut() {
131            for sigs in contract.functions.values_mut() {
132                for func in sigs.values_mut() {
133                    func.frames.sort_unstable();
134                    func.min = func.frames.first().copied().unwrap_or_default();
135                    func.max = func.frames.last().copied().unwrap_or_default();
136                    func.mean = calc::mean(&func.frames);
137                    func.median = calc::median_sorted(&func.frames);
138                    func.calls = func.frames.len() as u64;
139                }
140            }
141        }
142        self
143    }
144}
145
146impl Display for GasReport {
147    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
148        if shell::is_json() {
149            writeln!(f, "{}", &self.format_json_output())?;
150        } else {
151            for (name, contract) in &self.contracts {
152                if contract.functions.is_empty() {
153                    trace!(name, "gas report contract without functions");
154                    continue;
155                }
156
157                let table = self.format_table_output(contract, name);
158                writeln!(f, "\n{table}")?;
159            }
160        }
161
162        Ok(())
163    }
164}
165
166impl GasReport {
167    fn format_json_output(&self) -> String {
168        serde_json::to_string(
169            &self
170                .contracts
171                .iter()
172                .filter_map(|(name, contract)| {
173                    if contract.functions.is_empty() {
174                        trace!(name, "gas report contract without functions");
175                        return None;
176                    }
177
178                    let functions = contract
179                        .functions
180                        .iter()
181                        .flat_map(|(_, sigs)| {
182                            sigs.iter().map(|(sig, gas_info)| {
183                                let display_name = sig.replace(':', "");
184                                (display_name, gas_info)
185                            })
186                        })
187                        .collect::<BTreeMap<_, _>>();
188
189                    Some(json!({
190                        "contract": name,
191                        "deployment": {
192                            "gas": contract.gas,
193                            "size": contract.size,
194                        },
195                        "functions": functions,
196                    }))
197                })
198                .collect::<Vec<_>>(),
199        )
200        .unwrap()
201    }
202
203    fn format_table_output(&self, contract: &ContractInfo, name: &str) -> Table {
204        let mut table = Table::new();
205        if shell::is_markdown() {
206            table.load_preset(ASCII_MARKDOWN);
207        } else {
208            table.apply_modifier(UTF8_ROUND_CORNERS);
209        }
210
211        table.set_header(vec![Cell::new(format!("{name} Contract")).fg(Color::Magenta)]);
212
213        table.add_row(vec![
214            Cell::new("Deployment Cost").fg(Color::Cyan),
215            Cell::new("Deployment Size").fg(Color::Cyan),
216        ]);
217        table.add_row(vec![
218            Cell::new(contract.gas.to_string()).set_alignment(CellAlignment::Right),
219            Cell::new(contract.size.to_string()).set_alignment(CellAlignment::Right),
220        ]);
221
222        // Add a blank row to separate deployment info from function info.
223        table.add_row(vec![Cell::new("")]);
224
225        table.add_row(vec![
226            Cell::new("Function Name"),
227            Cell::new("Min").fg(Color::Green),
228            Cell::new("Avg").fg(Color::Yellow),
229            Cell::new("Median").fg(Color::Yellow),
230            Cell::new("Max").fg(Color::Red),
231            Cell::new("# Calls").fg(Color::Cyan),
232        ]);
233
234        contract.functions.iter().for_each(|(fname, sigs)| {
235            sigs.iter().for_each(|(sig, gas_info)| {
236                // Show function signature if overloaded else display function name.
237                let display_name =
238                    if sigs.len() == 1 { fname.to_string() } else { sig.replace(':', "") };
239
240                table.add_row(vec![
241                    Cell::new(display_name),
242                    Cell::new(gas_info.min.to_string())
243                        .fg(Color::Green)
244                        .set_alignment(CellAlignment::Right),
245                    Cell::new(gas_info.mean.to_string())
246                        .fg(Color::Yellow)
247                        .set_alignment(CellAlignment::Right),
248                    Cell::new(gas_info.median.to_string())
249                        .fg(Color::Yellow)
250                        .set_alignment(CellAlignment::Right),
251                    Cell::new(gas_info.max.to_string())
252                        .fg(Color::Red)
253                        .set_alignment(CellAlignment::Right),
254                    Cell::new(gas_info.calls.to_string()).set_alignment(CellAlignment::Right),
255                ]);
256            })
257        });
258
259        table
260    }
261}
262
/// Deployment data and per-function gas statistics collected for a single contract.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ContractInfo {
    /// `gas_used` of the contract's create call.
    ///
    /// Only creates at trace depth 1 or less are recorded.
    pub gas: u64,
    /// Length in bytes of the create call's trace `data`.
    ///
    /// NOTE(review): presumably the init code; confirm.
    pub size: usize,
    /// Function name -> Function signature -> GasInfo
    pub functions: BTreeMap<String, BTreeMap<String, GasInfo>>,
}
270
/// Gas statistics for one function signature.
///
/// `frames` is filled while traces are analyzed. The summary fields stay at their default
/// (zero) values until `GasReport::finalize` computes them.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GasInfo {
    /// Number of recorded calls (`frames.len()` at finalization).
    pub calls: u64,
    /// Smallest recorded gas usage, or 0 when nothing was recorded.
    pub min: u64,
    /// Mean gas usage, as computed by `calc::mean`.
    pub mean: u64,
    /// Median gas usage, as computed by `calc::median_sorted` over the sorted frames.
    pub median: u64,
    /// Largest recorded gas usage, or 0 when nothing was recorded.
    pub max: u64,

    /// Raw `gas_used` of every recorded call.
    ///
    /// Skipped by serde, so it never appears in the serialized output.
    #[serde(skip)]
    pub frames: Vec<u64>,
}