// foundry_evm_fuzz/invariant/call_override.rs

1use crate::{BasicTxDetails, CallDetails};
2use alloy_primitives::Address;
3use parking_lot::{Mutex, RwLock};
4use proptest::{
5    option::weighted,
6    strategy::{SBoxedStrategy, Strategy, ValueTree},
7    test_runner::TestRunner,
8};
9use std::{collections::HashSet, sync::Arc};
10
11/// Given a TestRunner and a strategy, it generates calls. Used inside the Fuzzer inspector to
12/// override external calls to test for potential reentrancy vulnerabilities.
13///
14/// The key insight is that we only override calls TO handler contracts (targeted contracts).
15/// This simulates a malicious contract that reenters when receiving ETH via its receive() function.
16#[derive(Clone, Debug)]
17pub struct RandomCallGenerator {
18    /// Address of the test contract.
19    pub test_address: Address,
20    /// Addresses of handler contracts that can be reentered.
21    /// We only inject callbacks when the call target is one of these.
22    pub handler_addresses: Arc<RwLock<HashSet<Address>>>,
23    /// Runner that will generate the call from the strategy.
24    pub runner: Arc<Mutex<TestRunner>>,
25    /// Strategy to be used to generate calls from `target_reference`.
26    pub strategy: SBoxedStrategy<Option<CallDetails>>,
27    /// Reference to which contract we want a fuzzed calldata from.
28    pub target_reference: Arc<RwLock<Address>>,
29    /// Tracks the call depth when an override is active. When > 0, we're inside an overridden
30    /// call and should not override nested calls. Incremented when we override a call,
31    /// decremented when any call ends while inside an override.
32    pub override_depth: usize,
33    /// If set to `true`, consumes the next call from `last_sequence`, otherwise queries it from
34    /// the strategy.
35    pub replay: bool,
36    /// Saves the sequence of generated calls that can be replayed later on.
37    pub last_sequence: Arc<RwLock<Vec<Option<BasicTxDetails>>>>,
38}
39
40impl RandomCallGenerator {
41    pub fn new(
42        test_address: Address,
43        handler_addresses: HashSet<Address>,
44        runner: TestRunner,
45        strategy: impl Strategy<Value = CallDetails> + Send + Sync + 'static,
46        target_reference: Arc<RwLock<Address>>,
47    ) -> Self {
48        Self {
49            test_address,
50            handler_addresses: Arc::new(RwLock::new(handler_addresses)),
51            runner: Arc::new(Mutex::new(runner)),
52            strategy: weighted(0.9, strategy).sboxed(),
53            target_reference,
54            last_sequence: Arc::default(),
55            replay: false,
56            override_depth: 0,
57        }
58    }
59
60    /// Check if the given address is a handler that can be reentered.
61    pub fn is_handler(&self, address: Address) -> bool {
62        self.handler_addresses.read().contains(&address)
63    }
64
65    /// All `self.next()` calls will now pop `self.last_sequence`. Used to replay an invariant
66    /// failure.
67    pub fn set_replay(&mut self, status: bool) {
68        self.replay = status;
69        if status {
70            // So it can later be popped.
71            self.last_sequence.write().reverse();
72        }
73    }
74
75    /// Gets the next call. Random if replay is not set. Otherwise, it pops from `last_sequence`.
76    pub fn next(
77        &mut self,
78        original_caller: Address,
79        original_target: Address,
80    ) -> Option<BasicTxDetails> {
81        if self.replay {
82            self.last_sequence.write().pop().expect(
83                "to have same size as the number of (unsafe) external calls of the sequence.",
84            )
85        } else {
86            // TODO: Do we want it to be 80% chance only too ?
87            let sender = original_target;
88
89            // Set which contract we mostly (80% chance) want to generate calldata from.
90            *self.target_reference.write() = original_caller;
91
92            // `original_caller` has a 80% chance of being the `new_target`.
93            let choice = self.strategy.new_tree(&mut self.runner.lock()).unwrap().current().map(
94                |call_details| BasicTxDetails { warp: None, roll: None, sender, call_details },
95            );
96
97            self.last_sequence.write().push(choice.clone());
98            choice
99        }
100    }
101}