Skip to main content

foundry_cli/utils/
mod.rs

1use alloy_json_abi::JsonAbi;
2use alloy_primitives::{Address, U256, map::HashMap};
3use alloy_provider::{Provider, network::AnyNetwork};
4use eyre::{ContextCompat, Result};
5use foundry_common::{
6    provider::{ProviderBuilder, RetryProvider},
7    shell,
8};
9use foundry_config::{Chain, Config};
10use itertools::Itertools;
11use path_slash::PathExt;
12use regex::Regex;
13use serde::de::DeserializeOwned;
14use std::{
15    ffi::OsStr,
16    path::{Path, PathBuf},
17    process::{Command, Output, Stdio},
18    str::FromStr,
19    sync::LazyLock,
20    time::{Duration, SystemTime, UNIX_EPOCH},
21};
22use tracing_subscriber::prelude::*;
23
24mod cmd;
25pub use cmd::*;
26
27mod suggestions;
28pub use suggestions::*;
29
30mod abi;
31pub use abi::*;
32
33mod allocator;
34pub use allocator::*;
35
36// reexport all `foundry_config::utils`
37#[doc(hidden)]
38pub use foundry_config::utils::*;
39
/// Fixed fuzzer seed, used so that gas snapshots and coverage reports are deterministic.
///
/// The 32 bytes are the keccak256 hash of "foundry rulez".
pub const STATIC_FUZZ_SEED: [u8; 32] = [
    0x01, 0x00, 0xfa, 0x69, 0xa5, 0xf1, 0x71, 0x0a, //
    0x95, 0xcd, 0xef, 0x94, 0x88, 0x9b, 0x02, 0x84, //
    0x5d, 0x64, 0x0b, 0x19, 0xad, 0xf0, 0xe3, 0x57, //
    0xb8, 0xd4, 0xbe, 0x7d, 0x49, 0xee, 0x70, 0xe6, //
];
47
48/// Regex used to parse `.gitmodules` file and capture the submodule path and branch.
49pub static SUBMODULE_BRANCH_REGEX: LazyLock<Regex> =
50    LazyLock::new(|| Regex::new(r#"\[submodule "([^"]+)"\](?:[^\[]*?branch = ([^\s]+))"#).unwrap());
51/// Regex used to parse `git submodule status` output.
52pub static SUBMODULE_STATUS_REGEX: LazyLock<Regex> =
53    LazyLock::new(|| Regex::new(r"^[\s+-]?([a-f0-9]+)\s+([^\s]+)(?:\s+\([^)]+\))?$").unwrap());
54
/// Useful extensions to [`std::path::Path`].
pub trait FoundryPathExt {
    /// Returns true if the [`Path`]'s file name ends with `.t.sol`.
    fn is_sol_test(&self) -> bool;

    /// Returns true if the [`Path`] has a `sol` extension.
    fn is_sol(&self) -> bool;

    /// Returns true if the [`Path`] has a `yul` extension.
    fn is_yul(&self) -> bool;
}

impl<T: AsRef<Path>> FoundryPathExt for T {
    fn is_sol_test(&self) -> bool {
        // Paths without a file name (e.g. `..`) or with a non-UTF-8 name are not tests.
        match self.as_ref().file_name().and_then(OsStr::to_str) {
            Some(name) => name.ends_with(".t.sol"),
            None => false,
        }
    }

    fn is_sol(&self) -> bool {
        self.as_ref().extension().is_some_and(|ext| ext == "sol")
    }

    fn is_yul(&self) -> bool {
        self.as_ref().extension().is_some_and(|ext| ext == "yul")
    }
}
84
85/// Initializes a tracing Subscriber for logging
86pub fn subscriber() {
87    let registry = tracing_subscriber::Registry::default().with(env_filter());
88    #[cfg(feature = "tracy")]
89    let registry = registry.with(tracing_tracy::TracyLayer::default());
90    registry.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr)).init()
91}
92
93fn env_filter() -> tracing_subscriber::EnvFilter {
94    const DEFAULT_DIRECTIVES: &[&str] = &include!("./default_directives.txt");
95    let mut filter = tracing_subscriber::EnvFilter::from_default_env();
96    for &directive in DEFAULT_DIRECTIVES {
97        filter = filter.add_directive(directive.parse().unwrap());
98    }
99    filter
100}
101
102/// Returns a [RetryProvider] instantiated using [Config]'s RPC settings.
103pub fn get_provider(config: &Config) -> Result<RetryProvider> {
104    get_provider_builder(config, false)?.build()
105}
106
107/// Returns a [RetryProvider] with curl mode option.
108///
109/// When `curl_mode` is true, the provider will print equivalent curl commands
110/// to stdout instead of executing RPC requests.
111pub fn get_provider_with_curl(config: &Config, curl_mode: bool) -> Result<RetryProvider> {
112    get_provider_builder(config, curl_mode)?.build()
113}
114
115/// Returns a [ProviderBuilder] instantiated using [Config] values.
116///
117/// Defaults to `http://localhost:8545` and `Mainnet`.
118pub fn get_provider_builder(config: &Config, curl_mode: bool) -> Result<ProviderBuilder> {
119    ProviderBuilder::from_config(config).map(|builder| builder.curl_mode(curl_mode))
120}
121
122pub async fn get_chain<P>(chain: Option<Chain>, provider: P) -> Result<Chain>
123where
124    P: Provider<AnyNetwork>,
125{
126    match chain {
127        Some(chain) => Ok(chain),
128        None => Ok(Chain::from_id(provider.get_chain_id().await?)),
129    }
130}
131
132/// Parses an ether value from a string.
133///
134/// The amount can be tagged with a unit, e.g. "1ether".
135///
136/// If the string represents an untagged amount (e.g. "100") then
137/// it is interpreted as wei.
138pub fn parse_ether_value(value: &str) -> Result<U256> {
139    Ok(if value.starts_with("0x") {
140        U256::from_str_radix(value, 16)?
141    } else {
142        alloy_dyn_abi::DynSolType::coerce_str(&alloy_dyn_abi::DynSolType::Uint(256), value)?
143            .as_uint()
144            .wrap_err("Could not parse ether value from string")?
145            .0
146    })
147}
148
149/// Parses a `T` from a string using [`serde_json::from_str`].
150pub fn parse_json<T: DeserializeOwned>(value: &str) -> serde_json::Result<T> {
151    serde_json::from_str(value)
152}
153
154/// Parses a `Duration` from a &str
155pub fn parse_delay(delay: &str) -> Result<Duration> {
156    let delay = if delay.ends_with("ms") {
157        let d: u64 = delay.trim_end_matches("ms").parse()?;
158        Duration::from_millis(d)
159    } else {
160        let d: f64 = delay.parse()?;
161        let delay = (d * 1000.0).round();
162        if delay.is_infinite() || delay.is_nan() || delay.is_sign_negative() {
163            eyre::bail!("delay must be finite and non-negative");
164        }
165
166        Duration::from_millis(delay as u64)
167    };
168    Ok(delay)
169}
170
/// Returns the time elapsed since the Unix epoch, as a [`Duration`].
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn now() -> Duration {
    let current = SystemTime::now();
    current.duration_since(UNIX_EPOCH).expect("time went backwards")
}
175
/// Common setup for all CLI tools. Does not include [tracing subscriber](subscriber).
///
/// The steps run in this order:
/// 1. [`install_crypto_provider`];
/// 2. `crate::handler::install`;
/// 3. [`load_dotenv`];
/// 4. [`enable_paint`].
///
/// [`install_crypto_provider`] panics if its `install_default` call fails. This setup is
/// therefore presumably meant to run once per process (TODO: confirm).
pub fn common_setup() {
    install_crypto_provider();
    crate::handler::install();
    // Load `.env` before color detection, presumably so `.env` values can influence the
    // color settings read by `enable_paint`.
    load_dotenv();
    enable_paint();
}
183
184/// Loads a dotenv file, from the cwd and the project root, ignoring potential failure.
185///
186/// We could use `warn!` here, but that would imply that the dotenv file can't configure
187/// the logging behavior of Foundry.
188///
189/// Similarly, we could just use `eprintln!`, but colors are off limits otherwise dotenv is implied
190/// to not be able to configure the colors. It would also mess up the JSON output.
191pub fn load_dotenv() {
192    let load = |p: &Path| {
193        dotenvy::from_path(p.join(".env")).ok();
194    };
195
196    // we only want the .env file of the cwd and project root
197    // `find_project_root` calls `current_dir` internally so both paths are either both `Ok` or
198    // both `Err`
199    if let (Ok(cwd), Ok(prj_root)) = (std::env::current_dir(), find_project_root(None)) {
200        load(&prj_root);
201        if cwd != prj_root {
202            // prj root and cwd can be identical
203            load(&cwd);
204        }
205    };
206}
207
208/// Sets the default [`yansi`] color output condition.
209pub fn enable_paint() {
210    let enable = yansi::Condition::os_support() && yansi::Condition::tty_and_color_live();
211    yansi::whenever(yansi::Condition::cached(enable));
212}
213
214/// This force installs the default crypto provider.
215///
216/// This is necessary in case there are more than one available backends enabled in rustls (ring,
217/// aws-lc-rs).
218///
219/// This should be called high in the main fn.
220///
221/// See also:
222///   <https://github.com/snapview/tokio-tungstenite/issues/353#issuecomment-2455100010>
223///   <https://github.com/awslabs/aws-sdk-rust/discussions/1257>
224pub fn install_crypto_provider() {
225    // https://github.com/snapview/tokio-tungstenite/issues/353
226    rustls::crypto::ring::default_provider()
227        .install_default()
228        .expect("Failed to install default rustls crypto provider");
229}
230
231/// Fetches the ABI of a contract from Etherscan.
232pub async fn fetch_abi_from_etherscan(
233    address: Address,
234    config: &foundry_config::Config,
235) -> Result<Vec<(JsonAbi, String)>> {
236    let chain = config.chain.unwrap_or_default();
237    let api_key = config.get_etherscan_api_key(Some(chain)).unwrap_or_default();
238    let client = foundry_block_explorers::Client::new(chain, api_key)?;
239    let source = client.contract_source_code(address).await?;
240    source.items.into_iter().map(|item| Ok((item.abi()?, item.contract_name))).collect()
241}
242
243/// Useful extensions to [`std::process::Command`].
244pub trait CommandUtils {
245    /// Returns the command's output if execution is successful, otherwise, throws an error.
246    fn exec(&mut self) -> Result<Output>;
247
248    /// Returns the command's stdout if execution is successful, otherwise, throws an error.
249    fn get_stdout_lossy(&mut self) -> Result<String>;
250}
251
252impl CommandUtils for Command {
253    #[track_caller]
254    fn exec(&mut self) -> Result<Output> {
255        trace!(command=?self, "executing");
256
257        let output = self.output()?;
258
259        trace!(code=?output.status.code(), ?output);
260
261        if output.status.success() {
262            Ok(output)
263        } else {
264            let stdout = String::from_utf8_lossy(&output.stdout);
265            let stdout = stdout.trim();
266            let stderr = String::from_utf8_lossy(&output.stderr);
267            let stderr = stderr.trim();
268            let msg = if stdout.is_empty() {
269                stderr.to_string()
270            } else if stderr.is_empty() {
271                stdout.to_string()
272            } else {
273                format!("stdout:\n{stdout}\n\nstderr:\n{stderr}")
274            };
275
276            let mut name = self.get_program().to_string_lossy();
277            if let Some(arg) = self.get_args().next() {
278                let arg = arg.to_string_lossy();
279                if !arg.starts_with('-') {
280                    let name = name.to_mut();
281                    name.push(' ');
282                    name.push_str(&arg);
283                }
284            }
285
286            let mut err = match output.status.code() {
287                Some(code) => format!("{name} exited with code {code}"),
288                None => format!("{name} terminated by a signal"),
289            };
290            if !msg.is_empty() {
291                err.push(':');
292                err.push(if msg.lines().count() == 0 { ' ' } else { '\n' });
293                err.push_str(&msg);
294            }
295            Err(eyre::eyre!(err))
296        }
297    }
298
299    #[track_caller]
300    fn get_stdout_lossy(&mut self) -> Result<String> {
301        let output = self.exec()?;
302        let stdout = String::from_utf8_lossy(&output.stdout);
303        Ok(stdout.trim().into())
304    }
305}
306
307#[derive(Clone, Copy, Debug)]
308pub struct Git<'a> {
309    pub root: &'a Path,
310    pub quiet: bool,
311    pub shallow: bool,
312}
313
314impl<'a> Git<'a> {
315    pub fn new(root: &'a Path) -> Self {
316        Self { root, quiet: shell::is_quiet(), shallow: false }
317    }
318
319    pub fn from_config(config: &'a Config) -> Self {
320        Self::new(config.root.as_path())
321    }
322
323    pub fn root_of(relative_to: &Path) -> Result<PathBuf> {
324        let output = Self::cmd_no_root()
325            .current_dir(relative_to)
326            .args(["rev-parse", "--show-toplevel"])
327            .get_stdout_lossy()?;
328        Ok(PathBuf::from(output))
329    }
330
331    pub fn clone_with_branch(
332        shallow: bool,
333        from: impl AsRef<OsStr>,
334        branch: impl AsRef<OsStr>,
335        to: Option<impl AsRef<OsStr>>,
336    ) -> Result<()> {
337        Self::cmd_no_root()
338            .stderr(Stdio::inherit())
339            .args(["clone", "--recurse-submodules"])
340            .args(shallow.then_some("--depth=1"))
341            .args(shallow.then_some("--shallow-submodules"))
342            .arg("-b")
343            .arg(branch)
344            .arg(from)
345            .args(to)
346            .exec()
347            .map(drop)
348    }
349
350    pub fn clone(
351        shallow: bool,
352        from: impl AsRef<OsStr>,
353        to: Option<impl AsRef<OsStr>>,
354    ) -> Result<()> {
355        Self::cmd_no_root()
356            .stderr(Stdio::inherit())
357            .args(["clone", "--recurse-submodules"])
358            .args(shallow.then_some("--depth=1"))
359            .args(shallow.then_some("--shallow-submodules"))
360            .arg(from)
361            .args(to)
362            .exec()
363            .map(drop)
364    }
365
366    pub fn fetch(
367        self,
368        shallow: bool,
369        remote: impl AsRef<OsStr>,
370        branch: Option<impl AsRef<OsStr>>,
371    ) -> Result<()> {
372        self.cmd()
373            .stderr(Stdio::inherit())
374            .arg("fetch")
375            .args(shallow.then_some("--no-tags"))
376            .args(shallow.then_some("--depth=1"))
377            .arg(remote)
378            .args(branch)
379            .exec()
380            .map(drop)
381    }
382
383    pub fn root(self, root: &Path) -> Git<'_> {
384        Git { root, ..self }
385    }
386
387    pub fn quiet(self, quiet: bool) -> Self {
388        Self { quiet, ..self }
389    }
390
391    /// True to perform shallow clones
392    pub fn shallow(self, shallow: bool) -> Self {
393        Self { shallow, ..self }
394    }
395
396    pub fn checkout(self, recursive: bool, tag: impl AsRef<OsStr>) -> Result<()> {
397        self.cmd()
398            .arg("checkout")
399            .args(recursive.then_some("--recurse-submodules"))
400            .arg(tag)
401            .exec()
402            .map(drop)
403    }
404
405    /// Returns the current HEAD commit hash of the current branch.
406    pub fn head(self) -> Result<String> {
407        self.cmd().args(["rev-parse", "HEAD"]).get_stdout_lossy()
408    }
409
410    pub fn checkout_at(self, tag: impl AsRef<OsStr>, at: &Path) -> Result<()> {
411        self.cmd_at(at).arg("checkout").arg(tag).exec().map(drop)
412    }
413
414    pub fn init(self) -> Result<()> {
415        self.cmd().arg("init").exec().map(drop)
416    }
417
418    pub fn current_rev_branch(self, at: &Path) -> Result<(String, String)> {
419        let rev = self.cmd_at(at).args(["rev-parse", "HEAD"]).get_stdout_lossy()?;
420        let branch =
421            self.cmd_at(at).args(["rev-parse", "--abbrev-ref", "HEAD"]).get_stdout_lossy()?;
422        Ok((rev, branch))
423    }
424
425    #[expect(clippy::should_implement_trait)] // this is not std::ops::Add clippy
426    pub fn add<I, S>(self, paths: I) -> Result<()>
427    where
428        I: IntoIterator<Item = S>,
429        S: AsRef<OsStr>,
430    {
431        self.cmd().arg("add").args(paths).exec().map(drop)
432    }
433
434    pub fn reset(self, hard: bool, tree: impl AsRef<OsStr>) -> Result<()> {
435        self.cmd().arg("reset").args(hard.then_some("--hard")).arg(tree).exec().map(drop)
436    }
437
438    pub fn commit_tree(
439        self,
440        tree: impl AsRef<OsStr>,
441        msg: Option<impl AsRef<OsStr>>,
442    ) -> Result<String> {
443        self.cmd()
444            .arg("commit-tree")
445            .arg(tree)
446            .args(msg.as_ref().is_some().then_some("-m"))
447            .args(msg)
448            .get_stdout_lossy()
449    }
450
451    pub fn rm<I, S>(self, force: bool, paths: I) -> Result<()>
452    where
453        I: IntoIterator<Item = S>,
454        S: AsRef<OsStr>,
455    {
456        self.cmd().arg("rm").args(force.then_some("--force")).args(paths).exec().map(drop)
457    }
458
459    pub fn commit(self, msg: &str) -> Result<()> {
460        let output = self
461            .cmd()
462            .args(["commit", "-m", msg])
463            .args(cfg!(any(test, debug_assertions)).then_some("--no-gpg-sign"))
464            .output()?;
465        if !output.status.success() {
466            let stdout = String::from_utf8_lossy(&output.stdout);
467            let stderr = String::from_utf8_lossy(&output.stderr);
468            // ignore "nothing to commit" error
469            let msg = "nothing to commit, working tree clean";
470            if !(stdout.contains(msg) || stderr.contains(msg)) {
471                return Err(eyre::eyre!(
472                    "failed to commit (code={:?}, stdout={:?}, stderr={:?})",
473                    output.status.code(),
474                    stdout.trim(),
475                    stderr.trim()
476                ));
477            }
478        }
479        Ok(())
480    }
481
482    pub fn is_in_repo(self) -> std::io::Result<bool> {
483        self.cmd().args(["rev-parse", "--is-inside-work-tree"]).status().map(|s| s.success())
484    }
485
486    pub fn is_repo_root(self) -> Result<bool> {
487        self.cmd().args(["rev-parse", "--show-cdup"]).exec().map(|out| out.stdout.is_empty())
488    }
489
490    pub fn is_clean(self) -> Result<bool> {
491        self.cmd().args(["status", "--porcelain"]).exec().map(|out| out.stdout.is_empty())
492    }
493
494    pub fn has_branch(self, branch: impl AsRef<OsStr>, at: &Path) -> Result<bool> {
495        self.cmd_at(at)
496            .args(["branch", "--list", "--no-color"])
497            .arg(branch)
498            .get_stdout_lossy()
499            .map(|stdout| !stdout.is_empty())
500    }
501
502    pub fn has_tag(self, tag: impl AsRef<OsStr>, at: &Path) -> Result<bool> {
503        self.cmd_at(at)
504            .args(["tag", "--list"])
505            .arg(tag)
506            .get_stdout_lossy()
507            .map(|stdout| !stdout.is_empty())
508    }
509
510    pub fn has_rev(self, rev: impl AsRef<OsStr>, at: &Path) -> Result<bool> {
511        self.cmd_at(at)
512            .args(["cat-file", "-t"])
513            .arg(rev)
514            .get_stdout_lossy()
515            .map(|stdout| &stdout == "commit")
516    }
517
518    pub fn get_rev(self, tag_or_branch: impl AsRef<OsStr>, at: &Path) -> Result<String> {
519        self.cmd_at(at).args(["rev-list", "-n", "1"]).arg(tag_or_branch).get_stdout_lossy()
520    }
521
522    pub fn ensure_clean(self) -> Result<()> {
523        if self.is_clean()? {
524            Ok(())
525        } else {
526            Err(eyre::eyre!(
527                "\
528The target directory is a part of or on its own an already initialized git repository,
529and it requires clean working and staging areas, including no untracked files.
530
531Check the current git repository's status with `git status`.
532Then, you can track files with `git add ...` and then commit them with `git commit`,
533ignore them in the `.gitignore` file."
534            ))
535        }
536    }
537
538    pub fn commit_hash(self, short: bool, revision: &str) -> Result<String> {
539        self.cmd()
540            .arg("rev-parse")
541            .args(short.then_some("--short"))
542            .arg(revision)
543            .get_stdout_lossy()
544    }
545
546    pub fn tag(self) -> Result<String> {
547        self.cmd().arg("tag").get_stdout_lossy()
548    }
549
550    /// Returns the tag the commit first appeared in.
551    ///
552    /// E.g Take rev = `abc1234`. This commit can be found in multiple releases (tags).
553    /// Consider releases: `v0.1.0`, `v0.2.0`, `v0.3.0` in chronological order, `rev` first appeared
554    /// in `v0.2.0`.
555    ///
556    /// Hence, `tag_for_commit("abc1234")` will return `v0.2.0`.
557    pub fn tag_for_commit(self, rev: &str, at: &Path) -> Result<Option<String>> {
558        self.cmd_at(at)
559            .args(["tag", "--contains"])
560            .arg(rev)
561            .get_stdout_lossy()
562            .map(|stdout| stdout.lines().next().map(str::to_string))
563    }
564
565    /// Returns a list of tuples of submodule paths and their respective branches.
566    ///
567    /// This function reads the `.gitmodules` file and returns the paths of all submodules that have
568    /// a branch. The paths are relative to the Git::root_of(git.root) and not lib/ directory.
569    ///
570    /// `at` is the dir in which the `.gitmodules` file is located, this is the git root.
571    /// `lib` is name of the directory where the submodules are located.
572    pub fn read_submodules_with_branch(
573        self,
574        at: &Path,
575        lib: &OsStr,
576    ) -> Result<HashMap<PathBuf, String>> {
577        // Read the .gitmodules file
578        let gitmodules = foundry_common::fs::read_to_string(at.join(".gitmodules"))?;
579
580        let paths = SUBMODULE_BRANCH_REGEX
581            .captures_iter(&gitmodules)
582            .map(|cap| {
583                let path_str = cap.get(1).unwrap().as_str();
584                let path = PathBuf::from_str(path_str).unwrap();
585                trace!(path = %path.display(), "unstripped path");
586
587                // Keep only the components that come after the lib directory.
588                // This needs to be done because the lockfile uses paths relative foundry project
589                // root whereas .gitmodules use paths relative to the git root which may not be the
590                // project root. e.g monorepo.
591                // Hence, if path is lib/solady, then `lib/solady` is kept. if path is
592                // packages/contract-bedrock/lib/solady, then `lib/solady` is kept.
593                let lib_pos = path.components().find_position(|c| c.as_os_str() == lib);
594                let path = path
595                    .components()
596                    .skip(lib_pos.map(|(i, _)| i).unwrap_or(0))
597                    .collect::<PathBuf>();
598
599                let branch = cap.get(2).unwrap().as_str().to_string();
600                (path, branch)
601            })
602            .collect::<HashMap<_, _>>();
603
604        Ok(paths)
605    }
606
607    pub fn has_missing_dependencies<I, S>(self, paths: I) -> Result<bool>
608    where
609        I: IntoIterator<Item = S>,
610        S: AsRef<OsStr>,
611    {
612        self.cmd()
613            .args(["submodule", "status"])
614            .args(paths)
615            .get_stdout_lossy()
616            .map(|stdout| stdout.lines().any(|line| line.starts_with('-')))
617    }
618
619    /// Returns true if the given path has submodules by checking `git submodule status`
620    pub fn has_submodules<I, S>(self, paths: I) -> Result<bool>
621    where
622        I: IntoIterator<Item = S>,
623        S: AsRef<OsStr>,
624    {
625        self.cmd()
626            .args(["submodule", "status"])
627            .args(paths)
628            .get_stdout_lossy()
629            .map(|stdout| stdout.trim().lines().next().is_some())
630    }
631
632    pub fn submodule_add(
633        self,
634        force: bool,
635        url: impl AsRef<OsStr>,
636        path: impl AsRef<OsStr>,
637    ) -> Result<()> {
638        self.cmd()
639            .stderr(self.stderr())
640            .args(["submodule", "add"])
641            .args(self.shallow.then_some("--depth=1"))
642            .args(force.then_some("--force"))
643            .arg(url)
644            .arg(path)
645            .exec()
646            .map(drop)
647    }
648
649    pub fn submodule_update<I, S>(
650        self,
651        force: bool,
652        remote: bool,
653        no_fetch: bool,
654        recursive: bool,
655        paths: I,
656    ) -> Result<()>
657    where
658        I: IntoIterator<Item = S>,
659        S: AsRef<OsStr>,
660    {
661        self.cmd()
662            .stderr(self.stderr())
663            .args(["submodule", "update", "--progress", "--init"])
664            .args(self.shallow.then_some("--depth=1"))
665            .args(force.then_some("--force"))
666            .args(remote.then_some("--remote"))
667            .args(no_fetch.then_some("--no-fetch"))
668            .args(recursive.then_some("--recursive"))
669            .args(paths)
670            .exec()
671            .map(drop)
672    }
673
674    pub fn submodule_foreach(self, recursive: bool, cmd: impl AsRef<OsStr>) -> Result<()> {
675        self.cmd()
676            .stderr(self.stderr())
677            .args(["submodule", "foreach"])
678            .args(recursive.then_some("--recursive"))
679            .arg(cmd)
680            .exec()
681            .map(drop)
682    }
683
684    /// If the status is prefix with `-`, the submodule is not initialized.
685    ///
686    /// Ref: <https://git-scm.com/docs/git-submodule#Documentation/git-submodule.txt-status--cached--recursive--ltpathgt82308203>
687    pub fn submodules_uninitialized(self) -> Result<bool> {
688        // keep behavior consistent with `has_missing_dependencies`, but avoid duplicating the
689        // "submodule status has '-' prefix" logic.
690        self.has_missing_dependencies(std::iter::empty::<&OsStr>())
691    }
692
693    /// Initializes the git submodules.
694    pub fn submodule_init(self) -> Result<()> {
695        self.cmd().stderr(self.stderr()).args(["submodule", "init"]).exec().map(drop)
696    }
697
698    pub fn submodules(&self) -> Result<Submodules> {
699        self.cmd().args(["submodule", "status"]).get_stdout_lossy().map(|stdout| stdout.parse())?
700    }
701
702    pub fn submodule_sync(self) -> Result<()> {
703        self.cmd().stderr(self.stderr()).args(["submodule", "sync"]).exec().map(drop)
704    }
705
706    /// Get the URL of a submodule from git config
707    pub fn submodule_url(self, path: &Path) -> Result<Option<String>> {
708        self.cmd()
709            .args(["config", "--get", &format!("submodule.{}.url", path.to_slash_lossy())])
710            .get_stdout_lossy()
711            .map(|url| Some(url.trim().to_string()))
712    }
713
714    pub fn cmd(self) -> Command {
715        let mut cmd = Self::cmd_no_root();
716        cmd.current_dir(self.root);
717        cmd
718    }
719
720    pub fn cmd_at(self, path: &Path) -> Command {
721        let mut cmd = Self::cmd_no_root();
722        cmd.current_dir(path);
723        cmd
724    }
725
726    pub fn cmd_no_root() -> Command {
727        let mut cmd = Command::new("git");
728        cmd.stdout(Stdio::piped()).stderr(Stdio::piped());
729        cmd
730    }
731
732    // don't set this in cmd() because it's not wanted for all commands
733    fn stderr(self) -> Stdio {
734        if self.quiet { Stdio::piped() } else { Stdio::inherit() }
735    }
736}
737
738/// Deserialized `git submodule status lib/dep` output.
739#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq, Eq)]
740pub struct Submodule {
741    /// Current commit hash the submodule is checked out at.
742    rev: String,
743    /// Relative path to the submodule.
744    path: PathBuf,
745}
746
747impl Submodule {
748    pub fn new(rev: String, path: PathBuf) -> Self {
749        Self { rev, path }
750    }
751
752    pub fn rev(&self) -> &str {
753        &self.rev
754    }
755
756    pub fn path(&self) -> &PathBuf {
757        &self.path
758    }
759}
760
761impl FromStr for Submodule {
762    type Err = eyre::Report;
763
764    fn from_str(s: &str) -> Result<Self> {
765        let caps = SUBMODULE_STATUS_REGEX
766            .captures(s)
767            .ok_or_else(|| eyre::eyre!("Invalid submodule status format"))?;
768
769        Ok(Self {
770            rev: caps.get(1).unwrap().as_str().to_string(),
771            path: PathBuf::from(caps.get(2).unwrap().as_str()),
772        })
773    }
774}
775
776/// Deserialized `git submodule status` output.
777#[derive(Debug, Clone, PartialEq, Eq)]
778pub struct Submodules(pub Vec<Submodule>);
779
780impl Submodules {
781    pub fn len(&self) -> usize {
782        self.0.len()
783    }
784
785    pub fn is_empty(&self) -> bool {
786        self.0.is_empty()
787    }
788}
789
790impl FromStr for Submodules {
791    type Err = eyre::Report;
792
793    fn from_str(s: &str) -> Result<Self> {
794        let subs = s.lines().map(str::parse).collect::<Result<Vec<Submodule>>>()?;
795        Ok(Self(subs))
796    }
797}
798
799impl<'a> IntoIterator for &'a Submodules {
800    type Item = &'a Submodule;
801    type IntoIter = std::slice::Iter<'a, Submodule>;
802
803    fn into_iter(self) -> Self::IntoIter {
804        self.0.iter()
805    }
806}
807#[cfg(test)]
808mod tests {
809    use super::*;
810    use foundry_common::fs;
811    use std::{env, fs::File, io::Write};
812    use tempfile::tempdir;
813
814    #[test]
815    fn parse_submodule_status() {
816        let s = "+8829465a08cac423dcf59852f21e448449c1a1a8 lib/openzeppelin-contracts (v4.8.0-791-g8829465a)";
817        let sub = Submodule::from_str(s).unwrap();
818        assert_eq!(sub.rev(), "8829465a08cac423dcf59852f21e448449c1a1a8");
819        assert_eq!(sub.path(), Path::new("lib/openzeppelin-contracts"));
820
821        let s = "-8829465a08cac423dcf59852f21e448449c1a1a8 lib/openzeppelin-contracts";
822        let sub = Submodule::from_str(s).unwrap();
823        assert_eq!(sub.rev(), "8829465a08cac423dcf59852f21e448449c1a1a8");
824        assert_eq!(sub.path(), Path::new("lib/openzeppelin-contracts"));
825
826        let s = "8829465a08cac423dcf59852f21e448449c1a1a8 lib/openzeppelin-contracts";
827        let sub = Submodule::from_str(s).unwrap();
828        assert_eq!(sub.rev(), "8829465a08cac423dcf59852f21e448449c1a1a8");
829        assert_eq!(sub.path(), Path::new("lib/openzeppelin-contracts"));
830    }
831
832    #[test]
833    fn parse_multiline_submodule_status() {
834        let s = r#"+d3db4ef90a72b7d24aa5a2e5c649593eaef7801d lib/forge-std (v1.9.4-6-gd3db4ef)
835+8829465a08cac423dcf59852f21e448449c1a1a8 lib/openzeppelin-contracts (v4.8.0-791-g8829465a)
836"#;
837        let subs = Submodules::from_str(s).unwrap().0;
838        assert_eq!(subs.len(), 2);
839        assert_eq!(subs[0].rev(), "d3db4ef90a72b7d24aa5a2e5c649593eaef7801d");
840        assert_eq!(subs[0].path(), Path::new("lib/forge-std"));
841        assert_eq!(subs[1].rev(), "8829465a08cac423dcf59852f21e448449c1a1a8");
842        assert_eq!(subs[1].path(), Path::new("lib/openzeppelin-contracts"));
843    }
844
845    #[test]
846    fn foundry_path_ext_works() {
847        let p = Path::new("contracts/MyTest.t.sol");
848        assert!(p.is_sol_test());
849        assert!(p.is_sol());
850        let p = Path::new("contracts/Greeter.sol");
851        assert!(!p.is_sol_test());
852    }
853
854    // loads .env from cwd and project dir, See [`find_project_root()`]
855    #[test]
856    fn can_load_dotenv() {
857        let temp = tempdir().unwrap();
858        Git::new(temp.path()).init().unwrap();
859        let cwd_env = temp.path().join(".env");
860        fs::create_file(temp.path().join("foundry.toml")).unwrap();
861        let nested = temp.path().join("nested");
862        fs::create_dir(&nested).unwrap();
863
864        let mut cwd_file = File::create(cwd_env).unwrap();
865        let mut prj_file = File::create(nested.join(".env")).unwrap();
866
867        cwd_file.write_all("TESTCWDKEY=cwd_val".as_bytes()).unwrap();
868        cwd_file.sync_all().unwrap();
869
870        prj_file.write_all("TESTPRJKEY=prj_val".as_bytes()).unwrap();
871        prj_file.sync_all().unwrap();
872
873        let cwd = env::current_dir().unwrap();
874        env::set_current_dir(nested).unwrap();
875        load_dotenv();
876        env::set_current_dir(cwd).unwrap();
877
878        assert_eq!(env::var("TESTCWDKEY").unwrap(), "cwd_val");
879        assert_eq!(env::var("TESTPRJKEY").unwrap(), "prj_val");
880    }
881
882    #[test]
883    fn test_read_gitmodules_regex() {
884        let gitmodules = r#"
885        [submodule "lib/solady"]
886        path = lib/solady
887        url = ""
888        branch = v0.1.0
889        [submodule "lib/openzeppelin-contracts"]
890        path = lib/openzeppelin-contracts
891        url = ""
892        branch = v4.8.0-791-g8829465a
893        [submodule "lib/forge-std"]
894        path = lib/forge-std
895        url = ""
896"#;
897
898        let paths = SUBMODULE_BRANCH_REGEX
899            .captures_iter(gitmodules)
900            .map(|cap| {
901                (
902                    PathBuf::from_str(cap.get(1).unwrap().as_str()).unwrap(),
903                    String::from(cap.get(2).unwrap().as_str()),
904                )
905            })
906            .collect::<HashMap<_, _>>();
907
908        assert_eq!(paths.get(Path::new("lib/solady")).unwrap(), "v0.1.0");
909        assert_eq!(
910            paths.get(Path::new("lib/openzeppelin-contracts")).unwrap(),
911            "v4.8.0-791-g8829465a"
912        );
913
914        let no_branch_gitmodules = r#"
915        [submodule "lib/solady"]
916        path = lib/solady
917        url = ""
918        [submodule "lib/openzeppelin-contracts"]
919        path = lib/openzeppelin-contracts
920        url = ""
921        [submodule "lib/forge-std"]
922        path = lib/forge-std
923        url = ""
924"#;
925        let paths = SUBMODULE_BRANCH_REGEX
926            .captures_iter(no_branch_gitmodules)
927            .map(|cap| {
928                (
929                    PathBuf::from_str(cap.get(1).unwrap().as_str()).unwrap(),
930                    String::from(cap.get(2).unwrap().as_str()),
931                )
932            })
933            .collect::<HashMap<_, _>>();
934
935        assert!(paths.is_empty());
936
937        let branch_in_between = r#"
938        [submodule "lib/solady"]
939        path = lib/solady
940        url = ""
941        [submodule "lib/openzeppelin-contracts"]
942        path = lib/openzeppelin-contracts
943        url = ""
944        branch = v4.8.0-791-g8829465a
945        [submodule "lib/forge-std"]
946        path = lib/forge-std
947        url = ""
948        "#;
949
950        let paths = SUBMODULE_BRANCH_REGEX
951            .captures_iter(branch_in_between)
952            .map(|cap| {
953                (
954                    PathBuf::from_str(cap.get(1).unwrap().as_str()).unwrap(),
955                    String::from(cap.get(2).unwrap().as_str()),
956                )
957            })
958            .collect::<HashMap<_, _>>();
959
960        assert_eq!(paths.len(), 1);
961        assert_eq!(
962            paths.get(Path::new("lib/openzeppelin-contracts")).unwrap(),
963            "v4.8.0-791-g8829465a"
964        );
965    }
966}