foundry_evm/executors/invariant/
shrink.rs1use crate::executors::{
2 EarlyExit, Executor,
3 invariant::{call_after_invariant_function, call_invariant_function, execute_tx},
4};
5use alloy_primitives::{Address, Bytes, I256, U256};
6use foundry_config::InvariantConfig;
7use foundry_evm_core::constants::MAGIC_ASSUME;
8use foundry_evm_fuzz::{BasicTxDetails, invariant::InvariantContract};
9use indicatif::ProgressBar;
10use proptest::bits::{BitSetLike, VarBitSet};
11
12#[derive(Debug)]
17struct CallSequenceShrinker {
18 call_sequence_len: usize,
20 included_calls: VarBitSet,
22}
23
24impl CallSequenceShrinker {
25 fn new(call_sequence_len: usize) -> Self {
26 Self { call_sequence_len, included_calls: VarBitSet::saturated(call_sequence_len) }
27 }
28
29 fn current(&self) -> impl Iterator<Item = usize> + '_ {
31 (0..self.call_sequence_len).filter(|&call_id| self.included_calls.test(call_id))
32 }
33
34 fn next_index(&self, call_idx: usize) -> usize {
36 if call_idx + 1 == self.call_sequence_len { 0 } else { call_idx + 1 }
37 }
38}
39
40fn reset_shrink_progress(config: &InvariantConfig, progress: Option<&ProgressBar>) {
42 if let Some(progress) = progress {
43 progress.set_length(config.shrink_run_limit as u64);
44 progress.reset();
45 progress.set_message(" Shrink");
46 }
47}
48
49fn apply_warp_roll(call: &BasicTxDetails, warp: U256, roll: U256) -> BasicTxDetails {
51 let mut result = call.clone();
52 if warp > U256::ZERO {
53 result.warp = Some(warp);
54 }
55 if roll > U256::ZERO {
56 result.roll = Some(roll);
57 }
58 result
59}
60
61fn apply_warp_roll_to_env(executor: &mut Executor, warp: U256, roll: U256) {
63 if warp > U256::ZERO || roll > U256::ZERO {
64 executor.env_mut().evm_env.block_env.timestamp += warp;
65 executor.env_mut().evm_env.block_env.number += roll;
66
67 let block_env = executor.env().evm_env.block_env.clone();
68 if let Some(cheatcodes) = executor.inspector_mut().cheatcodes.as_mut() {
69 if let Some(block) = cheatcodes.block.as_mut() {
70 block.timestamp += warp;
71 block.number += roll;
72 } else {
73 cheatcodes.block = Some(block_env);
74 }
75 }
76 }
77}
78
79pub(crate) fn shrink_sequence(
80 config: &InvariantConfig,
81 invariant_contract: &InvariantContract<'_>,
82 calls: &[BasicTxDetails],
83 executor: &Executor,
84 progress: Option<&ProgressBar>,
85 early_exit: &EarlyExit,
86) -> eyre::Result<Vec<BasicTxDetails>> {
87 trace!(target: "forge::test", "Shrinking sequence of {} calls.", calls.len());
88
89 reset_shrink_progress(config, progress);
90
91 let target_address = invariant_contract.address;
92 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
93 let (_, success) = call_invariant_function(executor, target_address, calldata.clone())?;
96 if !success {
97 return Ok(vec![]);
98 }
99
100 let mut call_idx = 0;
101 let mut shrinker = CallSequenceShrinker::new(calls.len());
102
103 for _ in 0..config.shrink_run_limit {
104 if early_exit.should_stop() {
105 break;
106 }
107
108 shrinker.included_calls.clear(call_idx);
109
110 match check_sequence(
111 executor.clone(),
112 calls,
113 shrinker.current().collect(),
114 target_address,
115 calldata.clone(),
116 config.fail_on_revert,
117 invariant_contract.call_after_invariant,
118 ) {
119 Ok((false, _)) if shrinker.included_calls.count() == 1 => break,
121 Ok((true, _)) => shrinker.included_calls.set(call_idx),
123 _ => {}
124 }
125
126 if let Some(progress) = progress {
127 progress.inc(1);
128 }
129
130 call_idx = shrinker.next_index(call_idx);
131 }
132
133 Ok(shrinker.current().map(|idx| &calls[idx]).cloned().collect())
134}
135
136pub fn check_sequence(
143 mut executor: Executor,
144 calls: &[BasicTxDetails],
145 sequence: Vec<usize>,
146 test_address: Address,
147 calldata: Bytes,
148 fail_on_revert: bool,
149 call_after_invariant: bool,
150) -> eyre::Result<(bool, bool)> {
151 for call_index in sequence {
153 let tx = &calls[call_index];
154 let mut call_result = execute_tx(&mut executor, tx)?;
155 executor.commit(&mut call_result);
156 if call_result.reverted && fail_on_revert && call_result.result.as_ref() != MAGIC_ASSUME {
160 return Ok((false, false));
163 }
164 }
165
166 let (_, mut success) = call_invariant_function(&executor, test_address, calldata)?;
168 if success && call_after_invariant {
171 (_, success) = call_after_invariant_function(&executor, test_address)?;
172 }
173
174 Ok((success, true))
175}
176
177pub(crate) fn shrink_sequence_value(
185 config: &InvariantConfig,
186 invariant_contract: &InvariantContract<'_>,
187 calls: &[BasicTxDetails],
188 executor: &Executor,
189 target_value: I256,
190 progress: Option<&ProgressBar>,
191 early_exit: &EarlyExit,
192) -> eyre::Result<Vec<BasicTxDetails>> {
193 trace!(target: "forge::test", "Shrinking optimization sequence of {} calls for target value {}.", calls.len(), target_value);
194
195 reset_shrink_progress(config, progress);
196
197 let target_address = invariant_contract.address;
198 let calldata: Bytes = invariant_contract.invariant_function.selector().to_vec().into();
199
200 if check_sequence_value(executor.clone(), calls, vec![], target_address, calldata.clone())?
202 == Some(target_value)
203 {
204 return Ok(vec![]);
205 }
206
207 let mut call_idx = 0;
208 let mut shrinker = CallSequenceShrinker::new(calls.len());
209
210 for _ in 0..config.shrink_run_limit {
211 if early_exit.should_stop() {
212 break;
213 }
214
215 shrinker.included_calls.clear(call_idx);
216
217 let keeps_target = check_sequence_value(
218 executor.clone(),
219 calls,
220 shrinker.current().collect(),
221 target_address,
222 calldata.clone(),
223 )? == Some(target_value);
224
225 if keeps_target {
226 if shrinker.included_calls.count() == 1 {
227 break;
228 }
229 } else {
230 shrinker.included_calls.set(call_idx);
231 }
232
233 if let Some(progress) = progress {
234 progress.inc(1);
235 }
236
237 call_idx = shrinker.next_index(call_idx);
238 }
239
240 let mut result = Vec::new();
242 let mut accumulated_warp = U256::ZERO;
243 let mut accumulated_roll = U256::ZERO;
244
245 for (idx, call) in calls.iter().enumerate() {
246 accumulated_warp += call.warp.unwrap_or(U256::ZERO);
247 accumulated_roll += call.roll.unwrap_or(U256::ZERO);
248
249 if shrinker.included_calls.test(idx) {
250 result.push(apply_warp_roll(call, accumulated_warp, accumulated_roll));
251 accumulated_warp = U256::ZERO;
252 accumulated_roll = U256::ZERO;
253 }
254 }
255
256 Ok(result)
257}
258
259pub fn check_sequence_value(
265 mut executor: Executor,
266 calls: &[BasicTxDetails],
267 sequence: Vec<usize>,
268 test_address: Address,
269 calldata: Bytes,
270) -> eyre::Result<Option<I256>> {
271 let mut accumulated_warp = U256::ZERO;
272 let mut accumulated_roll = U256::ZERO;
273 let mut seq_iter = sequence.iter().peekable();
274
275 for (idx, tx) in calls.iter().enumerate() {
276 accumulated_warp += tx.warp.unwrap_or(U256::ZERO);
277 accumulated_roll += tx.roll.unwrap_or(U256::ZERO);
278
279 if seq_iter.peek() == Some(&&idx) {
280 seq_iter.next();
281
282 let tx_with_accumulated = apply_warp_roll(tx, accumulated_warp, accumulated_roll);
283 let mut call_result = execute_tx(&mut executor, &tx_with_accumulated)?;
284
285 if !call_result.reverted {
286 executor.commit(&mut call_result);
287 }
288
289 accumulated_warp = U256::ZERO;
290 accumulated_roll = U256::ZERO;
291 }
292 }
293
294 apply_warp_roll_to_env(&mut executor, accumulated_warp, accumulated_roll);
296
297 let (inv_result, success) = call_invariant_function(&executor, test_address, calldata)?;
298
299 if success
300 && inv_result.result.len() >= 32
301 && let Some(value) = I256::try_from_be_slice(&inv_result.result[..32])
302 {
303 return Ok(Some(value));
304 }
305
306 Ok(None)
307}