// foundry_evm/executors/invariant/shrink.rs

1use crate::executors::{
2    EarlyExit, EvmError, Executor, RawCallResult,
3    invariant::{call_after_invariant_function, call_invariant_function, execute_tx},
4};
5use alloy_primitives::{Address, Bytes, I256, U256};
6use foundry_config::InvariantConfig;
7use foundry_evm_core::{
8    FoundryBlock, constants::MAGIC_ASSUME, decode::RevertDecoder, evm::FoundryEvmNetwork,
9};
10use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
11use indicatif::ProgressBar;
12use proptest::bits::{BitSetLike, VarBitSet};
13use revm::context::Block;
14
15/// Shrinker for a call sequence failure.
16/// Iterates sequence call sequence top down and removes calls one by one.
17/// If the failure is still reproducible with removed call then moves to the next one.
18/// If the failure is not reproducible then restore removed call and moves to next one.
19#[derive(Debug)]
20struct CallSequenceShrinker {
21    /// Length of call sequence to be shrunk.
22    call_sequence_len: usize,
23    /// Call ids contained in current shrunk sequence.
24    included_calls: VarBitSet,
25}
26
27impl CallSequenceShrinker {
28    fn new(call_sequence_len: usize) -> Self {
29        Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
30    }
31
32    /// Return candidate shrink sequence to be tested, by removing ids from original sequence.
33    fn current(&self) -> impl Iterator<Item = usize> + '_ {
34        (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
35    }
36
37    /// Advance to the next call index, wrapping around to 0 at the end.
38    fn next_index(&self, call_idx: usize) -> usize {
39        if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
40    }
41}
42
43/// Resets the progress bar for shrinking.
44fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
45    if let Some(progress) = progress {
46        progress.set_length(config.shrink_run_limit as u64);
47        progress.reset();
48        progress.set_message(" Shrink");
49    }
50}
51
52/// Applies accumulated warp/roll to a call, returning a modified copy.
53fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
54    let mut result = call.clone();
55    if warp > U256::ZERO {
56        result.warp = Some(warp);
57    }
58    if roll > U256::ZERO {
59        result.roll = Some(roll);
60    }
61    result
62}
63
64/// Applies warp/roll adjustments directly to the executor's environment.
65fn apply_warp_roll_to_env<FEN: FoundryEvmNetwork>(
66    executor: &mut Executor<FEN>,
67    warp: U256,
68    roll: U256,
69) {
70    if warp > U256::ZERO || roll > U256::ZERO {
71        let ts = executor.evm_env().block_env.timestamp();
72        let num = executor.evm_env().block_env.number();
73        executor.evm_env_mut().block_env.set_timestamp(ts + warp);
74        executor.evm_env_mut().block_env.set_number(num + roll);
75
76        let block_env = executor.evm_env().block_env.clone();
77        if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
78            if let Some(block) = cheatcodes.block.as_mut() {
79                let bts = block.timestamp();
80                let bnum = block.number();
81                block.set_timestamp(bts + warp);
82                block.set_number(bnum + roll);
83            } else {
84                cheatcodes.block = Some(block_env);
85            }
86        }
87    }
88}
89
90/// Builds the final shrunk sequence from the shrinker state.
91///
92/// When `accumulate_warp_roll` is enabled, warp/roll from removed calls is folded into the next
93/// kept call so the final sequence remains reproducible.
94fn build_shrunk_sequence(
95    calls: &[BasicTxDetails],
96    shrinker: &CallSequenceShrinker,
97    accumulate_warp_roll: bool,
98) -> Vec<BasicTxDetails> {
99    if !accumulate_warp_roll {
100        return shrinker.current().map(|idx| calls[idx].clone()).collect();
101    }
102
103    let mut result = Vec::new();
104    let mut accumulated_warp = U256::ZERO;
105    let mut accumulated_roll = U256::ZERO;
106
107    for (idx, call) in calls.iter().enumerate() {
108        accumulated_warp += call.warp.unwrap_or(U256::ZERO);
109        accumulated_roll += call.roll.unwrap_or(U256::ZERO);
110
111        if shrinker.included_calls.test(idx) {
112            result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
113            accumulated_warp = U256::ZERO;
114            accumulated_roll = U256::ZERO;
115        }
116    }
117
118    result
119}
120
121pub(crate) fn shrink_sequence<FEN: FoundryEvmNetwork>(
122    config: &InvariantConfig,
123    invariant_contract: &InvariantContract<'_>,
124    calls: &[BasicTxDetails],
125    executor: &Executor<FEN>,
126    progress: Option<&ProgressBar>,
127    early_exit: &EarlyExit,
128) -> eyre::Result<Vec<BasicTxDetails>> {
129    trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
130
131    reset_shrink_progress(config, progress);
132
133    let target_address = invariant_contract.address;
134    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
135    // Special case test: the invariant is *unsatisfiable* - it took 0 calls to
136    // break the invariant -- consider emitting a warning.
137    let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
138    if !success {
139        return Ok(vec![]);
140    }
141
142    let accumulate_warp_roll = config.has_delay();
143    let mut call_idx = 0;
144    let mut shrinker = CallSequenceShrinker::new(calls.len());
145
146    for _ in 0..config.shrink_run_limit {
147        if early_exit.should_stop() {
148            break;
149        }
150
151        shrinker.included_calls.clear(call_idx);
152
153        match check_sequence(
154            executor.clone(),
155            calls,
156            shrinker.current().collect(),
157            target_address,
158            calldata.clone(),
159            CheckSequenceOptions {
160                accumulate_warp_roll,
161                fail_on_revert: config.fail_on_revert,
162                call_after_invariant: invariant_contract.call_after_invariant,
163                rd: None,
164            },
165        ) {
166            // If candidate sequence still fails, shrink until shortest possible.
167            Ok((false, _, _)) if shrinker.included_calls.count() == 1 => break,
168            // Restore last removed call as it caused sequence to pass invariant.
169            Ok((true, _, _)) => shrinker.included_calls.set(call_idx),
170            _ => {}
171        }
172
173        if let Some(progress) = progress {
174            progress.inc(1);
175        }
176
177        call_idx = shrinker.next_index(call_idx);
178    }
179
180    Ok(build_shrunk_sequence(calls, &shrinker, accumulate_warp_roll))
181}
182
183/// Checks if the given call sequence breaks the invariant.
184///
185/// Used in shrinking phase for checking candidate sequences and in replay failures phase to test
186/// persisted failures.
187/// Returns the result of invariant check (and afterInvariant call if needed) and if sequence was
188/// entirely applied.
189///
190/// When `options.accumulate_warp_roll` is enabled, warp/roll from removed calls is folded into the
191/// next kept call so the candidate sequence stays representable as a concrete counterexample.
192pub fn check_sequence<FEN: FoundryEvmNetwork>(
193    executor: Executor<FEN>,
194    calls: &[BasicTxDetails],
195    sequence: Vec<usize>,
196    test_address: Address,
197    calldata: Bytes,
198    options: CheckSequenceOptions<'_>,
199) -> eyre::Result<(bool, bool, Option<String>)> {
200    if options.accumulate_warp_roll {
201        check_sequence_with_accumulation(executor, calls, sequence, test_address, calldata, options)
202    } else {
203        check_sequence_simple(executor, calls, sequence, test_address, calldata, options)
204    }
205}
206
207fn check_sequence_simple<FEN: FoundryEvmNetwork>(
208    mut executor: Executor<FEN>,
209    calls: &[BasicTxDetails],
210    sequence: Vec<usize>,
211    test_address: Address,
212    calldata: Bytes,
213    options: CheckSequenceOptions<'_>,
214) -> eyre::Result<(bool, bool, Option<String>)> {
215    // Apply the call sequence.
216    for call_index in sequence {
217        let tx = &calls[call_index];
218        let mut call_result = execute_tx(&mut executor, tx)?;
219        executor.commit(&mut call_result);
220        // Ignore calls reverted with `MAGIC_ASSUME`. This is needed to handle failed scenarios that
221        // are replayed with a modified version of test driver (that use new `vm.assume`
222        // cheatcodes).
223        if call_result.reverted
224            && options.fail_on_revert
225            && call_result.result.as_ref() != MAGIC_ASSUME
226        {
227            // Candidate sequence fails test.
228            // We don't have to apply remaining calls to check sequence.
229            return Ok((false, false, call_failure_reason(call_result, options.rd)));
230        }
231    }
232
233    finish_sequence_check(&executor, test_address, calldata, &options)
234}
235
236fn check_sequence_with_accumulation<FEN: FoundryEvmNetwork>(
237    mut executor: Executor<FEN>,
238    calls: &[BasicTxDetails],
239    sequence: Vec<usize>,
240    test_address: Address,
241    calldata: Bytes,
242    options: CheckSequenceOptions<'_>,
243) -> eyre::Result<(bool, bool, Option<String>)> {
244    let mut accumulated_warp = U256::ZERO;
245    let mut accumulated_roll = U256::ZERO;
246    let mut seq_iter = sequence.iter().peekable();
247
248    for (idx, tx) in calls.iter().enumerate() {
249        accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
250        accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
251
252        if seq_iter.peek() != Some(&&idx) {
253            continue;
254        }
255
256        seq_iter.next();
257
258        let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
259        let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
260
261        if call_result.reverted {
262            if options.fail_on_revert && call_result.result.as_ref() != MAGIC_ASSUME {
263                return Ok((false, false, call_failure_reason(call_result, options.rd)));
264            }
265        } else {
266            executor.commit(&mut call_result);
267        }
268
269        accumulated_warp = U256::ZERO;
270        accumulated_roll = U256::ZERO;
271    }
272
273    // Unlike optimization mode we intentionally do not apply trailing warp/roll before the
274    // invariant call: those delays would not be representable in the final shrunk sequence.
275    finish_sequence_check(&executor, test_address, calldata, &options)
276}
277
278fn finish_sequence_check<FEN: FoundryEvmNetwork>(
279    executor: &Executor<FEN>,
280    test_address: Address,
281    calldata: Bytes,
282    options: &CheckSequenceOptions<'_>,
283) -> eyre::Result<(bool, bool, Option<String>)> {
284    let (invariant_result, mut success) =
285        call_invariant_function(executor, test_address, calldata)?;
286    if !success {
287        return Ok((false, true, call_failure_reason(invariant_result, options.rd)));
288    }
289
290    // Check after invariant result if invariant is success and `afterInvariant` function is
291    // declared.
292    if success && options.call_after_invariant {
293        let (after_invariant_result, after_invariant_success) =
294            call_after_invariant_function(executor, test_address)?;
295        success = after_invariant_success;
296        if !success {
297            return Ok((false, true, call_failure_reason(after_invariant_result, options.rd)));
298        }
299    }
300
301    Ok((success, true, None))
302}
303
304pub struct CheckSequenceOptions<'a> {
305    pub accumulate_warp_roll: bool,
306    pub fail_on_revert: bool,
307    pub call_after_invariant: bool,
308    pub rd: Option<&'a RevertDecoder>,
309}
310
311fn call_failure_reason<FEN: FoundryEvmNetwork>(
312    call_result: RawCallResult<FEN>,
313    rd: Option<&RevertDecoder>,
314) -> Option<String> {
315    match call_result.into_evm_error(rd) {
316        EvmError::Execution(err) => Some(err.reason),
317        _ => None,
318    }
319}
320
321/// Shrinks a call sequence to the shortest sequence that still produces the target optimization
322/// value. This is specifically for optimization mode where we want to find the minimal sequence
323/// that achieves the maximum value.
324///
325/// Unlike `shrink_sequence` (for check mode), this function:
326/// - Accumulates warp/roll values from removed calls into the next kept call
327/// - Checks for target value equality rather than invariant failure
328pub(crate) fn shrink_sequence_value<FEN: FoundryEvmNetwork>(
329    config: &InvariantConfig,
330    invariant_contract: &InvariantContract<'_>,
331    calls: &[BasicTxDetails],
332    executor: &Executor<FEN>,
333    target_value: I256,
334    progress: Option<&ProgressBar>,
335    early_exit: &EarlyExit,
336) -> eyre::Result<Vec<BasicTxDetails>> {
337    trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
338
339    reset_shrink_progress(config, progress);
340
341    let target_address = invariant_contract.address;
342    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
343
344    // Special case: check if target value is achieved with 0 calls.
345    if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
346        == Some(target_value)
347    {
348        return Ok(vec![]);
349    }
350
351    let mut call_idx = 0;
352    let mut shrinker = CallSequenceShrinker::new(calls.len());
353
354    for _ in 0..config.shrink_run_limit {
355        if early_exit.should_stop() {
356            break;
357        }
358
359        shrinker.included_calls.clear(call_idx);
360
361        let keeps_target = check_sequence_value(
362            executor.clone(),
363            calls,
364            shrinker.current().collect(),
365            target_address,
366            calldata.clone(),
367        )? == Some(target_value);
368
369        if keeps_target {
370            if shrinker.included_calls.count() == 1 {
371                break;
372            }
373        } else {
374            shrinker.included_calls.set(call_idx);
375        }
376
377        if let Some(progress) = progress {
378            progress.inc(1);
379        }
380
381        call_idx = shrinker.next_index(call_idx);
382    }
383
384    Ok(build_shrunk_sequence(calls, &shrinker, true))
385}
386
387/// Executes a call sequence and returns the optimization value (int256) from the invariant
388/// function. Used during shrinking for optimization mode.
389///
390/// Returns `None` if the invariant call fails or doesn't return a valid int256.
391/// Unlike `check_sequence`, this applies warp/roll from ALL calls (including removed ones).
392pub fn check_sequence_value<FEN: FoundryEvmNetwork>(
393    mut executor: Executor<FEN>,
394    calls: &[BasicTxDetails],
395    sequence: Vec<usize>,
396    test_address: Address,
397    calldata: Bytes,
398) -> eyre::Result<Option<I256>> {
399    let mut accumulated_warp = U256::ZERO;
400    let mut accumulated_roll = U256::ZERO;
401    let mut seq_iter = sequence.iter().peekable();
402
403    for (idx, tx) in calls.iter().enumerate() {
404        accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
405        accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
406
407        if seq_iter.peek() == Some(&&idx) {
408            seq_iter.next();
409
410            let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
411            let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
412
413            if !call_result.reverted {
414                executor.commit(&mut call_result);
415            }
416
417            accumulated_warp = U256::ZERO;
418            accumulated_roll = U256::ZERO;
419        }
420    }
421
422    // Apply any remaining accumulated warp/roll before calling invariant.
423    apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
424
425    let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
426
427    if success
428        && inv_result.result.len() >= 32
429        && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
430    {
431        return Ok(Some(value));
432    }
433
434    Ok(None)
435}
436
#[cfg(test)]
mod tests {
    use super::{CallSequenceShrinker, build_shrunk_sequence};
    use alloy_primitives::{Address, Bytes, U256};
    use foundry_evm_fuzz::{BasicTxDetails, CallDetails};
    use proptest::bits::BitSetLike;

    /// Builds an empty call to the zero address carrying the given block adjustments.
    fn tx(warp: Option<u64>, roll: Option<u64>) -> BasicTxDetails {
        BasicTxDetails {
            warp: warp.map(U256::from),
            roll: roll.map(U256::from),
            sender: Address::ZERO,
            call_details: CallDetails { target: Address::ZERO, calldata: Bytes::new() },
        }
    }

    /// Drops the call at `removed` and rebuilds the sequence with warp/roll accumulation.
    fn shrink_without(calls: &[BasicTxDetails], removed: usize) -> Vec<BasicTxDetails> {
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(removed);
        build_shrunk_sequence(calls, &shrinker, true)
    }

    /// Asserts that `call` carries exactly the given warp and roll.
    fn assert_delay(call: &BasicTxDetails, warp: u64, roll: u64) {
        assert_eq!(call.warp, Some(U256::from(warp)));
        assert_eq!(call.roll, Some(U256::from(roll)));
    }

    #[test]
    fn build_shrunk_sequence_accumulates_removed_delay_into_next_kept_call() {
        let calls = [tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];

        let shrunk = shrink_without(&calls, 0);

        assert_eq!(shrunk.len(), 2);
        // 3 + 7 and 5 + 11: the removed first call's delay lands on its successor.
        assert_delay(&shrunk[0], 10, 16);
        assert_delay(&shrunk[1], 13, 17);
    }

    #[test]
    fn build_shrunk_sequence_does_not_move_trailing_delay_backward() {
        let calls = [tx(Some(3), Some(5)), tx(Some(7), Some(11))];

        let shrunk = shrink_without(&calls, 1);

        assert_eq!(shrunk.len(), 1);
        // The removed last call's delay has no successor and is dropped.
        assert_delay(&shrunk[0], 3, 5);
    }
}