Skip to main content

forge/
multi_runner.rs

1//! Forge test runner for multiple contracts.
2
3use crate::{
4    ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5    runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_cli::opts::configure_pcx_from_compile_output;
11use foundry_common::{
12    ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
13};
14use foundry_compilers::{
15    Artifact, ArtifactId, Compiler, ProjectCompileOutput,
16    artifacts::{Contract, Libraries},
17};
18use foundry_config::{Config, InlineConfig};
19use foundry_evm::{
20    backend::Backend,
21    core::evm::{EvmEnvFor, FoundryEvmNetwork, SpecFor, TxEnvFor},
22    decode::RevertDecoder,
23    executors::{EarlyExit, Executor, ExecutorBuilder},
24    fork::CreateFork,
25    fuzz::strategies::LiteralsDictionary,
26    inspectors::CheatsConfig,
27    opts::EvmOpts,
28    traces::{InternalTraceMode, TraceMode},
29};
30use foundry_evm_networks::NetworkVariant;
31
32use foundry_linking::{LinkOutput, Linker};
33use rayon::prelude::*;
34use std::{
35    borrow::Borrow,
36    collections::BTreeMap,
37    ops::{Deref, DerefMut},
38    path::Path,
39    sync::{Arc, mpsc},
40    time::Instant,
41};
42
/// A contract that contains tests and is ready to be deployed by the runner.
#[derive(Debug, Clone)]
pub struct TestContract {
    /// The contract's ABI, used to enumerate and filter its test functions.
    pub abi: JsonAbi,
    /// Linked, non-empty creation bytecode of the contract.
    pub bytecode: Bytes,
}

/// Test contracts that can be deployed, keyed by their artifact id.
pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
50
/// A multi contract runner receives a set of contracts deployed in an EVM instance and proceeds
/// to run all test functions in these contracts.
#[derive(Clone, Debug)]
pub struct MultiContractRunner<FEN: FoundryEvmNetwork> {
    /// Test contracts to run, keyed by artifact id. Each entry holds the contract's ABI and its
    /// linked creation bytecode; library bytecode is tracked separately in `libs_to_deploy`.
    pub contracts: DeployableContracts,
    /// Known contracts linked with computed library addresses.
    pub known_contracts: ContractsByArtifact,
    /// Revert decoder. Contains all known errors and their selectors.
    pub revert_decoder: RevertDecoder,
    /// Bytecode of the libraries that need to be deployed, as computed by the linker.
    pub libs_to_deploy: Vec<Bytes>,
    /// Library addresses used to link contracts.
    pub libraries: Libraries,
    /// Solar compiler instance, to grant syntactic and semantic analysis capabilities.
    pub analysis: Arc<solar::sema::Compiler>,
    /// Literals dictionary for fuzzing.
    pub fuzz_literals: LiteralsDictionary,

    /// The fork to use at launch.
    ///
    /// `MultiContractRunner::test` takes this value (leaving `None` behind) when it spawns the
    /// backend, so only the first test run uses it.
    pub fork: Option<CreateFork>,

    /// The base configuration for the test runner.
    pub tcfg: TestRunnerConfig<FEN>,
}
77
78impl<FEN: FoundryEvmNetwork> Deref for MultiContractRunner<FEN> {
79    type Target = TestRunnerConfig<FEN>;
80
81    fn deref(&self) -> &Self::Target {
82        &self.tcfg
83    }
84}
85
86impl<FEN: FoundryEvmNetwork> DerefMut for MultiContractRunner<FEN> {
87    fn deref_mut(&mut self) -> &mut Self::Target {
88        &mut self.tcfg
89    }
90}
91
92impl<FEN: FoundryEvmNetwork> MultiContractRunner<FEN> {
93    /// Returns an iterator over all contracts that match the filter.
94    pub fn matching_contracts<'a: 'b, 'b>(
95        &'a self,
96        filter: &'b dyn TestFilter,
97    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
98        self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
99    }
100
101    /// Returns an iterator over all test functions that match the filter.
102    pub fn matching_test_functions<'a: 'b, 'b>(
103        &'a self,
104        filter: &'b dyn TestFilter,
105    ) -> impl Iterator<Item = &'a Function> + 'b {
106        self.matching_contracts(filter)
107            .flat_map(|(_, c)| c.abi.functions())
108            .filter(|func| filter.matches_test_function(func))
109    }
110
111    /// Returns an iterator over all test functions in contracts that match the filter.
112    pub fn all_test_functions<'a: 'b, 'b>(
113        &'a self,
114        filter: &'b dyn TestFilter,
115    ) -> impl Iterator<Item = &'a Function> + 'b {
116        self.contracts
117            .iter()
118            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
119            .flat_map(|(_, c)| c.abi.functions())
120            .filter(|func| func.is_any_test())
121    }
122
123    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
124    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
125        self.matching_contracts(filter)
126            .map(|(id, c)| {
127                let source = id.source.as_path().display().to_string();
128                let name = id.name.clone();
129                let tests = c
130                    .abi
131                    .functions()
132                    .filter(|func| filter.matches_test_function(func))
133                    .map(|func| func.name.clone())
134                    .collect::<Vec<_>>();
135                (source, name, tests)
136            })
137            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
138                acc.entry(source).or_default().insert(name, tests);
139                acc
140            })
141    }
142
143    /// Executes _all_ tests that match the given `filter`.
144    ///
145    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
146    ///
147    /// Note that this method returns only when all tests have been executed.
148    pub fn test_collect(
149        &mut self,
150        filter: &dyn TestFilter,
151    ) -> Result<BTreeMap<String, SuiteResult>> {
152        Ok(self.test_iter(filter)?.collect())
153    }
154
155    /// Executes _all_ tests that match the given `filter`.
156    ///
157    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
158    ///
159    /// Note that this method returns only when all tests have been executed.
160    pub fn test_iter(
161        &mut self,
162        filter: &dyn TestFilter,
163    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
164        let (tx, rx) = mpsc::channel();
165        self.test(filter, tx, false)?;
166        Ok(rx.into_iter())
167    }
168
169    /// Executes _all_ tests that match the given `filter`.
170    ///
171    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
172    /// before executing all contracts and their tests in _parallel_.
173    ///
174    /// Each Executor gets its own instance of the `Backend`.
175    pub fn test(
176        &mut self,
177        filter: &dyn TestFilter,
178        tx: mpsc::Sender<(String, SuiteResult)>,
179        show_progress: bool,
180    ) -> Result<()> {
181        let tokio_handle = tokio::runtime::Handle::current();
182        trace!("running all tests");
183
184        // The DB backend that serves all the data.
185        let db = Backend::spawn(self.fork.take())?;
186
187        let find_timer = Instant::now();
188        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
189        let find_time = find_timer.elapsed();
190        debug!(
191            "Found {} test contracts out of {} in {:?}",
192            contracts.len(),
193            self.contracts.len(),
194            find_time,
195        );
196
197        if show_progress {
198            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
199            // Collect test suite results to stream at the end of test run.
200            let results: Vec<(String, SuiteResult)> = contracts
201                .par_iter()
202                .map(|&(id, contract)| {
203                    let _guard = tokio_handle.enter();
204                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
205
206                    let result = self.run_test_suite(
207                        id,
208                        contract,
209                        &db,
210                        filter,
211                        &tokio_handle,
212                        Some(&tests_progress),
213                    );
214
215                    tests_progress
216                        .inner
217                        .lock()
218                        .end_suite_progress(&id.identifier(), result.summary());
219
220                    (id.identifier(), result)
221                })
222                .collect();
223
224            tests_progress.inner.lock().clear();
225
226            for result in &results {
227                let _ = tx.send(result.to_owned());
228            }
229        } else {
230            contracts.par_iter().for_each(|&(id, contract)| {
231                let _guard = tokio_handle.enter();
232                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
233                let _ = tx.send((id.identifier(), result));
234            })
235        }
236
237        Ok(())
238    }
239
240    fn run_test_suite(
241        &self,
242        artifact_id: &ArtifactId,
243        contract: &TestContract,
244        db: &Backend<FEN>,
245        filter: &dyn TestFilter,
246        tokio_handle: &tokio::runtime::Handle,
247        progress: Option<&TestsProgress>,
248    ) -> SuiteResult {
249        let identifier = artifact_id.identifier();
250        let span_name = if enabled!(tracing::Level::TRACE) {
251            identifier.as_str()
252        } else {
253            get_contract_name(&identifier)
254        };
255        let span = debug_span!("suite", name = %span_name);
256        let span_local = span.clone();
257        let _guard = span_local.enter();
258
259        debug!("start executing all tests in contract");
260
261        let executor = self.tcfg.executor(
262            self.known_contracts.clone(),
263            self.analysis.clone(),
264            artifact_id,
265            db.clone(),
266        );
267        let runner = ContractRunner::new(
268            &identifier,
269            contract,
270            executor,
271            progress,
272            tokio_handle,
273            span,
274            self,
275        );
276        let r = runner.run_tests(filter);
277
278        debug!(duration=?r.duration, "executed all tests in contract");
279
280        r
281    }
282}
283
/// Tracks network assignment across a multi-network test run.
///
/// When inline config specifies different networks for different tests, the runner performs one
/// pass per distinct network. This struct encodes which pass we're in so each `ContractRunner`
/// can skip tests that belong to a different pass.
///
/// The `Default` value (empty `all_override_networks`, `None` pass) is single-pass mode, in which
/// every test runs.
#[derive(Clone, Debug, Default)]
pub struct MultiNetworkConfig {
    /// All networks explicitly referenced in inline config annotations across the whole suite.
    /// Empty means single-pass mode (no per-test network overrides present).
    pub all_override_networks: Vec<NetworkVariant>,
    /// The network this pass is responsible for.
    ///
    /// - `None`: default pass; runs tests *without* an explicit network annotation (or annotated
    ///   with a network not in `all_override_networks`).
    /// - `Some(v)`: override pass; runs only tests annotated with exactly `v`.
    pub pass_network: Option<NetworkVariant>,
}
302
/// Configuration for the test runner.
///
/// This is modified after instantiation through inline config (see `reconfigure_with`).
#[derive(Clone, Debug)]
pub struct TestRunnerConfig<FEN: FoundryEvmNetwork> {
    /// Project config.
    pub config: Arc<Config>,
    /// Inline configuration.
    pub inline_config: Arc<InlineConfig>,

    /// EVM configuration.
    pub evm_opts: EvmOpts,
    /// EVM environment.
    pub evm_env: EvmEnvFor<FEN>,
    /// Transaction environment.
    pub tx_env: TxEnvFor<FEN>,
    /// EVM version (spec) to execute with.
    pub spec_id: SpecFor<FEN>,
    /// The address which will be used to deploy the initial contracts and send all transactions.
    pub sender: Address,

    /// Whether to collect line coverage info.
    pub line_coverage: bool,
    /// Whether to collect debug info.
    pub debug: bool,
    /// Internal trace decoding mode, i.e. whether and how steps are tracked in the tracer.
    pub decode_internal: InternalTraceMode,
    /// Whether to enable call isolation.
    pub isolation: bool,
    /// Whether to exit early on test failure or if test run interrupted.
    pub early_exit: EarlyExit,

    /// Multi-network pass configuration. Default = single-pass mode.
    pub multi_network: MultiNetworkConfig,
}
338
339impl<FEN: FoundryEvmNetwork> TestRunnerConfig<FEN> {
340    /// Reconfigures all fields using the given `config`.
341    /// This is for example used to override the configuration with inline config.
342    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
343        debug_assert!(!Arc::ptr_eq(&self.config, &config));
344
345        self.spec_id = config.evm_spec_id();
346        self.sender = config.sender;
347        self.evm_opts.networks = config.networks;
348        self.isolation = config.isolate;
349
350        // Specific to Forge, not present in config.
351        // self.line_coverage = N/A;
352        // self.debug = N/A;
353        // self.decode_internal = N/A;
354
355        // TODO: self.evm_opts
356        self.evm_opts.always_use_create_2_factory = config.always_use_create_2_factory;
357
358        // TODO: self.env
359
360        self.config = config;
361    }
362
363    /// Configures the given executor with this configuration.
364    pub fn configure_executor(&self, executor: &mut Executor<FEN>) {
365        // TODO: See above
366
367        let inspector = executor.inspector_mut();
368        // inspector.set_env(&self.env);
369        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
370            cheatcodes.config =
371                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
372        }
373        inspector.tracing(self.trace_mode());
374        inspector.collect_line_coverage(self.line_coverage);
375        inspector.enable_isolation(self.isolation);
376        inspector.networks(self.evm_opts.networks);
377        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
378
379        // executor.env_mut().clone_from(&self.env);
380        executor.set_spec_id(self.spec_id);
381        // executor.set_gas_limit(self.evm_opts.gas_limit());
382        executor.set_legacy_assertions(self.config.legacy_assertions);
383    }
384
385    /// Creates a new executor with this configuration.
386    pub fn executor(
387        &self,
388        known_contracts: ContractsByArtifact,
389        analysis: Arc<solar::sema::Compiler>,
390        artifact_id: &ArtifactId,
391        db: Backend<FEN>,
392    ) -> Executor<FEN> {
393        let cheats_config = Arc::new(CheatsConfig::new(
394            &self.config,
395            self.evm_opts.clone(),
396            Some(known_contracts),
397            Some(artifact_id.clone()),
398            None,
399        ));
400        ExecutorBuilder::default()
401            .inspectors(|stack| {
402                stack
403                    .logs(self.config.live_logs)
404                    .cheatcodes(cheats_config)
405                    .trace_mode(self.trace_mode())
406                    .line_coverage(self.line_coverage)
407                    .enable_isolation(self.isolation)
408                    .networks(self.evm_opts.networks)
409                    .create2_deployer(self.evm_opts.create2_deployer)
410                    .set_analysis(analysis)
411            })
412            .spec_id(self.spec_id)
413            .gas_limit(self.evm_opts.gas_limit())
414            .legacy_assertions(self.config.legacy_assertions)
415            .build(self.evm_env.clone(), self.tx_env.clone(), db)
416    }
417
418    fn trace_mode(&self) -> TraceMode {
419        TraceMode::default()
420            .with_debug(self.debug)
421            .with_decode_internal(self.decode_internal)
422            .with_verbosity(self.evm_opts.verbosity)
423    }
424}
425
426/// Builder used for instantiating the multi-contract runner
427#[derive(Clone)]
428#[must_use = "builders do nothing unless you call `build` on them"]
429pub struct MultiContractRunnerBuilder {
430    /// The address which will be used to deploy the initial contracts and send all
431    /// transactions
432    pub sender: Option<Address>,
433    /// The initial balance for each one of the deployed smart contracts
434    pub initial_balance: U256,
435    /// The fork to use at launch
436    pub fork: Option<CreateFork>,
437    /// Project config.
438    pub config: Arc<Config>,
439    /// Whether or not to collect line coverage info
440    pub line_coverage: bool,
441    /// Whether or not to collect debug info
442    pub debug: bool,
443    /// Whether to enable steps tracking in the tracer.
444    pub decode_internal: InternalTraceMode,
445    /// Whether to enable call isolation
446    pub isolation: bool,
447    /// Whether to exit early on test failure.
448    pub fail_fast: bool,
449    /// Multi-network pass configuration.
450    pub multi_network: MultiNetworkConfig,
451}
452
453impl MultiContractRunnerBuilder {
454    pub fn new(config: Arc<Config>) -> Self {
455        Self {
456            config,
457            sender: Default::default(),
458            initial_balance: Default::default(),
459            fork: Default::default(),
460            line_coverage: Default::default(),
461            debug: Default::default(),
462            isolation: Default::default(),
463            decode_internal: Default::default(),
464            fail_fast: false,
465            multi_network: Default::default(),
466        }
467    }
468
469    pub const fn sender(mut self, sender: Address) -> Self {
470        self.sender = Some(sender);
471        self
472    }
473
474    pub const fn initial_balance(mut self, initial_balance: U256) -> Self {
475        self.initial_balance = initial_balance;
476        self
477    }
478
479    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
480        self.fork = fork;
481        self
482    }
483
484    pub const fn set_coverage(mut self, enable: bool) -> Self {
485        self.line_coverage = enable;
486        self
487    }
488
489    pub const fn set_debug(mut self, enable: bool) -> Self {
490        self.debug = enable;
491        self
492    }
493
494    pub const fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
495        self.decode_internal = mode;
496        self
497    }
498
499    pub fn with_multi_network(mut self, multi_network: MultiNetworkConfig) -> Self {
500        self.multi_network = multi_network;
501        self
502    }
503
504    pub const fn fail_fast(mut self, fail_fast: bool) -> Self {
505        self.fail_fast = fail_fast;
506        self
507    }
508
509    pub const fn enable_isolation(mut self, enable: bool) -> Self {
510        self.isolation = enable;
511        self
512    }
513
514    /// Given an EVM, proceeds to return a runner which is able to execute all tests
515    /// against that evm
516    pub fn build<FEN: FoundryEvmNetwork, C: Compiler<CompilerContract = Contract>>(
517        self,
518        output: &ProjectCompileOutput,
519        evm_env: EvmEnvFor<FEN>,
520        tx_env: TxEnvFor<FEN>,
521        evm_opts: EvmOpts,
522    ) -> Result<MultiContractRunner<FEN>> {
523        let root = &self.config.root;
524        let contracts = output
525            .artifact_ids()
526            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
527            .collect();
528        let linker = Linker::new(root, contracts);
529
530        // Build revert decoder from ABIs of all artifacts.
531        let abis = linker
532            .contracts
533            .values()
534            .filter_map(|contract| contract.abi.as_ref().map(|abi| abi.borrow()));
535        let revert_decoder = RevertDecoder::new().with_abis(abis);
536
537        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
538            Default::default(),
539            LIBRARY_DEPLOYER,
540            0,
541            linker.contracts.keys(),
542        )?;
543
544        let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
545
546        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
547        let mut deployable_contracts = DeployableContracts::default();
548
549        for (id, contract) in linked_contracts.iter() {
550            let Some(abi) = contract.abi.as_ref() else { continue };
551
552            // if it's a test, link it and add to deployable contracts
553            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
554                && abi.functions().any(|func| func.name.is_any_test())
555            {
556                linker.ensure_linked(contract, id)?;
557
558                let Some(bytecode) =
559                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
560                else {
561                    continue;
562                };
563
564                deployable_contracts
565                    .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
566            }
567        }
568
569        // Create known contracts from linked contracts and storage layout information (if any).
570        let known_contracts =
571            ContractsByArtifactBuilder::new(linked_contracts).with_output(output, root).build();
572
573        // Initialize and configure the solar compiler.
574        let mut analysis = solar::sema::Compiler::new(
575            solar::interface::Session::builder().with_stderr_emitter().build(),
576        );
577        let dcx = analysis.dcx_mut();
578        dcx.set_emitter(Box::new(
579            solar::interface::diagnostics::HumanEmitter::stderr(Default::default())
580                .source_map(Some(dcx.source_map().unwrap())),
581        ));
582        dcx.set_flags_mut(|f| f.track_diagnostics = false);
583
584        // Populate solar's global context by parsing and lowering the sources.
585        let files: Vec<_> = output.output().sources.as_ref().keys().cloned().collect();
586
587        analysis.enter_mut(|compiler| -> Result<()> {
588            let mut pcx = compiler.parse();
589            configure_pcx_from_compile_output(
590                &mut pcx,
591                &self.config,
592                output,
593                if files.is_empty() { None } else { Some(&files) },
594            )?;
595            pcx.parse();
596            let _ = compiler.lower_asts();
597            Ok(())
598        })?;
599
600        let analysis = Arc::new(analysis);
601        let fuzz_literals = LiteralsDictionary::new(
602            Some(analysis.clone()),
603            Some(self.config.project_paths()),
604            self.config.fuzz.dictionary.max_fuzz_dictionary_literals,
605        );
606
607        Ok(MultiContractRunner {
608            contracts: deployable_contracts,
609            revert_decoder,
610            known_contracts,
611            libs_to_deploy,
612            libraries,
613            analysis,
614            fuzz_literals,
615
616            tcfg: TestRunnerConfig {
617                evm_opts,
618                evm_env,
619                tx_env,
620                spec_id: self.config.evm_spec_id(),
621                sender: self.sender.unwrap_or(self.config.sender),
622                line_coverage: self.line_coverage,
623                debug: self.debug,
624                decode_internal: self.decode_internal,
625                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
626                isolation: self.isolation,
627                early_exit: EarlyExit::new(self.fail_fast),
628                multi_network: self.multi_network,
629                config: self.config,
630            },
631
632            fork: self.fork,
633        })
634    }
635}
636
637pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
638    matches_contract(filter, &id.source, &id.name, abi.functions())
639}
640
641pub(crate) fn matches_contract(
642    filter: &dyn TestFilter,
643    path: &Path,
644    contract_name: &str,
645    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
646) -> bool {
647    (filter.matches_path(path) && filter.matches_contract(contract_name))
648        && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
649}