forge/
multi_runner.rs

//! Forge test runner for multiple contracts.
2
3use crate::{
4    ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5    runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_common::{
11    ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
12    shell::verbosity,
13};
14use foundry_compilers::{
15    Artifact, ArtifactId, ProjectCompileOutput,
16    artifacts::{Contract, Libraries},
17    compilers::Compiler,
18};
19use foundry_config::{Config, InlineConfig};
20use foundry_evm::{
21    Env,
22    backend::Backend,
23    decode::RevertDecoder,
24    executors::{Executor, ExecutorBuilder, FailFast},
25    fork::CreateFork,
26    inspectors::CheatsConfig,
27    opts::EvmOpts,
28    traces::{InternalTraceMode, TraceMode},
29};
30use foundry_linking::{LinkOutput, Linker};
31use rayon::prelude::*;
32use revm::primitives::hardfork::SpecId;
33use std::{
34    borrow::Borrow,
35    collections::BTreeMap,
36    fmt::Debug,
37    path::Path,
38    sync::{Arc, mpsc},
39    time::Instant,
40};
41
/// A deployable test contract: its ABI together with its linked creation bytecode.
#[derive(Debug, Clone)]
pub struct TestContract {
    /// The contract ABI, used to discover the contract's test functions.
    pub abi: JsonAbi,
    /// Linked creation (deployment) bytecode; never empty for contracts collected by the builder.
    pub bytecode: Bytes,
}

/// Deployable test contracts, keyed by their artifact id.
pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
49
/// Runs the test functions of a set of linked test contracts.
///
/// Matching test suites are executed in parallel, each with its own executor built from the shared
/// [`TestRunnerConfig`] (see [`MultiContractRunner::test`]).
pub struct MultiContractRunner {
    /// Deployable test contracts (ABI and linked creation bytecode), keyed by artifact id.
    pub contracts: DeployableContracts,
    /// Known contracts linked with computed library addresses.
    pub known_contracts: ContractsByArtifact,
    /// Revert decoder. Contains all known errors and their selectors.
    pub revert_decoder: RevertDecoder,
    /// Bytecode of the libraries to deploy.
    pub libs_to_deploy: Vec<Bytes>,
    /// Library addresses used to link contracts.
    pub libraries: Libraries,

    /// The fork to use at launch.
    ///
    /// This is taken (leaving `None`) when tests are run via [`MultiContractRunner::test`].
    pub fork: Option<CreateFork>,

    /// The base configuration for the test runner; its fields are also reachable directly on the
    /// runner through `Deref`/`DerefMut`.
    pub tcfg: TestRunnerConfig,
}
71
72impl std::ops::Deref for MultiContractRunner {
73    type Target = TestRunnerConfig;
74
75    fn deref(&self) -> &Self::Target {
76        &self.tcfg
77    }
78}
79
80impl std::ops::DerefMut for MultiContractRunner {
81    fn deref_mut(&mut self) -> &mut Self::Target {
82        &mut self.tcfg
83    }
84}
85
86impl MultiContractRunner {
87    /// Returns an iterator over all contracts that match the filter.
88    pub fn matching_contracts<'a: 'b, 'b>(
89        &'a self,
90        filter: &'b dyn TestFilter,
91    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
92        self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
93    }
94
95    /// Returns an iterator over all test functions that match the filter.
96    pub fn matching_test_functions<'a: 'b, 'b>(
97        &'a self,
98        filter: &'b dyn TestFilter,
99    ) -> impl Iterator<Item = &'a Function> + 'b {
100        self.matching_contracts(filter)
101            .flat_map(|(_, c)| c.abi.functions())
102            .filter(|func| filter.matches_test_function(func))
103    }
104
105    /// Returns an iterator over all test functions in contracts that match the filter.
106    pub fn all_test_functions<'a: 'b, 'b>(
107        &'a self,
108        filter: &'b dyn TestFilter,
109    ) -> impl Iterator<Item = &'a Function> + 'b {
110        self.contracts
111            .iter()
112            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
113            .flat_map(|(_, c)| c.abi.functions())
114            .filter(|func| func.is_any_test())
115    }
116
117    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
118    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
119        self.matching_contracts(filter)
120            .map(|(id, c)| {
121                let source = id.source.as_path().display().to_string();
122                let name = id.name.clone();
123                let tests = c
124                    .abi
125                    .functions()
126                    .filter(|func| filter.matches_test_function(func))
127                    .map(|func| func.name.clone())
128                    .collect::<Vec<_>>();
129                (source, name, tests)
130            })
131            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
132                acc.entry(source).or_default().insert(name, tests);
133                acc
134            })
135    }
136
137    /// Executes _all_ tests that match the given `filter`.
138    ///
139    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
140    ///
141    /// Note that this method returns only when all tests have been executed.
142    pub fn test_collect(
143        &mut self,
144        filter: &dyn TestFilter,
145    ) -> Result<BTreeMap<String, SuiteResult>> {
146        Ok(self.test_iter(filter)?.collect())
147    }
148
149    /// Executes _all_ tests that match the given `filter`.
150    ///
151    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
152    ///
153    /// Note that this method returns only when all tests have been executed.
154    pub fn test_iter(
155        &mut self,
156        filter: &dyn TestFilter,
157    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
158        let (tx, rx) = mpsc::channel();
159        self.test(filter, tx, false)?;
160        Ok(rx.into_iter())
161    }
162
163    /// Executes _all_ tests that match the given `filter`.
164    ///
165    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
166    /// before executing all contracts and their tests in _parallel_.
167    ///
168    /// Each Executor gets its own instance of the `Backend`.
169    pub fn test(
170        &mut self,
171        filter: &dyn TestFilter,
172        tx: mpsc::Sender<(String, SuiteResult)>,
173        show_progress: bool,
174    ) -> Result<()> {
175        let tokio_handle = tokio::runtime::Handle::current();
176        trace!("running all tests");
177
178        // The DB backend that serves all the data.
179        let db = Backend::spawn(self.fork.take())?;
180
181        let find_timer = Instant::now();
182        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
183        let find_time = find_timer.elapsed();
184        debug!(
185            "Found {} test contracts out of {} in {:?}",
186            contracts.len(),
187            self.contracts.len(),
188            find_time,
189        );
190
191        if show_progress {
192            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
193            // Collect test suite results to stream at the end of test run.
194            let results: Vec<(String, SuiteResult)> = contracts
195                .par_iter()
196                .map(|&(id, contract)| {
197                    let _guard = tokio_handle.enter();
198                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
199
200                    let result = self.run_test_suite(
201                        id,
202                        contract,
203                        &db,
204                        filter,
205                        &tokio_handle,
206                        Some(&tests_progress),
207                    );
208
209                    tests_progress
210                        .inner
211                        .lock()
212                        .end_suite_progress(&id.identifier(), result.summary());
213
214                    (id.identifier(), result)
215                })
216                .collect();
217
218            tests_progress.inner.lock().clear();
219
220            results.iter().for_each(|result| {
221                let _ = tx.send(result.to_owned());
222            });
223        } else {
224            contracts.par_iter().for_each(|&(id, contract)| {
225                let _guard = tokio_handle.enter();
226                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
227                let _ = tx.send((id.identifier(), result));
228            })
229        }
230
231        Ok(())
232    }
233
234    fn run_test_suite(
235        &self,
236        artifact_id: &ArtifactId,
237        contract: &TestContract,
238        db: &Backend,
239        filter: &dyn TestFilter,
240        tokio_handle: &tokio::runtime::Handle,
241        progress: Option<&TestsProgress>,
242    ) -> SuiteResult {
243        let identifier = artifact_id.identifier();
244        let mut span_name = identifier.as_str();
245
246        if !enabled!(tracing::Level::TRACE) {
247            span_name = get_contract_name(&identifier);
248        }
249        let span = debug_span!("suite", name = %span_name);
250        let span_local = span.clone();
251        let _guard = span_local.enter();
252
253        debug!("start executing all tests in contract");
254
255        let executor = self.tcfg.executor(self.known_contracts.clone(), artifact_id, db.clone());
256        let runner = ContractRunner::new(
257            &identifier,
258            contract,
259            executor,
260            progress,
261            tokio_handle,
262            span,
263            self,
264        );
265        let r = runner.run_tests(filter);
266
267        debug!(duration=?r.duration, "executed all tests in contract");
268
269        r
270    }
271}
272
/// Configuration for the test runner.
///
/// This is modified after instantiation through inline config (see
/// [`TestRunnerConfig::reconfigure_with`]).
#[derive(Clone)]
pub struct TestRunnerConfig {
    /// Project config.
    pub config: Arc<Config>,
    /// Inline configuration.
    pub inline_config: Arc<InlineConfig>,

    /// EVM configuration.
    pub evm_opts: EvmOpts,
    /// EVM environment.
    pub env: Env,
    /// EVM version (spec id) used by executors built from this configuration.
    pub spec_id: SpecId,
    /// The address which will be used to deploy the initial contracts and send all transactions.
    pub sender: Address,

    /// Whether to collect line coverage info.
    pub line_coverage: bool,
    /// Whether to collect debug info.
    pub debug: bool,
    /// Internal-function decoding mode; controls whether steps are tracked in the tracer.
    pub decode_internal: InternalTraceMode,
    /// Whether to enable call isolation.
    pub isolation: bool,
    /// Whether to enable Odyssey features.
    pub odyssey: bool,
    /// Whether to exit early on test failure.
    pub fail_fast: FailFast,
}
305
306impl TestRunnerConfig {
307    /// Reconfigures all fields using the given `config`.
308    /// This is for example used to override the configuration with inline config.
309    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
310        debug_assert!(!Arc::ptr_eq(&self.config, &config));
311
312        self.spec_id = config.evm_spec_id();
313        self.sender = config.sender;
314        self.odyssey = config.odyssey;
315        self.isolation = config.isolate;
316
317        // Specific to Forge, not present in config.
318        // TODO: self.evm_opts
319        // TODO: self.env
320        // self.coverage = N/A;
321        // self.debug = N/A;
322        // self.decode_internal = N/A;
323
324        self.config = config;
325    }
326
327    /// Configures the given executor with this configuration.
328    pub fn configure_executor(&self, executor: &mut Executor) {
329        // TODO: See above
330
331        let inspector = executor.inspector_mut();
332        // inspector.set_env(&self.env);
333        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
334            cheatcodes.config =
335                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
336        }
337        inspector.tracing(self.trace_mode());
338        inspector.collect_line_coverage(self.line_coverage);
339        inspector.enable_isolation(self.isolation);
340        inspector.odyssey(self.odyssey);
341        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
342
343        // executor.env_mut().clone_from(&self.env);
344        executor.set_spec_id(self.spec_id);
345        // executor.set_gas_limit(self.evm_opts.gas_limit());
346        executor.set_legacy_assertions(self.config.legacy_assertions);
347    }
348
349    /// Creates a new executor with this configuration.
350    pub fn executor(
351        &self,
352        known_contracts: ContractsByArtifact,
353        artifact_id: &ArtifactId,
354        db: Backend,
355    ) -> Executor {
356        let cheats_config = Arc::new(CheatsConfig::new(
357            &self.config,
358            self.evm_opts.clone(),
359            Some(known_contracts),
360            Some(artifact_id.clone()),
361        ));
362        ExecutorBuilder::new()
363            .inspectors(|stack| {
364                stack
365                    .cheatcodes(cheats_config)
366                    .trace_mode(self.trace_mode())
367                    .line_coverage(self.line_coverage)
368                    .enable_isolation(self.isolation)
369                    .odyssey(self.odyssey)
370                    .create2_deployer(self.evm_opts.create2_deployer)
371            })
372            .spec_id(self.spec_id)
373            .gas_limit(self.evm_opts.gas_limit())
374            .legacy_assertions(self.config.legacy_assertions)
375            .build(self.env.clone(), db)
376    }
377
378    fn trace_mode(&self) -> TraceMode {
379        TraceMode::default()
380            .with_debug(self.debug)
381            .with_decode_internal(self.decode_internal)
382            .with_verbosity(self.evm_opts.verbosity)
383            .with_state_changes(verbosity() > 4)
384    }
385}
386
/// Builder used for instantiating the multi-contract runner.
#[derive(Clone, Debug)]
#[must_use = "builders do nothing unless you call `build` on them"]
pub struct MultiContractRunnerBuilder {
    /// The address which will be used to deploy the initial contracts and send all
    /// transactions. `build` falls back to `config.sender` when this is `None`.
    pub sender: Option<Address>,
    /// The initial balance for each one of the deployed smart contracts.
    ///
    /// NOTE(review): `build` does not read this field; confirm whether it is still needed.
    pub initial_balance: U256,
    /// The EVM spec to use. `build` falls back to `config.evm_spec_id()` when this is `None`.
    pub evm_spec: Option<SpecId>,
    /// The fork to use at launch.
    pub fork: Option<CreateFork>,
    /// Project config.
    pub config: Arc<Config>,
    /// Whether or not to collect line coverage info.
    pub line_coverage: bool,
    /// Whether or not to collect debug info.
    pub debug: bool,
    /// Internal-function decoding mode; controls whether steps are tracked in the tracer.
    pub decode_internal: InternalTraceMode,
    /// Whether to enable call isolation.
    pub isolation: bool,
    /// Whether to enable Odyssey features.
    pub odyssey: bool,
    /// Whether to exit early on test failure.
    pub fail_fast: bool,
}
415
416impl MultiContractRunnerBuilder {
417    pub fn new(config: Arc<Config>) -> Self {
418        Self {
419            config,
420            sender: Default::default(),
421            initial_balance: Default::default(),
422            evm_spec: Default::default(),
423            fork: Default::default(),
424            line_coverage: Default::default(),
425            debug: Default::default(),
426            isolation: Default::default(),
427            decode_internal: Default::default(),
428            odyssey: Default::default(),
429            fail_fast: false,
430        }
431    }
432
433    pub fn sender(mut self, sender: Address) -> Self {
434        self.sender = Some(sender);
435        self
436    }
437
438    pub fn initial_balance(mut self, initial_balance: U256) -> Self {
439        self.initial_balance = initial_balance;
440        self
441    }
442
443    pub fn evm_spec(mut self, spec: SpecId) -> Self {
444        self.evm_spec = Some(spec);
445        self
446    }
447
448    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
449        self.fork = fork;
450        self
451    }
452
453    pub fn set_coverage(mut self, enable: bool) -> Self {
454        self.line_coverage = enable;
455        self
456    }
457
458    pub fn set_debug(mut self, enable: bool) -> Self {
459        self.debug = enable;
460        self
461    }
462
463    pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
464        self.decode_internal = mode;
465        self
466    }
467
468    pub fn fail_fast(mut self, fail_fast: bool) -> Self {
469        self.fail_fast = fail_fast;
470        self
471    }
472
473    pub fn enable_isolation(mut self, enable: bool) -> Self {
474        self.isolation = enable;
475        self
476    }
477
478    pub fn odyssey(mut self, enable: bool) -> Self {
479        self.odyssey = enable;
480        self
481    }
482
483    /// Given an EVM, proceeds to return a runner which is able to execute all tests
484    /// against that evm
485    pub fn build<C: Compiler<CompilerContract = Contract>>(
486        self,
487        root: &Path,
488        output: &ProjectCompileOutput,
489        env: Env,
490        evm_opts: EvmOpts,
491    ) -> Result<MultiContractRunner> {
492        let contracts = output
493            .artifact_ids()
494            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
495            .collect();
496        let linker = Linker::new(root, contracts);
497
498        // Build revert decoder from ABIs of all artifacts.
499        let abis = linker
500            .contracts
501            .iter()
502            .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
503        let revert_decoder = RevertDecoder::new().with_abis(abis);
504
505        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
506            Default::default(),
507            LIBRARY_DEPLOYER,
508            0,
509            linker.contracts.keys(),
510        )?;
511
512        let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
513
514        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
515        let mut deployable_contracts = DeployableContracts::default();
516
517        for (id, contract) in linked_contracts.iter() {
518            let Some(abi) = contract.abi.as_ref() else { continue };
519
520            // if it's a test, link it and add to deployable contracts
521            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
522                && abi.functions().any(|func| func.name.is_any_test())
523            {
524                linker.ensure_linked(contract, id)?;
525
526                let Some(bytecode) =
527                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
528                else {
529                    continue;
530                };
531
532                deployable_contracts
533                    .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
534            }
535        }
536
537        // Create known contracts from linked contracts and storage layout information (if any).
538        let known_contracts = ContractsByArtifactBuilder::new(linked_contracts)
539            .with_storage_layouts(output.clone().with_stripped_file_prefixes(root))
540            .build();
541
542        Ok(MultiContractRunner {
543            contracts: deployable_contracts,
544            revert_decoder,
545            known_contracts,
546            libs_to_deploy,
547            libraries,
548
549            fork: self.fork,
550
551            tcfg: TestRunnerConfig {
552                evm_opts,
553                env,
554                spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
555                sender: self.sender.unwrap_or(self.config.sender),
556                line_coverage: self.line_coverage,
557                debug: self.debug,
558                decode_internal: self.decode_internal,
559                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
560                isolation: self.isolation,
561                odyssey: self.odyssey,
562                config: self.config,
563                fail_fast: FailFast::new(self.fail_fast),
564            },
565        })
566    }
567}
568
569pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
570    matches_contract(filter, &id.source, &id.name, abi.functions())
571}
572
573pub(crate) fn matches_contract(
574    filter: &dyn TestFilter,
575    path: &Path,
576    contract_name: &str,
577    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
578) -> bool {
579    (filter.matches_path(path) && filter.matches_contract(contract_name))
580        && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
581}
582
583/// Returns `true` if the given contract is a test contract.
584pub(crate) fn is_test_contract(
585    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
586) -> bool {
587    functions.into_iter().any(|func| func.borrow().is_any_test())
588}