Skip to main content

foundry_evm/executors/invariant/
shrink.rs

1use crate::executors::{
2    EarlyExit, Executor,
3    invariant::{call_after_invariant_function, call_invariant_function, execute_tx},
4};
5use alloy_primitives::{Address, Bytes, I256, U256};
6use foundry_config::InvariantConfig;
7use foundry_evm_core::constants::MAGIC_ASSUME;
8use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
9use indicatif::ProgressBar;
10use proptest::bits::{BitSetLike, VarBitSet};
11
12/// Shrinker for a call sequence failure.
13/// Iterates sequence call sequence top down and removes calls one by one.
14/// If the failure is still reproducible with removed call then moves to the next one.
15/// If the failure is not reproducible then restore removed call and moves to next one.
16#[derive(Debug)]
17struct CallSequenceShrinker {
18    /// Length of call sequence to be shrunk.
19    call_sequence_len: usize,
20    /// Call ids contained in current shrunk sequence.
21    included_calls: VarBitSet,
22}
23
24impl CallSequenceShrinker {
25    fn new(call_sequence_len: usize) -> Self {
26        Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
27    }
28
29    /// Return candidate shrink sequence to be tested, by removing ids from original sequence.
30    fn current(&self) -> impl Iterator<Item = usize> + '_ {
31        (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
32    }
33
34    /// Advance to the next call index, wrapping around to 0 at the end.
35    fn next_index(&self, call_idx: usize) -> usize {
36        if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
37    }
38}
39
40/// Resets the progress bar for shrinking.
41fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
42    if let Some(progress) = progress {
43        progress.set_length(config.shrink_run_limit as u64);
44        progress.reset();
45        progress.set_message(" Shrink");
46    }
47}
48
49/// Applies accumulated warp/roll to a call, returning a modified copy.
50fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
51    let mut result = call.clone();
52    if warp > U256::ZERO {
53        result.warp = Some(warp);
54    }
55    if roll > U256::ZERO {
56        result.roll = Some(roll);
57    }
58    result
59}
60
61/// Applies warp/roll adjustments directly to the executor's environment.
62fn apply_warp_roll_to_env(executor: &mut Executor, warp: U256, roll: U256) {
63    if warp > U256::ZERO || roll > U256::ZERO {
64        executor.env_mut().evm_env.block_env.timestamp += warp;
65        executor.env_mut().evm_env.block_env.number += roll;
66
67        let block_env = executor.env().evm_env.block_env.clone();
68        if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
69            if let Some(block) = cheatcodes.block.as_mut() {
70                block.timestamp += warp;
71                block.number += roll;
72            } else {
73                cheatcodes.block = Some(block_env);
74            }
75        }
76    }
77}
78
79pub(crate) fn shrink_sequence(
80    config: &InvariantConfig,
81    invariant_contract: &InvariantContract<'_>,
82    calls: &[BasicTxDetails],
83    executor: &Executor,
84    progress: Option<&ProgressBar>,
85    early_exit: &EarlyExit,
86) -> eyre::Result<Vec<BasicTxDetails>> {
87    trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
88
89    reset_shrink_progress(config, progress);
90
91    let target_address = invariant_contract.address;
92    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
93    // Special case test: the invariant is *unsatisfiable* - it took 0 calls to
94    // break the invariant -- consider emitting a warning.
95    let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
96    if !success {
97        return Ok(vec![]);
98    }
99
100    let mut call_idx = 0;
101    let mut shrinker = CallSequenceShrinker::new(calls.len());
102
103    for _ in 0..config.shrink_run_limit {
104        if early_exit.should_stop() {
105            break;
106        }
107
108        shrinker.included_calls.clear(call_idx);
109
110        match check_sequence(
111            executor.clone(),
112            calls,
113            shrinker.current().collect(),
114            target_address,
115            calldata.clone(),
116            config.fail_on_revert,
117            invariant_contract.call_after_invariant,
118        ) {
119            // If candidate sequence still fails, shrink until shortest possible.
120            Ok((false, _)) if shrinker.included_calls.count() == 1 => break,
121            // Restore last removed call as it caused sequence to pass invariant.
122            Ok((true, _)) => shrinker.included_calls.set(call_idx),
123            _ => {}
124        }
125
126        if let Some(progress) = progress {
127            progress.inc(1);
128        }
129
130        call_idx = shrinker.next_index(call_idx);
131    }
132
133    Ok(shrinker.current().map(|idx| &calls[idx]).cloned().collect())
134}
135
136/// Checks if the given call sequence breaks the invariant.
137///
138/// Used in shrinking phase for checking candidate sequences and in replay failures phase to test
139/// persisted failures.
140/// Returns the result of invariant check (and afterInvariant call if needed) and if sequence was
141/// entirely applied.
142pub fn check_sequence(
143    mut executor: Executor,
144    calls: &[BasicTxDetails],
145    sequence: Vec<usize>,
146    test_address: Address,
147    calldata: Bytes,
148    fail_on_revert: bool,
149    call_after_invariant: bool,
150) -> eyre::Result<(bool, bool)> {
151    // Apply the call sequence.
152    for call_index in sequence {
153        let tx = &calls[call_index];
154        let mut call_result = execute_tx(&mut executor, tx)?;
155        executor.commit(&mut call_result);
156        // Ignore calls reverted with `MAGIC_ASSUME`. This is needed to handle failed scenarios that
157        // are replayed with a modified version of test driver (that use new `vm.assume`
158        // cheatcodes).
159        if call_result.reverted && fail_on_revert && call_result.result.as_ref() != MAGIC_ASSUME {
160            // Candidate sequence fails test.
161            // We don't have to apply remaining calls to check sequence.
162            return Ok((false, false));
163        }
164    }
165
166    // Check the invariant for call sequence.
167    let (_, mut success) = call_invariant_function(&executor, test_address, calldata)?;
168    // Check after invariant result if invariant is success and `afterInvariant` function is
169    // declared.
170    if success && call_after_invariant {
171        (_, success) = call_after_invariant_function(&executor, test_address)?;
172    }
173
174    Ok((success, true))
175}
176
177/// Shrinks a call sequence to the shortest sequence that still produces the target optimization
178/// value. This is specifically for optimization mode where we want to find the minimal sequence
179/// that achieves the maximum value.
180///
181/// Unlike `shrink_sequence` (for check mode), this function:
182/// - Accumulates warp/roll values from removed calls into the next kept call
183/// - Checks for target value equality rather than invariant failure
184pub(crate) fn shrink_sequence_value(
185    config: &InvariantConfig,
186    invariant_contract: &InvariantContract<'_>,
187    calls: &[BasicTxDetails],
188    executor: &Executor,
189    target_value: I256,
190    progress: Option<&ProgressBar>,
191    early_exit: &EarlyExit,
192) -> eyre::Result<Vec<BasicTxDetails>> {
193    trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
194
195    reset_shrink_progress(config, progress);
196
197    let target_address = invariant_contract.address;
198    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
199
200    // Special case: check if target value is achieved with 0 calls.
201    if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
202        == Some(target_value)
203    {
204        return Ok(vec![]);
205    }
206
207    let mut call_idx = 0;
208    let mut shrinker = CallSequenceShrinker::new(calls.len());
209
210    for _ in 0..config.shrink_run_limit {
211        if early_exit.should_stop() {
212            break;
213        }
214
215        shrinker.included_calls.clear(call_idx);
216
217        let keeps_target = check_sequence_value(
218            executor.clone(),
219            calls,
220            shrinker.current().collect(),
221            target_address,
222            calldata.clone(),
223        )? == Some(target_value);
224
225        if keeps_target {
226            if shrinker.included_calls.count() == 1 {
227                break;
228            }
229        } else {
230            shrinker.included_calls.set(call_idx);
231        }
232
233        if let Some(progress) = progress {
234            progress.inc(1);
235        }
236
237        call_idx = shrinker.next_index(call_idx);
238    }
239
240    // Build the final shrunk sequence, accumulating warp/roll from removed calls.
241    let mut result = Vec::new();
242    let mut accumulated_warp = U256::ZERO;
243    let mut accumulated_roll = U256::ZERO;
244
245    for (idx, call) in calls.iter().enumerate() {
246        accumulated_warp += call.warp.unwrap_or(U256::ZERO);
247        accumulated_roll += call.roll.unwrap_or(U256::ZERO);
248
249        if shrinker.included_calls.test(idx) {
250            result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
251            accumulated_warp = U256::ZERO;
252            accumulated_roll = U256::ZERO;
253        }
254    }
255
256    Ok(result)
257}
258
259/// Executes a call sequence and returns the optimization value (int256) from the invariant
260/// function. Used during shrinking for optimization mode.
261///
262/// Returns `None` if the invariant call fails or doesn't return a valid int256.
263/// Unlike `check_sequence`, this applies warp/roll from ALL calls (including removed ones).
264pub fn check_sequence_value(
265    mut executor: Executor,
266    calls: &[BasicTxDetails],
267    sequence: Vec<usize>,
268    test_address: Address,
269    calldata: Bytes,
270) -> eyre::Result<Option<I256>> {
271    let mut accumulated_warp = U256::ZERO;
272    let mut accumulated_roll = U256::ZERO;
273    let mut seq_iter = sequence.iter().peekable();
274
275    for (idx, tx) in calls.iter().enumerate() {
276        accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
277        accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
278
279        if seq_iter.peek() == Some(&&idx) {
280            seq_iter.next();
281
282            let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
283            let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
284
285            if !call_result.reverted {
286                executor.commit(&mut call_result);
287            }
288
289            accumulated_warp = U256::ZERO;
290            accumulated_roll = U256::ZERO;
291        }
292    }
293
294    // Apply any remaining accumulated warp/roll before calling invariant.
295    apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
296
297    let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
298
299    if success
300        && inv_result.result.len() >= 32
301        && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
302    {
303        return Ok(Some(value));
304    }
305
306    Ok(None)
307}