forge/
multi_runner.rs

1//! Forge test runner for multiple contracts.
2
3use crate::{
4    progress::TestsProgress, result::SuiteResult, runner::LIBRARY_DEPLOYER, ContractRunner,
5    TestFilter,
6};
7use alloy_json_abi::{Function, JsonAbi};
8use alloy_primitives::{Address, Bytes, U256};
9use eyre::Result;
10use foundry_common::{get_contract_name, shell::verbosity, ContractsByArtifact, TestFunctionExt};
11use foundry_compilers::{
12    artifacts::{Contract, Libraries},
13    compilers::Compiler,
14    Artifact, ArtifactId, ProjectCompileOutput,
15};
16use foundry_config::{Config, InlineConfig};
17use foundry_evm::{
18    backend::Backend,
19    decode::RevertDecoder,
20    executors::{Executor, ExecutorBuilder},
21    fork::CreateFork,
22    inspectors::CheatsConfig,
23    opts::EvmOpts,
24    revm,
25    traces::{InternalTraceMode, TraceMode},
26};
27use foundry_linking::{LinkOutput, Linker};
28use rayon::prelude::*;
29use revm::primitives::SpecId;
30use std::{
31    borrow::Borrow,
32    collections::BTreeMap,
33    fmt::Debug,
34    path::Path,
35    sync::{mpsc, Arc},
36    time::Instant,
37};
38
39#[derive(Debug, Clone)]
40pub struct TestContract {
41    pub abi: JsonAbi,
42    pub bytecode: Bytes,
43}
44
45pub type DeployableContracts = BTreeMap<ArtifactId, TestContract>;
46
47/// A multi contract runner receives a set of contracts deployed in an EVM instance and proceeds
48/// to run all test functions in these contracts.
49pub struct MultiContractRunner {
50    /// Mapping of contract name to JsonAbi, creation bytecode and library bytecode which
51    /// needs to be deployed & linked against
52    pub contracts: DeployableContracts,
53    /// Known contracts linked with computed library addresses.
54    pub known_contracts: ContractsByArtifact,
55    /// Revert decoder. Contains all known errors and their selectors.
56    pub revert_decoder: RevertDecoder,
57    /// Libraries to deploy.
58    pub libs_to_deploy: Vec<Bytes>,
59    /// Library addresses used to link contracts.
60    pub libraries: Libraries,
61
62    /// The fork to use at launch
63    pub fork: Option<CreateFork>,
64
65    /// The base configuration for the test runner.
66    pub tcfg: TestRunnerConfig,
67}
68
69impl std::ops::Deref for MultiContractRunner {
70    type Target = TestRunnerConfig;
71
72    fn deref(&self) -> &Self::Target {
73        &self.tcfg
74    }
75}
76
77impl std::ops::DerefMut for MultiContractRunner {
78    fn deref_mut(&mut self) -> &mut Self::Target {
79        &mut self.tcfg
80    }
81}
82
83impl MultiContractRunner {
84    /// Returns an iterator over all contracts that match the filter.
85    pub fn matching_contracts<'a: 'b, 'b>(
86        &'a self,
87        filter: &'b dyn TestFilter,
88    ) -> impl Iterator<Item = (&'a ArtifactId, &'a TestContract)> + 'b {
89        self.contracts.iter().filter(|&(id, c)| matches_contract(id, &c.abi, filter))
90    }
91
92    /// Returns an iterator over all test functions that match the filter.
93    pub fn matching_test_functions<'a: 'b, 'b>(
94        &'a self,
95        filter: &'b dyn TestFilter,
96    ) -> impl Iterator<Item = &'a Function> + 'b {
97        self.matching_contracts(filter)
98            .flat_map(|(_, c)| c.abi.functions())
99            .filter(|func| is_matching_test(func, filter))
100    }
101
102    /// Returns an iterator over all test functions in contracts that match the filter.
103    pub fn all_test_functions<'a: 'b, 'b>(
104        &'a self,
105        filter: &'b dyn TestFilter,
106    ) -> impl Iterator<Item = &'a Function> + 'b {
107        self.contracts
108            .iter()
109            .filter(|(id, _)| filter.matches_path(&id.source) && filter.matches_contract(&id.name))
110            .flat_map(|(_, c)| c.abi.functions())
111            .filter(|func| func.is_any_test())
112    }
113
114    /// Returns all matching tests grouped by contract grouped by file (file -> (contract -> tests))
115    pub fn list(&self, filter: &dyn TestFilter) -> BTreeMap<String, BTreeMap<String, Vec<String>>> {
116        self.matching_contracts(filter)
117            .map(|(id, c)| {
118                let source = id.source.as_path().display().to_string();
119                let name = id.name.clone();
120                let tests = c
121                    .abi
122                    .functions()
123                    .filter(|func| is_matching_test(func, filter))
124                    .map(|func| func.name.clone())
125                    .collect::<Vec<_>>();
126                (source, name, tests)
127            })
128            .fold(BTreeMap::new(), |mut acc, (source, name, tests)| {
129                acc.entry(source).or_default().insert(name, tests);
130                acc
131            })
132    }
133
134    /// Executes _all_ tests that match the given `filter`.
135    ///
136    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
137    ///
138    /// Note that this method returns only when all tests have been executed.
139    pub fn test_collect(
140        &mut self,
141        filter: &dyn TestFilter,
142    ) -> Result<BTreeMap<String, SuiteResult>> {
143        Ok(self.test_iter(filter)?.collect())
144    }
145
146    /// Executes _all_ tests that match the given `filter`.
147    ///
148    /// The same as [`test`](Self::test), but returns the results instead of streaming them.
149    ///
150    /// Note that this method returns only when all tests have been executed.
151    pub fn test_iter(
152        &mut self,
153        filter: &dyn TestFilter,
154    ) -> Result<impl Iterator<Item = (String, SuiteResult)>> {
155        let (tx, rx) = mpsc::channel();
156        self.test(filter, tx, false)?;
157        Ok(rx.into_iter())
158    }
159
160    /// Executes _all_ tests that match the given `filter`.
161    ///
162    /// This will create the runtime based on the configured `evm` ops and create the `Backend`
163    /// before executing all contracts and their tests in _parallel_.
164    ///
165    /// Each Executor gets its own instance of the `Backend`.
166    pub fn test(
167        &mut self,
168        filter: &dyn TestFilter,
169        tx: mpsc::Sender<(String, SuiteResult)>,
170        show_progress: bool,
171    ) -> Result<()> {
172        let tokio_handle = tokio::runtime::Handle::current();
173        trace!("running all tests");
174
175        // The DB backend that serves all the data.
176        let db = Backend::spawn(self.fork.take())?;
177
178        let find_timer = Instant::now();
179        let contracts = self.matching_contracts(filter).collect::<Vec<_>>();
180        let find_time = find_timer.elapsed();
181        debug!(
182            "Found {} test contracts out of {} in {:?}",
183            contracts.len(),
184            self.contracts.len(),
185            find_time,
186        );
187
188        if show_progress {
189            let tests_progress = TestsProgress::new(contracts.len(), rayon::current_num_threads());
190            // Collect test suite results to stream at the end of test run.
191            let results: Vec<(String, SuiteResult)> = contracts
192                .par_iter()
193                .map(|&(id, contract)| {
194                    let _guard = tokio_handle.enter();
195                    tests_progress.inner.lock().start_suite_progress(&id.identifier());
196
197                    let result = self.run_test_suite(
198                        id,
199                        contract,
200                        &db,
201                        filter,
202                        &tokio_handle,
203                        Some(&tests_progress),
204                    );
205
206                    tests_progress
207                        .inner
208                        .lock()
209                        .end_suite_progress(&id.identifier(), result.summary());
210
211                    (id.identifier(), result)
212                })
213                .collect();
214
215            tests_progress.inner.lock().clear();
216
217            results.iter().for_each(|result| {
218                let _ = tx.send(result.to_owned());
219            });
220        } else {
221            contracts.par_iter().for_each(|&(id, contract)| {
222                let _guard = tokio_handle.enter();
223                let result = self.run_test_suite(id, contract, &db, filter, &tokio_handle, None);
224                let _ = tx.send((id.identifier(), result));
225            })
226        }
227
228        Ok(())
229    }
230
231    fn run_test_suite(
232        &self,
233        artifact_id: &ArtifactId,
234        contract: &TestContract,
235        db: &Backend,
236        filter: &dyn TestFilter,
237        tokio_handle: &tokio::runtime::Handle,
238        progress: Option<&TestsProgress>,
239    ) -> SuiteResult {
240        let identifier = artifact_id.identifier();
241        let mut span_name = identifier.as_str();
242
243        if !enabled!(tracing::Level::TRACE) {
244            span_name = get_contract_name(&identifier);
245        }
246        let span = debug_span!("suite", name = %span_name);
247        let span_local = span.clone();
248        let _guard = span_local.enter();
249
250        debug!("start executing all tests in contract");
251
252        let runner = ContractRunner::new(
253            &identifier,
254            contract,
255            self.tcfg.executor(self.known_contracts.clone(), artifact_id, db.clone()),
256            progress,
257            tokio_handle,
258            span,
259            self,
260        );
261        let r = runner.run_tests(filter);
262
263        debug!(duration=?r.duration, "executed all tests in contract");
264
265        r
266    }
267}
268
269/// Configuration for the test runner.
270///
271/// This is modified after instantiation through inline config.
272#[derive(Clone)]
273pub struct TestRunnerConfig {
274    /// Project config.
275    pub config: Arc<Config>,
276    /// Inline configuration.
277    pub inline_config: Arc<InlineConfig>,
278
279    /// EVM configuration.
280    pub evm_opts: EvmOpts,
281    /// EVM environment.
282    pub env: revm::primitives::Env,
283    /// EVM version.
284    pub spec_id: SpecId,
285    /// The address which will be used to deploy the initial contracts and send all transactions.
286    pub sender: Address,
287
288    /// Whether to collect coverage info
289    pub coverage: bool,
290    /// Whether to collect debug info
291    pub debug: bool,
292    /// Whether to enable steps tracking in the tracer.
293    pub decode_internal: InternalTraceMode,
294    /// Whether to enable call isolation.
295    pub isolation: bool,
296    /// Whether to enable Odyssey features.
297    pub odyssey: bool,
298}
299
300impl TestRunnerConfig {
301    /// Reconfigures all fields using the given `config`.
302    /// This is for example used to override the configuration with inline config.
303    pub fn reconfigure_with(&mut self, config: Arc<Config>) {
304        debug_assert!(!Arc::ptr_eq(&self.config, &config));
305
306        self.spec_id = config.evm_spec_id();
307        self.sender = config.sender;
308        self.odyssey = config.odyssey;
309        self.isolation = config.isolate;
310
311        // Specific to Forge, not present in config.
312        // TODO: self.evm_opts
313        // TODO: self.env
314        // self.coverage = N/A;
315        // self.debug = N/A;
316        // self.decode_internal = N/A;
317
318        self.config = config;
319    }
320
321    /// Configures the given executor with this configuration.
322    pub fn configure_executor(&self, executor: &mut Executor) {
323        // TODO: See above
324
325        let inspector = executor.inspector_mut();
326        // inspector.set_env(&self.env);
327        if let Some(cheatcodes) = inspector.cheatcodes.as_mut() {
328            cheatcodes.config =
329                Arc::new(cheatcodes.config.clone_with(&self.config, self.evm_opts.clone()));
330        }
331        inspector.tracing(self.trace_mode());
332        inspector.collect_coverage(self.coverage);
333        inspector.enable_isolation(self.isolation);
334        inspector.odyssey(self.odyssey);
335        // inspector.set_create2_deployer(self.evm_opts.create2_deployer);
336
337        // executor.env_mut().clone_from(&self.env);
338        executor.set_spec_id(self.spec_id);
339        // executor.set_gas_limit(self.evm_opts.gas_limit());
340        executor.set_legacy_assertions(self.config.legacy_assertions);
341    }
342
343    /// Creates a new executor with this configuration.
344    pub fn executor(
345        &self,
346        known_contracts: ContractsByArtifact,
347        artifact_id: &ArtifactId,
348        db: Backend,
349    ) -> Executor {
350        let cheats_config = Arc::new(CheatsConfig::new(
351            &self.config,
352            self.evm_opts.clone(),
353            Some(known_contracts),
354            Some(artifact_id.clone()),
355        ));
356        ExecutorBuilder::new()
357            .inspectors(|stack| {
358                stack
359                    .cheatcodes(cheats_config)
360                    .trace_mode(self.trace_mode())
361                    .coverage(self.coverage)
362                    .enable_isolation(self.isolation)
363                    .odyssey(self.odyssey)
364                    .create2_deployer(self.evm_opts.create2_deployer)
365            })
366            .spec_id(self.spec_id)
367            .gas_limit(self.evm_opts.gas_limit())
368            .legacy_assertions(self.config.legacy_assertions)
369            .build(self.env.clone(), db)
370    }
371
372    fn trace_mode(&self) -> TraceMode {
373        TraceMode::default()
374            .with_debug(self.debug)
375            .with_decode_internal(self.decode_internal)
376            .with_verbosity(self.evm_opts.verbosity)
377            .with_state_changes(verbosity() > 4)
378    }
379}
380
381/// Builder used for instantiating the multi-contract runner
382#[derive(Clone, Debug)]
383#[must_use = "builders do nothing unless you call `build` on them"]
384pub struct MultiContractRunnerBuilder {
385    /// The address which will be used to deploy the initial contracts and send all
386    /// transactions
387    pub sender: Option<Address>,
388    /// The initial balance for each one of the deployed smart contracts
389    pub initial_balance: U256,
390    /// The EVM spec to use
391    pub evm_spec: Option<SpecId>,
392    /// The fork to use at launch
393    pub fork: Option<CreateFork>,
394    /// Project config.
395    pub config: Arc<Config>,
396    /// Whether or not to collect coverage info
397    pub coverage: bool,
398    /// Whether or not to collect debug info
399    pub debug: bool,
400    /// Whether to enable steps tracking in the tracer.
401    pub decode_internal: InternalTraceMode,
402    /// Whether to enable call isolation
403    pub isolation: bool,
404    /// Whether to enable Odyssey features.
405    pub odyssey: bool,
406}
407
408impl MultiContractRunnerBuilder {
409    pub fn new(config: Arc<Config>) -> Self {
410        Self {
411            config,
412            sender: Default::default(),
413            initial_balance: Default::default(),
414            evm_spec: Default::default(),
415            fork: Default::default(),
416            coverage: Default::default(),
417            debug: Default::default(),
418            isolation: Default::default(),
419            decode_internal: Default::default(),
420            odyssey: Default::default(),
421        }
422    }
423
424    pub fn sender(mut self, sender: Address) -> Self {
425        self.sender = Some(sender);
426        self
427    }
428
429    pub fn initial_balance(mut self, initial_balance: U256) -> Self {
430        self.initial_balance = initial_balance;
431        self
432    }
433
434    pub fn evm_spec(mut self, spec: SpecId) -> Self {
435        self.evm_spec = Some(spec);
436        self
437    }
438
439    pub fn with_fork(mut self, fork: Option<CreateFork>) -> Self {
440        self.fork = fork;
441        self
442    }
443
444    pub fn set_coverage(mut self, enable: bool) -> Self {
445        self.coverage = enable;
446        self
447    }
448
449    pub fn set_debug(mut self, enable: bool) -> Self {
450        self.debug = enable;
451        self
452    }
453
454    pub fn set_decode_internal(mut self, mode: InternalTraceMode) -> Self {
455        self.decode_internal = mode;
456        self
457    }
458
459    pub fn enable_isolation(mut self, enable: bool) -> Self {
460        self.isolation = enable;
461        self
462    }
463
464    pub fn odyssey(mut self, enable: bool) -> Self {
465        self.odyssey = enable;
466        self
467    }
468
469    /// Given an EVM, proceeds to return a runner which is able to execute all tests
470    /// against that evm
471    pub fn build<C: Compiler<CompilerContract = Contract>>(
472        self,
473        root: &Path,
474        output: &ProjectCompileOutput,
475        env: revm::primitives::Env,
476        evm_opts: EvmOpts,
477    ) -> Result<MultiContractRunner> {
478        let contracts = output
479            .artifact_ids()
480            .map(|(id, v)| (id.with_stripped_file_prefixes(root), v))
481            .collect();
482        let linker = Linker::new(root, contracts);
483
484        // Build revert decoder from ABIs of all artifacts.
485        let abis = linker
486            .contracts
487            .iter()
488            .filter_map(|(_, contract)| contract.abi.as_ref().map(|abi| abi.borrow()));
489        let revert_decoder = RevertDecoder::new().with_abis(abis);
490
491        let LinkOutput { libraries, libs_to_deploy } = linker.link_with_nonce_or_address(
492            Default::default(),
493            LIBRARY_DEPLOYER,
494            0,
495            linker.contracts.keys(),
496        )?;
497
498        let linked_contracts = linker.get_linked_artifacts(&libraries)?;
499
500        // Create a mapping of name => (abi, deployment code, Vec<library deployment code>)
501        let mut deployable_contracts = DeployableContracts::default();
502
503        for (id, contract) in linked_contracts.iter() {
504            let Some(abi) = &contract.abi else { continue };
505
506            // if it's a test, link it and add to deployable contracts
507            if abi.constructor.as_ref().map(|c| c.inputs.is_empty()).unwrap_or(true) &&
508                abi.functions().any(|func| func.name.is_any_test())
509            {
510                let Some(bytecode) =
511                    contract.get_bytecode_bytes().map(|b| b.into_owned()).filter(|b| !b.is_empty())
512                else {
513                    continue;
514                };
515
516                deployable_contracts
517                    .insert(id.clone(), TestContract { abi: abi.clone(), bytecode });
518            }
519        }
520
521        let known_contracts = ContractsByArtifact::new(linked_contracts);
522
523        Ok(MultiContractRunner {
524            contracts: deployable_contracts,
525            revert_decoder,
526            known_contracts,
527            libs_to_deploy,
528            libraries,
529
530            fork: self.fork,
531
532            tcfg: TestRunnerConfig {
533                evm_opts,
534                env,
535                spec_id: self.evm_spec.unwrap_or_else(|| self.config.evm_spec_id()),
536                sender: self.sender.unwrap_or(self.config.sender),
537
538                coverage: self.coverage,
539                debug: self.debug,
540                decode_internal: self.decode_internal,
541                inline_config: Arc::new(InlineConfig::new_parsed(output, &self.config)?),
542                isolation: self.isolation,
543                odyssey: self.odyssey,
544
545                config: self.config,
546            },
547        })
548    }
549}
550
551pub fn matches_contract(id: &ArtifactId, abi: &JsonAbi, filter: &dyn TestFilter) -> bool {
552    (filter.matches_path(&id.source) && filter.matches_contract(&id.name)) &&
553        abi.functions().any(|func| is_matching_test(func, filter))
554}
555
556/// Returns `true` if the function is a test function that matches the given filter.
557pub(crate) fn is_matching_test(func: &Function, filter: &dyn TestFilter) -> bool {
558    func.is_any_test() && filter.matches_test(&func.signature())
559}