foundry_test_utils/
script.rs

1use crate::{init_tracing, util::lossy_string, TestCommand};
2use alloy_primitives::{address, Address};
3use alloy_provider::Provider;
4use eyre::Result;
5use foundry_common::provider::{get_http_provider, RetryProvider};
6use std::{
7    collections::BTreeMap,
8    fs,
9    path::{Path, PathBuf},
10};
11
/// Destination of the broadcast test contract, relative to the project root it is copied into.
const BROADCAST_TEST_PATH: &str = "src/Broadcast.t.sol";
/// Path to the shared `testdata` directory, two levels above this crate's manifest directory;
/// fixed at compile time via `env!("CARGO_MANIFEST_DIR")`.
const TESTDATA: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata");
14
15fn init_script_cmd(
16    cmd: &mut TestCommand,
17    project_root: &Path,
18    target_contract: &str,
19    endpoint: Option<&str>,
20) {
21    cmd.forge_fuse();
22    cmd.set_current_dir(project_root);
23
24    cmd.args([
25        "script",
26        "-R",
27        "ds-test/=lib/",
28        "-R",
29        "cheats/=cheats/",
30        target_contract,
31        "--root",
32        project_root.to_str().unwrap(),
33        "-vvvvv",
34    ]);
35
36    if let Some(rpc_url) = endpoint {
37        cmd.args(["--fork-url", rpc_url]);
38    }
39}
/// A helper struct to test forge script scenarios.
///
/// Wraps a [`TestCommand`] configured for `forge script` together with a fixed set of test
/// accounts. When an RPC endpoint is available, it also holds a provider used to record account
/// nonces before a script runs and compare them afterwards.
pub struct ScriptTester {
    /// Addresses of the built-in test accounts; indexed in parallel with `accounts_priv`.
    pub accounts_pub: Vec<Address>,
    /// Hex-encoded private keys of the built-in test accounts, passed via `--private-keys`.
    pub accounts_priv: Vec<String>,
    /// Provider for the RPC endpoint; `None` when the tester was created without an endpoint.
    pub provider: Option<RetryProvider>,
    /// Nonces recorded by `load_private_keys`, keyed by account index.
    pub nonces: BTreeMap<u32, u64>,
    /// Nonces recorded by `load_addresses`, keyed by address.
    pub address_nonces: BTreeMap<Address, u64>,
    /// The `forge script` command being assembled and executed.
    pub cmd: TestCommand,
    /// Directory the script runs in; kept so `clear` can rebuild the command.
    pub project_root: PathBuf,
    /// Script target passed to `forge script`; kept so `clear` can rebuild the command.
    pub target_contract: String,
    /// Optional RPC URL passed as `--fork-url`; kept so `clear` can rebuild the command.
    pub endpoint: Option<String>,
}
52
53impl ScriptTester {
54    /// Creates a new instance of a Tester for the given contract
55    pub fn new(
56        mut cmd: TestCommand,
57        endpoint: Option<&str>,
58        project_root: &Path,
59        target_contract: &str,
60    ) -> Self {
61        init_tracing();
62        Self::copy_testdata(project_root).unwrap();
63        init_script_cmd(&mut cmd, project_root, target_contract, endpoint);
64
65        let mut provider = None;
66        if let Some(endpoint) = endpoint {
67            provider = Some(get_http_provider(endpoint))
68        }
69
70        Self {
71            accounts_pub: vec![
72                address!("0xf39Fd6e51aad88F6F4ce6aB8827279cffFb92266"),
73                address!("0x70997970C51812dc3A010C7d01b50e0d17dc79C8"),
74                address!("0x3C44CdDdB6a900fa2b585dd299e03d12FA4293BC"),
75            ],
76            accounts_priv: vec![
77                "ac0974bec39a17e36ba4a6b4d238ff944bacb478cbed5efcae784d7bf4f2ff80".to_string(),
78                "59c6995e998f97a5a0044966f0945389dc9e86dae88c7a8412f4603b6b78690d".to_string(),
79                "5de4111afa1a4b94908f83103eb1f1706367c2e68ca870fc3fb9a804cdab365a".to_string(),
80            ],
81            provider,
82            nonces: BTreeMap::default(),
83            address_nonces: BTreeMap::default(),
84            cmd,
85            project_root: project_root.to_path_buf(),
86            target_contract: target_contract.to_string(),
87            endpoint: endpoint.map(|s| s.to_string()),
88        }
89    }
90
91    /// Creates a new instance of a Tester for the `broadcast` test at the given `project_root` by
92    /// configuring the `TestCommand` with script
93    pub fn new_broadcast(cmd: TestCommand, endpoint: &str, project_root: &Path) -> Self {
94        let target_contract = project_root.join(BROADCAST_TEST_PATH).to_string_lossy().to_string();
95
96        // copy the broadcast test
97        fs::copy(
98            Self::testdata_path().join("default/cheats/Broadcast.t.sol"),
99            project_root.join(BROADCAST_TEST_PATH),
100        )
101        .expect("Failed to initialize broadcast contract");
102
103        Self::new(cmd, Some(endpoint), project_root, &target_contract)
104    }
105
106    /// Creates a new instance of a Tester for the `broadcast` test at the given `project_root` by
107    /// configuring the `TestCommand` with script without an endpoint
108    pub fn new_broadcast_without_endpoint(cmd: TestCommand, project_root: &Path) -> Self {
109        let target_contract = project_root.join(BROADCAST_TEST_PATH).to_string_lossy().to_string();
110
111        // copy the broadcast test
112        let testdata = Self::testdata_path();
113        fs::copy(
114            testdata.join("default/cheats/Broadcast.t.sol"),
115            project_root.join(BROADCAST_TEST_PATH),
116        )
117        .expect("Failed to initialize broadcast contract");
118
119        Self::new(cmd, None, project_root, &target_contract)
120    }
121
122    /// Returns the path to the dir that contains testdata
123    fn testdata_path() -> &'static Path {
124        Path::new(TESTDATA)
125    }
126
127    /// Initialises the test contracts by copying them into the workspace
128    fn copy_testdata(current_dir: &Path) -> Result<()> {
129        let testdata = Self::testdata_path();
130        fs::create_dir_all(current_dir.join("cheats"))?;
131        fs::copy(testdata.join("cheats/Vm.sol"), current_dir.join("cheats/Vm.sol"))?;
132        fs::copy(testdata.join("lib/ds-test/src/test.sol"), current_dir.join("lib/test.sol"))?;
133        Ok(())
134    }
135
136    pub async fn load_private_keys(&mut self, keys_indexes: &[u32]) -> &mut Self {
137        for &index in keys_indexes {
138            self.cmd.args(["--private-keys", &self.accounts_priv[index as usize]]);
139
140            if let Some(provider) = &self.provider {
141                let nonce = provider
142                    .get_transaction_count(self.accounts_pub[index as usize])
143                    .await
144                    .unwrap();
145                self.nonces.insert(index, nonce);
146            }
147        }
148        self
149    }
150
151    pub async fn load_addresses(&mut self, addresses: &[Address]) -> &mut Self {
152        for &address in addresses {
153            let nonce =
154                self.provider.as_ref().unwrap().get_transaction_count(address).await.unwrap();
155            self.address_nonces.insert(address, nonce);
156        }
157        self
158    }
159
160    pub fn add_deployer(&mut self, index: u32) -> &mut Self {
161        self.sender(self.accounts_pub[index as usize])
162    }
163
164    /// Adds given address as sender
165    pub fn sender(&mut self, addr: Address) -> &mut Self {
166        self.args(&["--sender", addr.to_string().as_str()])
167    }
168
169    pub fn add_sig(&mut self, contract_name: &str, sig: &str) -> &mut Self {
170        self.args(&["--tc", contract_name, "--sig", sig])
171    }
172
173    pub fn add_create2_deployer(&mut self, create2_deployer: Address) -> &mut Self {
174        self.args(&["--create2-deployer", create2_deployer.to_string().as_str()])
175    }
176
177    /// Adds the `--unlocked` flag
178    pub fn unlocked(&mut self) -> &mut Self {
179        self.arg("--unlocked")
180    }
181
182    pub fn simulate(&mut self, expected: ScriptOutcome) -> &mut Self {
183        self.run(expected)
184    }
185
186    pub fn broadcast(&mut self, expected: ScriptOutcome) -> &mut Self {
187        self.arg("--broadcast").run(expected)
188    }
189
190    pub fn resume(&mut self, expected: ScriptOutcome) -> &mut Self {
191        self.arg("--resume").run(expected)
192    }
193
194    /// `[(private_key_slot, expected increment)]`
195    pub async fn assert_nonce_increment(&mut self, keys_indexes: &[(u32, u32)]) -> &mut Self {
196        for &(private_key_slot, expected_increment) in keys_indexes {
197            let addr = self.accounts_pub[private_key_slot as usize];
198            let nonce = self.provider.as_ref().unwrap().get_transaction_count(addr).await.unwrap();
199            let prev_nonce = self.nonces.get(&private_key_slot).unwrap();
200
201            assert_eq!(
202                nonce,
203                (*prev_nonce + expected_increment as u64),
204                "nonce not incremented correctly for {addr}: \
205                 {prev_nonce} + {expected_increment} != {nonce}"
206            );
207        }
208        self
209    }
210
211    /// In Vec<(address, expected increment)>
212    pub async fn assert_nonce_increment_addresses(
213        &mut self,
214        address_indexes: &[(Address, u32)],
215    ) -> &mut Self {
216        for (address, expected_increment) in address_indexes {
217            let nonce =
218                self.provider.as_ref().unwrap().get_transaction_count(*address).await.unwrap();
219            let prev_nonce = self.address_nonces.get(address).unwrap();
220
221            assert_eq!(nonce, *prev_nonce + *expected_increment as u64);
222        }
223        self
224    }
225
226    pub fn run(&mut self, expected: ScriptOutcome) -> &mut Self {
227        let out = self.cmd.execute();
228        let (stdout, stderr) = (lossy_string(&out.stdout), lossy_string(&out.stderr));
229
230        trace!(target: "tests", "STDOUT\n{stdout}\n\nSTDERR\n{stderr}");
231
232        if !stdout.contains(expected.as_str()) && !stderr.contains(expected.as_str()) {
233            panic!(
234                "--STDOUT--\n{stdout}\n\n--STDERR--\n{stderr}\n\n--EXPECTED--\n{:?} not found in stdout or stderr",
235                expected.as_str()
236            );
237        }
238
239        self
240    }
241
242    pub fn slow(&mut self) -> &mut Self {
243        self.arg("--slow")
244    }
245
246    pub fn arg(&mut self, arg: &str) -> &mut Self {
247        self.cmd.arg(arg);
248        self
249    }
250
251    pub fn args(&mut self, args: &[&str]) -> &mut Self {
252        self.cmd.args(args);
253        self
254    }
255
256    pub fn clear(&mut self) {
257        init_script_cmd(
258            &mut self.cmd,
259            &self.project_root,
260            &self.target_contract,
261            self.endpoint.as_deref(),
262        );
263        self.nonces.clear();
264        self.address_nonces.clear();
265    }
266}
267
/// Various `forge` script results.
///
/// Each outcome is identified by a substring (see [`ScriptOutcome::as_str`]) that is expected to
/// appear in forge's stdout or stderr.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ScriptOutcome {
    /// Ran without an RPC URL; forge suggests passing one to simulate on-chain transactions.
    OkNoEndpoint,
    /// Simulation completed and forge offered to broadcast.
    OkSimulation,
    /// On-chain execution completed successfully.
    OkBroadcast,
    /// More than one deployer could predeploy libraries, so forge fell back to `--sender`.
    WarnSpecifyDeployer,
    /// Foundry's default sender was used instead of an explicit `--sender`.
    MissingSender,
    /// No wallet is associated with a sender.
    MissingWallet,
    /// A `staticcall` happened after `broadcast`.
    StaticCallNotAllowed,
    /// Forge reported `script failed: `.
    ScriptFailed,
    /// Library linking was needed in a multi-chain deployment.
    UnsupportedLibraries,
    /// A fork was selected during a broadcast.
    ErrorSelectForkOnBroadcast,
    /// The script ran successfully.
    OkRun,
}

impl ScriptOutcome {
    /// Returns the substring that identifies this outcome in forge's output.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::OkNoEndpoint => "If you wish to simulate on-chain transactions pass a RPC URL.",
            Self::OkSimulation => "SIMULATION COMPLETE. To broadcast these",
            Self::OkBroadcast => "ONCHAIN EXECUTION COMPLETE & SUCCESSFUL",
            Self::WarnSpecifyDeployer => "Warning: You have more than one deployer who could predeploy libraries. Using `--sender` instead.",
            Self::MissingSender => "You seem to be using Foundry's default sender. Be sure to set your own --sender",
            Self::MissingWallet => "No associated wallet",
            Self::StaticCallNotAllowed => "staticcall`s are not allowed after `broadcast`; use `startBroadcast` instead",
            Self::ScriptFailed => "script failed: ",
            Self::UnsupportedLibraries => "Multi chain deployment does not support library linking at the moment.",
            Self::ErrorSelectForkOnBroadcast => "cannot select forks during a broadcast",
            Self::OkRun => "Script ran successfully",
        }
    }

    /// Returns `true` for outcomes that represent a failure.
    ///
    /// `WarnSpecifyDeployer` is only a warning and counts as success.
    pub fn is_err(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::OkNoEndpoint |
            Self::OkSimulation |
            Self::OkBroadcast |
            Self::WarnSpecifyDeployer |
            Self::OkRun => false,
            Self::MissingSender |
            Self::MissingWallet |
            Self::StaticCallNotAllowed |
            Self::UnsupportedLibraries |
            Self::ErrorSelectForkOnBroadcast |
            Self::ScriptFailed => true,
        }
    }
}