foundry_evm/executors/invariant/
shrink.rs1use crate::executors::{
2 EarlyExit, EvmError, Executor, RawCallResult,
3 invariant::{call_after_invariant_function, call_invariant_function, execute_tx},
4};
5use alloy_primitives::{Address, Bytes, I256, U256};
6use foundry_config::InvariantConfig;
7use foundry_evm_core::{
8 FoundryBlock, constants::MAGIC_ASSUME, decode::RevertDecoder, evm::FoundryEvmNetwork,
9};
10use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
11use indicatif::ProgressBar;
12use proptest::bits::{BitSetLike, VarBitSet};
13use revm::context::Block;
14
15#[derive(Debug)]
20struct CallSequenceShrinker {
21 call_sequence_len: usize,
23 included_calls: VarBitSet,
25}
26
27impl CallSequenceShrinker {
28 fn new(call_sequence_len: usize) -> Self {
29 Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
30 }
31
32 fn current(&self) -> impl Iterator<Item = usize> + '_ {
34 (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
35 }
36
37 fn next_index(&self, call_idx: usize) -> usize {
39 if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
40 }
41}
42
43fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
45 if let Some(progress) = progress {
46 progress.set_length(config.shrink_run_limit as u64);
47 progress.reset();
48 progress.set_message(" Shrink");
49 }
50}
51
52fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
54 let mut result = call.clone();
55 if warp > U256::ZERO {
56 result.warp = Some(warp);
57 }
58 if roll > U256::ZERO {
59 result.roll = Some(roll);
60 }
61 result
62}
63
64fn apply_warp_roll_to_env<FEN: FoundryEvmNetwork>(
66 executor: &mut Executor<FEN>,
67 warp: U256,
68 roll: U256,
69) {
70 if warp > U256::ZERO || roll > U256::ZERO {
71 let ts = executor.evm_env().block_env.timestamp();
72 let num = executor.evm_env().block_env.number();
73 executor.evm_env_mut().block_env.set_timestamp(ts + warp);
74 executor.evm_env_mut().block_env.set_number(num + roll);
75
76 let block_env = executor.evm_env().block_env.clone();
77 if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
78 if let Some(block) = cheatcodes.block.as_mut() {
79 let bts = block.timestamp();
80 let bnum = block.number();
81 block.set_timestamp(bts + warp);
82 block.set_number(bnum + roll);
83 } else {
84 cheatcodes.block = Some(block_env);
85 }
86 }
87 }
88}
89
90fn build_shrunk_sequence(
95 calls: &[BasicTxDetails],
96 shrinker: &CallSequenceShrinker,
97 accumulate_warp_roll: bool,
98) -> Vec<BasicTxDetails> {
99 if !accumulate_warp_roll {
100 return shrinker.current().map(|idx| calls[idx].clone()).collect();
101 }
102
103 let mut result = Vec::new();
104 let mut accumulated_warp = U256::ZERO;
105 let mut accumulated_roll = U256::ZERO;
106
107 for (idx, call) in calls.iter().enumerate() {
108 accumulated_warp += call.warp.unwrap_or(U256::ZERO);
109 accumulated_roll += call.roll.unwrap_or(U256::ZERO);
110
111 if shrinker.included_calls.test(idx) {
112 result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
113 accumulated_warp = U256::ZERO;
114 accumulated_roll = U256::ZERO;
115 }
116 }
117
118 result
119}
120
121pub(crate) fn shrink_sequence<FEN: FoundryEvmNetwork>(
122 config: &InvariantConfig,
123 invariant_contract: &InvariantContract<'_>,
124 calls: &[BasicTxDetails],
125 executor: &Executor<FEN>,
126 progress: Option<&ProgressBar>,
127 early_exit: &EarlyExit,
128) -> eyre::Result<Vec<BasicTxDetails>> {
129 trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
130
131 reset_shrink_progress(config, progress);
132
133 let target_address = invariant_contract.address;
134 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
135 let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
138 if !success {
139 return Ok(vec![]);
140 }
141
142 let accumulate_warp_roll = config.has_delay();
143 let mut call_idx = 0;
144 let mut shrinker = CallSequenceShrinker::new(calls.len());
145
146 for _ in 0..config.shrink_run_limit {
147 if early_exit.should_stop() {
148 break;
149 }
150
151 shrinker.included_calls.clear(call_idx);
152
153 match check_sequence(
154 executor.clone(),
155 calls,
156 shrinker.current().collect(),
157 target_address,
158 calldata.clone(),
159 CheckSequenceOptions {
160 accumulate_warp_roll,
161 fail_on_revert: config.fail_on_revert,
162 call_after_invariant: invariant_contract.call_after_invariant,
163 rd: None,
164 },
165 ) {
166 Ok((false, _, _)) if shrinker.included_calls.count() == 1 => break,
168 Ok((true, _, _)) => shrinker.included_calls.set(call_idx),
170 _ => {}
171 }
172
173 if let Some(progress) = progress {
174 progress.inc(1);
175 }
176
177 call_idx = shrinker.next_index(call_idx);
178 }
179
180 Ok(build_shrunk_sequence(calls, &shrinker, accumulate_warp_roll))
181}
182
183pub fn check_sequence<FEN: FoundryEvmNetwork>(
193 executor: Executor<FEN>,
194 calls: &[BasicTxDetails],
195 sequence: Vec<usize>,
196 test_address: Address,
197 calldata: Bytes,
198 options: CheckSequenceOptions<'_>,
199) -> eyre::Result<(bool, bool, Option<String>)> {
200 if options.accumulate_warp_roll {
201 check_sequence_with_accumulation(executor, calls, sequence, test_address, calldata, options)
202 } else {
203 check_sequence_simple(executor, calls, sequence, test_address, calldata, options)
204 }
205}
206
207fn check_sequence_simple<FEN: FoundryEvmNetwork>(
208 mut executor: Executor<FEN>,
209 calls: &[BasicTxDetails],
210 sequence: Vec<usize>,
211 test_address: Address,
212 calldata: Bytes,
213 options: CheckSequenceOptions<'_>,
214) -> eyre::Result<(bool, bool, Option<String>)> {
215 for call_index in sequence {
217 let tx = &calls[call_index];
218 let mut call_result = execute_tx(&mut executor, tx)?;
219 executor.commit(&mut call_result);
220 if call_result.reverted
224 && options.fail_on_revert
225 && call_result.result.as_ref() != MAGIC_ASSUME
226 {
227 return Ok((false, false, call_failure_reason(call_result, options.rd)));
230 }
231 }
232
233 finish_sequence_check(&executor, test_address, calldata, &options)
234}
235
236fn check_sequence_with_accumulation<FEN: FoundryEvmNetwork>(
237 mut executor: Executor<FEN>,
238 calls: &[BasicTxDetails],
239 sequence: Vec<usize>,
240 test_address: Address,
241 calldata: Bytes,
242 options: CheckSequenceOptions<'_>,
243) -> eyre::Result<(bool, bool, Option<String>)> {
244 let mut accumulated_warp = U256::ZERO;
245 let mut accumulated_roll = U256::ZERO;
246 let mut seq_iter = sequence.iter().peekable();
247
248 for (idx, tx) in calls.iter().enumerate() {
249 accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
250 accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
251
252 if seq_iter.peek() != Some(&&idx) {
253 continue;
254 }
255
256 seq_iter.next();
257
258 let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
259 let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
260
261 if call_result.reverted {
262 if options.fail_on_revert && call_result.result.as_ref() != MAGIC_ASSUME {
263 return Ok((false, false, call_failure_reason(call_result, options.rd)));
264 }
265 } else {
266 executor.commit(&mut call_result);
267 }
268
269 accumulated_warp = U256::ZERO;
270 accumulated_roll = U256::ZERO;
271 }
272
273 finish_sequence_check(&executor, test_address, calldata, &options)
276}
277
278fn finish_sequence_check<FEN: FoundryEvmNetwork>(
279 executor: &Executor<FEN>,
280 test_address: Address,
281 calldata: Bytes,
282 options: &CheckSequenceOptions<'_>,
283) -> eyre::Result<(bool, bool, Option<String>)> {
284 let (invariant_result, mut success) =
285 call_invariant_function(executor, test_address, calldata)?;
286 if !success {
287 return Ok((false, true, call_failure_reason(invariant_result, options.rd)));
288 }
289
290 if success && options.call_after_invariant {
293 let (after_invariant_result, after_invariant_success) =
294 call_after_invariant_function(executor, test_address)?;
295 success = after_invariant_success;
296 if !success {
297 return Ok((false, true, call_failure_reason(after_invariant_result, options.rd)));
298 }
299 }
300
301 Ok((success, true, None))
302}
303
304pub struct CheckSequenceOptions<'a> {
305 pub accumulate_warp_roll: bool,
306 pub fail_on_revert: bool,
307 pub call_after_invariant: bool,
308 pub rd: Option<&'a RevertDecoder>,
309}
310
311fn call_failure_reason<FEN: FoundryEvmNetwork>(
312 call_result: RawCallResult<FEN>,
313 rd: Option<&RevertDecoder>,
314) -> Option<String> {
315 match call_result.into_evm_error(rd) {
316 EvmError::Execution(err) => Some(err.reason),
317 _ => None,
318 }
319}
320
321pub(crate) fn shrink_sequence_value<FEN: FoundryEvmNetwork>(
329 config: &InvariantConfig,
330 invariant_contract: &InvariantContract<'_>,
331 calls: &[BasicTxDetails],
332 executor: &Executor<FEN>,
333 target_value: I256,
334 progress: Option<&ProgressBar>,
335 early_exit: &EarlyExit,
336) -> eyre::Result<Vec<BasicTxDetails>> {
337 trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
338
339 reset_shrink_progress(config, progress);
340
341 let target_address = invariant_contract.address;
342 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
343
344 if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
346 == Some(target_value)
347 {
348 return Ok(vec![]);
349 }
350
351 let mut call_idx = 0;
352 let mut shrinker = CallSequenceShrinker::new(calls.len());
353
354 for _ in 0..config.shrink_run_limit {
355 if early_exit.should_stop() {
356 break;
357 }
358
359 shrinker.included_calls.clear(call_idx);
360
361 let keeps_target = check_sequence_value(
362 executor.clone(),
363 calls,
364 shrinker.current().collect(),
365 target_address,
366 calldata.clone(),
367 )? == Some(target_value);
368
369 if keeps_target {
370 if shrinker.included_calls.count() == 1 {
371 break;
372 }
373 } else {
374 shrinker.included_calls.set(call_idx);
375 }
376
377 if let Some(progress) = progress {
378 progress.inc(1);
379 }
380
381 call_idx = shrinker.next_index(call_idx);
382 }
383
384 Ok(build_shrunk_sequence(calls, &shrinker, true))
385}
386
387pub fn check_sequence_value<FEN: FoundryEvmNetwork>(
393 mut executor: Executor<FEN>,
394 calls: &[BasicTxDetails],
395 sequence: Vec<usize>,
396 test_address: Address,
397 calldata: Bytes,
398) -> eyre::Result<Option<I256>> {
399 let mut accumulated_warp = U256::ZERO;
400 let mut accumulated_roll = U256::ZERO;
401 let mut seq_iter = sequence.iter().peekable();
402
403 for (idx, tx) in calls.iter().enumerate() {
404 accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
405 accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
406
407 if seq_iter.peek() == Some(&&idx) {
408 seq_iter.next();
409
410 let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
411 let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
412
413 if !call_result.reverted {
414 executor.commit(&mut call_result);
415 }
416
417 accumulated_warp = U256::ZERO;
418 accumulated_roll = U256::ZERO;
419 }
420 }
421
422 apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
424
425 let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
426
427 if success
428 && inv_result.result.len() >= 32
429 && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
430 {
431 return Ok(Some(value));
432 }
433
434 Ok(None)
435}
436
#[cfg(test)]
mod tests {
    use super::{CallSequenceShrinker, build_shrunk_sequence};
    use alloy_primitives::{Address, Bytes, U256};
    use foundry_evm_fuzz::{BasicTxDetails, CallDetails};
    use proptest::bits::BitSetLike;

    /// Builds a minimal transaction with the given optional warp/roll delays.
    fn tx(warp: Option<u64>, roll: Option<u64>) -> BasicTxDetails {
        BasicTxDetails {
            warp: warp.map(U256::from),
            roll: roll.map(U256::from),
            sender: Address::ZERO,
            call_details: CallDetails { target: Address::ZERO, calldata: Bytes::new() },
        }
    }

    #[test]
    fn build_shrunk_sequence_accumulates_removed_delay_into_next_kept_call() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(10)));
        assert_eq!(shrunk[0].roll, Some(U256::from(16)));
        assert_eq!(shrunk[1].warp, Some(U256::from(13)));
        assert_eq!(shrunk[1].roll, Some(U256::from(17)));
    }

    #[test]
    fn build_shrunk_sequence_does_not_move_trailing_delay_backward() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 1);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
    }

    #[test]
    fn build_shrunk_sequence_accumulates_across_removed_middle_call() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
        assert_eq!(shrunk[1].warp, Some(U256::from(20)));
        assert_eq!(shrunk[1].roll, Some(U256::from(28)));
    }

    #[test]
    fn build_shrunk_sequence_without_accumulation_keeps_original_delays() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, false);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
        assert_eq!(shrunk[1].warp, Some(U256::from(13)));
        assert_eq!(shrunk[1].roll, Some(U256::from(17)));
    }

    #[test]
    fn build_shrunk_sequence_leaves_missing_delays_unset() {
        let calls = vec![tx(None, None), tx(None, Some(2))];
        let shrinker = CallSequenceShrinker::new(calls.len());

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, None);
        assert_eq!(shrunk[0].roll, None);
        assert_eq!(shrunk[1].warp, None);
        assert_eq!(shrunk[1].roll, Some(U256::from(2)));
    }

    #[test]
    fn shrinker_tracks_included_calls_and_wraps_index() {
        let mut shrinker = CallSequenceShrinker::new(3);
        assert_eq!(shrinker.current().collect::<Vec<_>>(), vec![0, 1, 2]);

        shrinker.included_calls.clear(1);
        assert_eq!(shrinker.current().collect::<Vec<_>>(), vec![0, 2]);

        assert_eq!(shrinker.next_index(0), 1);
        assert_eq!(shrinker.next_index(1), 2);
        assert_eq!(shrinker.next_index(2), 0);
    }
}