Skip to main content

forge/
multi_runner.rs

1//! Forge test runner for multiple contracts.
2
3use crate::{
4    ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5    runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_cli::opts::configure_pcx_from_compile_output;
11use foundry_common::{
12    ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
13};
14use foundry_compilers::{
15    Artifact, ArtifactId, Compiler, ProjectCompileOutput,
16    artifacts::{Contract, Libraries},
17};
18use foundry_config::{Config, InlineConfig};
19use foundry_evm::{
20    backend::Backend,
21    core::evm::{EvmEnvFor, FoundryEvmNetwork, SpecFor, TxEnvFor},
22    decode::RevertDecoder,
23    executors::{EarlyExit, Executor, ExecutorBuilder},
24    fork::CreateFork,
25    fuzz::strategies::LiteralsDictionary,
26    inspectors::CheatsConfig,
27    opts::EvmOpts,
28    traces::{InternalTraceMode, TraceMode},
29};
30
31use foundry_linking::{LinkOutput, Linker};
32use rayon::prelude::*;
33use std::{
34    borrow::Borrow,
35    collections::BTreeMap,
36    ops::{Deref, DerefMut},
37    path::Path,
38    sync::{Arc, mpsc},
39    time::Instant,
40};
41
/// A test contract that can be deployed: its ABI together with its bytecode.
#[derive(Debug, Clone)]
pub struct TestContract {
    /// The contract's ABI, used to discover its test functions.
    pub abi: JsonAbi,
    /// The contract's bytecode after library linking (populated from the artifact's
    /// `get_bytecode_bytes()` by the runner builder; never empty).
    pub bytecode: Bytes,
}

/// Deployable test contracts, keyed by their artifact id.
pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
49
/// A multi contract runner receives a set of contracts deployed in an EVM instance and proceeds
/// to run all test functions in these contracts.
///
/// Dereferences to its base [`TestRunnerConfig`] (`tcfg`).
#[derive(Clone, Debug)]
pub struct MultiContractRunner<FEN: FoundryEvmNetwork> {
    /// Deployable test contracts (ABI and linked bytecode), keyed by artifact id.
    pub contracts: DeployableContracts,
    /// Known contracts linked with computed library addresses.
    pub known_contracts: ContractsByArtifact,
    /// Revert decoder. Contains all known errors and their selectors.
    pub revert_decoder: RevertDecoder,
    /// Bytecode of the libraries to deploy.
    pub libs_to_deploy: Vec<Bytes>,
    /// Library addresses used to link contracts.
    pub libraries: Libraries,
    /// Solar compiler instance, to grant syntactic and semantic analysis capabilities.
    pub analysis: Arc<solar::sema::Compiler>,
    /// Literals dictionary for fuzzing.
    pub fuzz_literals: LiteralsDictionary,

    /// The fork to use at launch. Taken (left as `None`) by the first call to `test`.
    pub fork: Option<CreateFork>,

    /// The base configuration for the test runner.
    pub tcfg: TestRunnerConfig<FEN>,
}
76
/// Exposes the fields of the runner's [`TestRunnerConfig`] directly on the runner
/// (e.g. `runner.config`).
impl<FEN: FoundryEvmNetwork> Deref for MultiContractRunner<FEN> {
    type Target = TestRunnerConfig<FEN>;

    /// Borrows the runner's base test configuration (`tcfg`).
    fn deref(&self) -> &Self::Target {
        &self.tcfg
    }
}

/// Allows mutating the runner's base [`TestRunnerConfig`] in place through the runner.
impl<FEN: FoundryEvmNetwork> DerefMut for MultiContractRunner<FEN> {
    /// Mutably borrows the runner's base test configuration (`tcfg`).
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.tcfg
    }
}
90
91impl<FEN: FoundryEvmNetwork> MultiContractRunner<FEN> {
92    /// Returns an iterator over all contracts that match the filter.
93    pub fn matching_contracts<'a: 'b, 'b>(
94        &'a self,
95        filter: &'b dyn TestFilter,
96    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
97        self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
98    }
99
100    /// Returns an iterator over all test functions that match the filter.
101    pub fn matching_test_functions<'a: 'b, 'b>(
102        &'a self,
103        filter: &'b dyn TestFilter,
104    ) -> impl Iterator<Item = &'a Function> + 'b {
105        self.matching_contracts(filter)
106            .flat_map(|(_, c)| c.abi.functions())
107            .filter(|func| filter.matches_test_function(func))
108    }
109
110    /// Returns an iterator over all test functions in contracts that match the filter.
111    pub fn all_test_functions<'a: 'b, 'b>(
112        &'a self,
113        filter: &'b dyn TestFilter,
114    ) -> impl Iterator<Item = &'a Function> + 'b {
115        self.contracts
116            .iter()
117            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
118            .flat_map(|(_, c)| c.abi.functions())
119            .filter(|func| func.is_any_test())
120    }
121
122    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
123    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
124        self.matching_contracts(filter)
125            .map(|(id, c)| {
126                let source = id.source.as_path().display().to_string();
127                let name = id.name.clone();
128                let tests = c
129                    .abi
130                    .functions()
131                    .filter(|func| filter.matches_test_function(func))
132                    .map(|func| func.name.clone())
133                    .collect::<Vec<_>>();
134                (source, name, tests)
135            })
136            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
137                acc.entry(source).or_default().insert(name, tests);
138                acc
139            })
140    }
141
142    /// Executes _all_ tests that match the given `filter`.
143    ///
144    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
145    ///
146    /// Note that this method returns only when all tests have been executed.
147    pub fn test_collect(
148        &mut self,
149        filter: &dyn TestFilter,
150    ) -> Result<BTreeMap<String, SuiteResult>> {
151        Ok(self.test_iter(filter)?.collect())
152    }
153
154    /// Executes _all_ tests that match the given `filter`.
155    ///
156    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
157    ///
158    /// Note that this method returns only when all tests have been executed.
159    pub fn test_iter(
160        &mut self,
161        filter: &dyn TestFilter,
162    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
163        let (tx, rx) = mpsc::channel();
164        self.test(filter, tx, false)?;
165        Ok(rx.into_iter())
166    }
167
168    /// Executes _all_ tests that match the given `filter`.
169    ///
170    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
171    /// before executing all contracts and their tests in _parallel_.
172    ///
173    /// Each Executor gets its own instance of the `Backend`.
174    pub fn test(
175        &mut self,
176        filter: &dyn TestFilter,
177        tx: mpsc::Sender<(String, SuiteResult)>,
178        show_progress: bool,
179    ) -> Result<()> {
180        let tokio_handle = tokio::runtime::Handle::current();
181        trace!("running all tests");
182
183        // The DB backend that serves all the data.
184        let db = Backend::spawn(self.fork.take())?;
185
186        let find_timer = Instant::now();
187        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
188        let find_time = find_timer.elapsed();
189        debug!(
190            "Found {} test contracts out of {} in {:?}",
191            contracts.len(),
192            self.contracts.len(),
193            find_time,
194        );
195
196        if show_progress {
197            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
198            // Collect test suite results to stream at the end of test run.
199            let results: Vec<(String, SuiteResult)> = contracts
200                .par_iter()
201                .map(|&(id, contract)| {
202                    let _guard = tokio_handle.enter();
203                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
204
205                    let result = self.run_test_suite(
206                        id,
207                        contract,
208                        &db,
209                        filter,
210                        &tokio_handle,
211                        Some(&tests_progress),
212                    );
213
214                    tests_progress
215                        .inner
216                        .lock()
217                        .end_suite_progress(&id.identifier(), result.summary());
218
219                    (id.identifier(), result)
220                })
221                .collect();
222
223            tests_progress.inner.lock().clear();
224
225            for result in &results {
226                let _ = tx.send(result.to_owned());
227            }
228        } else {
229            contracts.par_iter().for_each(|&(id, contract)| {
230                let _guard = tokio_handle.enter();
231                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
232                let _ = tx.send((id.identifier(), result));
233            })
234        }
235
236        Ok(())
237    }
238
239    fn run_test_suite(
240        &self,
241        artifact_id: &ArtifactId,
242        contract: &TestContract,
243        db: &Backend<FEN>,
244        filter: &dyn TestFilter,
245        tokio_handle: &tokio::runtime::Handle,
246        progress: Option<&TestsProgress>,
247    ) -> SuiteResult {
248        let identifier = artifact_id.identifier();
249        let span_name = if enabled!(tracing::Level::TRACE) {
250            identifier.as_str()
251        } else {
252            get_contract_name(&identifier)
253        };
254        let span = debug_span!("suite", name = %span_name);
255        let span_local = span.clone();
256        let _guard = span_local.enter();
257
258        debug!("start executing all tests in contract");
259
260        let executor = self.tcfg.executor(
261            self.known_contracts.clone(),
262            self.analysis.clone(),
263            artifact_id,
264            db.clone(),
265        );
266        let runner = ContractRunner::new(
267            &identifier,
268            contract,
269            executor,
270            progress,
271            tokio_handle,
272            span,
273            self,
274        );
275        let r = runner.run_tests(filter);
276
277        debug!(duration=?r.duration, "executed all tests in contract");
278
279        r
280    }
281}
282
/// Configuration for the test runner.
///
/// This is modified after instantiation through inline config (see `reconfigure_with`).
#[derive(Clone, Debug)]
pub struct TestRunnerConfig<FEN: FoundryEvmNetwork> {
    /// Project config.
    pub config: Arc<Config>,
    /// Inline configuration.
    pub inline_config: Arc<InlineConfig>,

    /// EVM configuration.
    pub evm_opts: EvmOpts,
    /// EVM environment.
    pub evm_env: EvmEnvFor<FEN>,
    /// Transaction environment.
    pub tx_env: TxEnvFor<FEN>,
    /// EVM version (spec id) that executors are configured with.
    pub spec_id: SpecFor<FEN>,
    /// The address which will be used to deploy the initial contracts and send all transactions.
    pub sender: Address,

    /// Whether to collect line coverage info. Forge-specific; not re-read from config.
    pub line_coverage: bool,
    /// Whether to collect debug info. Forge-specific; not re-read from config.
    pub debug: bool,
    /// Whether to enable steps tracking in the tracer (internal function decoding mode).
    pub decode_internal: InternalTraceMode,
    /// Whether to enable call isolation.
    pub isolation: bool,
    /// Whether to exit early on test failure or if test run interrupted.
    pub early_exit: EarlyExit,
}
315
316impl<FEN: FoundryEvmNetwork> TestRunnerConfig<FEN> {
317    /// Reconfigures all fields using the given `config`.
318    /// This is for example used to override the configuration with inline config.
319    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
320        debug_assert!(!Arc::ptr_eq(&self.config, &config));
321
322        self.spec_id = config.evm_spec_id();
323        self.sender = config.sender;
324        self.evm_opts.networks = config.networks;
325        self.isolation = config.isolate;
326
327        // Specific to Forge, not present in config.
328        // self.line_coverage = N/A;
329        // self.debug = N/A;
330        // self.decode_internal = N/A;
331
332        // TODO: self.evm_opts
333        self.evm_opts.always_use_create_2_factory = config.always_use_create_2_factory;
334
335        // TODO: self.env
336
337        self.config = config;
338    }
339
340    /// Configures the given executor with this configuration.
341    pub fn configure_executor(&self, executor: &mut Executor<FEN>) {
342        // TODO: See above
343
344        let inspector = executor.inspector_mut();
345        // inspector.set_env(&self.env);
346        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
347            cheatcodes.config =
348                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
349        }
350        inspector.tracing(self.trace_mode());
351        inspector.collect_line_coverage(self.line_coverage);
352        inspector.enable_isolation(self.isolation);
353        inspector.networks(self.evm_opts.networks);
354        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
355
356        // executor.env_mut().clone_from(&self.env);
357        executor.set_spec_id(self.spec_id);
358        // executor.set_gas_limit(self.evm_opts.gas_limit());
359        executor.set_legacy_assertions(self.config.legacy_assertions);
360    }
361
362    /// Creates a new executor with this configuration.
363    pub fn executor(
364        &self,
365        known_contracts: ContractsByArtifact,
366        analysis: Arc<solar::sema::Compiler>,
367        artifact_id: &ArtifactId,
368        db: Backend<FEN>,
369    ) -> Executor<FEN> {
370        let cheats_config = Arc::new(CheatsConfig::new(
371            &self.config,
372            self.evm_opts.clone(),
373            Some(known_contracts),
374            Some(artifact_id.clone()),
375            None,
376        ));
377        ExecutorBuilder::default()
378            .inspectors(|stack| {
379                stack
380                    .logs(self.config.live_logs)
381                    .cheatcodes(cheats_config)
382                    .trace_mode(self.trace_mode())
383                    .line_coverage(self.line_coverage)
384                    .enable_isolation(self.isolation)
385                    .networks(self.evm_opts.networks)
386                    .create2_deployer(self.evm_opts.create2_deployer)
387                    .set_analysis(analysis)
388            })
389            .spec_id(self.spec_id)
390            .gas_limit(self.evm_opts.gas_limit())
391            .legacy_assertions(self.config.legacy_assertions)
392            .build(self.evm_env.clone(), self.tx_env.clone(), db)
393    }
394
395    fn trace_mode(&self) -> TraceMode {
396        TraceMode::default()
397            .with_debug(self.debug)
398            .with_decode_internal(self.decode_internal)
399            .with_verbosity(self.evm_opts.verbosity)
400    }
401}
402
/// Builder used for instantiating the multi-contract runner.
#[derive(Clone)]
#[must_use = "builders do nothing unless you call `build` on them"]
pub struct MultiContractRunnerBuilder {
    /// The address which will be used to deploy the initial contracts and send all
    /// transactions. When `None`, `build` falls back to `config.sender`.
    pub sender: Option<Address>,
    /// The initial balance for each one of the deployed smart contracts.
    ///
    /// NOTE(review): `build` in this file does not read this field — confirm whether it is
    /// consumed elsewhere.
    pub initial_balance: U256,
    /// The fork to use at launch.
    pub fork: Option<CreateFork>,
    /// Project config.
    pub config: Arc<Config>,
    /// Whether or not to collect line coverage info.
    pub line_coverage: bool,
    /// Whether or not to collect debug info.
    pub debug: bool,
    /// Whether to enable steps tracking in the tracer.
    pub decode_internal: InternalTraceMode,
    /// Whether to enable call isolation.
    pub isolation: bool,
    /// Whether to exit early on test failure.
    pub fail_fast: bool,
}
427
428impl MultiContractRunnerBuilder {
429    pub fn new(config: Arc<Config>) -> Self {
430        Self {
431            config,
432            sender: Default::default(),
433            initial_balance: Default::default(),
434            fork: Default::default(),
435            line_coverage: Default::default(),
436            debug: Default::default(),
437            isolation: Default::default(),
438            decode_internal: Default::default(),
439            fail_fast: false,
440        }
441    }
442
443    pub const fn sender(mut self, sender: Address) -> Self {
444        self.sender = Some(sender);
445        self
446    }
447
448    pub const fn initial_balance(mut self, initial_balance: U256) -> Self {
449        self.initial_balance = initial_balance;
450        self
451    }
452
453    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
454        self.fork = fork;
455        self
456    }
457
458    pub const fn set_coverage(mut self, enable: bool) -> Self {
459        self.line_coverage = enable;
460        self
461    }
462
463    pub const fn set_debug(mut self, enable: bool) -> Self {
464        self.debug = enable;
465        self
466    }
467
468    pub const fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
469        self.decode_internal = mode;
470        self
471    }
472
473    pub const fn fail_fast(mut self, fail_fast: bool) -> Self {
474        self.fail_fast = fail_fast;
475        self
476    }
477
478    pub const fn enable_isolation(mut self, enable: bool) -> Self {
479        self.isolation = enable;
480        self
481    }
482
483    /// Given an EVM, proceeds to return a runner which is able to execute all tests
484    /// against that evm
485    pub fn build<FEN: FoundryEvmNetwork, C: Compiler<CompilerContract = Contract>>(
486        self,
487        output: &ProjectCompileOutput,
488        evm_env: EvmEnvFor<FEN>,
489        tx_env: TxEnvFor<FEN>,
490        evm_opts: EvmOpts,
491    ) -> Result<MultiContractRunner<FEN>> {
492        let root = &self.config.root;
493        let contracts = output
494            .artifact_ids()
495            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
496            .collect();
497        let linker = Linker::new(root, contracts);
498
499        // Build revert decoder from ABIs of all artifacts.
500        let abis = linker
501            .contracts
502            .values()
503            .filter_map(|contract| contract.abi.as_ref().map(|abi| abi.borrow()));
504        let revert_decoder = RevertDecoder::new().with_abis(abis);
505
506        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
507            Default::default(),
508            LIBRARY_DEPLOYER,
509            0,
510            linker.contracts.keys(),
511        )?;
512
513        let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
514
515        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
516        let mut deployable_contracts = DeployableContracts::default();
517
518        for (id, contract) in linked_contracts.iter() {
519            let Some(abi) = contract.abi.as_ref() else { continue };
520
521            // if it's a test, link it and add to deployable contracts
522            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
523                && abi.functions().any(|func| func.name.is_any_test())
524            {
525                linker.ensure_linked(contract, id)?;
526
527                let Some(bytecode) =
528                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
529                else {
530                    continue;
531                };
532
533                deployable_contracts
534                    .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
535            }
536        }
537
538        // Create known contracts from linked contracts and storage layout information (if any).
539        let known_contracts =
540            ContractsByArtifactBuilder::new(linked_contracts).with_output(output, root).build();
541
542        // Initialize and configure the solar compiler.
543        let mut analysis = solar::sema::Compiler::new(
544            solar::interface::Session::builder().with_stderr_emitter().build(),
545        );
546        let dcx = analysis.dcx_mut();
547        dcx.set_emitter(Box::new(
548            solar::interface::diagnostics::HumanEmitter::stderr(Default::default())
549                .source_map(Some(dcx.source_map().unwrap())),
550        ));
551        dcx.set_flags_mut(|f| f.track_diagnostics = false);
552
553        // Populate solar's global context by parsing and lowering the sources.
554        let files: Vec<_> = output.output().sources.as_ref().keys().cloned().collect();
555
556        analysis.enter_mut(|compiler| -> Result<()> {
557            let mut pcx = compiler.parse();
558            configure_pcx_from_compile_output(
559                &mut pcx,
560                &self.config,
561                output,
562                if files.is_empty() { None } else { Some(&files) },
563            )?;
564            pcx.parse();
565            let _ = compiler.lower_asts();
566            Ok(())
567        })?;
568
569        let analysis = Arc::new(analysis);
570        let fuzz_literals = LiteralsDictionary::new(
571            Some(analysis.clone()),
572            Some(self.config.project_paths()),
573            self.config.fuzz.dictionary.max_fuzz_dictionary_literals,
574        );
575
576        Ok(MultiContractRunner {
577            contracts: deployable_contracts,
578            revert_decoder,
579            known_contracts,
580            libs_to_deploy,
581            libraries,
582            analysis,
583            fuzz_literals,
584
585            tcfg: TestRunnerConfig {
586                evm_opts,
587                evm_env,
588                tx_env,
589                spec_id: self.config.evm_spec_id(),
590                sender: self.sender.unwrap_or(self.config.sender),
591                line_coverage: self.line_coverage,
592                debug: self.debug,
593                decode_internal: self.decode_internal,
594                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
595                isolation: self.isolation,
596                early_exit: EarlyExit::new(self.fail_fast),
597                config: self.config,
598            },
599
600            fork: self.fork,
601        })
602    }
603}
604
605pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
606    matches_contract(filter, &id.source, &id.name, abi.functions())
607}
608
609pub(crate) fn matches_contract(
610    filter: &dyn TestFilter,
611    path: &Path,
612    contract_name: &str,
613    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
614) -> bool {
615    (filter.matches_path(path) && filter.matches_contract(contract_name))
616        && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
617}