chisel/
session_source.rs

//! Session Source
//!
//! This module contains the `SessionSource` struct, which is a minimal wrapper around
//! the REPL contract's source code. It provides simple compilation, parsing, and
//! execution helpers.

7use alloy_primitives::map::HashMap;
8use eyre::Result;
9use forge_fmt::solang_ext::SafeUnwrap;
10use foundry_compilers::{
11    artifacts::{CompilerOutput, Settings, SolcInput, Source, Sources},
12    compilers::solc::Solc,
13};
14use foundry_config::{Config, SolcReq};
15use foundry_evm::{backend::Backend, opts::EvmOpts};
16use semver::Version;
17use serde::{Deserialize, Serialize};
18use solang_parser::{diagnostics::Diagnostic, pt};
19use std::{fs, path::PathBuf};
20use walkdir::WalkDir;
21use yansi::Paint;
22
23/// The minimum Solidity version of the `Vm` interface.
24pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);
25
26/// Solidity source for the `Vm` interface in [forge-std](https://github.com/foundry-rs/forge-std)
27static VM_SOURCE: &str = include_str!("../../../testdata/cheats/Vm.sol");
28
29/// Intermediate output for the compiled [SessionSource]
30#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
31pub struct IntermediateOutput {
32    /// All expressions within the REPL contract's run function and top level scope.
33    #[serde(skip)]
34    pub repl_contract_expressions: HashMap<String, pt::Expression>,
35    /// Intermediate contracts
36    #[serde(skip)]
37    pub intermediate_contracts: IntermediateContracts,
38}
39
40/// A refined intermediate parse tree for a contract that enables easy lookups
41/// of definitions.
42#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
43pub struct IntermediateContract {
44    /// All function definitions within the contract
45    #[serde(skip)]
46    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
47    /// All event definitions within the contract
48    #[serde(skip)]
49    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
50    /// All struct definitions within the contract
51    #[serde(skip)]
52    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
53    /// All variable definitions within the top level scope of the contract
54    #[serde(skip)]
55    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
56}
57
58/// A defined type for a map of contract names to [IntermediateContract]s
59type IntermediateContracts = HashMap<String, IntermediateContract>;
60
61/// Full compilation output for the [SessionSource]
62#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
63pub struct GeneratedOutput {
64    /// The [IntermediateOutput] component
65    pub intermediate: IntermediateOutput,
66    /// The [CompilerOutput] component
67    pub compiler_output: CompilerOutput,
68}
69
70/// Configuration for the [SessionSource]
71#[derive(Clone, Debug, Default, Serialize, Deserialize)]
72pub struct SessionSourceConfig {
73    /// Foundry configuration
74    pub foundry_config: Config,
75    /// EVM Options
76    pub evm_opts: EvmOpts,
77    /// Disable the default `Vm` import.
78    pub no_vm: bool,
79    #[serde(skip)]
80    /// In-memory REVM db for the session's runner.
81    pub backend: Option<Backend>,
82    /// Optionally enable traces for the REPL contract execution
83    pub traces: bool,
84    /// Optionally set calldata for the REPL contract execution
85    pub calldata: Option<Vec<u8>>,
86}
87
88impl SessionSourceConfig {
89    /// Returns the solc version to use
90    ///
91    /// Solc version precedence
92    /// - Foundry configuration / `--use` flag
93    /// - Latest installed version via SVM
94    /// - Default: Latest 0.8.19
95    pub(crate) fn solc(&self) -> Result<Solc> {
96        let solc_req = if let Some(solc_req) = self.foundry_config.solc.clone() {
97            solc_req
98        } else if let Some(version) = Solc::installed_versions().into_iter().max() {
99            SolcReq::Version(version)
100        } else {
101            if !self.foundry_config.offline {
102                sh_print!("{}", "No solidity versions installed! ".green())?;
103            }
104            // use default
105            SolcReq::Version(Version::new(0, 8, 19))
106        };
107
108        match solc_req {
109            SolcReq::Version(version) => {
110                let solc = if let Some(solc) = Solc::find_svm_installed_version(&version)? {
111                    solc
112                } else {
113                    if self.foundry_config.offline {
114                        eyre::bail!("can't install missing solc {version} in offline mode")
115                    }
116                    sh_println!("{}", format!("Installing solidity version {version}...").green())?;
117                    Solc::blocking_install(&version)?
118                };
119                Ok(solc)
120            }
121            SolcReq::Local(solc) => {
122                if !solc.is_file() {
123                    eyre::bail!("`solc` {} does not exist", solc.display());
124                }
125                Ok(Solc::new(solc)?)
126            }
127        }
128    }
129}
130
131/// REPL Session Source wrapper
132///
133/// Heavily based on soli's [`ConstructedSource`](https://github.com/jpopesculian/soli/blob/master/src/main.rs#L166)
134#[derive(Clone, Debug, Serialize, Deserialize)]
135pub struct SessionSource {
136    /// The file name
137    pub file_name: PathBuf,
138    /// The contract name
139    pub contract_name: String,
140    /// The solidity compiler version
141    pub solc: Solc,
142    /// Global level solidity code
143    ///
144    /// Typically, global-level code is present between the contract definition and the first
145    /// function (usually constructor)
146    pub global_code: String,
147    /// Top level solidity code
148    ///
149    /// Typically, this is code seen above the constructor
150    pub top_level_code: String,
151    /// Code existing within the "run()" function's scope
152    pub run_code: String,
153    /// The generated output
154    pub generated_output: Option<GeneratedOutput>,
155    /// Session Source configuration
156    pub config: SessionSourceConfig,
157}
158
159impl SessionSource {
160    /// Creates a new source given a solidity compiler version
161    ///
162    /// # Panics
163    ///
164    /// If no Solc binary is set, cannot be found or the `--version` command fails
165    ///
166    /// ### Takes
167    ///
168    /// - An instance of [Solc]
169    /// - An instance of [SessionSourceConfig]
170    ///
171    /// ### Returns
172    ///
173    /// A new instance of [SessionSource]
174    #[track_caller]
175    pub fn new(solc: Solc, mut config: SessionSourceConfig) -> Self {
176        if solc.version < MIN_VM_VERSION && !config.no_vm {
177            tracing::info!(version=%solc.version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
178            config.no_vm = true;
179        }
180
181        Self {
182            file_name: PathBuf::from("ReplContract.sol".to_string()),
183            contract_name: "REPL".to_string(),
184            solc,
185            config,
186            global_code: Default::default(),
187            top_level_code: Default::default(),
188            run_code: Default::default(),
189            generated_output: None,
190        }
191    }
192
193    /// Clones a [SessionSource] without copying the [GeneratedOutput], as it will
194    /// need to be regenerated as soon as new code is added.
195    ///
196    /// ### Returns
197    ///
198    /// A shallow-cloned [SessionSource]
199    pub fn shallow_clone(&self) -> Self {
200        Self {
201            file_name: self.file_name.clone(),
202            contract_name: self.contract_name.clone(),
203            solc: self.solc.clone(),
204            global_code: self.global_code.clone(),
205            top_level_code: self.top_level_code.clone(),
206            run_code: self.run_code.clone(),
207            generated_output: None,
208            config: self.config.clone(),
209        }
210    }
211
212    /// Clones the [SessionSource] and appends a new line of code. Will return
213    /// an error result if the new line fails to be parsed.
214    ///
215    /// ### Returns
216    ///
217    /// Optionally, a shallow-cloned [SessionSource] with the passed content appended to the
218    /// source code.
219    pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
220        let new_source = self.shallow_clone();
221        if let Some(parsed) = parse_fragment(new_source.solc, new_source.config, &content)
222            .or_else(|| {
223                let new_source = self.shallow_clone();
224                content.push(';');
225                parse_fragment(new_source.solc, new_source.config, &content)
226            })
227            .or_else(|| {
228                let new_source = self.shallow_clone();
229                content = content.trim_end().trim_end_matches(';').to_string();
230                parse_fragment(new_source.solc, new_source.config, &content)
231            })
232        {
233            let mut new_source = self.shallow_clone();
234            // Flag that tells the dispatcher whether to build or execute the session
235            // source based on the scope of the new code.
236            match parsed {
237                ParseTreeFragment::Function => new_source.with_run_code(&content),
238                ParseTreeFragment::Contract => new_source.with_top_level_code(&content),
239                ParseTreeFragment::Source => new_source.with_global_code(&content),
240            };
241
242            Ok((new_source, matches!(parsed, ParseTreeFragment::Function)))
243        } else {
244            eyre::bail!("\"{}\"", content.trim().to_owned());
245        }
246    }
247
248    // Fillers
249
250    /// Appends global-level code to the source
251    pub fn with_global_code(&mut self, content: &str) -> &mut Self {
252        self.global_code.push_str(content.trim());
253        self.global_code.push('\n');
254        self.generated_output = None;
255        self
256    }
257
258    /// Appends top-level code to the source
259    pub fn with_top_level_code(&mut self, content: &str) -> &mut Self {
260        self.top_level_code.push_str(content.trim());
261        self.top_level_code.push('\n');
262        self.generated_output = None;
263        self
264    }
265
266    /// Appends code to the "run()" function
267    pub fn with_run_code(&mut self, content: &str) -> &mut Self {
268        self.run_code.push_str(content.trim());
269        self.run_code.push('\n');
270        self.generated_output = None;
271        self
272    }
273
274    // Drains
275
276    /// Clears global code from the source
277    pub fn drain_global_code(&mut self) -> &mut Self {
278        String::clear(&mut self.global_code);
279        self.generated_output = None;
280        self
281    }
282
283    /// Clears top-level code from the source
284    pub fn drain_top_level_code(&mut self) -> &mut Self {
285        String::clear(&mut self.top_level_code);
286        self.generated_output = None;
287        self
288    }
289
290    /// Clears the "run()" function's code
291    pub fn drain_run(&mut self) -> &mut Self {
292        String::clear(&mut self.run_code);
293        self.generated_output = None;
294        self
295    }
296
297    /// Generates and [`SolcInput`] from the source.
298    ///
299    /// ### Returns
300    ///
301    /// A [`SolcInput`] object containing forge-std's `Vm` interface as well as the REPL contract
302    /// source.
303    pub fn compiler_input(&self) -> SolcInput {
304        let mut sources = Sources::new();
305        sources.insert(self.file_name.clone(), Source::new(self.to_repl_source()));
306
307        let remappings = self.config.foundry_config.get_all_remappings().collect::<Vec<_>>();
308
309        // Include Vm.sol if forge-std remapping is not available
310        if !self.config.no_vm && !remappings.iter().any(|r| r.name.starts_with("forge-std")) {
311            sources.insert(PathBuf::from("forge-std/Vm.sol"), Source::new(VM_SOURCE));
312        }
313
314        let settings = Settings {
315            remappings,
316            evm_version: self
317                .config
318                .foundry_config
319                .evm_version
320                .normalize_version_solc(&self.solc.version),
321            ..Default::default()
322        };
323
324        // we only care about the solidity source, so we can safely unwrap
325        SolcInput::resolve_and_build(sources, settings)
326            .into_iter()
327            .next()
328            .map(|i| i.sanitized(&self.solc.version))
329            .expect("Solidity source not found")
330    }
331
332    /// Compiles the source using [solang_parser]
333    ///
334    /// ### Returns
335    ///
336    /// A [pt::SourceUnit] if successful.
337    /// A vec of [solang_parser::diagnostics::Diagnostic]s if unsuccessful.
338    pub fn parse(&self) -> Result<pt::SourceUnit, Vec<solang_parser::diagnostics::Diagnostic>> {
339        solang_parser::parse(&self.to_repl_source(), 0).map(|(pt, _)| pt)
340    }
341
342    /// Generate intermediate contracts for all contract definitions in the compilation source.
343    ///
344    /// ### Returns
345    ///
346    /// Optionally, a map of contract names to a vec of [IntermediateContract]s.
347    pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
348        let mut res_map = HashMap::default();
349        let parsed_map = self.compiler_input().sources;
350        for source in parsed_map.values() {
351            Self::get_intermediate_contract(&source.content, &mut res_map);
352        }
353        Ok(res_map)
354    }
355
356    /// Generate intermediate output for the REPL contract
357    pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
358        // Parse generate intermediate contracts
359        let intermediate_contracts = self.generate_intermediate_contracts()?;
360
361        // Construct variable definitions
362        let variable_definitions = intermediate_contracts
363            .get("REPL")
364            .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
365            .variable_definitions
366            .clone()
367            .into_iter()
368            .map(|(k, v)| (k, v.ty))
369            .collect::<HashMap<String, pt::Expression>>();
370        // Construct intermediate output
371        let mut intermediate_output = IntermediateOutput {
372            repl_contract_expressions: variable_definitions,
373            intermediate_contracts,
374        };
375
376        // Add all statements within the run function to the repl_contract_expressions map
377        for (key, val) in intermediate_output
378            .run_func_body()?
379            .clone()
380            .iter()
381            .flat_map(Self::get_statement_definitions)
382        {
383            intermediate_output.repl_contract_expressions.insert(key, val);
384        }
385
386        Ok(intermediate_output)
387    }
388
389    /// Compile the contract
390    ///
391    /// ### Returns
392    ///
393    /// Optionally, a [CompilerOutput] object that contains compilation artifacts.
394    pub fn compile(&self) -> Result<CompilerOutput> {
395        // Compile the contract
396        let compiled = self.solc.compile_exact(&self.compiler_input())?;
397
398        // Extract compiler errors
399        let errors =
400            compiled.errors.iter().filter(|error| error.severity.is_error()).collect::<Vec<_>>();
401        if !errors.is_empty() {
402            eyre::bail!(
403                "Compiler errors:\n{}",
404                errors.into_iter().map(|err| err.to_string()).collect::<String>()
405            );
406        }
407
408        Ok(compiled)
409    }
410
411    /// Builds the SessionSource from input into the complete CompiledOutput
412    ///
413    /// ### Returns
414    ///
415    /// Optionally, a [GeneratedOutput] object containing both the [CompilerOutput] and the
416    /// [IntermediateOutput].
417    pub fn build(&mut self) -> Result<GeneratedOutput> {
418        // Compile
419        let compiler_output = self.compile()?;
420
421        // Generate intermediate output
422        let intermediate_output = self.generate_intermediate_output()?;
423
424        // Construct generated output
425        let generated_output =
426            GeneratedOutput { intermediate: intermediate_output, compiler_output };
427        self.generated_output = Some(generated_output.clone()); // ehhh, need to not clone this.
428        Ok(generated_output)
429    }
430
431    /// Convert the [SessionSource] to a valid Script contract
432    ///
433    /// ### Returns
434    ///
435    /// The [SessionSource] represented as a Forge Script contract.
436    pub fn to_script_source(&self) -> String {
437        let Version { major, minor, patch, .. } = self.solc.version;
438        let Self { contract_name, global_code, top_level_code, run_code, config, .. } = self;
439
440        let script_import =
441            if !config.no_vm { "import {Script} from \"forge-std/Script.sol\";\n" } else { "" };
442
443        format!(
444            r#"
445// SPDX-License-Identifier: UNLICENSED
446pragma solidity ^{major}.{minor}.{patch};
447
448{script_import}
449{global_code}
450
451contract {contract_name} is Script {{
452    {top_level_code}
453  
454    /// @notice Script entry point
455    function run() public {{
456        {run_code}
457    }}
458}}"#,
459        )
460    }
461
462    /// Convert the [SessionSource] to a valid REPL contract
463    ///
464    /// ### Returns
465    ///
466    /// The [SessionSource] represented as a REPL contract.
467    pub fn to_repl_source(&self) -> String {
468        let Version { major, minor, patch, .. } = self.solc.version;
469        let Self { contract_name, global_code, top_level_code, run_code, config, .. } = self;
470        let (mut vm_import, mut vm_constant) = (String::new(), String::new());
471        if !config.no_vm {
472            // Check if there's any `forge-std` remapping and determine proper path to it by
473            // searching remapping path.
474            if let Some(remapping) = config
475                .foundry_config
476                .remappings
477                .iter()
478                .find(|remapping| remapping.name == "forge-std/")
479            {
480                if let Some(vm_path) = WalkDir::new(&remapping.path.path)
481                    .into_iter()
482                    .filter_map(|e| e.ok())
483                    .find(|e| e.file_name() == "Vm.sol")
484                {
485                    vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
486                    vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
487                }
488            }
489        }
490
491        format!(
492            r#"
493// SPDX-License-Identifier: UNLICENSED
494pragma solidity ^{major}.{minor}.{patch};
495
496{vm_import}
497{global_code}
498
499contract {contract_name} {{
500    {vm_constant}
501    {top_level_code}
502  
503    /// @notice REPL contract entry point
504    function run() public {{
505        {run_code}
506    }}
507}}"#,
508        )
509    }
510
511    /// Gets the [IntermediateContract] for a Solidity source string and inserts it into the
512    /// passed `res_map`. In addition, recurses on any imported files as well.
513    ///
514    /// ### Takes
515    /// - `content` - A Solidity source string
516    /// - `res_map` - A mutable reference to a map of contract names to [IntermediateContract]s
517    pub fn get_intermediate_contract(
518        content: &str,
519        res_map: &mut HashMap<String, IntermediateContract>,
520    ) {
521        if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
522            let func_defs = source_unit_parts
523                .into_iter()
524                .filter_map(|sup| match sup {
525                    pt::SourceUnitPart::ImportDirective(i) => match i {
526                        pt::Import::Plain(s, _) |
527                        pt::Import::Rename(s, _, _) |
528                        pt::Import::GlobalSymbol(s, _, _) => {
529                            let s = match s {
530                                pt::ImportPath::Filename(s) => s.string,
531                                pt::ImportPath::Path(p) => p.to_string(),
532                            };
533                            let path = PathBuf::from(s);
534
535                            match fs::read_to_string(path) {
536                                Ok(source) => {
537                                    Self::get_intermediate_contract(&source, res_map);
538                                    None
539                                }
540                                Err(_) => None,
541                            }
542                        }
543                    },
544                    pt::SourceUnitPart::ContractDefinition(cd) => {
545                        let mut intermediate = IntermediateContract::default();
546
547                        cd.parts.into_iter().for_each(|part| match part {
548                            pt::ContractPart::FunctionDefinition(def) => {
549                                // Only match normal function definitions here.
550                                if matches!(def.ty, pt::FunctionTy::Function) {
551                                    intermediate
552                                        .function_definitions
553                                        .insert(def.name.clone().unwrap().name, def);
554                                }
555                            }
556                            pt::ContractPart::EventDefinition(def) => {
557                                let event_name = def.name.safe_unwrap().name.clone();
558                                intermediate.event_definitions.insert(event_name, def);
559                            }
560                            pt::ContractPart::StructDefinition(def) => {
561                                let struct_name = def.name.safe_unwrap().name.clone();
562                                intermediate.struct_definitions.insert(struct_name, def);
563                            }
564                            pt::ContractPart::VariableDefinition(def) => {
565                                let var_name = def.name.safe_unwrap().name.clone();
566                                intermediate.variable_definitions.insert(var_name, def);
567                            }
568                            _ => {}
569                        });
570                        Some((cd.name.safe_unwrap().name.clone(), intermediate))
571                    }
572                    _ => None,
573                })
574                .collect::<HashMap<String, IntermediateContract>>();
575            res_map.extend(func_defs);
576        }
577    }
578
579    /// Helper to deconstruct a statement
580    ///
581    /// ### Takes
582    ///
583    /// A reference to a [pt::Statement]
584    ///
585    /// ### Returns
586    ///
587    /// A vector containing tuples of the inner expressions' names, types, and storage locations.
588    pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
589        match statement {
590            pt::Statement::VariableDefinition(_, def, _) => {
591                vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
592            }
593            pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
594                if let pt::Expression::List(_, list) = left.as_ref() {
595                    list.iter()
596                        .filter_map(|(_, param)| {
597                            param.as_ref().and_then(|param| {
598                                param
599                                    .name
600                                    .as_ref()
601                                    .map(|name| (name.name.clone(), param.ty.clone()))
602                            })
603                        })
604                        .collect()
605                } else {
606                    Vec::default()
607                }
608            }
609            _ => Vec::default(),
610        }
611    }
612}
613
614impl IntermediateOutput {
615    /// Helper function that returns the body of the REPL contract's "run" function.
616    ///
617    /// ### Returns
618    ///
619    /// Optionally, the last statement within the "run" function of the REPL contract.
620    pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
621        match self
622            .intermediate_contracts
623            .get("REPL")
624            .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
625            .function_definitions
626            .get("run")
627            .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
628            .body
629            .as_ref()
630            .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
631        {
632            pt::Statement::Block { statements, .. } => Ok(statements),
633            _ => eyre::bail!("Could not find statements within run function body!"),
634        }
635    }
636}
637
/// The scope a fragment of REPL input belongs to.
///
/// Decides whether input is placed in the `run()` function, the top level of the REPL
/// contract, or the global (file-level) scope.
#[derive(Debug)]
pub enum ParseTreeFragment {
    /// Code for the global (file-level) scope
    Source,
    /// Code for the top level of the contract
    Contract,
    /// Code for the body of the "run()" function
    Function,
}

652/// Parses a fragment of solidity code with solang_parser and assigns
653/// it a scope within the [SessionSource].
654pub fn parse_fragment(
655    solc: Solc,
656    config: SessionSourceConfig,
657    buffer: &str,
658) -> Option<ParseTreeFragment> {
659    let mut base = SessionSource::new(solc, config);
660
661    match base.clone().with_run_code(buffer).parse() {
662        Ok(_) => return Some(ParseTreeFragment::Function),
663        Err(e) => debug_errors(&e),
664    }
665    match base.clone().with_top_level_code(buffer).parse() {
666        Ok(_) => return Some(ParseTreeFragment::Contract),
667        Err(e) => debug_errors(&e),
668    }
669    match base.with_global_code(buffer).parse() {
670        Ok(_) => return Some(ParseTreeFragment::Source),
671        Err(e) => debug_errors(&e),
672    }
673
674    None
675}
676
677fn debug_errors(errors: &[Diagnostic]) {
678    if !tracing::enabled!(tracing::Level::DEBUG) {
679        return;
680    }
681
682    for error in errors {
683        tracing::debug!("error: {}", error.message);
684    }
685}