forge_sol_macro_gen/
sol_macro_gen.rs

1//! SolMacroGen and MultiSolMacroGen
2//!
3//! This type encapsulates the logic for expansion of a Rust TokenStream from Solidity tokens. It
4//! uses the `expand` method from `alloy_sol_macro_expander` underneath.
5//!
//! It holds the `path` to the ABI file, the `name` used for both the generated file and the Rust
//! binding, and lastly the `expansion` itself, i.e. the Rust binding for the provided ABI.
//!
//! It contains methods to read the JSON ABI, generate Rust bindings from it, and ultimately
//! write the bindings to a crate or to modules.
11
12use alloy_sol_macro_expander::expand::expand;
13use alloy_sol_macro_input::{SolInput, SolInputKind};
14use eyre::{Context, OptionExt, Result};
15use foundry_common::fs;
16use proc_macro2::{Span, TokenStream};
17use std::{
18    env::temp_dir,
19    fmt::Write,
20    path::{Path, PathBuf},
21    str::FromStr,
22};
23
24use heck::ToSnakeCase;
25
26pub struct SolMacroGen {
27    pub path: PathBuf,
28    pub name: String,
29    pub expansion: Option<TokenStream>,
30}
31
32impl SolMacroGen {
33    pub fn new(path: PathBuf, name: String) -> Self {
34        Self { path, name, expansion: None }
35    }
36
37    pub fn get_sol_input(&self) -> Result<SolInput> {
38        let path = self.path.to_string_lossy().into_owned();
39        let name = proc_macro2::Ident::new(&self.name, Span::call_site());
40        let tokens = quote::quote! {
41            #name,
42            #path
43        };
44
45        let sol_input: SolInput = syn::parse2(tokens).wrap_err("failed to parse input")?;
46
47        Ok(sol_input)
48    }
49}
50
51pub struct MultiSolMacroGen {
52    pub artifacts_path: PathBuf,
53    pub instances: Vec<SolMacroGen>,
54}
55
56impl MultiSolMacroGen {
57    pub fn new(artifacts_path: &Path, instances: Vec<SolMacroGen>) -> Self {
58        Self { artifacts_path: artifacts_path.to_path_buf(), instances }
59    }
60
61    pub fn populate_expansion(&mut self, bindings_path: &Path) -> Result<()> {
62        for instance in &mut self.instances {
63            let path = bindings_path.join(format!("{}.rs", instance.name.to_lowercase()));
64            let expansion = fs::read_to_string(path).wrap_err("Failed to read file")?;
65
66            let tokens = TokenStream::from_str(&expansion)
67                .map_err(|e| eyre::eyre!("Failed to parse TokenStream: {e}"))?;
68            instance.expansion = Some(tokens);
69        }
70        Ok(())
71    }
72
73    pub fn generate_bindings(&mut self, all_derives: bool) -> Result<()> {
74        for instance in &mut self.instances {
75            Self::generate_binding(instance, all_derives).wrap_err_with(|| {
76                format!(
77                    "failed to generate bindings for {}:{}",
78                    instance.path.display(),
79                    instance.name
80                )
81            })?;
82        }
83
84        Ok(())
85    }
86
87    fn generate_binding(instance: &mut SolMacroGen, all_derives: bool) -> Result<()> {
88        // TODO: in `get_sol_input` we currently can't handle unlinked bytecode: <https://github.com/alloy-rs/core/issues/926>
89        let input = match instance.get_sol_input() {
90            Ok(input) => input.normalize_json()?,
91            Err(error) => {
92                // TODO(mattsse): remove after <https://github.com/alloy-rs/core/issues/926>
93                if error.to_string().contains("expected bytecode, found unlinked bytecode") {
94                    // we attempt to do a little hack here until we have this properly supported by
95                    // removing the bytecode objects from the json file and using a tmpfile (very
96                    // hacky)
97                    let content = std::fs::read_to_string(&instance.path)?;
98                    let mut value = serde_json::from_str::<serde_json::Value>(&content)?;
99                    let obj = value.as_object_mut().expect("valid abi");
100
101                    // clear unlinked bytecode
102                    obj.remove("bytecode");
103                    obj.remove("deployedBytecode");
104
105                    let tmpdir = temp_dir();
106                    let mut tmp_file = tmpdir.join(instance.path.file_name().unwrap());
107                    std::fs::write(&tmp_file, serde_json::to_string(&value)?)?;
108
109                    // try again
110                    std::mem::swap(&mut tmp_file, &mut instance.path);
111                    let input = instance.get_sol_input()?.normalize_json()?;
112                    std::mem::swap(&mut tmp_file, &mut instance.path);
113                    input.normalize_json()?
114                } else {
115                    return Err(error)
116                }
117            }
118        };
119
120        let SolInput { attrs: _, path: _, kind } = input;
121
122        let tokens = match kind {
123            SolInputKind::Sol(mut file) => {
124                let sol_attr: syn::Attribute = if all_derives {
125                    syn::parse_quote! {
126                            #[sol(rpc, alloy_sol_types = alloy::sol_types, alloy_contract =
127                    alloy::contract, all_derives = true, extra_derives(serde::Serialize,
128                    serde::Deserialize))]     }
129                } else {
130                    syn::parse_quote! {
131                            #[sol(rpc, alloy_sol_types = alloy::sol_types, alloy_contract =
132                    alloy::contract)]     }
133                };
134                file.attrs.push(sol_attr);
135                expand(file).wrap_err("failed to expand")?
136            }
137            _ => unreachable!(),
138        };
139
140        instance.expansion = Some(tokens);
141        Ok(())
142    }
143
144    #[allow(clippy::too_many_arguments)]
145    pub fn write_to_crate(
146        &mut self,
147        name: &str,
148        version: &str,
149        description: &str,
150        license: &str,
151        bindings_path: &Path,
152        single_file: bool,
153        alloy_version: Option<String>,
154        alloy_rev: Option<String>,
155        all_derives: bool,
156    ) -> Result<()> {
157        self.generate_bindings(all_derives)?;
158
159        let src = bindings_path.join("src");
160        let _ = fs::create_dir_all(&src);
161
162        // Write Cargo.toml
163        let cargo_toml_path = bindings_path.join("Cargo.toml");
164        let mut toml_contents = format!(
165            r#"[package]
166name = "{name}"
167version = "{version}"
168edition = "2021"
169"#
170        );
171
172        if !description.is_empty() {
173            toml_contents.push_str(&format!("description = \"{description}\"\n"));
174        }
175
176        if !license.is_empty() {
177            let formatted_licenses: Vec<String> =
178                license.split(',').map(Self::parse_license_alias).collect();
179
180            let formatted_license = formatted_licenses.join(" OR ");
181            toml_contents.push_str(&format!("license = \"{formatted_license}\"\n"));
182        }
183
184        toml_contents.push_str("\n[dependencies]\n");
185
186        let alloy_dep = Self::get_alloy_dep(alloy_version, alloy_rev);
187        write!(toml_contents, "{alloy_dep}")?;
188
189        if all_derives {
190            let serde_dep = r#"serde = { version = "1.0", features = ["derive"] }"#;
191            write!(toml_contents, "\n{serde_dep}")?;
192        }
193
194        fs::write(cargo_toml_path, toml_contents).wrap_err("Failed to write Cargo.toml")?;
195
196        let mut lib_contents = String::new();
197        write!(
198            &mut lib_contents,
199            r#"#![allow(unused_imports, clippy::all, rustdoc::all)]
200        //! This module contains the sol! generated bindings for solidity contracts.
201        //! This is autogenerated code.
202        //! Do not manually edit these files.
203        //! These files may be overwritten by the codegen system at any time.
204        "#
205        )?;
206
207        // Write src
208        let parse_error = |name: &str| {
209            format!("failed to parse generated tokens as an AST for {name};\nthis is likely a bug")
210        };
211        for instance in &self.instances {
212            let contents = instance.expansion.as_ref().unwrap();
213
214            let name = instance.name.to_snake_case();
215            let path = src.join(format!("{name}.rs"));
216            let file = syn::parse2(contents.clone())
217                .wrap_err_with(|| parse_error(&format!("{}:{}", path.display(), name)))?;
218            let contents = prettyplease::unparse(&file);
219            if single_file {
220                write!(&mut lib_contents, "{contents}")?;
221            } else {
222                fs::write(path, contents).wrap_err("failed to write to file")?;
223                write_mod_name(&mut lib_contents, &name)?;
224            }
225        }
226
227        let lib_path = src.join("lib.rs");
228        let lib_file = syn::parse_file(&lib_contents).wrap_err_with(|| parse_error("lib.rs"))?;
229        let lib_contents = prettyplease::unparse(&lib_file);
230        fs::write(lib_path, lib_contents).wrap_err("Failed to write lib.rs")?;
231
232        Ok(())
233    }
234
235    /// Attempts to detect the appropriate license.
236    pub fn parse_license_alias(license: &str) -> String {
237        match license.trim().to_lowercase().as_str() {
238            "mit" => "MIT".to_string(),
239            "apache" | "apache2" | "apache20" | "apache2.0" => "Apache-2.0".to_string(),
240            "gpl" | "gpl3" => "GPL-3.0".to_string(),
241            "lgpl" | "lgpl3" => "LGPL-3.0".to_string(),
242            "agpl" | "agpl3" => "AGPL-3.0".to_string(),
243            "bsd" | "bsd3" => "BSD-3-Clause".to_string(),
244            "bsd2" => "BSD-2-Clause".to_string(),
245            "mpl" | "mpl2" => "MPL-2.0".to_string(),
246            "isc" => "ISC".to_string(),
247            "unlicense" => "Unlicense".to_string(),
248            _ => license.trim().to_string(),
249        }
250    }
251
252    pub fn write_to_module(
253        &mut self,
254        bindings_path: &Path,
255        single_file: bool,
256        all_derives: bool,
257    ) -> Result<()> {
258        self.generate_bindings(all_derives)?;
259
260        let _ = fs::create_dir_all(bindings_path);
261
262        let mut mod_contents = r#"#![allow(unused_imports, clippy::all, rustdoc::all)]
263        //! This module contains the sol! generated bindings for solidity contracts.
264        //! This is autogenerated code.
265        //! Do not manually edit these files.
266        //! These files may be overwritten by the codegen system at any time.
267        "#
268        .to_string();
269
270        for instance in &self.instances {
271            let name = instance.name.to_snake_case();
272            if !single_file {
273                // Module
274                write_mod_name(&mut mod_contents, &name)?;
275                let mut contents = String::new();
276
277                write!(contents, "{}", instance.expansion.as_ref().unwrap())?;
278                let file = syn::parse_file(&contents)?;
279
280                let contents = prettyplease::unparse(&file);
281                fs::write(bindings_path.join(format!("{name}.rs")), contents)
282                    .wrap_err("Failed to write file")?;
283            } else {
284                // Single File
285                let mut contents = String::new();
286                write!(contents, "{}\n\n", instance.expansion.as_ref().unwrap())?;
287                write!(mod_contents, "{contents}")?;
288            }
289        }
290
291        let mod_path = bindings_path.join("mod.rs");
292        let mod_file = syn::parse_file(&mod_contents)?;
293        let mod_contents = prettyplease::unparse(&mod_file);
294
295        fs::write(mod_path, mod_contents).wrap_err("Failed to write mod.rs")?;
296
297        Ok(())
298    }
299
300    /// Checks that the generated bindings are up to date with the latest version of
301    /// `sol!`.
302    ///
303    /// Returns `Ok(())` if the generated bindings are up to date, otherwise it returns
304    /// `Err(_)`.
305    #[expect(clippy::too_many_arguments)]
306    pub fn check_consistency(
307        &self,
308        name: &str,
309        version: &str,
310        crate_path: &Path,
311        single_file: bool,
312        check_cargo_toml: bool,
313        is_mod: bool,
314        alloy_version: Option<String>,
315        alloy_rev: Option<String>,
316    ) -> Result<()> {
317        if check_cargo_toml {
318            self.check_cargo_toml(name, version, crate_path, alloy_version, alloy_rev)?;
319        }
320
321        let mut super_contents = String::new();
322        write!(
323            &mut super_contents,
324            r#"#![allow(unused_imports, clippy::all, rustdoc::all)]
325            //! This module contains the sol! generated bindings for solidity contracts.
326            //! This is autogenerated code.
327            //! Do not manually edit these files.
328            //! These files may be overwritten by the codegen system at any time.
329            "#
330        )?;
331        if !single_file {
332            for instance in &self.instances {
333                let name = instance.name.to_snake_case();
334                let path = if is_mod {
335                    crate_path.join(format!("{name}.rs"))
336                } else {
337                    crate_path.join(format!("src/{name}.rs"))
338                };
339                let tokens = instance
340                    .expansion
341                    .as_ref()
342                    .ok_or_eyre(format!("TokenStream for {path:?} does not exist"))?
343                    .to_string();
344
345                self.check_file_contents(&path, &tokens)?;
346                write_mod_name(&mut super_contents, &name)?;
347            }
348
349            let super_path =
350                if is_mod { crate_path.join("mod.rs") } else { crate_path.join("src/lib.rs") };
351            self.check_file_contents(&super_path, &super_contents)?;
352        }
353
354        Ok(())
355    }
356
357    fn check_file_contents(&self, file_path: &Path, expected_contents: &str) -> Result<()> {
358        eyre::ensure!(
359            file_path.is_file() && file_path.exists(),
360            "{} is not a file",
361            file_path.display()
362        );
363        let file_contents = &fs::read_to_string(file_path).wrap_err("Failed to read file")?;
364
365        // Format both
366        let file_contents = syn::parse_file(file_contents)?;
367        let formatted_file = prettyplease::unparse(&file_contents);
368
369        let expected_contents = syn::parse_file(expected_contents)?;
370        let formatted_exp = prettyplease::unparse(&expected_contents);
371
372        eyre::ensure!(
373            formatted_file == formatted_exp,
374            "File contents do not match expected contents for {file_path:?}"
375        );
376        Ok(())
377    }
378
379    fn check_cargo_toml(
380        &self,
381        name: &str,
382        version: &str,
383        crate_path: &Path,
384        alloy_version: Option<String>,
385        alloy_rev: Option<String>,
386    ) -> Result<()> {
387        eyre::ensure!(crate_path.is_dir(), "Crate path must be a directory");
388
389        let cargo_toml_path = crate_path.join("Cargo.toml");
390
391        eyre::ensure!(cargo_toml_path.is_file(), "Cargo.toml must exist");
392        let cargo_toml_contents =
393            fs::read_to_string(cargo_toml_path).wrap_err("Failed to read Cargo.toml")?;
394
395        let name_check = format!("name = \"{name}\"");
396        let version_check = format!("version = \"{version}\"");
397        let alloy_dep_check = Self::get_alloy_dep(alloy_version, alloy_rev);
398        let toml_consistent = cargo_toml_contents.contains(&name_check) &&
399            cargo_toml_contents.contains(&version_check) &&
400            cargo_toml_contents.contains(&alloy_dep_check);
401        eyre::ensure!(
402            toml_consistent,
403            r#"The contents of Cargo.toml do not match the expected output of the latest `sol!` version.
404                This indicates that the existing bindings are outdated and need to be generated again."#
405        );
406
407        Ok(())
408    }
409
410    /// Returns the `alloy` dependency string for the Cargo.toml file.
411    /// If `alloy_version` is provided, it will use that version from crates.io.
412    /// If `alloy_rev` is provided, it will use that revision from the GitHub repository.
413    fn get_alloy_dep(alloy_version: Option<String>, alloy_rev: Option<String>) -> String {
414        if let Some(alloy_version) = alloy_version {
415            format!(
416                r#"alloy = {{ version = "{alloy_version}", features = ["sol-types", "contract"] }}"#,
417            )
418        } else if let Some(alloy_rev) = alloy_rev {
419            format!(
420                r#"alloy = {{ git = "https://github.com/alloy-rs/alloy", rev = "{alloy_rev}", features = ["sol-types", "contract"] }}"#,
421            )
422        } else {
423            r#"alloy = { git = "https://github.com/alloy-rs/alloy", features = ["sol-types", "contract"] }"#.to_string()
424        }
425    }
426}
427
428fn write_mod_name(contents: &mut String, name: &str) -> Result<()> {
429    if syn::parse_str::<syn::Ident>(&format!("pub mod {name};")).is_ok() {
430        write!(contents, "pub mod {name};")?;
431    } else {
432        write!(contents, "pub mod r#{name};")?;
433    }
434    Ok(())
435}