forge/
multi_runner.rs

1//! Forge test runner for multiple contracts.
2
3use crate::{
4    ContractRunner, TestFilter, progress::TestsProgress, result::SuiteResult,
5    runner::LIBRARY_DEPLOYER,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_cli::opts::configure_pcx_from_compile_output;
11use foundry_common::{
12    ContractsByArtifact, ContractsByArtifactBuilder, TestFunctionExt, get_contract_name,
13    shell::verbosity,
14};
15use foundry_compilers::{
16    Artifact, ArtifactId, Compiler, ProjectCompileOutput,
17    artifacts::{Contract, Libraries},
18};
19use foundry_config::{Config, InlineConfig};
20use foundry_evm::{
21    Env,
22    backend::Backend,
23    decode::RevertDecoder,
24    executors::{EarlyExit, Executor, ExecutorBuilder},
25    fork::CreateFork,
26    inspectors::CheatsConfig,
27    opts::EvmOpts,
28    traces::{InternalTraceMode, TraceMode},
29};
30use foundry_evm_networks::NetworkConfigs;
31use foundry_linking::{LinkOutput, Linker};
32use rayon::prelude::*;
33use revm::primitives::hardfork::SpecId;
34use std::{
35    borrow::Borrow,
36    collections::BTreeMap,
37    path::Path,
38    sync::{Arc, mpsc},
39    time::Instant,
40};
41
42#[derive(Debug, Clone)]
43pub struct TestContract {
44    pub abi: JsonAbi,
45    pub bytecode: Bytes,
46}
47
48pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
49
50/// A multi contract runner receives a set of contracts deployed in an EVM instance and proceeds
51/// to run all test functions in these contracts.
52#[derive(Clone, Debug)]
53pub struct MultiContractRunner {
54    /// Mapping of contract name to JsonAbi, creation bytecode and library bytecode which
55    /// needs to be deployed & linked against
56    pub contracts: DeployableContracts,
57    /// Known contracts linked with computed library addresses.
58    pub known_contracts: ContractsByArtifact,
59    /// Revert decoder. Contains all known errors and their selectors.
60    pub revert_decoder: RevertDecoder,
61    /// Libraries to deploy.
62    pub libs_to_deploy: Vec<Bytes>,
63    /// Library addresses used to link contracts.
64    pub libraries: Libraries,
65    /// Solar compiler instance, to grant syntactic and semantic analysis capabilities
66    pub analysis: Arc<solar::sema::Compiler>,
67
68    /// The fork to use at launch
69    pub fork: Option<CreateFork>,
70
71    /// The base configuration for the test runner.
72    pub tcfg: TestRunnerConfig,
73}
74
75impl std::ops::Deref for MultiContractRunner {
76    type Target = TestRunnerConfig;
77
78    fn deref(&self) -> &Self::Target {
79        &self.tcfg
80    }
81}
82
83impl std::ops::DerefMut for MultiContractRunner {
84    fn deref_mut(&mut self) -> &mut Self::Target {
85        &mut self.tcfg
86    }
87}
88
89impl MultiContractRunner {
90    /// Returns an iterator over all contracts that match the filter.
91    pub fn matching_contracts<'a: 'b, 'b>(
92        &'a self,
93        filter: &'b dyn TestFilter,
94    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
95        self.contracts.iter().filter(|&(id, c)| matches_artifact(filter, id, &c.abi))
96    }
97
98    /// Returns an iterator over all test functions that match the filter.
99    pub fn matching_test_functions<'a: 'b, 'b>(
100        &'a self,
101        filter: &'b dyn TestFilter,
102    ) -> impl Iterator<Item = &'a Function> + 'b {
103        self.matching_contracts(filter)
104            .flat_map(|(_, c)| c.abi.functions())
105            .filter(|func| filter.matches_test_function(func))
106    }
107
108    /// Returns an iterator over all test functions in contracts that match the filter.
109    pub fn all_test_functions<'a: 'b, 'b>(
110        &'a self,
111        filter: &'b dyn TestFilter,
112    ) -> impl Iterator<Item = &'a Function> + 'b {
113        self.contracts
114            .iter()
115            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
116            .flat_map(|(_, c)| c.abi.functions())
117            .filter(|func| func.is_any_test())
118    }
119
120    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
121    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
122        self.matching_contracts(filter)
123            .map(|(id, c)| {
124                let source = id.source.as_path().display().to_string();
125                let name = id.name.clone();
126                let tests = c
127                    .abi
128                    .functions()
129                    .filter(|func| filter.matches_test_function(func))
130                    .map(|func| func.name.clone())
131                    .collect::<Vec<_>>();
132                (source, name, tests)
133            })
134            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
135                acc.entry(source).or_default().insert(name, tests);
136                acc
137            })
138    }
139
140    /// Executes _all_ tests that match the given `filter`.
141    ///
142    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
143    ///
144    /// Note that this method returns only when all tests have been executed.
145    pub fn test_collect(
146        &mut self,
147        filter: &dyn TestFilter,
148    ) -> Result<BTreeMap<String, SuiteResult>> {
149        Ok(self.test_iter(filter)?.collect())
150    }
151
152    /// Executes _all_ tests that match the given `filter`.
153    ///
154    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
155    ///
156    /// Note that this method returns only when all tests have been executed.
157    pub fn test_iter(
158        &mut self,
159        filter: &dyn TestFilter,
160    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
161        let (tx, rx) = mpsc::channel();
162        self.test(filter, tx, false)?;
163        Ok(rx.into_iter())
164    }
165
166    /// Executes _all_ tests that match the given `filter`.
167    ///
168    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
169    /// before executing all contracts and their tests in _parallel_.
170    ///
171    /// Each Executor gets its own instance of the `Backend`.
172    pub fn test(
173        &mut self,
174        filter: &dyn TestFilter,
175        tx: mpsc::Sender<(String, SuiteResult)>,
176        show_progress: bool,
177    ) -> Result<()> {
178        let tokio_handle = tokio::runtime::Handle::current();
179        trace!("running all tests");
180
181        // The DB backend that serves all the data.
182        let db = Backend::spawn(self.fork.take())?;
183
184        let find_timer = Instant::now();
185        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
186        let find_time = find_timer.elapsed();
187        debug!(
188            "Found {} test contracts out of {} in {:?}",
189            contracts.len(),
190            self.contracts.len(),
191            find_time,
192        );
193
194        if show_progress {
195            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
196            // Collect test suite results to stream at the end of test run.
197            let results: Vec<(String, SuiteResult)> = contracts
198                .par_iter()
199                .map(|&(id, contract)| {
200                    let _guard = tokio_handle.enter();
201                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
202
203                    let result = self.run_test_suite(
204                        id,
205                        contract,
206                        &db,
207                        filter,
208                        &tokio_handle,
209                        Some(&tests_progress),
210                    );
211
212                    tests_progress
213                        .inner
214                        .lock()
215                        .end_suite_progress(&id.identifier(), result.summary());
216
217                    (id.identifier(), result)
218                })
219                .collect();
220
221            tests_progress.inner.lock().clear();
222
223            results.iter().for_each(|result| {
224                let _ = tx.send(result.to_owned());
225            });
226        } else {
227            contracts.par_iter().for_each(|&(id, contract)| {
228                let _guard = tokio_handle.enter();
229                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
230                let _ = tx.send((id.identifier(), result));
231            })
232        }
233
234        Ok(())
235    }
236
237    fn run_test_suite(
238        &self,
239        artifact_id: &ArtifactId,
240        contract: &TestContract,
241        db: &Backend,
242        filter: &dyn TestFilter,
243        tokio_handle: &tokio::runtime::Handle,
244        progress: Option<&TestsProgress>,
245    ) -> SuiteResult {
246        let identifier = artifact_id.identifier();
247        let mut span_name = identifier.as_str();
248
249        if !enabled!(tracing::Level::TRACE) {
250            span_name = get_contract_name(&identifier);
251        }
252        let span = debug_span!("suite", name = %span_name);
253        let span_local = span.clone();
254        let _guard = span_local.enter();
255
256        debug!("start executing all tests in contract");
257
258        let executor = self.tcfg.executor(
259            self.known_contracts.clone(),
260            self.analysis.clone(),
261            artifact_id,
262            db.clone(),
263        );
264        let runner = ContractRunner::new(
265            &identifier,
266            contract,
267            executor,
268            progress,
269            tokio_handle,
270            span,
271            self,
272        );
273        let r = runner.run_tests(filter);
274
275        debug!(duration=?r.duration, "executed all tests in contract");
276
277        r
278    }
279}
280
281/// Configuration for the test runner.
282///
283/// This is modified after instantiation through inline config.
284#[derive(Clone, Debug)]
285pub struct TestRunnerConfig {
286    /// Project config.
287    pub config: Arc<Config>,
288    /// Inline configuration.
289    pub inline_config: Arc<InlineConfig>,
290
291    /// EVM configuration.
292    pub evm_opts: EvmOpts,
293    /// EVM environment.
294    pub env: Env,
295    /// EVM version.
296    pub spec_id: SpecId,
297    /// The address which will be used to deploy the initial contracts and send all transactions.
298    pub sender: Address,
299
300    /// Whether to collect line coverage info
301    pub line_coverage: bool,
302    /// Whether to collect debug info
303    pub debug: bool,
304    /// Whether to enable steps tracking in the tracer.
305    pub decode_internal: InternalTraceMode,
306    /// Whether to enable call isolation.
307    pub isolation: bool,
308    /// Networks with enabled features.
309    pub networks: NetworkConfigs,
310    /// Whether to exit early on test failure or if test run interrupted.
311    pub early_exit: EarlyExit,
312}
313
314impl TestRunnerConfig {
315    /// Reconfigures all fields using the given `config`.
316    /// This is for example used to override the configuration with inline config.
317    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
318        debug_assert!(!Arc::ptr_eq(&self.config, &config));
319
320        self.spec_id = config.evm_spec_id();
321        self.sender = config.sender;
322        self.networks = config.networks;
323        self.isolation = config.isolate;
324
325        // Specific to Forge, not present in config.
326        // self.line_coverage = N/A;
327        // self.debug = N/A;
328        // self.decode_internal = N/A;
329
330        // TODO: self.evm_opts
331        self.evm_opts.always_use_create_2_factory = config.always_use_create_2_factory;
332
333        // TODO: self.env
334
335        self.config = config;
336    }
337
338    /// Configures the given executor with this configuration.
339    pub fn configure_executor(&self, executor: &mut Executor) {
340        // TODO: See above
341
342        let inspector = executor.inspector_mut();
343        // inspector.set_env(&self.env);
344        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
345            cheatcodes.config =
346                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
347        }
348        inspector.tracing(self.trace_mode());
349        inspector.collect_line_coverage(self.line_coverage);
350        inspector.enable_isolation(self.isolation);
351        inspector.networks(self.networks);
352        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
353
354        // executor.env_mut().clone_from(&self.env);
355        executor.set_spec_id(self.spec_id);
356        // executor.set_gas_limit(self.evm_opts.gas_limit());
357        executor.set_legacy_assertions(self.config.legacy_assertions);
358    }
359
360    /// Creates a new executor with this configuration.
361    pub fn executor(
362        &self,
363        known_contracts: ContractsByArtifact,
364        analysis: Arc<solar::sema::Compiler>,
365        artifact_id: &ArtifactId,
366        db: Backend,
367    ) -> Executor {
368        let cheats_config = Arc::new(CheatsConfig::new(
369            &self.config,
370            self.evm_opts.clone(),
371            Some(known_contracts),
372            Some(artifact_id.clone()),
373        ));
374        ExecutorBuilder::new()
375            .inspectors(|stack| {
376                stack
377                    .cheatcodes(cheats_config)
378                    .trace_mode(self.trace_mode())
379                    .line_coverage(self.line_coverage)
380                    .enable_isolation(self.isolation)
381                    .networks(self.networks)
382                    .create2_deployer(self.evm_opts.create2_deployer)
383                    .set_analysis(analysis)
384            })
385            .spec_id(self.spec_id)
386            .gas_limit(self.evm_opts.gas_limit())
387            .legacy_assertions(self.config.legacy_assertions)
388            .build(self.env.clone(), db)
389    }
390
391    fn trace_mode(&self) -> TraceMode {
392        TraceMode::default()
393            .with_debug(self.debug)
394            .with_decode_internal(self.decode_internal)
395            .with_verbosity(self.evm_opts.verbosity)
396            .with_state_changes(verbosity() > 4)
397    }
398}
399
400/// Builder used for instantiating the multi-contract runner
401#[derive(Clone)]
402#[must_use = "builders do nothing unless you call `build` on them"]
403pub struct MultiContractRunnerBuilder {
404    /// The address which will be used to deploy the initial contracts and send all
405    /// transactions
406    pub sender: Option<Address>,
407    /// The initial balance for each one of the deployed smart contracts
408    pub initial_balance: U256,
409    /// The EVM spec to use
410    pub evm_spec: Option<SpecId>,
411    /// The fork to use at launch
412    pub fork: Option<CreateFork>,
413    /// Project config.
414    pub config: Arc<Config>,
415    /// Whether or not to collect line coverage info
416    pub line_coverage: bool,
417    /// Whether or not to collect debug info
418    pub debug: bool,
419    /// Whether to enable steps tracking in the tracer.
420    pub decode_internal: InternalTraceMode,
421    /// Whether to enable call isolation
422    pub isolation: bool,
423    /// Networks with enabled features.
424    pub networks: NetworkConfigs,
425    /// Whether to exit early on test failure.
426    pub fail_fast: bool,
427}
428
429impl MultiContractRunnerBuilder {
430    pub fn new(config: Arc<Config>) -> Self {
431        Self {
432            config,
433            sender: Default::default(),
434            initial_balance: Default::default(),
435            evm_spec: Default::default(),
436            fork: Default::default(),
437            line_coverage: Default::default(),
438            debug: Default::default(),
439            isolation: Default::default(),
440            decode_internal: Default::default(),
441            networks: Default::default(),
442            fail_fast: false,
443        }
444    }
445
446    pub fn sender(mut self, sender: Address) -> Self {
447        self.sender = Some(sender);
448        self
449    }
450
451    pub fn initial_balance(mut self, initial_balance: U256) -> Self {
452        self.initial_balance = initial_balance;
453        self
454    }
455
456    pub fn evm_spec(mut self, spec: SpecId) -> Self {
457        self.evm_spec = Some(spec);
458        self
459    }
460
461    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
462        self.fork = fork;
463        self
464    }
465
466    pub fn set_coverage(mut self, enable: bool) -> Self {
467        self.line_coverage = enable;
468        self
469    }
470
471    pub fn set_debug(mut self, enable: bool) -> Self {
472        self.debug = enable;
473        self
474    }
475
476    pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
477        self.decode_internal = mode;
478        self
479    }
480
481    pub fn fail_fast(mut self, fail_fast: bool) -> Self {
482        self.fail_fast = fail_fast;
483        self
484    }
485
486    pub fn enable_isolation(mut self, enable: bool) -> Self {
487        self.isolation = enable;
488        self
489    }
490
491    pub fn networks(mut self, networks: NetworkConfigs) -> Self {
492        self.networks = networks;
493        self
494    }
495
496    /// Given an EVM, proceeds to return a runner which is able to execute all tests
497    /// against that evm
498    pub fn build<C: Compiler<CompilerContract = Contract>>(
499        self,
500        output: &ProjectCompileOutput,
501        env: Env,
502        evm_opts: EvmOpts,
503    ) -> Result<MultiContractRunner> {
504        let root = &self.config.root;
505        let contracts = output
506            .artifact_ids()
507            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
508            .collect();
509        let linker = Linker::new(root, contracts);
510
511        // Build revert decoder from ABIs of all artifacts.
512        let abis = linker
513            .contracts
514            .iter()
515            .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
516        let revert_decoder = RevertDecoder::new().with_abis(abis);
517
518        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
519            Default::default(),
520            LIBRARY_DEPLOYER,
521            0,
522            linker.contracts.keys(),
523        )?;
524
525        let linked_contracts = linker.get_linked_artifacts_cow(&libraries)?;
526
527        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
528        let mut deployable_contracts = DeployableContracts::default();
529
530        for (id, contract) in linked_contracts.iter() {
531            let Some(abi) = contract.abi.as_ref() else { continue };
532
533            // if it's a test, link it and add to deployable contracts
534            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true)
535                && abi.functions().any(|func| func.name.is_any_test())
536            {
537                linker.ensure_linked(contract, id)?;
538
539                let Some(bytecode) =
540                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
541                else {
542                    continue;
543                };
544
545                deployable_contracts
546                    .insert(id.clone(), TestContract { abi: abi.clone().into_owned(), bytecode });
547            }
548        }
549
550        // Create known contracts from linked contracts and storage layout information (if any).
551        let known_contracts =
552            ContractsByArtifactBuilder::new(linked_contracts).with_output(output, root).build();
553
554        // Initialize and configure the solar compiler.
555        let mut analysis = solar::sema::Compiler::new(
556            solar::interface::Session::builder().with_stderr_emitter().build(),
557        );
558        let dcx = analysis.dcx_mut();
559        dcx.set_emitter(Box::new(
560            solar::interface::diagnostics::HumanEmitter::stderr(Default::default())
561                .source_map(Some(dcx.source_map().unwrap())),
562        ));
563        dcx.set_flags_mut(|f| f.track_diagnostics = false);
564
565        // Populate solar's global context by parsing and lowering the sources.
566        let files: Vec<_> =
567            output.output().sources.as_ref().keys().map(|path| path.to_path_buf()).collect();
568
569        analysis.enter_mut(|compiler| -> Result<()> {
570            let mut pcx = compiler.parse();
571            configure_pcx_from_compile_output(
572                &mut pcx,
573                &self.config,
574                output,
575                if files.is_empty() { None } else { Some(&files) },
576            )?;
577            pcx.parse();
578            let _ = compiler.lower_asts();
579            Ok(())
580        })?;
581
582        Ok(MultiContractRunner {
583            contracts: deployable_contracts,
584            revert_decoder,
585            known_contracts,
586            libs_to_deploy,
587            libraries,
588            analysis: Arc::new(analysis),
589
590            tcfg: TestRunnerConfig {
591                evm_opts,
592                env,
593                spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
594                sender: self.sender.unwrap_or(self.config.sender),
595                line_coverage: self.line_coverage,
596                debug: self.debug,
597                decode_internal: self.decode_internal,
598                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
599                isolation: self.isolation,
600                networks: self.networks,
601                early_exit: EarlyExit::new(self.fail_fast || self.config.show_progress),
602                config: self.config,
603            },
604
605            fork: self.fork,
606        })
607    }
608}
609
610pub fn matches_artifact(filter: &dyn TestFilter, id: &ArtifactId, abi: &JsonAbi) -> bool {
611    matches_contract(filter, &id.source, &id.name, abi.functions())
612}
613
614pub(crate) fn matches_contract(
615    filter: &dyn TestFilter,
616    path: &Path,
617    contract_name: &str,
618    functions: impl IntoIterator<Item = impl std::borrow::Borrow<Function>>,
619) -> bool {
620    (filter.matches_path(path) && filter.matches_contract(contract_name))
621        && functions.into_iter().any(|func| filter.matches_test_function(func.borrow()))
622}