1use crate::{
4 ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5 runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_common::{
11 ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
12 shell::verbosity,
13};
14use foundry_compilers::{
15 Artifact, ArtifactId, ProjectCompileOutput,
16 artifacts::{Contract, Libraries},
17 compilers::Compiler,
18};
19use foundry_config::{Config, InlineConfig};
20use foundry_evm::{
21 Env,
22 backend::Backend,
23 decode::RevertDecoder,
24 executors::{Executor, ExecutorBuilder, FailFast},
25 fork::CreateFork,
26 inspectors::CheatsConfig,
27 opts::EvmOpts,
28 traces::{InternalTraceMode, TraceMode},
29};
30use foundry_evm_networks::NetworkConfigs;
31use foundry_linking::{LinkOutput, Linker};
32use rayon::prelude::*;
33use revm::primitives::hardfork::SpecId;
34use std::{
35 borrow::Borrow,
36 collections::BTreeMap,
37 fmt::Debug,
38 path::Path,
39 sync::{Arc, mpsc},
40 time::Instant,
41};
42
43#[derive(Debug, Clone)]
44pub struct TestContract {
45 pub abi: JsonAbi,
46 pub bytecode: Bytes,
47}
48
49pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
50
51pub struct MultiContractRunner {
54 pub contracts: DeployableContracts,
57 pub known_contracts: ContractsByArtifact,
59 pub revert_decoder: RevertDecoder,
61 pub libs_to_deploy: Vec<Bytes>,
63 pub libraries: Libraries,
65
66 pub fork: Option<CreateFork>,
68
69 pub tcfg: TestRunnerConfig,
71}
72
73impl std::ops::Deref for MultiContractRunner {
74 type Target = TestRunnerConfig;
75
76 fn deref(&self) -> &Self::Target {
77 &self.tcfg
78 }
79}
80
81impl std::ops::DerefMut for MultiContractRunner {
82 fn deref_mut(&mut self) -> &mut Self::Target {
83 &mut self.tcfg
84 }
85}
86
87impl MultiContractRunner {
88 pub fn matching_contracts<'a: 'b, 'b>(
90 &'a self,
91 filter: &'b dyn TestFilter,
92 ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
93 self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
94 }
95
96 pub fn matching_test_functions<'a: 'b, 'b>(
98 &'a self,
99 filter: &'b dyn TestFilter,
100 ) -> impl Iterator<Item = &'a Function> + 'b {
101 self.matching_contracts(filter)
102 .flat_map(|(_, c)| c.abi.functions())
103 .filter(|func| filter.matches_test_function(func))
104 }
105
106 pub fn all_test_functions<'a: 'b, 'b>(
108 &'a self,
109 filter: &'b dyn TestFilter,
110 ) -> impl Iterator<Item = &'a Function> + 'b {
111 self.contracts
112 .iter()
113 .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
114 .flat_map(|(_, c)| c.abi.functions())
115 .filter(|func| func.is_any_test())
116 }
117
118 pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
120 self.matching_contracts(filter)
121 .map(|(id, c)| {
122 let source = id.source.as_path().display().to_string();
123 let name = id.name.clone();
124 let tests = c
125 .abi
126 .functions()
127 .filter(|func| filter.matches_test_function(func))
128 .map(|func| func.name.clone())
129 .collect::<Vec<_>>();
130 (source, name, tests)
131 })
132 .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
133 acc.entry(source).or_default().insert(name, tests);
134 acc
135 })
136 }
137
138 pub fn test_collect(
144 &mut self,
145 filter: &dyn TestFilter,
146 ) -> Result<BTreeMap<String, SuiteResult>> {
147 Ok(self.test_iter(filter)?.collect())
148 }
149
150 pub fn test_iter(
156 &mut self,
157 filter: &dyn TestFilter,
158 ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
159 let (tx, rx) = mpsc::channel();
160 self.test(filter, tx, false)?;
161 Ok(rx.into_iter())
162 }
163
164 pub fn test(
171 &mut self,
172 filter: &dyn TestFilter,
173 tx: mpsc::Sender<(String, SuiteResult)>,
174 show_progress: bool,
175 ) -> Result<()> {
176 let tokio_handle = tokio::runtime::Handle::current();
177 trace!("running all tests");
178
179 let db = Backend::spawn(self.fork.take())?;
181
182 let find_timer = Instant::now();
183 let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
184 let find_time = find_timer.elapsed();
185 debug!(
186 "Found {} test contracts out of {} in {:?}",
187 contracts.len(),
188 self.contracts.len(),
189 find_time,
190 );
191
192 if show_progress {
193 let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
194 let results: Vec<(String, SuiteResult)> = contracts
196 .par_iter()
197 .map(|&(id, contract)| {
198 let _guard = tokio_handle.enter();
199 tests_progress.inner.lock().start_suite_progress(&id.identifier());
200
201 let result = self.run_test_suite(
202 id,
203 contract,
204 &db,
205 filter,
206 &tokio_handle,
207 Some(&tests_progress),
208 );
209
210 tests_progress
211 .inner
212 .lock()
213 .end_suite_progress(&id.identifier(), result.summary());
214
215 (id.identifier(), result)
216 })
217 .collect();
218
219 tests_progress.inner.lock().clear();
220
221 results.iter().for_each(|result| {
222 let _ = tx.send(result.to_owned());
223 });
224 } else {
225 contracts.par_iter().for_each(|&(id, contract)| {
226 let _guard = tokio_handle.enter();
227 let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
228 let _ = tx.send((id.identifier(), result));
229 })
230 }
231
232 Ok(())
233 }
234
235 fn run_test_suite(
236 &self,
237 artifact_id: &ArtifactId,
238 contract: &TestContract,
239 db: &Backend,
240 filter: &dyn TestFilter,
241 tokio_handle: &tokio::runtime::Handle,
242 progress: Option<&TestsProgress>,
243 ) -> SuiteResult {
244 let identifier = artifact_id.identifier();
245 let mut span_name = identifier.as_str();
246
247 if !enabled!(tracing::Level::TRACE) {
248 span_name = get_contract_name(&identifier);
249 }
250 let span = debug_span!("suite", name = %span_name);
251 let span_local = span.clone();
252 let _guard = span_local.enter();
253
254 debug!("start executing all tests in contract");
255
256 let executor = self.tcfg.executor(self.known_contracts.clone(), artifact_id, db.clone());
257 let runner = ContractRunner::new(
258 &identifier,
259 contract,
260 executor,
261 progress,
262 tokio_handle,
263 span,
264 self,
265 );
266 let r = runner.run_tests(filter);
267
268 debug!(duration=?r.duration, "executed all tests in contract");
269
270 r
271 }
272}
273
274#[derive(Clone)]
278pub struct TestRunnerConfig {
279 pub config: Arc<Config>,
281 pub inline_config: Arc<InlineConfig>,
283
284 pub evm_opts: EvmOpts,
286 pub env: Env,
288 pub spec_id: SpecId,
290 pub sender: Address,
292
293 pub line_coverage: bool,
295 pub debug: bool,
297 pub decode_internal: InternalTraceMode,
299 pub isolation: bool,
301 pub networks: NetworkConfigs,
303 pub fail_fast: FailFast,
305}
306
307impl TestRunnerConfig {
308 pub fn reconfigure_with(&mut self, config: Arc<Config>) {
311 debug_assert!(!Arc::ptr_eq(&self.config, &config));
312
313 self.spec_id = config.evm_spec_id();
314 self.sender = config.sender;
315 self.networks.celo = config.celo;
316 self.isolation = config.isolate;
317
318 self.config = config;
326 }
327
328 pub fn configure_executor(&self, executor: &mut Executor) {
330 let inspector = executor.inspector_mut();
333 if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
335 cheatcodes.config =
336 Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
337 }
338 inspector.tracing(self.trace_mode());
339 inspector.collect_line_coverage(self.line_coverage);
340 inspector.enable_isolation(self.isolation);
341 inspector.networks(self.networks);
342 executor.set_spec_id(self.spec_id);
346 executor.set_legacy_assertions(self.config.legacy_assertions);
348 }
349
350 pub fn executor(
352 &self,
353 known_contracts: ContractsByArtifact,
354 artifact_id: &ArtifactId,
355 db: Backend,
356 ) -> Executor {
357 let cheats_config = Arc::new(CheatsConfig::new(
358 &self.config,
359 self.evm_opts.clone(),
360 Some(known_contracts),
361 Some(artifact_id.clone()),
362 ));
363 ExecutorBuilder::new()
364 .inspectors(|stack| {
365 stack
366 .cheatcodes(cheats_config)
367 .trace_mode(self.trace_mode())
368 .line_coverage(self.line_coverage)
369 .enable_isolation(self.isolation)
370 .networks(self.networks)
371 .create2_deployer(self.evm_opts.create2_deployer)
372 })
373 .spec_id(self.spec_id)
374 .gas_limit(self.evm_opts.gas_limit())
375 .legacy_assertions(self.config.legacy_assertions)
376 .build(self.env.clone(), db)
377 }
378
379 fn trace_mode(&self) -> TraceMode {
380 TraceMode::default()
381 .with_debug(self.debug)
382 .with_decode_internal(self.decode_internal)
383 .with_verbosity(self.evm_opts.verbosity)
384 .with_state_changes(verbosity() > 4)
385 }
386}
387
388#[derive(Clone, Debug)]
390#[must_use = "builders do nothing unless you call `build` on them"]
391pub struct MultiContractRunnerBuilder {
392 pub sender: Option<Address>,
395 pub initial_balance: U256,
397 pub evm_spec: Option<SpecId>,
399 pub fork: Option<CreateFork>,
401 pub config: Arc<Config>,
403 pub line_coverage: bool,
405 pub debug: bool,
407 pub decode_internal: InternalTraceMode,
409 pub isolation: bool,
411 pub networks: NetworkConfigs,
413 pub fail_fast: bool,
415}
416
417impl MultiContractRunnerBuilder {
418 pub fn new(config: Arc<Config>) -> Self {
419 Self {
420 config,
421 sender: Default::default(),
422 initial_balance: Default::default(),
423 evm_spec: Default::default(),
424 fork: Default::default(),
425 line_coverage: Default::default(),
426 debug: Default::default(),
427 isolation: Default::default(),
428 decode_internal: Default::default(),
429 networks: Default::default(),
430 fail_fast: false,
431 }
432 }
433
434 pub fn sender(mut self, sender: Address) -> Self {
435 self.sender = Some(sender);
436 self
437 }
438
439 pub fn initial_balance(mut self, initial_balance: U256) -> Self {
440 self.initial_balance = initial_balance;
441 self
442 }
443
444 pub fn evm_spec(mut self, spec: SpecId) -> Self {
445 self.evm_spec = Some(spec);
446 self
447 }
448
449 pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
450 self.fork = fork;
451 self
452 }
453
454 pub fn set_coverage(mut self, enable: bool) -> Self {
455 self.line_coverage = enable;
456 self
457 }
458
459 pub fn set_debug(mut self, enable: bool) -> Self {
460 self.debug = enable;
461 self
462 }
463
464 pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
465 self.decode_internal = mode;
466 self
467 }
468
469 pub fn fail_fast(mut self, fail_fast: bool) -> Self {
470 self.fail_fast = fail_fast;
471 self
472 }
473
474 pub fn enable_isolation(mut self, enable: bool) -> Self {
475 self.isolation = enable;
476 self
477 }
478
479 pub fn networks(mut self, networks: NetworkConfigs) -> Self {
480 self.networks = networks;
481 self
482 }
483
484 pub fn build<C: Compiler<CompilerContract = Contract>>(
487 self,
488 root: &Path,
489 output: &ProjectCompileOutput,
490 env: Env,
491 evm_opts: EvmOpts,
492 ) -> Result<MultiContractRunner> {
493 let contracts = output
494 .artifact_ids()
495 .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
496 .collect();
497 let linker = Linker::new(root, contracts);
498
499 let abis = linker
501 .contracts
502 .iter()
503 .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
504 let revert_decoder = RevertDecoder::new().with_abis(abis);
505
506 let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
507 Default::default(),
508 LIBRARY_DEPLOYER,
509 0,
510 linker.contracts.keys(),
511 )?;
512
513 let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
514
515 let mut deployable_contracts = DeployableContracts::default();
517
518 for (id, contract) in linked_contracts.iter() {
519 let Some(abi) = contract.abi.as_ref() else { continue };
520
521 if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
523 && abi.functions().any(|func| func.name.is_any_test())
524 {
525 linker.ensure_linked(contract, id)?;
526
527 let Some(bytecode) =
528 contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
529 else {
530 continue;
531 };
532
533 deployable_contracts
534 .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
535 }
536 }
537
538 let known_contracts =
540 ContractsByArtifactBuilder::new(linked_contracts).with_output(output, root).build();
541
542 Ok(MultiContractRunner {
543 contracts: deployable_contracts,
544 revert_decoder,
545 known_contracts,
546 libs_to_deploy,
547 libraries,
548
549 fork: self.fork,
550
551 tcfg: TestRunnerConfig {
552 evm_opts,
553 env,
554 spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
555 sender: self.sender.unwrap_or(self.config.sender),
556 line_coverage: self.line_coverage,
557 debug: self.debug,
558 decode_internal: self.decode_internal,
559 inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
560 isolation: self.isolation,
561 networks: self.networks,
562 config: self.config,
563 fail_fast: FailFast::new(self.fail_fast),
564 },
565 })
566 }
567}
568
569pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
570 matches_contract(filter, &id.source, &id.name, abi.functions())
571}
572
573pub(crate) fn matches_contract(
574 filter: &dyn TestFilter,
575 path: &Path,
576 contract_name: &str,
577 functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
578) -> bool {
579 (filter.matches_path(path) && filter.matches_contract(contract_name))
580 && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
581}