Skip to main content

forge_sol_macro_gen/
sol_macro_gen.rs

1//! SolMacroGen and MultiSolMacroGen
2//!
3//! This type encapsulates the logic for expansion of a Rust TokenStream from Solidity tokens. It
4//! uses the `expand` method from `alloy_sol_macro_expander` underneath.
5//!
//! It holds information such as the `path` to the ABI file, the `name` of the file and of the Rust
//! binding being generated, and lastly the `expansion` itself, i.e. the Rust binding for the
//! provided ABI.
//!
//! It contains methods to read the JSON ABI, generate Rust bindings from the ABI and ultimately
//! write the bindings to a crate or a module.
11
12use alloy_sol_macro_expander::expand::expand;
13use alloy_sol_macro_input::{SolInput, SolInputKind};
14use eyre::{Context, OptionExt, Result};
15use foundry_common::fs;
16use proc_macro2::{Span, TokenStream};
17use std::{
18    fmt::Write,
19    path::{Path, PathBuf},
20    str::FromStr,
21};
22
23use heck::ToSnakeCase;
24
25pub struct SolMacroGen {
26    pub path: PathBuf,
27    pub name: String,
28    pub expansion: Option<TokenStream>,
29}
30
31impl SolMacroGen {
32    pub const fn new(path: PathBuf, name: String) -> Self {
33        Self { path, name, expansion: None }
34    }
35
36    pub fn get_sol_input(&self) -> Result<SolInput> {
37        let path = self.path.to_string_lossy().into_owned();
38        let name = proc_macro2::Ident::new(&self.name, Span::call_site());
39        let tokens = quote::quote! {
40            #[sol(ignore_unlinked)]
41            #name,
42            #path
43        };
44
45        let sol_input: SolInput = syn::parse2(tokens).wrap_err("failed to parse input")?;
46
47        Ok(sol_input)
48    }
49}
50
51pub struct MultiSolMacroGen {
52    pub instances: Vec<SolMacroGen>,
53}
54
55impl MultiSolMacroGen {
56    pub const fn new(instances: Vec<SolMacroGen>) -> Self {
57        Self { instances }
58    }
59
60    pub fn populate_expansion(&mut self, bindings_path: &Path) -> Result<()> {
61        for instance in &mut self.instances {
62            let path = bindings_path.join(format!("{}.rs", instance.name.to_snake_case()));
63            let expansion = fs::read_to_string(path).wrap_err("Failed to read file")?;
64
65            let tokens = TokenStream::from_str(&expansion)
66                .map_err(|e| eyre::eyre!("Failed to parse TokenStream: {e}"))?;
67            instance.expansion = Some(tokens);
68        }
69        Ok(())
70    }
71
72    pub fn generate_bindings(&mut self, all_derives: bool) -> Result<()> {
73        for instance in &mut self.instances {
74            Self::generate_binding(instance, all_derives).wrap_err_with(|| {
75                format!(
76                    "failed to generate bindings for {}:{}",
77                    instance.path.display(),
78                    instance.name
79                )
80            })?;
81        }
82
83        Ok(())
84    }
85
86    fn generate_binding(instance: &mut SolMacroGen, all_derives: bool) -> Result<()> {
87        let input = instance.get_sol_input()?.normalize_json()?;
88        let SolInput { attrs: _, path: _, kind } = input;
89
90        let tokens = match kind {
91            SolInputKind::Sol(mut file) => {
92                let sol_attr: syn::Attribute = if all_derives {
93                    syn::parse_quote! {
94                            #[sol(rpc, alloy_sol_types = alloy::sol_types, alloy_contract =
95                    alloy::contract, all_derives = true, extra_derives(serde::Serialize,
96                    serde::Deserialize))]     }
97                } else {
98                    syn::parse_quote! {
99                            #[sol(rpc, alloy_sol_types = alloy::sol_types, alloy_contract =
100                    alloy::contract)]     }
101                };
102                file.attrs.push(sol_attr);
103                expand(file).wrap_err("failed to expand")?
104            }
105            _ => unreachable!(),
106        };
107
108        instance.expansion = Some(tokens);
109        Ok(())
110    }
111
112    #[allow(clippy::too_many_arguments)]
113    pub fn write_to_crate(
114        &mut self,
115        name: &str,
116        version: &str,
117        description: &str,
118        license: &str,
119        bindings_path: &Path,
120        single_file: bool,
121        alloy_version: Option<String>,
122        alloy_rev: Option<String>,
123        all_derives: bool,
124    ) -> Result<()> {
125        self.generate_bindings(all_derives)?;
126
127        let src = bindings_path.join("src");
128        fs::create_dir_all(&src)?;
129
130        // Write Cargo.toml
131        let cargo_toml_path = bindings_path.join("Cargo.toml");
132        let mut toml_contents = format!(
133            r#"[package]
134name = "{name}"
135version = "{version}"
136edition = "2021"
137"#
138        );
139
140        if !description.is_empty() {
141            toml_contents.push_str(&format!("description = \"{description}\"\n"));
142        }
143
144        if !license.is_empty() {
145            let formatted_licenses: Vec<String> =
146                license.split(',').map(Self::parse_license_alias).collect();
147
148            let formatted_license = formatted_licenses.join(" OR ");
149            toml_contents.push_str(&format!("license = \"{formatted_license}\"\n"));
150        }
151
152        toml_contents.push_str("\n[dependencies]\n");
153
154        let alloy_dep = Self::get_alloy_dep(alloy_version, alloy_rev);
155        write!(toml_contents, "{alloy_dep}")?;
156
157        if all_derives {
158            let serde_dep = r#"serde = { version = "1.0", features = ["derive"] }"#;
159            write!(toml_contents, "\n{serde_dep}")?;
160        }
161
162        fs::write(cargo_toml_path, toml_contents).wrap_err("Failed to write Cargo.toml")?;
163
164        let mut lib_contents = String::new();
165        write!(
166            &mut lib_contents,
167            r#"#![allow(unused_imports, unused_attributes, clippy::all, rustdoc::all)]
168        //! This module contains the sol! generated bindings for solidity contracts.
169        //! This is autogenerated code.
170        //! Do not manually edit these files.
171        //! These files may be overwritten by the codegen system at any time.
172        "#
173        )?;
174
175        // Write src
176        let parse_error = |name: &str| {
177            format!("failed to parse generated tokens as an AST for {name};\nthis is likely a bug")
178        };
179        for instance in &self.instances {
180            let contents = instance.expansion.as_ref().unwrap();
181
182            let name = instance.name.to_snake_case();
183            let path = src.join(format!("{name}.rs"));
184            let file = syn::parse2(contents.clone())
185                .wrap_err_with(|| parse_error(&format!("{}:{}", path.display(), name)))?;
186            let contents = qualify_shadowed_sibling_module_paths(prettyplease::unparse(&file));
187            if single_file {
188                write!(&mut lib_contents, "{contents}")?;
189            } else {
190                fs::write(path, contents).wrap_err("failed to write to file")?;
191                write_mod_name(&mut lib_contents, &name)?;
192            }
193        }
194
195        let lib_path = src.join("lib.rs");
196        let lib_file = syn::parse_file(&lib_contents).wrap_err_with(|| parse_error("lib.rs"))?;
197        let lib_contents = prettyplease::unparse(&lib_file);
198        fs::write(lib_path, lib_contents).wrap_err("Failed to write lib.rs")?;
199
200        Ok(())
201    }
202
203    /// Attempts to detect the appropriate license.
204    pub fn parse_license_alias(license: &str) -> String {
205        match license.trim().to_lowercase().as_str() {
206            "mit" => "MIT".to_string(),
207            "apache" | "apache2" | "apache20" | "apache2.0" => "Apache-2.0".to_string(),
208            "gpl" | "gpl3" => "GPL-3.0".to_string(),
209            "lgpl" | "lgpl3" => "LGPL-3.0".to_string(),
210            "agpl" | "agpl3" => "AGPL-3.0".to_string(),
211            "bsd" | "bsd3" => "BSD-3-Clause".to_string(),
212            "bsd2" => "BSD-2-Clause".to_string(),
213            "mpl" | "mpl2" => "MPL-2.0".to_string(),
214            "isc" => "ISC".to_string(),
215            "unlicense" => "Unlicense".to_string(),
216            _ => license.trim().to_string(),
217        }
218    }
219
220    pub fn write_to_module(
221        &mut self,
222        bindings_path: &Path,
223        single_file: bool,
224        all_derives: bool,
225    ) -> Result<()> {
226        self.generate_bindings(all_derives)?;
227
228        fs::create_dir_all(bindings_path)?;
229
230        let mut mod_contents =
231            r#"#![allow(unused_imports, unused_attributes, clippy::all, rustdoc::all)]
232        //! This module contains the sol! generated bindings for solidity contracts.
233        //! This is autogenerated code.
234        //! Do not manually edit these files.
235        //! These files may be overwritten by the codegen system at any time.
236        "#
237            .to_string();
238
239        for instance in &self.instances {
240            let name = instance.name.to_snake_case();
241            if single_file {
242                // Single File
243                let mut contents = String::new();
244                write!(contents, "{}\n\n", instance.expansion.as_ref().unwrap())?;
245                write!(mod_contents, "{contents}")?;
246            } else {
247                // Module
248                write_mod_name(&mut mod_contents, &name)?;
249                let mut contents = String::new();
250
251                write!(contents, "{}", instance.expansion.as_ref().unwrap())?;
252                let file = syn::parse_file(&contents)?;
253
254                let contents = qualify_shadowed_sibling_module_paths(prettyplease::unparse(&file));
255                fs::write(bindings_path.join(format!("{name}.rs")), contents)
256                    .wrap_err("Failed to write file")?;
257            }
258        }
259
260        let mod_path = bindings_path.join("mod.rs");
261        let mod_file = syn::parse_file(&mod_contents)?;
262        let mod_contents = qualify_shadowed_sibling_module_paths(prettyplease::unparse(&mod_file));
263
264        fs::write(mod_path, mod_contents).wrap_err("Failed to write mod.rs")?;
265
266        Ok(())
267    }
268
269    /// Checks that the generated bindings are up to date with the latest version of
270    /// `sol!`.
271    ///
272    /// Returns `Ok(())` if the generated bindings are up to date, otherwise it returns
273    /// `Err(_)`.
274    #[expect(clippy::too_many_arguments)]
275    pub fn check_consistency(
276        &self,
277        name: &str,
278        version: &str,
279        crate_path: &Path,
280        single_file: bool,
281        check_cargo_toml: bool,
282        is_mod: bool,
283        alloy_version: Option<String>,
284        alloy_rev: Option<String>,
285    ) -> Result<()> {
286        if check_cargo_toml && !is_mod {
287            self.check_cargo_toml(name, version, crate_path, alloy_version, alloy_rev)?;
288        }
289
290        let mut super_contents = String::new();
291        write!(
292            &mut super_contents,
293            r#"#![allow(unused_imports, unused_attributes, clippy::all, rustdoc::all)]
294            //! This module contains the sol! generated bindings for solidity contracts.
295            //! This is autogenerated code.
296            //! Do not manually edit these files.
297            //! These files may be overwritten by the codegen system at any time.
298            "#
299        )?;
300        if !single_file {
301            for instance in &self.instances {
302                let name = instance.name.to_snake_case();
303                let path = if is_mod {
304                    crate_path.join(format!("{name}.rs"))
305                } else {
306                    crate_path.join(format!("src/{name}.rs"))
307                };
308                let tokens = instance
309                    .expansion
310                    .as_ref()
311                    .ok_or_eyre(format!("TokenStream for {path:?} does not exist"))?
312                    .to_string();
313
314                self.check_file_contents(&path, &tokens)?;
315                write_mod_name(&mut super_contents, &name)?;
316            }
317
318            let super_path =
319                if is_mod { crate_path.join("mod.rs") } else { crate_path.join("src/lib.rs") };
320            self.check_file_contents(&super_path, &super_contents)?;
321        }
322
323        Ok(())
324    }
325
326    fn check_file_contents(&self, file_path: &Path, expected_contents: &str) -> Result<()> {
327        eyre::ensure!(file_path.is_file(), "{} is not a file", file_path.display());
328        let file_contents = &fs::read_to_string(file_path).wrap_err("Failed to read file")?;
329
330        // Format both
331        let file_contents = syn::parse_file(file_contents)?;
332        let formatted_file = prettyplease::unparse(&file_contents);
333
334        let expected_contents = syn::parse_file(expected_contents)?;
335        let formatted_exp =
336            qualify_shadowed_sibling_module_paths(prettyplease::unparse(&expected_contents));
337
338        eyre::ensure!(
339            formatted_file == formatted_exp,
340            "File contents do not match expected contents for {file_path:?}"
341        );
342        Ok(())
343    }
344
345    fn check_cargo_toml(
346        &self,
347        name: &str,
348        version: &str,
349        crate_path: &Path,
350        alloy_version: Option<String>,
351        alloy_rev: Option<String>,
352    ) -> Result<()> {
353        eyre::ensure!(crate_path.is_dir(), "Crate path must be a directory");
354
355        let cargo_toml_path = crate_path.join("Cargo.toml");
356
357        eyre::ensure!(cargo_toml_path.is_file(), "Cargo.toml must exist");
358        let cargo_toml_contents =
359            fs::read_to_string(cargo_toml_path).wrap_err("Failed to read Cargo.toml")?;
360
361        let name_check = format!("name = \"{name}\"");
362        let version_check = format!("version = \"{version}\"");
363        let alloy_dep_check = Self::get_alloy_dep(alloy_version, alloy_rev);
364        let toml_consistent = cargo_toml_contents.contains(&name_check)
365            && cargo_toml_contents.contains(&version_check)
366            && cargo_toml_contents.contains(&alloy_dep_check);
367        eyre::ensure!(
368            toml_consistent,
369            r#"The contents of Cargo.toml do not match the expected output of the latest `sol!` version.
370                This indicates that the existing bindings are outdated and need to be generated again."#
371        );
372
373        Ok(())
374    }
375
376    /// Returns the `alloy` dependency string for the Cargo.toml file.
377    /// If `alloy_version` is provided, it will use that version from crates.io.
378    /// If `alloy_rev` is provided, it will use that revision from the GitHub repository.
379    fn get_alloy_dep(alloy_version: Option<String>, alloy_rev: Option<String>) -> String {
380        if let Some(alloy_version) = alloy_version {
381            format!(
382                r#"alloy = {{ version = "{alloy_version}", features = ["sol-types", "contract"] }}"#,
383            )
384        } else if let Some(alloy_rev) = alloy_rev {
385            format!(
386                r#"alloy = {{ git = "https://github.com/alloy-rs/alloy", rev = "{alloy_rev}", features = ["sol-types", "contract"] }}"#,
387            )
388        } else {
389            r#"alloy = { version = "1.0", features = ["sol-types", "contract"] }"#.to_string()
390        }
391    }
392}
393
394fn write_mod_name(contents: &mut String, name: &str) -> Result<()> {
395    if syn::parse_str::<syn::Ident>(name).is_ok() {
396        write!(contents, "pub mod {name};")?;
397    } else {
398        write!(contents, "pub mod r#{name};")?;
399    }
400    Ok(())
401}
402
403/// Qualifies paths to sibling binding modules when a generated item in the current module shadows
404/// that module name.
405///
406/// Alloy names the event enum for a contract module by appending `Events` to the contract name. If
407/// the ABI also contains a sibling contract/interface with that exact name, inherited event
408/// parameter types such as `IExampleContractEvents::SomeEventData` resolve to the local event enum
409/// instead of the sibling module that owns `SomeEventData`. Qualifying those paths with `super::`
410/// keeps the generated binding compiling without changing the upstream `sol!` expansion.
411fn qualify_shadowed_sibling_module_paths(mut contents: String) -> String {
412    let module_names = top_level_module_names(&contents);
413    let enum_names = public_enum_names(&contents);
414
415    for module_name in module_names {
416        if enum_names.iter().any(|enum_name| enum_name == &module_name) {
417            contents = qualify_unqualified_module_paths(&contents, &module_name);
418        }
419    }
420
421    contents
422}
423
/// Rewrites every unqualified `module_name::` path prefix in `contents` to
/// `super::module_name::`.
///
/// A match counts as unqualified when the character right before it in the whole input is
/// neither part of an identifier nor a `:`. So `FooBar::`, `self::Bar::`, `Foo::Bar::` (a
/// nested segment) and an already qualified `super::Bar::` are left alone. A raw identifier
/// (`r#Bar::`) is qualified as `super::r#Bar::`. An empty `module_name` leaves the input
/// unchanged.
///
/// This is a purely textual rewrite, so matches inside string literals and comments are
/// rewritten too.
fn qualify_unqualified_module_paths(contents: &str, module_name: &str) -> String {
    /// Returns `true` if `c` can end an identifier or a path prefix such as `foo::`.
    fn continues_path(c: char) -> bool {
        c == '_' || c == ':' || c.is_alphanumeric()
    }

    /// Returns `true` if a match directly after `preceding` begins a new path.
    fn starts_new_path(preceding: &str) -> bool {
        preceding.chars().next_back().is_none_or(|c| !continues_path(c))
    }

    if module_name.is_empty() {
        return contents.to_owned();
    }

    let needle = format!("{module_name}::");
    let mut qualified = String::with_capacity(contents.len());
    let mut copied_up_to = 0;

    for (index, _) in contents.match_indices(needle.as_str()) {
        // Look at the full input, not just the text since the previous match: in `A::A::`
        // the second `A::` is preceded by `:` and must not be qualified.
        let preceding = &contents[..index];
        match preceding.strip_suffix("r#") {
            // `r#Name::` is a single path segment; put `super::` before the `r#`.
            Some(before_raw) if starts_new_path(before_raw) => {
                qualified.push_str(&contents[copied_up_to..index - 2]);
                qualified.push_str("super::r#");
            }
            _ => {
                qualified.push_str(&contents[copied_up_to..index]);
                if starts_new_path(preceding) {
                    qualified.push_str("super::");
                }
            }
        }
        qualified.push_str(&needle);
        copied_up_to = index + needle.len();
    }

    qualified.push_str(&contents[copied_up_to..]);
    qualified
}
451
/// Collects the names of all modules declared with `pub mod` in `contents`.
///
/// This is a textual scan. Nested `pub mod` items are included as well as top-level ones. Both
/// `pub mod name { ... }` and `pub mod name;` are recognized, the name stops at the first
/// character that cannot be part of an identifier, and a raw-identifier prefix (`r#`) is
/// stripped.
fn top_level_module_names(contents: &str) -> Vec<String> {
    /// Returns the identifier at the start of `text`, ignoring leading whitespace and an `r#`
    /// prefix, or `None` if `text` does not start with one.
    fn leading_identifier(text: &str) -> Option<String> {
        let text = text.trim_start();
        let text = text.strip_prefix("r#").unwrap_or(text);
        let end = text.find(|c: char| !(c == '_' || c.is_alphanumeric())).unwrap_or(text.len());
        (end > 0).then(|| text[..end].to_string())
    }

    contents.split("pub mod ").skip(1).filter_map(leading_identifier).collect()
}
460
/// Collects the names of all enums declared with `pub enum` in `contents`.
///
/// This is a textual scan. The name stops at the first character that cannot be part of an
/// identifier, so `pub enum Foo<T>` and `pub enum Foo{` both yield `Foo`, and a raw-identifier
/// prefix (`r#`) is stripped.
fn public_enum_names(contents: &str) -> Vec<String> {
    /// Returns the identifier at the start of `text`, ignoring leading whitespace and an `r#`
    /// prefix, or `None` if `text` does not start with one.
    fn leading_identifier(text: &str) -> Option<String> {
        let text = text.trim_start();
        let text = text.strip_prefix("r#").unwrap_or(text);
        let end = text.find(|c: char| !(c == '_' || c.is_alphanumeric())).unwrap_or(text.len());
        (end > 0).then(|| text[..end].to_string())
    }

    contents.split("pub enum ").skip(1).filter_map(leading_identifier).collect()
}