Skip to main content

foundry_evm/executors/invariant/
shrink.rs

1use crate::executors::{
2    EarlyExit, EvmError, Executor, RawCallResult,
3    invariant::{
4        call_after_invariant_function, call_invariant_function, execute_tx,
5        result::did_fail_on_assert,
6    },
7};
8use alloy_primitives::{Address, Bytes, I256, U256};
9use foundry_config::InvariantConfig;
10use foundry_evm_core::{
11    FoundryBlock, constants::MAGIC_ASSUME, decode::RevertDecoder, evm::FoundryEvmNetwork,
12};
13use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
14use indicatif::ProgressBar;
15use proptest::bits::{BitSetLike, VarBitSet};
16use revm::context::Block;
17
18/// Shrinker for a call sequence failure.
19/// Iterates sequence call sequence top down and removes calls one by one.
20/// If the failure is still reproducible with removed call then moves to the next one.
21/// If the failure is not reproducible then restore removed call and moves to next one.
22#[derive(Debug)]
23struct CallSequenceShrinker {
24    /// Length of call sequence to be shrunk.
25    call_sequence_len: usize,
26    /// Call ids contained in current shrunk sequence.
27    included_calls: VarBitSet,
28}
29
30impl CallSequenceShrinker {
31    fn new(call_sequence_len: usize) -> Self {
32        Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
33    }
34
35    /// Return candidate shrink sequence to be tested, by removing ids from original sequence.
36    fn current(&self) -> impl Iterator<Item = usize> + '_ {
37        (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
38    }
39
40    /// Advance to the next call index, wrapping around to 0 at the end.
41    const fn next_index(&self, call_idx: usize) -> usize {
42        if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
43    }
44}
45
46/// Resets the progress bar for shrinking.
47fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
48    if let Some(progress) = progress {
49        progress.set_length(config.shrink_run_limit as u64);
50        progress.reset();
51        progress.set_message(" Shrink");
52    }
53}
54
55/// Applies accumulated warp/roll to a call, returning a modified copy.
56fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
57    let mut result = call.clone();
58    if warp > U256::ZERO {
59        result.warp = Some(warp);
60    }
61    if roll > U256::ZERO {
62        result.roll = Some(roll);
63    }
64    result
65}
66
67/// Applies warp/roll adjustments directly to the executor's environment.
68fn apply_warp_roll_to_env<FEN: FoundryEvmNetwork>(
69    executor: &mut Executor<FEN>,
70    warp: U256,
71    roll: U256,
72) {
73    if warp > U256::ZERO || roll > U256::ZERO {
74        let ts = executor.evm_env().block_env.timestamp();
75        let num = executor.evm_env().block_env.number();
76        executor.evm_env_mut().block_env.set_timestamp(ts + warp);
77        executor.evm_env_mut().block_env.set_number(num + roll);
78
79        let block_env = executor.evm_env().block_env.clone();
80        if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
81            if let Some(block) = cheatcodes.block.as_mut() {
82                let bts = block.timestamp();
83                let bnum = block.number();
84                block.set_timestamp(bts + warp);
85                block.set_number(bnum + roll);
86            } else {
87                cheatcodes.block = Some(block_env);
88            }
89        }
90    }
91}
92
93/// Builds the final shrunk sequence from the shrinker state.
94///
95/// When `accumulate_warp_roll` is enabled, warp/roll from removed calls is folded into the next
96/// kept call so the final sequence remains reproducible.
97fn build_shrunk_sequence(
98    calls: &[BasicTxDetails],
99    shrinker: &CallSequenceShrinker,
100    accumulate_warp_roll: bool,
101) -> Vec<BasicTxDetails> {
102    if !accumulate_warp_roll {
103        return shrinker.current().map(|idx| calls[idx].clone()).collect();
104    }
105
106    let mut result = Vec::new();
107    let mut accumulated_warp = U256::ZERO;
108    let mut accumulated_roll = U256::ZERO;
109
110    for (idx, call) in calls.iter().enumerate() {
111        accumulated_warp += call.warp.unwrap_or(U256::ZERO);
112        accumulated_roll += call.roll.unwrap_or(U256::ZERO);
113
114        if shrinker.included_calls.test(idx) {
115            result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
116            accumulated_warp = U256::ZERO;
117            accumulated_roll = U256::ZERO;
118        }
119    }
120
121    result
122}
123
124pub(crate) fn shrink_sequence<FEN: FoundryEvmNetwork>(
125    config: &InvariantConfig,
126    invariant_contract: &InvariantContract<'_>,
127    calls: &[BasicTxDetails],
128    expect_assertion_failure: bool,
129    executor: &Executor<FEN>,
130    progress: Option<&ProgressBar>,
131    early_exit: &EarlyExit,
132) -> eyre::Result<Vec<BasicTxDetails>> {
133    trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
134
135    reset_shrink_progress(config, progress);
136
137    let target_address = invariant_contract.address;
138    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
139    // Special case test: the invariant is *unsatisfiable* - it took 0 calls to
140    // break the invariant -- consider emitting a warning.
141    let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
142    if !success {
143        return Ok(vec![]);
144    }
145
146    let accumulate_warp_roll = config.has_delay();
147    let mut call_idx = 0;
148    let mut shrinker = CallSequenceShrinker::new(calls.len());
149
150    for _ in 0..config.shrink_run_limit {
151        if early_exit.should_stop() {
152            break;
153        }
154
155        shrinker.included_calls.clear(call_idx);
156
157        match check_sequence(
158            executor.clone(),
159            calls,
160            shrinker.current().collect(),
161            target_address,
162            calldata.clone(),
163            CheckSequenceOptions {
164                accumulate_warp_roll,
165                fail_on_revert: config.fail_on_revert,
166                expect_assertion_failure,
167                call_after_invariant: invariant_contract.call_after_invariant,
168                rd: None,
169            },
170        ) {
171            // If candidate sequence still fails, shrink until shortest possible.
172            Ok((false, _, _)) if shrinker.included_calls.count() == 1 => break,
173            // Restore last removed call as it caused sequence to pass invariant.
174            Ok((true, _, _)) => shrinker.included_calls.set(call_idx),
175            _ => {}
176        }
177
178        if let Some(progress) = progress {
179            progress.inc(1);
180        }
181
182        call_idx = shrinker.next_index(call_idx);
183    }
184
185    Ok(build_shrunk_sequence(calls, &shrinker, accumulate_warp_roll))
186}
187
188/// Checks if the given call sequence breaks the invariant.
189///
190/// Used in shrinking phase for checking candidate sequences and in replay failures phase to test
191/// persisted failures.
192/// Returns the result of invariant check (and afterInvariant call if needed) and if sequence was
193/// entirely applied.
194///
195/// When `options.accumulate_warp_roll` is enabled, warp/roll from removed calls is folded into the
196/// next kept call so the candidate sequence stays representable as a concrete counterexample.
197pub fn check_sequence<FEN: FoundryEvmNetwork>(
198    executor: Executor<FEN>,
199    calls: &[BasicTxDetails],
200    sequence: Vec<usize>,
201    test_address: Address,
202    calldata: Bytes,
203    options: CheckSequenceOptions<'_>,
204) -> eyre::Result<(bool, bool, Option<String>)> {
205    if options.accumulate_warp_roll {
206        check_sequence_with_accumulation(executor, calls, sequence, test_address, calldata, options)
207    } else {
208        check_sequence_simple(executor, calls, sequence, test_address, calldata, options)
209    }
210}
211
212fn check_sequence_simple<FEN: FoundryEvmNetwork>(
213    mut executor: Executor<FEN>,
214    calls: &[BasicTxDetails],
215    sequence: Vec<usize>,
216    test_address: Address,
217    calldata: Bytes,
218    options: CheckSequenceOptions<'_>,
219) -> eyre::Result<(bool, bool, Option<String>)> {
220    // Apply the call sequence.
221    for call_index in sequence {
222        let tx = &calls[call_index];
223        let mut call_result = execute_tx(&mut executor, tx)?;
224        let assertion_failure = did_fail_on_assert(&call_result, &call_result.state_changeset);
225        // Ignore calls reverted with `MAGIC_ASSUME`. This is needed to handle failed scenarios that
226        // are replayed with a modified version of test driver (that use new `vm.assume`
227        // cheatcodes).
228        if call_result.result.as_ref() != MAGIC_ASSUME {
229            if assertion_failure {
230                return Ok((false, false, assertion_failure_reason(call_result, options.rd)));
231            }
232
233            if call_result.reverted && options.fail_on_revert {
234                if options.expect_assertion_failure {
235                    return Ok((true, false, None));
236                }
237                return Ok((false, false, call_failure_reason(call_result, options.rd)));
238            }
239        }
240
241        if !call_result.reverted {
242            executor.commit(&mut call_result);
243        }
244    }
245
246    finish_sequence_check(&executor, test_address, calldata, &options)
247}
248
249fn check_sequence_with_accumulation<FEN: FoundryEvmNetwork>(
250    mut executor: Executor<FEN>,
251    calls: &[BasicTxDetails],
252    sequence: Vec<usize>,
253    test_address: Address,
254    calldata: Bytes,
255    options: CheckSequenceOptions<'_>,
256) -> eyre::Result<(bool, bool, Option<String>)> {
257    let mut accumulated_warp = U256::ZERO;
258    let mut accumulated_roll = U256::ZERO;
259    let mut seq_iter = sequence.iter().peekable();
260
261    for (idx, tx) in calls.iter().enumerate() {
262        accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
263        accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
264
265        if seq_iter.peek() != Some(&&idx) {
266            continue;
267        }
268
269        seq_iter.next();
270
271        let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
272        let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
273        let assertion_failure = did_fail_on_assert(&call_result, &call_result.state_changeset);
274
275        if call_result.result.as_ref() != MAGIC_ASSUME {
276            if assertion_failure {
277                return Ok((false, false, assertion_failure_reason(call_result, options.rd)));
278            }
279
280            if call_result.reverted && options.fail_on_revert {
281                if options.expect_assertion_failure {
282                    return Ok((true, false, None));
283                }
284                return Ok((false, false, call_failure_reason(call_result, options.rd)));
285            }
286        }
287
288        if !call_result.reverted {
289            executor.commit(&mut call_result);
290        }
291
292        accumulated_warp = U256::ZERO;
293        accumulated_roll = U256::ZERO;
294    }
295
296    // Unlike optimization mode we intentionally do not apply trailing warp/roll before the
297    // invariant call: those delays would not be representable in the final shrunk sequence.
298    finish_sequence_check(&executor, test_address, calldata, &options)
299}
300
301fn finish_sequence_check<FEN: FoundryEvmNetwork>(
302    executor: &Executor<FEN>,
303    test_address: Address,
304    calldata: Bytes,
305    options: &CheckSequenceOptions<'_>,
306) -> eyre::Result<(bool, bool, Option<String>)> {
307    let handle_terminal_failure = |call_result: RawCallResult<FEN>| {
308        let should_ignore_failure = options.expect_assertion_failure
309            && !executor.has_global_failure(&call_result.state_changeset)
310            && !did_fail_on_assert(&call_result, &call_result.state_changeset);
311
312        if should_ignore_failure {
313            return (true, true, None);
314        }
315
316        let reason = if options.expect_assertion_failure {
317            assertion_failure_reason(call_result, options.rd)
318        } else {
319            call_failure_reason(call_result, options.rd)
320        };
321
322        (false, true, reason)
323    };
324
325    let (invariant_result, mut success) =
326        call_invariant_function(executor, test_address, calldata)?;
327    if !success {
328        return Ok(handle_terminal_failure(invariant_result));
329    }
330
331    // Check after invariant result if invariant is success and `afterInvariant` function is
332    // declared.
333    if success && options.call_after_invariant {
334        let (after_invariant_result, after_invariant_success) =
335            call_after_invariant_function(executor, test_address)?;
336        success = after_invariant_success;
337        if !success {
338            return Ok(handle_terminal_failure(after_invariant_result));
339        }
340    }
341
342    Ok((success, true, None))
343}
344
345pub struct CheckSequenceOptions<'a> {
346    pub accumulate_warp_roll: bool,
347    pub fail_on_revert: bool,
348    pub expect_assertion_failure: bool,
349    pub call_after_invariant: bool,
350    pub rd: Option<&'a RevertDecoder>,
351}
352
353fn call_failure_reason<FEN: FoundryEvmNetwork>(
354    call_result: RawCallResult<FEN>,
355    rd: Option<&RevertDecoder>,
356) -> Option<String> {
357    match call_result.into_evm_error(rd) {
358        EvmError::Execution(err) => Some(err.reason),
359        _ => None,
360    }
361}
362
363fn assertion_failure_reason<FEN: FoundryEvmNetwork>(
364    call_result: RawCallResult<FEN>,
365    rd: Option<&RevertDecoder>,
366) -> Option<String> {
367    call_failure_reason(call_result, rd).or_else(|| Some("assertion failed".to_string()))
368}
369
370/// Shrinks a call sequence to the shortest sequence that still produces the target optimization
371/// value. This is specifically for optimization mode where we want to find the minimal sequence
372/// that achieves the maximum value.
373///
374/// Unlike `shrink_sequence` (for check mode), this function:
375/// - Accumulates warp/roll values from removed calls into the next kept call
376/// - Checks for target value equality rather than invariant failure
377pub(crate) fn shrink_sequence_value<FEN: FoundryEvmNetwork>(
378    config: &InvariantConfig,
379    invariant_contract: &InvariantContract<'_>,
380    calls: &[BasicTxDetails],
381    executor: &Executor<FEN>,
382    target_value: I256,
383    progress: Option<&ProgressBar>,
384    early_exit: &EarlyExit,
385) -> eyre::Result<Vec<BasicTxDetails>> {
386    trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
387
388    reset_shrink_progress(config, progress);
389
390    let target_address = invariant_contract.address;
391    let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
392
393    // Special case: check if target value is achieved with 0 calls.
394    if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
395        == Some(target_value)
396    {
397        return Ok(vec![]);
398    }
399
400    let mut call_idx = 0;
401    let mut shrinker = CallSequenceShrinker::new(calls.len());
402
403    for _ in 0..config.shrink_run_limit {
404        if early_exit.should_stop() {
405            break;
406        }
407
408        shrinker.included_calls.clear(call_idx);
409
410        let keeps_target = check_sequence_value(
411            executor.clone(),
412            calls,
413            shrinker.current().collect(),
414            target_address,
415            calldata.clone(),
416        )? == Some(target_value);
417
418        if keeps_target {
419            if shrinker.included_calls.count() == 1 {
420                break;
421            }
422        } else {
423            shrinker.included_calls.set(call_idx);
424        }
425
426        if let Some(progress) = progress {
427            progress.inc(1);
428        }
429
430        call_idx = shrinker.next_index(call_idx);
431    }
432
433    Ok(build_shrunk_sequence(calls, &shrinker, true))
434}
435
436/// Executes a call sequence and returns the optimization value (int256) from the invariant
437/// function. Used during shrinking for optimization mode.
438///
439/// Returns `None` if the invariant call fails or doesn't return a valid int256.
440/// Unlike `check_sequence`, this applies warp/roll from ALL calls (including removed ones).
441pub fn check_sequence_value<FEN: FoundryEvmNetwork>(
442    mut executor: Executor<FEN>,
443    calls: &[BasicTxDetails],
444    sequence: Vec<usize>,
445    test_address: Address,
446    calldata: Bytes,
447) -> eyre::Result<Option<I256>> {
448    let mut accumulated_warp = U256::ZERO;
449    let mut accumulated_roll = U256::ZERO;
450    let mut seq_iter = sequence.iter().peekable();
451
452    for (idx, tx) in calls.iter().enumerate() {
453        accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
454        accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
455
456        if seq_iter.peek() == Some(&&idx) {
457            seq_iter.next();
458
459            let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
460            let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
461
462            if !call_result.reverted {
463                executor.commit(&mut call_result);
464            }
465
466            accumulated_warp = U256::ZERO;
467            accumulated_roll = U256::ZERO;
468        }
469    }
470
471    // Apply any remaining accumulated warp/roll before calling invariant.
472    apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
473
474    let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
475
476    if success
477        && inv_result.result.len() >= 32
478        && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
479    {
480        return Ok(Some(value));
481    }
482
483    Ok(None)
484}
485
#[cfg(test)]
mod tests {
    use super::{CallSequenceShrinker, build_shrunk_sequence};
    use alloy_primitives::{Address, Bytes, U256};
    use foundry_evm_fuzz::{BasicTxDetails, CallDetails};
    use proptest::bits::BitSetLike;

    /// Builds a call with the given delays and otherwise zeroed fields.
    fn tx(warp: Option<u64>, roll: Option<u64>) -> BasicTxDetails {
        BasicTxDetails {
            warp: warp.map(U256::from),
            roll: roll.map(U256::from),
            sender: Address::ZERO,
            call_details: CallDetails { target: Address::ZERO, calldata: Bytes::new() },
        }
    }

    #[test]
    fn build_shrunk_sequence_accumulates_removed_delay_into_next_kept_call() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(10)));
        assert_eq!(shrunk[0].roll, Some(U256::from(16)));
        assert_eq!(shrunk[1].warp, Some(U256::from(13)));
        assert_eq!(shrunk[1].roll, Some(U256::from(17)));
    }

    #[test]
    fn build_shrunk_sequence_does_not_move_trailing_delay_backward() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 1);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
    }

    #[test]
    fn build_shrunk_sequence_accumulates_consecutive_removed_delays() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), None), tx(None, None)];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 1);
        assert_eq!(shrunk[0].warp, Some(U256::from(10)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
    }

    #[test]
    fn build_shrunk_sequence_keeps_delays_untouched_without_accumulation() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(None, None)];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, false);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(7)));
        assert_eq!(shrunk[0].roll, Some(U256::from(11)));
        assert_eq!(shrunk[1].warp, None);
        assert_eq!(shrunk[1].roll, None);
    }

    #[test]
    fn build_shrunk_sequence_leaves_zero_delay_unset() {
        let calls = vec![tx(None, None), tx(None, Some(2))];
        let shrinker = CallSequenceShrinker::new(calls.len());

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, None);
        assert_eq!(shrunk[0].roll, None);
        assert_eq!(shrunk[1].warp, None);
        assert_eq!(shrunk[1].roll, Some(U256::from(2)));
    }

    #[test]
    fn next_index_wraps_around_at_end() {
        let shrinker = CallSequenceShrinker::new(3);

        assert_eq!(shrinker.next_index(0), 1);
        assert_eq!(shrinker.next_index(1), 2);
        assert_eq!(shrinker.next_index(2), 0);
    }
}