forge/
multi_runner.rs

1//! Forge test runner for multiple contracts.
2
3use crate::{
4    ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5    runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_common::{
11    ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
12    shell::verbosity,
13};
14use foundry_compilers::{
15    Artifact, ArtifactId, ProjectCompileOutput,
16    artifacts::{Contract, Libraries},
17    compilers::Compiler,
18};
19use foundry_config::{Config, InlineConfig};
20use foundry_evm::{
21    Env,
22    backend::Backend,
23    decode::RevertDecoder,
24    executors::{Executor, ExecutorBuilder, FailFast},
25    fork::CreateFork,
26    inspectors::CheatsConfig,
27    opts::EvmOpts,
28    traces::{InternalTraceMode, TraceMode},
29};
30use foundry_evm_networks::NetworkConfigs;
31use foundry_linking::{LinkOutput, Linker};
32use rayon::prelude::*;
33use revm::primitives::hardfork::SpecId;
34use std::{
35    borrow::Borrow,
36    collections::BTreeMap,
37    fmt::Debug,
38    path::Path,
39    sync::{Arc, mpsc},
40    time::Instant,
41};
42
43#[derive(Debug, Clone)]
44pub struct TestContract {
45    pub abi: JsonAbi,
46    pub bytecode: Bytes,
47}
48
49pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
50
51/// A multi contract runner receives a set of contracts deployed in an EVM instance and proceeds
52/// to run all test functions in these contracts.
53pub struct MultiContractRunner {
54    /// Mapping of contract name to JsonAbi, creation bytecode and library bytecode which
55    /// needs to be deployed & linked against
56    pub contracts: DeployableContracts,
57    /// Known contracts linked with computed library addresses.
58    pub known_contracts: ContractsByArtifact,
59    /// Revert decoder. Contains all known errors and their selectors.
60    pub revert_decoder: RevertDecoder,
61    /// Libraries to deploy.
62    pub libs_to_deploy: Vec<Bytes>,
63    /// Library addresses used to link contracts.
64    pub libraries: Libraries,
65
66    /// The fork to use at launch
67    pub fork: Option<CreateFork>,
68
69    /// The base configuration for the test runner.
70    pub tcfg: TestRunnerConfig,
71}
72
73impl std::ops::Deref for MultiContractRunner {
74    type Target = TestRunnerConfig;
75
76    fn deref(&self) -> &Self::Target {
77        &self.tcfg
78    }
79}
80
81impl std::ops::DerefMut for MultiContractRunner {
82    fn deref_mut(&mut self) -> &mut Self::Target {
83        &mut self.tcfg
84    }
85}
86
87impl MultiContractRunner {
88    /// Returns an iterator over all contracts that match the filter.
89    pub fn matching_contracts<'a: 'b, 'b>(
90        &'a self,
91        filter: &'b dyn TestFilter,
92    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
93        self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
94    }
95
96    /// Returns an iterator over all test functions that match the filter.
97    pub fn matching_test_functions<'a: 'b, 'b>(
98        &'a self,
99        filter: &'b dyn TestFilter,
100    ) -> impl Iterator<Item = &'a Function> + 'b {
101        self.matching_contracts(filter)
102            .flat_map(|(_, c)| c.abi.functions())
103            .filter(|func| filter.matches_test_function(func))
104    }
105
106    /// Returns an iterator over all test functions in contracts that match the filter.
107    pub fn all_test_functions<'a: 'b, 'b>(
108        &'a self,
109        filter: &'b dyn TestFilter,
110    ) -> impl Iterator<Item = &'a Function> + 'b {
111        self.contracts
112            .iter()
113            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
114            .flat_map(|(_, c)| c.abi.functions())
115            .filter(|func| func.is_any_test())
116    }
117
118    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
119    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
120        self.matching_contracts(filter)
121            .map(|(id, c)| {
122                let source = id.source.as_path().display().to_string();
123                let name = id.name.clone();
124                let tests = c
125                    .abi
126                    .functions()
127                    .filter(|func| filter.matches_test_function(func))
128                    .map(|func| func.name.clone())
129                    .collect::<Vec<_>>();
130                (source, name, tests)
131            })
132            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
133                acc.entry(source).or_default().insert(name, tests);
134                acc
135            })
136    }
137
138    /// Executes _all_ tests that match the given `filter`.
139    ///
140    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
141    ///
142    /// Note that this method returns only when all tests have been executed.
143    pub fn test_collect(
144        &mut self,
145        filter: &dyn TestFilter,
146    ) -> Result<BTreeMap<String, SuiteResult>> {
147        Ok(self.test_iter(filter)?.collect())
148    }
149
150    /// Executes _all_ tests that match the given `filter`.
151    ///
152    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
153    ///
154    /// Note that this method returns only when all tests have been executed.
155    pub fn test_iter(
156        &mut self,
157        filter: &dyn TestFilter,
158    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
159        let (tx, rx) = mpsc::channel();
160        self.test(filter, tx, false)?;
161        Ok(rx.into_iter())
162    }
163
164    /// Executes _all_ tests that match the given `filter`.
165    ///
166    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
167    /// before executing all contracts and their tests in _parallel_.
168    ///
169    /// Each Executor gets its own instance of the `Backend`.
170    pub fn test(
171        &mut self,
172        filter: &dyn TestFilter,
173        tx: mpsc::Sender<(String, SuiteResult)>,
174        show_progress: bool,
175    ) -> Result<()> {
176        let tokio_handle = tokio::runtime::Handle::current();
177        trace!("running all tests");
178
179        // The DB backend that serves all the data.
180        let db = Backend::spawn(self.fork.take())?;
181
182        let find_timer = Instant::now();
183        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
184        let find_time = find_timer.elapsed();
185        debug!(
186            "Found {} test contracts out of {} in {:?}",
187            contracts.len(),
188            self.contracts.len(),
189            find_time,
190        );
191
192        if show_progress {
193            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
194            // Collect test suite results to stream at the end of test run.
195            let results: Vec<(String, SuiteResult)> = contracts
196                .par_iter()
197                .map(|&(id, contract)| {
198                    let _guard = tokio_handle.enter();
199                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
200
201                    let result = self.run_test_suite(
202                        id,
203                        contract,
204                        &db,
205                        filter,
206                        &tokio_handle,
207                        Some(&tests_progress),
208                    );
209
210                    tests_progress
211                        .inner
212                        .lock()
213                        .end_suite_progress(&id.identifier(), result.summary());
214
215                    (id.identifier(), result)
216                })
217                .collect();
218
219            tests_progress.inner.lock().clear();
220
221            results.iter().for_each(|result| {
222                let _ = tx.send(result.to_owned());
223            });
224        } else {
225            contracts.par_iter().for_each(|&(id, contract)| {
226                let _guard = tokio_handle.enter();
227                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
228                let _ = tx.send((id.identifier(), result));
229            })
230        }
231
232        Ok(())
233    }
234
235    fn run_test_suite(
236        &self,
237        artifact_id: &ArtifactId,
238        contract: &TestContract,
239        db: &Backend,
240        filter: &dyn TestFilter,
241        tokio_handle: &tokio::runtime::Handle,
242        progress: Option<&TestsProgress>,
243    ) -> SuiteResult {
244        let identifier = artifact_id.identifier();
245        let mut span_name = identifier.as_str();
246
247        if !enabled!(tracing::Level::TRACE) {
248            span_name = get_contract_name(&identifier);
249        }
250        let span = debug_span!("suite", name = %span_name);
251        let span_local = span.clone();
252        let _guard = span_local.enter();
253
254        debug!("start executing all tests in contract");
255
256        let executor = self.tcfg.executor(self.known_contracts.clone(), artifact_id, db.clone());
257        let runner = ContractRunner::new(
258            &identifier,
259            contract,
260            executor,
261            progress,
262            tokio_handle,
263            span,
264            self,
265        );
266        let r = runner.run_tests(filter);
267
268        debug!(duration=?r.duration, "executed all tests in contract");
269
270        r
271    }
272}
273
274/// Configuration for the test runner.
275///
276/// This is modified after instantiation through inline config.
277#[derive(Clone)]
278pub struct TestRunnerConfig {
279    /// Project config.
280    pub config: Arc<Config>,
281    /// Inline configuration.
282    pub inline_config: Arc<InlineConfig>,
283
284    /// EVM configuration.
285    pub evm_opts: EvmOpts,
286    /// EVM environment.
287    pub env: Env,
288    /// EVM version.
289    pub spec_id: SpecId,
290    /// The address which will be used to deploy the initial contracts and send all transactions.
291    pub sender: Address,
292
293    /// Whether to collect line coverage info
294    pub line_coverage: bool,
295    /// Whether to collect debug info
296    pub debug: bool,
297    /// Whether to enable steps tracking in the tracer.
298    pub decode_internal: InternalTraceMode,
299    /// Whether to enable call isolation.
300    pub isolation: bool,
301    /// Networks with enabled features.
302    pub networks: NetworkConfigs,
303    /// Whether to exit early on test failure.
304    pub fail_fast: FailFast,
305}
306
307impl TestRunnerConfig {
308    /// Reconfigures all fields using the given `config`.
309    /// This is for example used to override the configuration with inline config.
310    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
311        debug_assert!(!Arc::ptr_eq(&self.config, &config));
312
313        self.spec_id = config.evm_spec_id();
314        self.sender = config.sender;
315        self.networks.celo = config.celo;
316        self.isolation = config.isolate;
317
318        // Specific to Forge, not present in config.
319        // TODO: self.evm_opts
320        // TODO: self.env
321        // self.coverage = N/A;
322        // self.debug = N/A;
323        // self.decode_internal = N/A;
324
325        self.config = config;
326    }
327
328    /// Configures the given executor with this configuration.
329    pub fn configure_executor(&self, executor: &mut Executor) {
330        // TODO: See above
331
332        let inspector = executor.inspector_mut();
333        // inspector.set_env(&self.env);
334        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
335            cheatcodes.config =
336                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
337        }
338        inspector.tracing(self.trace_mode());
339        inspector.collect_line_coverage(self.line_coverage);
340        inspector.enable_isolation(self.isolation);
341        inspector.networks(self.networks);
342        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
343
344        // executor.env_mut().clone_from(&self.env);
345        executor.set_spec_id(self.spec_id);
346        // executor.set_gas_limit(self.evm_opts.gas_limit());
347        executor.set_legacy_assertions(self.config.legacy_assertions);
348    }
349
350    /// Creates a new executor with this configuration.
351    pub fn executor(
352        &self,
353        known_contracts: ContractsByArtifact,
354        artifact_id: &ArtifactId,
355        db: Backend,
356    ) -> Executor {
357        let cheats_config = Arc::new(CheatsConfig::new(
358            &self.config,
359            self.evm_opts.clone(),
360            Some(known_contracts),
361            Some(artifact_id.clone()),
362        ));
363        ExecutorBuilder::new()
364            .inspectors(|stack| {
365                stack
366                    .cheatcodes(cheats_config)
367                    .trace_mode(self.trace_mode())
368                    .line_coverage(self.line_coverage)
369                    .enable_isolation(self.isolation)
370                    .networks(self.networks)
371                    .create2_deployer(self.evm_opts.create2_deployer)
372            })
373            .spec_id(self.spec_id)
374            .gas_limit(self.evm_opts.gas_limit())
375            .legacy_assertions(self.config.legacy_assertions)
376            .build(self.env.clone(), db)
377    }
378
379    fn trace_mode(&self) -> TraceMode {
380        TraceMode::default()
381            .with_debug(self.debug)
382            .with_decode_internal(self.decode_internal)
383            .with_verbosity(self.evm_opts.verbosity)
384            .with_state_changes(verbosity() > 4)
385    }
386}
387
388/// Builder used for instantiating the multi-contract runner
389#[derive(Clone, Debug)]
390#[must_use = "builders do nothing unless you call `build` on them"]
391pub struct MultiContractRunnerBuilder {
392    /// The address which will be used to deploy the initial contracts and send all
393    /// transactions
394    pub sender: Option<Address>,
395    /// The initial balance for each one of the deployed smart contracts
396    pub initial_balance: U256,
397    /// The EVM spec to use
398    pub evm_spec: Option<SpecId>,
399    /// The fork to use at launch
400    pub fork: Option<CreateFork>,
401    /// Project config.
402    pub config: Arc<Config>,
403    /// Whether or not to collect line coverage info
404    pub line_coverage: bool,
405    /// Whether or not to collect debug info
406    pub debug: bool,
407    /// Whether to enable steps tracking in the tracer.
408    pub decode_internal: InternalTraceMode,
409    /// Whether to enable call isolation
410    pub isolation: bool,
411    /// Networks with enabled features.
412    pub networks: NetworkConfigs,
413    /// Whether to exit early on test failure.
414    pub fail_fast: bool,
415}
416
417impl MultiContractRunnerBuilder {
418    pub fn new(config: Arc<Config>) -> Self {
419        Self {
420            config,
421            sender: Default::default(),
422            initial_balance: Default::default(),
423            evm_spec: Default::default(),
424            fork: Default::default(),
425            line_coverage: Default::default(),
426            debug: Default::default(),
427            isolation: Default::default(),
428            decode_internal: Default::default(),
429            networks: Default::default(),
430            fail_fast: false,
431        }
432    }
433
434    pub fn sender(mut self, sender: Address) -> Self {
435        self.sender = Some(sender);
436        self
437    }
438
439    pub fn initial_balance(mut self, initial_balance: U256) -> Self {
440        self.initial_balance = initial_balance;
441        self
442    }
443
444    pub fn evm_spec(mut self, spec: SpecId) -> Self {
445        self.evm_spec = Some(spec);
446        self
447    }
448
449    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
450        self.fork = fork;
451        self
452    }
453
454    pub fn set_coverage(mut self, enable: bool) -> Self {
455        self.line_coverage = enable;
456        self
457    }
458
459    pub fn set_debug(mut self, enable: bool) -> Self {
460        self.debug = enable;
461        self
462    }
463
464    pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
465        self.decode_internal = mode;
466        self
467    }
468
469    pub fn fail_fast(mut self, fail_fast: bool) -> Self {
470        self.fail_fast = fail_fast;
471        self
472    }
473
474    pub fn enable_isolation(mut self, enable: bool) -> Self {
475        self.isolation = enable;
476        self
477    }
478
479    pub fn networks(mut self, networks: NetworkConfigs) -> Self {
480        self.networks = networks;
481        self
482    }
483
484    /// Given an EVM, proceeds to return a runner which is able to execute all tests
485    /// against that evm
486    pub fn build<C: Compiler<CompilerContract = Contract>>(
487        self,
488        root: &Path,
489        output: &ProjectCompileOutput,
490        env: Env,
491        evm_opts: EvmOpts,
492    ) -> Result<MultiContractRunner> {
493        let contracts = output
494            .artifact_ids()
495            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
496            .collect();
497        let linker = Linker::new(root, contracts);
498
499        // Build revert decoder from ABIs of all artifacts.
500        let abis = linker
501            .contracts
502            .iter()
503            .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
504        let revert_decoder = RevertDecoder::new().with_abis(abis);
505
506        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
507            Default::default(),
508            LIBRARY_DEPLOYER,
509            0,
510            linker.contracts.keys(),
511        )?;
512
513        let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
514
515        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
516        let mut deployable_contracts = DeployableContracts::default();
517
518        for (id, contract) in linked_contracts.iter() {
519            let Some(abi) = contract.abi.as_ref() else { continue };
520
521            // if it's a test, link it and add to deployable contracts
522            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
523                && abi.functions().any(|func| func.name.is_any_test())
524            {
525                linker.ensure_linked(contract, id)?;
526
527                let Some(bytecode) =
528                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
529                else {
530                    continue;
531                };
532
533                deployable_contracts
534                    .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
535            }
536        }
537
538        // Create known contracts from linked contracts and storage layout information (if any).
539        let known_contracts =
540            ContractsByArtifactBuilder::new(linked_contracts).with_output(output, root).build();
541
542        Ok(MultiContractRunner {
543            contracts: deployable_contracts,
544            revert_decoder,
545            known_contracts,
546            libs_to_deploy,
547            libraries,
548
549            fork: self.fork,
550
551            tcfg: TestRunnerConfig {
552                evm_opts,
553                env,
554                spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
555                sender: self.sender.unwrap_or(self.config.sender),
556                line_coverage: self.line_coverage,
557                debug: self.debug,
558                decode_internal: self.decode_internal,
559                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
560                isolation: self.isolation,
561                networks: self.networks,
562                config: self.config,
563                fail_fast: FailFast::new(self.fail_fast),
564            },
565        })
566    }
567}
568
569pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
570    matches_contract(filter, &id.source, &id.name, abi.functions())
571}
572
573pub(crate) fn matches_contract(
574    filter: &dyn TestFilter,
575    path: &Path,
576    contract_name: &str,
577    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
578) -> bool {
579    (filter.matches_path(path) && filter.matches_contract(contract_name))
580        && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
581}