Skip to main content

foundry_evm_fuzz/invariant/
mod.rs

1use alloy_json_abi::{Function, JsonAbi};
2use alloy_primitives::{Address, Selector, map::HashMap};
3use foundry_compilers::artifacts::StorageLayout;
4use itertools::Either;
5use parking_lot::Mutex;
6use serde::{Deserialize, Serialize};
7use std::{collections::BTreeMap, fmt, sync::Arc};
8
9mod call_override;
10pub use call_override::RandomCallGenerator;
11
12mod filters;
13use crate::BasicTxDetails;
14pub use filters::{ArtifactFilters, SenderFilters};
15use foundry_common::{ContractsByAddress, ContractsByArtifact};
16use foundry_evm_core::utils::{StateChangeset, get_function};
17
18/// Returns true if the function returns `int256`, indicating optimization mode.
19/// In optimization mode, the fuzzer maximizes the return value instead of checking invariants.
20pub fn is_optimization_invariant(func: &Function) -> bool {
21    func.outputs.len() == 1 && func.outputs[0].ty == "int256"
22}
23
24/// Contracts identified as targets during a fuzz run.
25///
26/// During execution, any newly created contract is added as target and used through the rest of
27/// the fuzz run if the collection is updatable (no `targetContract` specified in `setUp`).
28#[derive(Clone, Debug)]
29pub struct FuzzRunIdentifiedContracts {
30    /// Contracts identified as targets during a fuzz run.
31    pub targets: Arc<Mutex<TargetedContracts>>,
32    /// Whether target contracts are updatable or not.
33    pub is_updatable: bool,
34}
35
36impl FuzzRunIdentifiedContracts {
37    /// Creates a new `FuzzRunIdentifiedContracts` instance.
38    pub fn new(targets: TargetedContracts, is_updatable: bool) -> Self {
39        Self { targets: Arc::new(Mutex::new(targets)), is_updatable }
40    }
41
42    /// If targets are updatable, collect all contracts created during an invariant run (which
43    /// haven't been discovered yet).
44    pub fn collect_created_contracts(
45        &self,
46        state_changeset: &StateChangeset,
47        project_contracts: &ContractsByArtifact,
48        setup_contracts: &ContractsByAddress,
49        artifact_filters: &ArtifactFilters,
50        created_contracts: &mut Vec<Address>,
51    ) -> eyre::Result<()> {
52        if !self.is_updatable {
53            return Ok(());
54        }
55
56        let mut targets = self.targets.lock();
57        for (address, account) in state_changeset {
58            if setup_contracts.contains_key(address) {
59                continue;
60            }
61            if !account.is_touched() {
62                continue;
63            }
64            let Some(code) = &account.info.code else {
65                continue;
66            };
67            if code.is_empty() {
68                continue;
69            }
70            let Some((artifact, contract)) =
71                project_contracts.find_by_deployed_code(code.original_byte_slice())
72            else {
73                continue;
74            };
75            let Some(functions) =
76                artifact_filters.get_targeted_functions(artifact, &contract.abi)?
77            else {
78                continue;
79            };
80            created_contracts.push(*address);
81            let contract = TargetedContract {
82                identifier: artifact.name.clone(),
83                abi: contract.abi.clone(),
84                targeted_functions: functions,
85                excluded_functions: Vec::new(),
86                storage_layout: contract.storage_layout.as_ref().map(Arc::clone),
87            };
88            targets.insert(*address, contract);
89        }
90        Ok(())
91    }
92
93    /// Clears targeted contracts created during an invariant run.
94    pub fn clear_created_contracts(&self, created_contracts: Vec<Address>) {
95        if !created_contracts.is_empty() {
96            let mut targets = self.targets.lock();
97            for addr in &created_contracts {
98                targets.remove(addr);
99            }
100        }
101    }
102}
103
104/// A collection of contracts identified as targets for invariant testing.
105#[derive(Clone, Debug, Default)]
106pub struct TargetedContracts {
107    /// The inner map of targeted contracts.
108    pub inner: BTreeMap<Address, TargetedContract>,
109}
110
111impl TargetedContracts {
112    /// Returns a new `TargetedContracts` instance.
113    pub fn new() -> Self {
114        Self::default()
115    }
116
117    /// Returns fuzzed contract abi and fuzzed function from address and provided calldata.
118    ///
119    /// Used to decode return values and logs in order to add values into fuzz dictionary.
120    pub fn fuzzed_artifacts(&self, tx: &BasicTxDetails) -> (Option<&JsonAbi>, Option<&Function>) {
121        match self.inner.get(&tx.call_details.target) {
122            Some(c) => (
123                Some(&c.abi),
124                c.abi.functions().find(|f| f.selector() == tx.call_details.calldata[..4]),
125            ),
126            None => (None, None),
127        }
128    }
129
130    /// Returns flatten target contract address and functions to be fuzzed.
131    /// Includes contract targeted functions if specified, else all mutable contract functions.
132    pub fn fuzzed_functions(&self) -> impl Iterator<Item = (&Address, &Function)> {
133        self.inner
134            .iter()
135            .filter(|(_, c)| !c.abi.functions.is_empty())
136            .flat_map(|(contract, c)| c.abi_fuzzed_functions().map(move |f| (contract, f)))
137    }
138
139    /// Returns whether the given transaction can be replayed or not with known contracts.
140    pub fn can_replay(&self, tx: &BasicTxDetails) -> bool {
141        match self.inner.get(&tx.call_details.target) {
142            Some(c) => c.abi.functions().any(|f| f.selector() == tx.call_details.calldata[..4]),
143            None => false,
144        }
145    }
146
147    /// Identifies fuzzed contract and function based on given tx details and returns unique metric
148    /// key composed from contract identifier and function name.
149    pub fn fuzzed_metric_key(&self, tx: &BasicTxDetails) -> Option<String> {
150        self.inner.get(&tx.call_details.target).and_then(|contract| {
151            contract
152                .abi
153                .functions()
154                .find(|f| f.selector() == tx.call_details.calldata[..4])
155                .map(|function| format!("{}.{}", contract.identifier.clone(), function.name))
156        })
157    }
158
159    /// Returns a map of contract addresses to their storage layouts.
160    pub fn get_storage_layouts(&self) -> HashMap<Address, Arc<StorageLayout>> {
161        self.inner
162            .iter()
163            .filter_map(|(addr, c)| {
164                c.storage_layout.as_ref().map(|layout| (*addr, Arc::clone(layout)))
165            })
166            .collect()
167    }
168}
169
170impl std::ops::Deref for TargetedContracts {
171    type Target = BTreeMap<Address, TargetedContract>;
172
173    fn deref(&self) -> &Self::Target {
174        &self.inner
175    }
176}
177
178impl std::ops::DerefMut for TargetedContracts {
179    fn deref_mut(&mut self) -> &mut Self::Target {
180        &mut self.inner
181    }
182}
183
184/// A contract identified as target for invariant testing.
185#[derive(Clone, Debug)]
186pub struct TargetedContract {
187    /// The contract identifier. This is only used in error messages.
188    pub identifier: String,
189    /// The contract's ABI.
190    pub abi: JsonAbi,
191    /// The targeted functions of the contract.
192    pub targeted_functions: Vec<Function>,
193    /// The excluded functions of the contract.
194    pub excluded_functions: Vec<Function>,
195    /// The contract's storage layout, if available.
196    pub storage_layout: Option<Arc<StorageLayout>>,
197}
198
199impl TargetedContract {
200    /// Returns a new `TargetedContract` instance.
201    pub const fn new(identifier: String, abi: JsonAbi) -> Self {
202        Self {
203            identifier,
204            abi,
205            targeted_functions: Vec::new(),
206            excluded_functions: Vec::new(),
207            storage_layout: None,
208        }
209    }
210
211    /// Determines contract storage layout from project contracts. Needs `storageLayout` to be
212    /// enabled as extra output in project configuration.
213    pub fn with_project_contracts(mut self, project_contracts: &ContractsByArtifact) -> Self {
214        if let Some((src, name)) = self.identifier.split_once(':')
215            && let Some((_, contract_data)) = project_contracts.iter().find(|(artifact, _)| {
216                artifact.name == name && artifact.source.as_path().ends_with(src)
217            })
218        {
219            self.storage_layout = contract_data.storage_layout.as_ref().map(Arc::clone);
220        }
221        self
222    }
223
224    /// Helper to retrieve functions to fuzz for specified abi.
225    /// Returns specified targeted functions if any, else mutable abi functions that are not
226    /// marked as excluded.
227    pub fn abi_fuzzed_functions(&self) -> impl Iterator<Item = &Function> {
228        if self.targeted_functions.is_empty() {
229            Either::Right(self.abi.functions().filter(|&func| {
230                !matches!(
231                    func.state_mutability,
232                    alloy_json_abi::StateMutability::Pure | alloy_json_abi::StateMutability::View
233                ) && !self.excluded_functions.contains(func)
234            }))
235        } else {
236            Either::Left(self.targeted_functions.iter())
237        }
238    }
239
240    /// Returns the function for the given selector.
241    pub fn get_function(&self, selector: Selector) -> eyre::Result<&Function> {
242        get_function(&self.identifier, selector, &self.abi)
243    }
244
245    /// Adds the specified selectors to the targeted functions.
246    pub fn add_selectors(
247        &mut self,
248        selectors: impl IntoIterator<Item = Selector>,
249        should_exclude: bool,
250    ) -> eyre::Result<()> {
251        for selector in selectors {
252            if should_exclude {
253                self.excluded_functions.push(self.get_function(selector)?.clone());
254            } else {
255                self.targeted_functions.push(self.get_function(selector)?.clone());
256            }
257        }
258        Ok(())
259    }
260}
261
262/// Test contract which is testing its invariants.
263#[derive(Clone, Debug)]
264pub struct InvariantContract<'a> {
265    /// Address of the test contract.
266    pub address: Address,
267    /// Name of the test contract.
268    pub name: &'a str,
269    /// Invariant function present in the test contract.
270    pub invariant_function: &'a Function,
271    /// If true, `afterInvariant` function is called after each invariant run.
272    pub call_after_invariant: bool,
273    /// ABI of the test contract.
274    pub abi: &'a JsonAbi,
275}
276
277impl<'a> InvariantContract<'a> {
278    /// Creates a new invariant contract.
279    pub const fn new(
280        address: Address,
281        name: &'a str,
282        invariant_function: &'a Function,
283        call_after_invariant: bool,
284        abi: &'a JsonAbi,
285    ) -> Self {
286        Self { address, name, invariant_function, call_after_invariant, abi }
287    }
288
289    /// Returns true if this is an optimization mode invariant (returns int256).
290    pub fn is_optimization(&self) -> bool {
291        is_optimization_invariant(self.invariant_function)
292    }
293}
294
295/// Settings that determine the validity of a persisted invariant counterexample.
296///
297/// When a counterexample is replayed, it's only valid if the same contracts, selectors,
298/// senders, and fail_on_revert settings are used. Changes to unrelated code (e.g., adding
299/// a log statement) should not invalidate the counterexample.
300#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
301pub struct InvariantSettings {
302    /// Target contracts with their addresses and identifiers.
303    pub target_contracts: BTreeMap<Address, String>,
304    /// Target selectors per contract address.
305    pub target_selectors: BTreeMap<Address, Vec<Selector>>,
306    /// Target senders for the invariant test.
307    pub target_senders: Vec<Address>,
308    /// Excluded senders for the invariant test.
309    pub excluded_senders: Vec<Address>,
310    /// Whether the test should fail on any revert.
311    pub fail_on_revert: bool,
312}
313
314impl InvariantSettings {
315    /// Creates new invariant settings from the given components.
316    pub fn new(
317        targeted_contracts: &TargetedContracts,
318        sender_filters: &SenderFilters,
319        fail_on_revert: bool,
320    ) -> Self {
321        let target_contracts = targeted_contracts
322            .inner
323            .iter()
324            .map(|(addr, contract)| (*addr, contract.identifier.clone()))
325            .collect();
326
327        let target_selectors = targeted_contracts
328            .inner
329            .iter()
330            .map(|(addr, contract)| {
331                let selectors: Vec<Selector> =
332                    contract.abi_fuzzed_functions().map(|f| f.selector()).collect();
333                (*addr, selectors)
334            })
335            .collect();
336
337        let mut target_senders = sender_filters.targeted.clone();
338        target_senders.sort();
339
340        let mut excluded_senders = sender_filters.excluded.clone();
341        excluded_senders.sort();
342
343        Self {
344            target_contracts,
345            target_selectors,
346            target_senders,
347            excluded_senders,
348            fail_on_revert,
349        }
350    }
351
352    /// Compares these settings with another and returns a description of what changed.
353    /// Returns `None` if the settings are equivalent.
354    pub fn diff(&self, other: &Self) -> Option<String> {
355        let mut changes = Vec::new();
356
357        if self.target_contracts != other.target_contracts {
358            let added: Vec<_> = other
359                .target_contracts
360                .iter()
361                .filter(|(addr, _)| !self.target_contracts.contains_key(*addr))
362                .map(|(_, name)| name.as_str())
363                .collect();
364            let removed: Vec<_> = self
365                .target_contracts
366                .iter()
367                .filter(|(addr, _)| !other.target_contracts.contains_key(*addr))
368                .map(|(_, name)| name.as_str())
369                .collect();
370
371            if !added.is_empty() {
372                changes.push(format!("added target contracts: {}", added.join(", ")));
373            }
374            if !removed.is_empty() {
375                changes.push(format!("removed target contracts: {}", removed.join(", ")));
376            }
377        }
378
379        if self.target_selectors != other.target_selectors {
380            changes.push("target selectors changed".to_string());
381        }
382
383        if self.target_senders != other.target_senders {
384            changes.push("target senders changed".to_string());
385        }
386
387        if self.excluded_senders != other.excluded_senders {
388            changes.push("excluded senders changed".to_string());
389        }
390
391        if self.fail_on_revert != other.fail_on_revert {
392            changes.push(format!(
393                "fail_on_revert changed from {} to {}",
394                self.fail_on_revert, other.fail_on_revert
395            ));
396        }
397
398        if changes.is_empty() { None } else { Some(changes.join(", ")) }
399    }
400}
401
402impl fmt::Display for InvariantSettings {
403    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
404        write!(
405            f,
406            "targets: {}, selectors: {}, senders: {}, excluded: {}, fail_on_revert: {}",
407            self.target_contracts.len(),
408            self.target_selectors.values().map(|v| v.len()).sum::<usize>(),
409            self.target_senders.len(),
410            self.excluded_senders.len(),
411            self.fail_on_revert,
412        )
413    }
414}