1use crate::executors::{
2 EarlyExit, EvmError, Executor, RawCallResult,
3 invariant::{
4 call_after_invariant_function, call_invariant_function, execute_tx,
5 result::did_fail_on_assert,
6 },
7};
8use alloy_primitives::{Address, Bytes, I256, U256};
9use foundry_config::InvariantConfig;
10use foundry_evm_core::{
11 FoundryBlock, constants::MAGIC_ASSUME, decode::RevertDecoder, evm::FoundryEvmNetwork,
12};
13use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
14use indicatif::ProgressBar;
15use proptest::bits::{BitSetLike, VarBitSet};
16use revm::context::Block;
17
18#[derive(Debug)]
23struct CallSequenceShrinker {
24 call_sequence_len: usize,
26 included_calls: VarBitSet,
28}
29
30impl CallSequenceShrinker {
31 fn new(call_sequence_len: usize) -> Self {
32 Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
33 }
34
35 fn current(&self) -> impl Iterator<Item = usize> + '_ {
37 (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
38 }
39
40 const fn next_index(&self, call_idx: usize) -> usize {
42 if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
43 }
44}
45
46fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
48 if let Some(progress) = progress {
49 progress.set_length(config.shrink_run_limit as u64);
50 progress.reset();
51 progress.set_message(" Shrink");
52 }
53}
54
55fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
57 let mut result = call.clone();
58 if warp > U256::ZERO {
59 result.warp = Some(warp);
60 }
61 if roll > U256::ZERO {
62 result.roll = Some(roll);
63 }
64 result
65}
66
67fn apply_warp_roll_to_env<FEN: FoundryEvmNetwork>(
69 executor: &mut Executor<FEN>,
70 warp: U256,
71 roll: U256,
72) {
73 if warp > U256::ZERO || roll > U256::ZERO {
74 let ts = executor.evm_env().block_env.timestamp();
75 let num = executor.evm_env().block_env.number();
76 executor.evm_env_mut().block_env.set_timestamp(ts + warp);
77 executor.evm_env_mut().block_env.set_number(num + roll);
78
79 let block_env = executor.evm_env().block_env.clone();
80 if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
81 if let Some(block) = cheatcodes.block.as_mut() {
82 let bts = block.timestamp();
83 let bnum = block.number();
84 block.set_timestamp(bts + warp);
85 block.set_number(bnum + roll);
86 } else {
87 cheatcodes.block = Some(block_env);
88 }
89 }
90 }
91}
92
93fn build_shrunk_sequence(
98 calls: &[BasicTxDetails],
99 shrinker: &CallSequenceShrinker,
100 accumulate_warp_roll: bool,
101) -> Vec<BasicTxDetails> {
102 if !accumulate_warp_roll {
103 return shrinker.current().map(|idx| calls[idx].clone()).collect();
104 }
105
106 let mut result = Vec::new();
107 let mut accumulated_warp = U256::ZERO;
108 let mut accumulated_roll = U256::ZERO;
109
110 for (idx, call) in calls.iter().enumerate() {
111 accumulated_warp += call.warp.unwrap_or(U256::ZERO);
112 accumulated_roll += call.roll.unwrap_or(U256::ZERO);
113
114 if shrinker.included_calls.test(idx) {
115 result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
116 accumulated_warp = U256::ZERO;
117 accumulated_roll = U256::ZERO;
118 }
119 }
120
121 result
122}
123
124pub(crate) fn shrink_sequence<FEN: FoundryEvmNetwork>(
125 config: &InvariantConfig,
126 invariant_contract: &InvariantContract<'_>,
127 calls: &[BasicTxDetails],
128 expect_assertion_failure: bool,
129 executor: &Executor<FEN>,
130 progress: Option<&ProgressBar>,
131 early_exit: &EarlyExit,
132) -> eyre::Result<Vec<BasicTxDetails>> {
133 trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
134
135 reset_shrink_progress(config, progress);
136
137 let target_address = invariant_contract.address;
138 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
139 let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
142 if !success {
143 return Ok(vec![]);
144 }
145
146 let accumulate_warp_roll = config.has_delay();
147 let mut call_idx = 0;
148 let mut shrinker = CallSequenceShrinker::new(calls.len());
149
150 for _ in 0..config.shrink_run_limit {
151 if early_exit.should_stop() {
152 break;
153 }
154
155 shrinker.included_calls.clear(call_idx);
156
157 match check_sequence(
158 executor.clone(),
159 calls,
160 shrinker.current().collect(),
161 target_address,
162 calldata.clone(),
163 CheckSequenceOptions {
164 accumulate_warp_roll,
165 fail_on_revert: config.fail_on_revert,
166 expect_assertion_failure,
167 call_after_invariant: invariant_contract.call_after_invariant,
168 rd: None,
169 },
170 ) {
171 Ok((false, _, _)) if shrinker.included_calls.count() == 1 => break,
173 Ok((true, _, _)) => shrinker.included_calls.set(call_idx),
175 _ => {}
176 }
177
178 if let Some(progress) = progress {
179 progress.inc(1);
180 }
181
182 call_idx = shrinker.next_index(call_idx);
183 }
184
185 Ok(build_shrunk_sequence(calls, &shrinker, accumulate_warp_roll))
186}
187
188pub fn check_sequence<FEN: FoundryEvmNetwork>(
198 executor: Executor<FEN>,
199 calls: &[BasicTxDetails],
200 sequence: Vec<usize>,
201 test_address: Address,
202 calldata: Bytes,
203 options: CheckSequenceOptions<'_>,
204) -> eyre::Result<(bool, bool, Option<String>)> {
205 if options.accumulate_warp_roll {
206 check_sequence_with_accumulation(executor, calls, sequence, test_address, calldata, options)
207 } else {
208 check_sequence_simple(executor, calls, sequence, test_address, calldata, options)
209 }
210}
211
212fn check_sequence_simple<FEN: FoundryEvmNetwork>(
213 mut executor: Executor<FEN>,
214 calls: &[BasicTxDetails],
215 sequence: Vec<usize>,
216 test_address: Address,
217 calldata: Bytes,
218 options: CheckSequenceOptions<'_>,
219) -> eyre::Result<(bool, bool, Option<String>)> {
220 for call_index in sequence {
222 let tx = &calls[call_index];
223 let mut call_result = execute_tx(&mut executor, tx)?;
224 let assertion_failure = did_fail_on_assert(&call_result, &call_result.state_changeset);
225 if call_result.result.as_ref() != MAGIC_ASSUME {
229 if assertion_failure {
230 return Ok((false, false, assertion_failure_reason(call_result, options.rd)));
231 }
232
233 if call_result.reverted && options.fail_on_revert {
234 if options.expect_assertion_failure {
235 return Ok((true, false, None));
236 }
237 return Ok((false, false, call_failure_reason(call_result, options.rd)));
238 }
239 }
240
241 if !call_result.reverted {
242 executor.commit(&mut call_result);
243 }
244 }
245
246 finish_sequence_check(&executor, test_address, calldata, &options)
247}
248
249fn check_sequence_with_accumulation<FEN: FoundryEvmNetwork>(
250 mut executor: Executor<FEN>,
251 calls: &[BasicTxDetails],
252 sequence: Vec<usize>,
253 test_address: Address,
254 calldata: Bytes,
255 options: CheckSequenceOptions<'_>,
256) -> eyre::Result<(bool, bool, Option<String>)> {
257 let mut accumulated_warp = U256::ZERO;
258 let mut accumulated_roll = U256::ZERO;
259 let mut seq_iter = sequence.iter().peekable();
260
261 for (idx, tx) in calls.iter().enumerate() {
262 accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
263 accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
264
265 if seq_iter.peek() != Some(&&idx) {
266 continue;
267 }
268
269 seq_iter.next();
270
271 let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
272 let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
273 let assertion_failure = did_fail_on_assert(&call_result, &call_result.state_changeset);
274
275 if call_result.result.as_ref() != MAGIC_ASSUME {
276 if assertion_failure {
277 return Ok((false, false, assertion_failure_reason(call_result, options.rd)));
278 }
279
280 if call_result.reverted && options.fail_on_revert {
281 if options.expect_assertion_failure {
282 return Ok((true, false, None));
283 }
284 return Ok((false, false, call_failure_reason(call_result, options.rd)));
285 }
286 }
287
288 if !call_result.reverted {
289 executor.commit(&mut call_result);
290 }
291
292 accumulated_warp = U256::ZERO;
293 accumulated_roll = U256::ZERO;
294 }
295
296 finish_sequence_check(&executor, test_address, calldata, &options)
299}
300
301fn finish_sequence_check<FEN: FoundryEvmNetwork>(
302 executor: &Executor<FEN>,
303 test_address: Address,
304 calldata: Bytes,
305 options: &CheckSequenceOptions<'_>,
306) -> eyre::Result<(bool, bool, Option<String>)> {
307 let handle_terminal_failure = |call_result: RawCallResult<FEN>| {
308 let should_ignore_failure = options.expect_assertion_failure
309 && !executor.has_global_failure(&call_result.state_changeset)
310 && !did_fail_on_assert(&call_result, &call_result.state_changeset);
311
312 if should_ignore_failure {
313 return (true, true, None);
314 }
315
316 let reason = if options.expect_assertion_failure {
317 assertion_failure_reason(call_result, options.rd)
318 } else {
319 call_failure_reason(call_result, options.rd)
320 };
321
322 (false, true, reason)
323 };
324
325 let (invariant_result, mut success) =
326 call_invariant_function(executor, test_address, calldata)?;
327 if !success {
328 return Ok(handle_terminal_failure(invariant_result));
329 }
330
331 if success && options.call_after_invariant {
334 let (after_invariant_result, after_invariant_success) =
335 call_after_invariant_function(executor, test_address)?;
336 success = after_invariant_success;
337 if !success {
338 return Ok(handle_terminal_failure(after_invariant_result));
339 }
340 }
341
342 Ok((success, true, None))
343}
344
345pub struct CheckSequenceOptions<'a> {
346 pub accumulate_warp_roll: bool,
347 pub fail_on_revert: bool,
348 pub expect_assertion_failure: bool,
349 pub call_after_invariant: bool,
350 pub rd: Option<&'a RevertDecoder>,
351}
352
353fn call_failure_reason<FEN: FoundryEvmNetwork>(
354 call_result: RawCallResult<FEN>,
355 rd: Option<&RevertDecoder>,
356) -> Option<String> {
357 match call_result.into_evm_error(rd) {
358 EvmError::Execution(err) => Some(err.reason),
359 _ => None,
360 }
361}
362
363fn assertion_failure_reason<FEN: FoundryEvmNetwork>(
364 call_result: RawCallResult<FEN>,
365 rd: Option<&RevertDecoder>,
366) -> Option<String> {
367 call_failure_reason(call_result, rd).or_else(|| Some("assertion failed".to_string()))
368}
369
370pub(crate) fn shrink_sequence_value<FEN: FoundryEvmNetwork>(
378 config: &InvariantConfig,
379 invariant_contract: &InvariantContract<'_>,
380 calls: &[BasicTxDetails],
381 executor: &Executor<FEN>,
382 target_value: I256,
383 progress: Option<&ProgressBar>,
384 early_exit: &EarlyExit,
385) -> eyre::Result<Vec<BasicTxDetails>> {
386 trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
387
388 reset_shrink_progress(config, progress);
389
390 let target_address = invariant_contract.address;
391 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
392
393 if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
395 == Some(target_value)
396 {
397 return Ok(vec![]);
398 }
399
400 let mut call_idx = 0;
401 let mut shrinker = CallSequenceShrinker::new(calls.len());
402
403 for _ in 0..config.shrink_run_limit {
404 if early_exit.should_stop() {
405 break;
406 }
407
408 shrinker.included_calls.clear(call_idx);
409
410 let keeps_target = check_sequence_value(
411 executor.clone(),
412 calls,
413 shrinker.current().collect(),
414 target_address,
415 calldata.clone(),
416 )? == Some(target_value);
417
418 if keeps_target {
419 if shrinker.included_calls.count() == 1 {
420 break;
421 }
422 } else {
423 shrinker.included_calls.set(call_idx);
424 }
425
426 if let Some(progress) = progress {
427 progress.inc(1);
428 }
429
430 call_idx = shrinker.next_index(call_idx);
431 }
432
433 Ok(build_shrunk_sequence(calls, &shrinker, true))
434}
435
436pub fn check_sequence_value<FEN: FoundryEvmNetwork>(
442 mut executor: Executor<FEN>,
443 calls: &[BasicTxDetails],
444 sequence: Vec<usize>,
445 test_address: Address,
446 calldata: Bytes,
447) -> eyre::Result<Option<I256>> {
448 let mut accumulated_warp = U256::ZERO;
449 let mut accumulated_roll = U256::ZERO;
450 let mut seq_iter = sequence.iter().peekable();
451
452 for (idx, tx) in calls.iter().enumerate() {
453 accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
454 accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
455
456 if seq_iter.peek() == Some(&&idx) {
457 seq_iter.next();
458
459 let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
460 let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
461
462 if !call_result.reverted {
463 executor.commit(&mut call_result);
464 }
465
466 accumulated_warp = U256::ZERO;
467 accumulated_roll = U256::ZERO;
468 }
469 }
470
471 apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
473
474 let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
475
476 if success
477 && inv_result.result.len() >= 32
478 && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
479 {
480 return Ok(Some(value));
481 }
482
483 Ok(None)
484}
485
#[cfg(test)]
mod tests {
    use super::{CallSequenceShrinker, build_shrunk_sequence};
    use alloy_primitives::{Address, Bytes, U256};
    use foundry_evm_fuzz::{BasicTxDetails, CallDetails};
    use proptest::bits::BitSetLike;

    /// Builds a no-op call to the zero address with the given optional warp/roll delays.
    fn tx(warp: Option<u64>, roll: Option<u64>) -> BasicTxDetails {
        BasicTxDetails {
            warp: warp.map(U256::from),
            roll: roll.map(U256::from),
            sender: Address::ZERO,
            call_details: CallDetails { target: Address::ZERO, calldata: Bytes::new() },
        }
    }

    #[test]
    fn build_shrunk_sequence_accumulates_removed_delay_into_next_kept_call() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(Some(13), Some(17))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(10)));
        assert_eq!(shrunk[0].roll, Some(U256::from(16)));
        assert_eq!(shrunk[1].warp, Some(U256::from(13)));
        assert_eq!(shrunk[1].roll, Some(U256::from(17)));
    }

    #[test]
    fn build_shrunk_sequence_does_not_move_trailing_delay_backward() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11))];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(1);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        assert_eq!(shrunk.len(), 1);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, Some(U256::from(5)));
    }

    #[test]
    fn build_shrunk_sequence_without_accumulation_keeps_original_delays() {
        let calls = vec![tx(Some(3), Some(5)), tx(Some(7), Some(11)), tx(None, None)];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, false);

        // The removed call's delay is dropped, not carried forward.
        assert_eq!(shrunk.len(), 2);
        assert_eq!(shrunk[0].warp, Some(U256::from(7)));
        assert_eq!(shrunk[0].roll, Some(U256::from(11)));
        assert_eq!(shrunk[1].warp, None);
        assert_eq!(shrunk[1].roll, None);
    }

    #[test]
    fn build_shrunk_sequence_leaves_unset_delay_untouched_when_nothing_accumulated() {
        let calls = vec![tx(Some(3), None), tx(None, None)];
        let mut shrinker = CallSequenceShrinker::new(calls.len());
        shrinker.included_calls.clear(0);

        let shrunk = build_shrunk_sequence(&calls, &shrinker, true);

        // Only a non-zero accumulated roll overrides the kept call's (absent) roll.
        assert_eq!(shrunk.len(), 1);
        assert_eq!(shrunk[0].warp, Some(U256::from(3)));
        assert_eq!(shrunk[0].roll, None);
    }

    #[test]
    fn shrinker_current_skips_removed_calls_and_next_index_wraps() {
        let mut shrinker = CallSequenceShrinker::new(3);
        shrinker.included_calls.clear(1);

        assert_eq!(shrinker.current().collect::<Vec<_>>(), vec![0, 2]);
        assert_eq!(shrinker.next_index(0), 1);
        assert_eq!(shrinker.next_index(2), 0);
    }
}