1use crate::{
4 ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5 runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_common::{
11 ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
12 shell::verbosity,
13};
14use foundry_compilers::{
15 Artifact, ArtifactId, ProjectCompileOutput,
16 artifacts::{Contract, Libraries},
17 compilers::Compiler,
18};
19use foundry_config::{Config, InlineConfig};
20use foundry_evm::{
21 Env,
22 backend::Backend,
23 decode::RevertDecoder,
24 executors::{Executor, ExecutorBuilder, FailFast},
25 fork::CreateFork,
26 inspectors::CheatsConfig,
27 opts::EvmOpts,
28 traces::{InternalTraceMode, TraceMode},
29};
30use foundry_linking::{LinkOutput, Linker};
31use rayon::prelude::*;
32use revm::primitives::hardfork::SpecId;
33use std::{
34 borrow::Borrow,
35 collections::BTreeMap,
36 fmt::Debug,
37 path::Path,
38 sync::{Arc, mpsc},
39 time::Instant,
40};
41
/// A test contract that is ready to be deployed: its ABI plus its linked bytecode.
#[derive(Debug, Clone)]
pub struct TestContract {
    /// The contract's ABI.
    pub abi: JsonAbi,
    /// The contract's linked bytecode; `MultiContractRunnerBuilder::build` never stores an
    /// empty one.
    pub bytecode: Bytes,
}
47
/// Deployable test contracts, keyed (and therefore ordered) by their artifact id.
pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
49
/// Runs the test suites of every deployable test contract of a compiled project.
///
/// Dereferences to its [`TestRunnerConfig`] (`tcfg`), so configuration fields can be read and
/// written directly on the runner.
pub struct MultiContractRunner {
    /// Deployable test contracts keyed by artifact id. `build` only admits contracts whose
    /// constructor takes no arguments, that expose at least one test function, and whose
    /// linked bytecode is non-empty.
    pub contracts: DeployableContracts,
    /// All linked contracts of the project, with storage layouts attached.
    pub known_contracts: ContractsByArtifact,
    /// Revert decoder built from the ABIs of every contract known to the linker.
    pub revert_decoder: RevertDecoder,
    /// Library bytecode reported by the linker (`LinkOutput::libs_to_deploy`).
    /// NOTE(review): presumably deployed before the test contracts — confirm in the runner.
    pub libs_to_deploy: Vec<Bytes>,
    /// Library addresses the contracts were linked against.
    pub libraries: Libraries,

    /// Fork the backend is spawned from; taken (left `None`) by the first call to `test`.
    pub fork: Option<CreateFork>,

    /// Configuration shared by every suite runner.
    pub tcfg: TestRunnerConfig,
}
71
/// Exposes the fields of the runner's [`TestRunnerConfig`] (`tcfg`) directly on the runner.
impl std::ops::Deref for MultiContractRunner {
    type Target = TestRunnerConfig;

    fn deref(&self) -> &Self::Target {
        &self.tcfg
    }
}

/// Mutable access to the runner's [`TestRunnerConfig`] (`tcfg`).
impl std::ops::DerefMut for MultiContractRunner {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.tcfg
    }
}
85
86impl MultiContractRunner {
87 pub fn matching_contracts<'a: 'b, 'b>(
89 &'a self,
90 filter: &'b dyn TestFilter,
91 ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
92 self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
93 }
94
95 pub fn matching_test_functions<'a: 'b, 'b>(
97 &'a self,
98 filter: &'b dyn TestFilter,
99 ) -> impl Iterator<Item = &'a Function> + 'b {
100 self.matching_contracts(filter)
101 .flat_map(|(_, c)| c.abi.functions())
102 .filter(|func| filter.matches_test_function(func))
103 }
104
105 pub fn all_test_functions<'a: 'b, 'b>(
107 &'a self,
108 filter: &'b dyn TestFilter,
109 ) -> impl Iterator<Item = &'a Function> + 'b {
110 self.contracts
111 .iter()
112 .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
113 .flat_map(|(_, c)| c.abi.functions())
114 .filter(|func| func.is_any_test())
115 }
116
117 pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
119 self.matching_contracts(filter)
120 .map(|(id, c)| {
121 let source = id.source.as_path().display().to_string();
122 let name = id.name.clone();
123 let tests = c
124 .abi
125 .functions()
126 .filter(|func| filter.matches_test_function(func))
127 .map(|func| func.name.clone())
128 .collect::<Vec<_>>();
129 (source, name, tests)
130 })
131 .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
132 acc.entry(source).or_default().insert(name, tests);
133 acc
134 })
135 }
136
137 pub fn test_collect(
143 &mut self,
144 filter: &dyn TestFilter,
145 ) -> Result<BTreeMap<String, SuiteResult>> {
146 Ok(self.test_iter(filter)?.collect())
147 }
148
149 pub fn test_iter(
155 &mut self,
156 filter: &dyn TestFilter,
157 ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
158 let (tx, rx) = mpsc::channel();
159 self.test(filter, tx, false)?;
160 Ok(rx.into_iter())
161 }
162
163 pub fn test(
170 &mut self,
171 filter: &dyn TestFilter,
172 tx: mpsc::Sender<(String, SuiteResult)>,
173 show_progress: bool,
174 ) -> Result<()> {
175 let tokio_handle = tokio::runtime::Handle::current();
176 trace!("running all tests");
177
178 let db = Backend::spawn(self.fork.take())?;
180
181 let find_timer = Instant::now();
182 let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
183 let find_time = find_timer.elapsed();
184 debug!(
185 "Found {} test contracts out of {} in {:?}",
186 contracts.len(),
187 self.contracts.len(),
188 find_time,
189 );
190
191 if show_progress {
192 let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
193 let results: Vec<(String, SuiteResult)> = contracts
195 .par_iter()
196 .map(|&(id, contract)| {
197 let _guard = tokio_handle.enter();
198 tests_progress.inner.lock().start_suite_progress(&id.identifier());
199
200 let result = self.run_test_suite(
201 id,
202 contract,
203 &db,
204 filter,
205 &tokio_handle,
206 Some(&tests_progress),
207 );
208
209 tests_progress
210 .inner
211 .lock()
212 .end_suite_progress(&id.identifier(), result.summary());
213
214 (id.identifier(), result)
215 })
216 .collect();
217
218 tests_progress.inner.lock().clear();
219
220 results.iter().for_each(|result| {
221 let _ = tx.send(result.to_owned());
222 });
223 } else {
224 contracts.par_iter().for_each(|&(id, contract)| {
225 let _guard = tokio_handle.enter();
226 let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
227 let _ = tx.send((id.identifier(), result));
228 })
229 }
230
231 Ok(())
232 }
233
234 fn run_test_suite(
235 &self,
236 artifact_id: &ArtifactId,
237 contract: &TestContract,
238 db: &Backend,
239 filter: &dyn TestFilter,
240 tokio_handle: &tokio::runtime::Handle,
241 progress: Option<&TestsProgress>,
242 ) -> SuiteResult {
243 let identifier = artifact_id.identifier();
244 let mut span_name = identifier.as_str();
245
246 if !enabled!(tracing::Level::TRACE) {
247 span_name = get_contract_name(&identifier);
248 }
249 let span = debug_span!("suite", name = %span_name);
250 let span_local = span.clone();
251 let _guard = span_local.enter();
252
253 debug!("start executing all tests in contract");
254
255 let executor = self.tcfg.executor(self.known_contracts.clone(), artifact_id, db.clone());
256 let runner = ContractRunner::new(
257 &identifier,
258 contract,
259 executor,
260 progress,
261 tokio_handle,
262 span,
263 self,
264 );
265 let r = runner.run_tests(filter);
266
267 debug!(duration=?r.duration, "executed all tests in contract");
268
269 r
270 }
271}
272
/// Settings used to build and (re)configure test executors.
///
/// `reconfigure_with` refreshes `config`, `spec_id`, `sender`, `odyssey` and `isolation`. The
/// remaining fields are set when the runner is built.
#[derive(Clone)]
pub struct TestRunnerConfig {
    /// Project configuration.
    pub config: Arc<Config>,
    /// Inline configuration parsed from the compile output.
    pub inline_config: Arc<InlineConfig>,

    /// EVM options. Feed the cheatcodes config, gas limit, CREATE2 deployer and trace
    /// verbosity.
    pub evm_opts: EvmOpts,
    /// EVM environment, cloned into every executor built by `executor`.
    pub env: Env,
    /// EVM hardfork executors are configured with.
    pub spec_id: SpecId,
    /// Sender address (builder override, else `Config::sender`).
    pub sender: Address,

    /// Whether the inspector collects line coverage.
    pub line_coverage: bool,
    /// Debug flag forwarded to the trace mode.
    pub debug: bool,
    /// Internal-call decoding mode forwarded to the trace mode.
    pub decode_internal: InternalTraceMode,
    /// Whether inspector isolation is enabled.
    pub isolation: bool,
    /// Whether Odyssey features are enabled on the inspector.
    pub odyssey: bool,
    /// Fail-fast state, created from the builder's `fail_fast` flag.
    pub fail_fast: FailFast,
}
305
306impl TestRunnerConfig {
307 pub fn reconfigure_with(&mut self, config: Arc<Config>) {
310 debug_assert!(!Arc::ptr_eq(&self.config, &config));
311
312 self.spec_id = config.evm_spec_id();
313 self.sender = config.sender;
314 self.odyssey = config.odyssey;
315 self.isolation = config.isolate;
316
317 self.config = config;
325 }
326
327 pub fn configure_executor(&self, executor: &mut Executor) {
329 let inspector = executor.inspector_mut();
332 if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
334 cheatcodes.config =
335 Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
336 }
337 inspector.tracing(self.trace_mode());
338 inspector.collect_line_coverage(self.line_coverage);
339 inspector.enable_isolation(self.isolation);
340 inspector.odyssey(self.odyssey);
341 executor.set_spec_id(self.spec_id);
345 executor.set_legacy_assertions(self.config.legacy_assertions);
347 }
348
349 pub fn executor(
351 &self,
352 known_contracts: ContractsByArtifact,
353 artifact_id: &ArtifactId,
354 db: Backend,
355 ) -> Executor {
356 let cheats_config = Arc::new(CheatsConfig::new(
357 &self.config,
358 self.evm_opts.clone(),
359 Some(known_contracts),
360 Some(artifact_id.clone()),
361 ));
362 ExecutorBuilder::new()
363 .inspectors(|stack| {
364 stack
365 .cheatcodes(cheats_config)
366 .trace_mode(self.trace_mode())
367 .line_coverage(self.line_coverage)
368 .enable_isolation(self.isolation)
369 .odyssey(self.odyssey)
370 .create2_deployer(self.evm_opts.create2_deployer)
371 })
372 .spec_id(self.spec_id)
373 .gas_limit(self.evm_opts.gas_limit())
374 .legacy_assertions(self.config.legacy_assertions)
375 .build(self.env.clone(), db)
376 }
377
378 fn trace_mode(&self) -> TraceMode {
379 TraceMode::default()
380 .with_debug(self.debug)
381 .with_decode_internal(self.decode_internal)
382 .with_verbosity(self.evm_opts.verbosity)
383 .with_state_changes(verbosity() > 4)
384 }
385}
386
/// Builder for a [`MultiContractRunner`].
#[derive(Clone, Debug)]
#[must_use = "builders do nothing unless you call `build` on them"]
pub struct MultiContractRunnerBuilder {
    /// Sender override; `build` falls back to `Config::sender` when unset.
    pub sender: Option<Address>,
    /// Initial balance setting.
    /// NOTE(review): not read by `build` in this file — confirm where it is consumed.
    pub initial_balance: U256,
    /// EVM hardfork override; `build` falls back to `Config::evm_spec_id()` when unset.
    pub evm_spec: Option<SpecId>,
    /// Fork handed over to the built runner.
    pub fork: Option<CreateFork>,
    /// Project configuration.
    pub config: Arc<Config>,
    /// Whether line coverage is collected.
    pub line_coverage: bool,
    /// Debug flag forwarded to the trace mode.
    pub debug: bool,
    /// Internal-call decoding mode forwarded to the trace mode.
    pub decode_internal: InternalTraceMode,
    /// Whether inspector isolation is enabled.
    pub isolation: bool,
    /// Whether Odyssey features are enabled.
    pub odyssey: bool,
    /// Fail-fast flag; `build` wraps it in a `FailFast`.
    pub fail_fast: bool,
}
415
416impl MultiContractRunnerBuilder {
417 pub fn new(config: Arc<Config>) -> Self {
418 Self {
419 config,
420 sender: Default::default(),
421 initial_balance: Default::default(),
422 evm_spec: Default::default(),
423 fork: Default::default(),
424 line_coverage: Default::default(),
425 debug: Default::default(),
426 isolation: Default::default(),
427 decode_internal: Default::default(),
428 odyssey: Default::default(),
429 fail_fast: false,
430 }
431 }
432
433 pub fn sender(mut self, sender: Address) -> Self {
434 self.sender = Some(sender);
435 self
436 }
437
438 pub fn initial_balance(mut self, initial_balance: U256) -> Self {
439 self.initial_balance = initial_balance;
440 self
441 }
442
443 pub fn evm_spec(mut self, spec: SpecId) -> Self {
444 self.evm_spec = Some(spec);
445 self
446 }
447
448 pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
449 self.fork = fork;
450 self
451 }
452
453 pub fn set_coverage(mut self, enable: bool) -> Self {
454 self.line_coverage = enable;
455 self
456 }
457
458 pub fn set_debug(mut self, enable: bool) -> Self {
459 self.debug = enable;
460 self
461 }
462
463 pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
464 self.decode_internal = mode;
465 self
466 }
467
468 pub fn fail_fast(mut self, fail_fast: bool) -> Self {
469 self.fail_fast = fail_fast;
470 self
471 }
472
473 pub fn enable_isolation(mut self, enable: bool) -> Self {
474 self.isolation = enable;
475 self
476 }
477
478 pub fn odyssey(mut self, enable: bool) -> Self {
479 self.odyssey = enable;
480 self
481 }
482
483 pub fn build<C: Compiler<CompilerContract = Contract>>(
486 self,
487 root: &Path,
488 output: &ProjectCompileOutput,
489 env: Env,
490 evm_opts: EvmOpts,
491 ) -> Result<MultiContractRunner> {
492 let contracts = output
493 .artifact_ids()
494 .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
495 .collect();
496 let linker = Linker::new(root, contracts);
497
498 let abis = linker
500 .contracts
501 .iter()
502 .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
503 let revert_decoder = RevertDecoder::new().with_abis(abis);
504
505 let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
506 Default::default(),
507 LIBRARY_DEPLOYER,
508 0,
509 linker.contracts.keys(),
510 )?;
511
512 let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
513
514 let mut deployable_contracts = DeployableContracts::default();
516
517 for (id, contract) in linked_contracts.iter() {
518 let Some(abi) = contract.abi.as_ref() else { continue };
519
520 if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
522 && abi.functions().any(|func| func.name.is_any_test())
523 {
524 linker.ensure_linked(contract, id)?;
525
526 let Some(bytecode) =
527 contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
528 else {
529 continue;
530 };
531
532 deployable_contracts
533 .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
534 }
535 }
536
537 let known_contracts = ContractsByArtifactBuilder::new(linked_contracts)
539 .with_storage_layouts(output.clone().with_stripped_file_prefixes(root))
540 .build();
541
542 Ok(MultiContractRunner {
543 contracts: deployable_contracts,
544 revert_decoder,
545 known_contracts,
546 libs_to_deploy,
547 libraries,
548
549 fork: self.fork,
550
551 tcfg: TestRunnerConfig {
552 evm_opts,
553 env,
554 spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
555 sender: self.sender.unwrap_or(self.config.sender),
556 line_coverage: self.line_coverage,
557 debug: self.debug,
558 decode_internal: self.decode_internal,
559 inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
560 isolation: self.isolation,
561 odyssey: self.odyssey,
562 config: self.config,
563 fail_fast: FailFast::new(self.fail_fast),
564 },
565 })
566 }
567}
568
569pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
570 matches_contract(filter, &id.source, &id.name, abi.functions())
571}
572
573pub(crate) fn matches_contract(
574 filter: &dyn TestFilter,
575 path: &Path,
576 contract_name: &str,
577 functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
578) -> bool {
579 (filter.matches_path(path) && filter.matches_contract(contract_name))
580 && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
581}
582
583pub(crate) fn is_test_contract(
585 functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
586) -> bool {
587 functions.into_iter().any(|func| func.borrow().is_any_test())
588}