foundry_evm_fuzz/invariant/
mod.rs1use alloy_json_abi::{Function, JsonAbi};
2use alloy_primitives::{Address, Selector, map::HashMap};
3use foundry_compilers::artifacts::StorageLayout;
4use itertools::Either;
5use serde::{Deserialize, Serialize};
6use std::{
7 cell::{Ref, RefCell},
8 collections::BTreeMap,
9 fmt,
10 rc::Rc,
11 sync::Arc,
12};
13
14mod call_override;
15pub use call_override::RandomCallGenerator;
16
17mod filters;
18use crate::BasicTxDetails;
19pub use filters::{ArtifactFilters, SenderFilters};
20use foundry_common::{ContractsByAddress, ContractsByArtifact};
21use foundry_evm_core::utils::{StateChangeset, get_function};
22
23pub fn is_optimization_invariant(func: &Function) -> bool {
26 func.outputs.len() == 1 && func.outputs[0].ty == "int256"
27}
28
29#[derive(Clone, Debug)]
34pub struct FuzzRunIdentifiedContracts {
35 pub targets: Rc<RefCell<TargetedContracts>>,
37 pub is_updatable: bool,
39}
40
41impl FuzzRunIdentifiedContracts {
42 pub fn new(targets: TargetedContracts, is_updatable: bool) -> Self {
44 Self { targets: Rc::new(RefCell::new(targets)), is_updatable }
45 }
46
47 pub fn targets(&self) -> Ref<'_, TargetedContracts> {
49 self.targets.borrow()
50 }
51
52 pub fn collect_created_contracts(
55 &self,
56 state_changeset: &StateChangeset,
57 project_contracts: &ContractsByArtifact,
58 setup_contracts: &ContractsByAddress,
59 artifact_filters: &ArtifactFilters,
60 created_contracts: &mut Vec<Address>,
61 ) -> eyre::Result<()> {
62 if !self.is_updatable {
63 return Ok(());
64 }
65
66 let mut targets = self.targets.borrow_mut();
67 for (address, account) in state_changeset {
68 if setup_contracts.contains_key(address) {
69 continue;
70 }
71 if !account.is_touched() {
72 continue;
73 }
74 let Some(code) = &account.info.code else {
75 continue;
76 };
77 if code.is_empty() {
78 continue;
79 }
80 let Some((artifact, contract)) =
81 project_contracts.find_by_deployed_code(code.original_byte_slice())
82 else {
83 continue;
84 };
85 let Some(functions) =
86 artifact_filters.get_targeted_functions(artifact, &contract.abi)?
87 else {
88 continue;
89 };
90 created_contracts.push(*address);
91 let contract = TargetedContract {
92 identifier: artifact.name.clone(),
93 abi: contract.abi.clone(),
94 targeted_functions: functions,
95 excluded_functions: Vec::new(),
96 storage_layout: contract.storage_layout.as_ref().map(Arc::clone),
97 };
98 targets.insert(*address, contract);
99 }
100 Ok(())
101 }
102
103 pub fn clear_created_contracts(&self, created_contracts: Vec<Address>) {
105 if !created_contracts.is_empty() {
106 let mut targets = self.targets.borrow_mut();
107 for addr in &created_contracts {
108 targets.remove(addr);
109 }
110 }
111 }
112}
113
114#[derive(Clone, Debug, Default)]
116pub struct TargetedContracts {
117 pub inner: BTreeMap<Address, TargetedContract>,
119}
120
121impl TargetedContracts {
122 pub fn new() -> Self {
124 Self::default()
125 }
126
127 pub fn fuzzed_artifacts(&self, tx: &BasicTxDetails) -> (Option<&JsonAbi>, Option<&Function>) {
131 match self.inner.get(&tx.call_details.target) {
132 Some(c) => (
133 Some(&c.abi),
134 c.abi.functions().find(|f| f.selector() == tx.call_details.calldata[..4]),
135 ),
136 None => (None, None),
137 }
138 }
139
140 pub fn fuzzed_functions(&self) -> impl Iterator<Item = (&Address, &Function)> {
143 self.inner
144 .iter()
145 .filter(|(_, c)| !c.abi.functions.is_empty())
146 .flat_map(|(contract, c)| c.abi_fuzzed_functions().map(move |f| (contract, f)))
147 }
148
149 pub fn can_replay(&self, tx: &BasicTxDetails) -> bool {
151 match self.inner.get(&tx.call_details.target) {
152 Some(c) => c.abi.functions().any(|f| f.selector() == tx.call_details.calldata[..4]),
153 None => false,
154 }
155 }
156
157 pub fn fuzzed_metric_key(&self, tx: &BasicTxDetails) -> Option<String> {
160 self.inner.get(&tx.call_details.target).and_then(|contract| {
161 contract
162 .abi
163 .functions()
164 .find(|f| f.selector() == tx.call_details.calldata[..4])
165 .map(|function| format!("{}.{}", contract.identifier.clone(), function.name))
166 })
167 }
168
169 pub fn get_storage_layouts(&self) -> HashMap<Address, Arc<StorageLayout>> {
171 self.inner
172 .iter()
173 .filter_map(|(addr, c)| {
174 c.storage_layout.as_ref().map(|layout| (*addr, Arc::clone(layout)))
175 })
176 .collect()
177 }
178}
179
180impl std::ops::Deref for TargetedContracts {
181 type Target = BTreeMap<Address, TargetedContract>;
182
183 fn deref(&self) -> &Self::Target {
184 &self.inner
185 }
186}
187
188impl std::ops::DerefMut for TargetedContracts {
189 fn deref_mut(&mut self) -> &mut Self::Target {
190 &mut self.inner
191 }
192}
193
194#[derive(Clone, Debug)]
196pub struct TargetedContract {
197 pub identifier: String,
199 pub abi: JsonAbi,
201 pub targeted_functions: Vec<Function>,
203 pub excluded_functions: Vec<Function>,
205 pub storage_layout: Option<Arc<StorageLayout>>,
207}
208
209impl TargetedContract {
210 pub const fn new(identifier: String, abi: JsonAbi) -> Self {
212 Self {
213 identifier,
214 abi,
215 targeted_functions: Vec::new(),
216 excluded_functions: Vec::new(),
217 storage_layout: None,
218 }
219 }
220
221 pub fn with_project_contracts(mut self, project_contracts: &ContractsByArtifact) -> Self {
224 if let Some((src, name)) = self.identifier.split_once(':')
225 && let Some((_, contract_data)) = project_contracts.iter().find(|(artifact, _)| {
226 artifact.name == name && artifact.source.as_path().ends_with(src)
227 })
228 {
229 self.storage_layout = contract_data.storage_layout.as_ref().map(Arc::clone);
230 }
231 self
232 }
233
234 pub fn abi_fuzzed_functions(&self) -> impl Iterator<Item = &Function> {
238 if self.targeted_functions.is_empty() {
239 Either::Right(self.abi.functions().filter(|&func| {
240 !matches!(
241 func.state_mutability,
242 alloy_json_abi::StateMutability::Pure | alloy_json_abi::StateMutability::View
243 ) && !self.excluded_functions.contains(func)
244 }))
245 } else {
246 Either::Left(self.targeted_functions.iter())
247 }
248 }
249
250 pub fn get_function(&self, selector: Selector) -> eyre::Result<&Function> {
252 get_function(&self.identifier, selector, &self.abi)
253 }
254
255 pub fn add_selectors(
257 &mut self,
258 selectors: impl IntoIterator<Item = Selector>,
259 should_exclude: bool,
260 ) -> eyre::Result<()> {
261 for selector in selectors {
262 if should_exclude {
263 self.excluded_functions.push(self.get_function(selector)?.clone());
264 } else {
265 self.targeted_functions.push(self.get_function(selector)?.clone());
266 }
267 }
268 Ok(())
269 }
270}
271
272#[derive(Clone, Debug)]
274pub struct InvariantContract<'a> {
275 pub address: Address,
277 pub name: &'a str,
279 pub invariant_fns: Vec<(&'a Function, bool)>,
283 pub anchor_idx: usize,
287 pub call_after_invariant: bool,
289 pub abi: &'a JsonAbi,
291}
292
293impl<'a> InvariantContract<'a> {
294 pub const fn new(
298 address: Address,
299 name: &'a str,
300 invariant_fns: Vec<(&'a Function, bool)>,
301 anchor_idx: usize,
302 call_after_invariant: bool,
303 abi: &'a JsonAbi,
304 ) -> Self {
305 Self { address, name, invariant_fns, anchor_idx, call_after_invariant, abi }
306 }
307
308 pub fn anchor(&self) -> &'a Function {
311 self.invariant_fns[self.anchor_idx].0
312 }
313
314 pub fn is_optimization(&self) -> bool {
316 is_optimization_invariant(self.anchor())
317 }
318}
319
320#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
326pub struct InvariantSettings {
327 pub target_contracts: BTreeMap<Address, String>,
329 pub target_selectors: BTreeMap<Address, Vec<Selector>>,
331 pub target_senders: Vec<Address>,
333 pub excluded_senders: Vec<Address>,
335 pub fail_on_revert: bool,
337}
338
339impl InvariantSettings {
340 pub fn new(
342 targeted_contracts: &TargetedContracts,
343 sender_filters: &SenderFilters,
344 fail_on_revert: bool,
345 ) -> Self {
346 let target_contracts = targeted_contracts
347 .inner
348 .iter()
349 .map(|(addr, contract)| (*addr, contract.identifier.clone()))
350 .collect();
351
352 let target_selectors = targeted_contracts
353 .inner
354 .iter()
355 .map(|(addr, contract)| {
356 let selectors: Vec<Selector> =
357 contract.abi_fuzzed_functions().map(|f| f.selector()).collect();
358 (*addr, selectors)
359 })
360 .collect();
361
362 let mut target_senders = sender_filters.targeted.clone();
363 target_senders.sort();
364
365 let mut excluded_senders = sender_filters.excluded.clone();
366 excluded_senders.sort();
367
368 Self {
369 target_contracts,
370 target_selectors,
371 target_senders,
372 excluded_senders,
373 fail_on_revert,
374 }
375 }
376
377 pub fn diff(&self, other: &Self) -> Option<String> {
380 let mut changes = Vec::new();
381
382 if self.target_contracts != other.target_contracts {
383 let added: Vec<_> = other
384 .target_contracts
385 .iter()
386 .filter(|(addr, _)| !self.target_contracts.contains_key(*addr))
387 .map(|(_, name)| name.as_str())
388 .collect();
389 let removed: Vec<_> = self
390 .target_contracts
391 .iter()
392 .filter(|(addr, _)| !other.target_contracts.contains_key(*addr))
393 .map(|(_, name)| name.as_str())
394 .collect();
395
396 if !added.is_empty() {
397 changes.push(format!("added target contracts: {}", added.join(", ")));
398 }
399 if !removed.is_empty() {
400 changes.push(format!("removed target contracts: {}", removed.join(", ")));
401 }
402 }
403
404 if self.target_selectors != other.target_selectors {
405 changes.push("target selectors changed".to_string());
406 }
407
408 if self.target_senders != other.target_senders {
409 changes.push("target senders changed".to_string());
410 }
411
412 if self.excluded_senders != other.excluded_senders {
413 changes.push("excluded senders changed".to_string());
414 }
415
416 if self.fail_on_revert != other.fail_on_revert {
417 changes.push(format!(
418 "fail_on_revert changed from {} to {}",
419 self.fail_on_revert, other.fail_on_revert
420 ));
421 }
422
423 if changes.is_empty() { None } else { Some(changes.join(", ")) }
424 }
425}
426
427impl fmt::Display for InvariantSettings {
428 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
429 write!(
430 f,
431 "targets: {}, selectors: {}, senders: {}, excluded: {}, fail_on_revert: {}",
432 self.target_contracts.len(),
433 self.target_selectors.values().map(|v| v.len()).sum::<usize>(),
434 self.target_senders.len(),
435 self.excluded_senders.len(),
436 self.fail_on_revert,
437 )
438 }
439}