chisel/
source.rs

1//! Session Source
2//!
3//! This module contains the `SessionSource` struct, which is a minimal wrapper around
4//! the REPL contract's source code. It provides simple compilation, parsing, and
5//! execution helpers.
6
7use eyre::Result;
8use forge_doc::solang_ext::{CodeLocationExt, SafeUnwrap};
9use foundry_common::fs;
10use foundry_compilers::{
11    Artifact, ProjectCompileOutput,
12    artifacts::{ConfigurableContractArtifact, Source, Sources},
13    project::ProjectCompiler,
14    solc::Solc,
15};
16use foundry_config::{Config, SolcReq};
17use foundry_evm::{backend::Backend, core::bytecode::InstIter, opts::EvmOpts};
18use semver::Version;
19use serde::{Deserialize, Serialize};
20use solang_parser::pt;
21use solar::interface::diagnostics::EmittedDiagnostics;
22use std::{cell::OnceCell, collections::HashMap, fmt, path::PathBuf};
23use walkdir::WalkDir;
24
25/// The minimum Solidity version of the `Vm` interface.
26pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);
27
28/// Solidity source for the `Vm` interface in [forge-std](https://github.com/foundry-rs/forge-std)
29static VM_SOURCE: &str = include_str!("../../../testdata/cheats/Vm.sol");
30
31/// [`SessionSource`] build output.
32pub struct GeneratedOutput {
33    output: ProjectCompileOutput,
34    pub(crate) intermediate: IntermediateOutput,
35}
36
37pub struct GeneratedOutputRef<'a> {
38    output: &'a ProjectCompileOutput,
39    // compiler: &'b solar::sema::CompilerRef<'c>,
40    pub(crate) intermediate: &'a IntermediateOutput,
41}
42
43/// Intermediate output for the compiled [SessionSource]
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct IntermediateOutput {
46    /// All expressions within the REPL contract's run function and top level scope.
47    pub repl_contract_expressions: HashMap<String, pt::Expression>,
48    /// Intermediate contracts
49    pub intermediate_contracts: IntermediateContracts,
50}
51
52/// A refined intermediate parse tree for a contract that enables easy lookups
53/// of definitions.
54#[derive(Clone, Debug, Default, PartialEq, Eq)]
55pub struct IntermediateContract {
56    /// All function definitions within the contract
57    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
58    /// All event definitions within the contract
59    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
60    /// All struct definitions within the contract
61    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
62    /// All variable definitions within the top level scope of the contract
63    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
64}
65
66/// A defined type for a map of contract names to [IntermediateContract]s
67type IntermediateContracts = HashMap<String, IntermediateContract>;
68
69impl fmt::Debug for GeneratedOutput {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("GeneratedOutput").finish_non_exhaustive()
72    }
73}
74
75impl GeneratedOutput {
76    pub fn enter<T: Send>(&self, f: impl FnOnce(GeneratedOutputRef<'_>) -> T + Send) -> T {
77        // TODO(dani): once intermediate is removed
78        // self.output
79        //     .parser()
80        //     .solc()
81        //     .compiler()
82        //     .enter(|compiler| f(GeneratedOutputRef { output: &self.output, compiler }))
83        f(GeneratedOutputRef { output: &self.output, intermediate: &self.intermediate })
84    }
85}
86
87impl GeneratedOutputRef<'_> {
88    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
89        self.output.find_first("REPL")
90    }
91}
92
93impl std::ops::Deref for GeneratedOutput {
94    type Target = IntermediateOutput;
95    fn deref(&self) -> &Self::Target {
96        &self.intermediate
97    }
98}
99impl std::ops::Deref for GeneratedOutputRef<'_> {
100    type Target = IntermediateOutput;
101    fn deref(&self) -> &Self::Target {
102        self.intermediate
103    }
104}
105
106impl IntermediateOutput {
107    pub fn get_event(&self, input: &str) -> Option<&pt::EventDefinition> {
108        self.intermediate_contracts
109            .get("REPL")
110            .and_then(|contract| contract.event_definitions.get(input).map(std::ops::Deref::deref))
111    }
112
113    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
114        let deployed_bytecode = contract
115            .get_deployed_bytecode()
116            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
117        let deployed_bytecode_bytes = deployed_bytecode
118            .bytes()
119            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
120
121        let run_func_statements = self.run_func_body()?;
122
123        // Record loc of first yul block return statement (if any).
124        // This is used to decide which is the final statement within the `run()` method.
125        // see <https://github.com/foundry-rs/foundry/issues/4617>.
126        let last_yul_return = run_func_statements.iter().find_map(|statement| {
127            if let pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } = statement
128                && let Some(statement) = block.statements.last()
129                && let pt::YulStatement::FunctionCall(yul_call) = statement
130                && yul_call.id.name == "return"
131            {
132                return Some(statement.loc());
133            }
134            None
135        });
136
137        // Find the last statement within the "run()" method and get the program
138        // counter via the source map.
139        let Some(final_statement) = run_func_statements.last() else { return Ok(None) };
140
141        // If the final statement is some type of block (assembly, unchecked, or regular),
142        // we need to find the final statement within that block. Otherwise, default to
143        // the source loc of the final statement of the `run()` function's block.
144        //
145        // There is some code duplication within the arms due to the difference between
146        // the [pt::Statement] type and the [pt::YulStatement] types.
147        let mut source_loc = match final_statement {
148            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
149                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
150                let last_statement = block.statements.iter().rev().find(|statement| {
151                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
152                });
153                if let Some(statement) = last_statement {
154                    statement.loc()
155                } else {
156                    // In the case where the block is empty, attempt to grab the statement
157                    // before the asm block. Because we use saturating sub to get the second
158                    // to last index, this can always be safely unwrapped.
159                    run_func_statements
160                        .get(run_func_statements.len().saturating_sub(2))
161                        .unwrap()
162                        .loc()
163                }
164            }
165            pt::Statement::Block { loc: _, unchecked: _, statements } => {
166                if let Some(statement) = statements.last() {
167                    statement.loc()
168                } else {
169                    // In the case where the block is empty, attempt to grab the statement
170                    // before the block. Because we use saturating sub to get the second to
171                    // last index, this can always be safely unwrapped.
172                    run_func_statements
173                        .get(run_func_statements.len().saturating_sub(2))
174                        .unwrap()
175                        .loc()
176                }
177            }
178            _ => final_statement.loc(),
179        };
180
181        // Consider yul return statement as final statement (if it's loc is lower) .
182        if let Some(yul_return) = last_yul_return
183            && yul_return.end() < source_loc.start()
184        {
185            source_loc = yul_return;
186        }
187
188        // Map the source location of the final statement of the `run()` function to its
189        // corresponding runtime program counter
190        let final_pc = {
191            let offset = source_loc.start() as u32;
192            let length = (source_loc.end() - source_loc.start()) as u32;
193            trace!(%offset, %length, "find pc");
194            contract
195                .get_source_map_deployed()
196                .unwrap()
197                .unwrap()
198                .into_iter()
199                .zip(InstIter::new(deployed_bytecode_bytes).with_pc().map(|(pc, _)| pc))
200                .filter(|(s, _)| s.offset() == offset && s.length() == length)
201                .map(|(_, pc)| pc)
202                .max()
203        };
204        trace!(?final_pc);
205        Ok(final_pc)
206    }
207
208    pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
209        match self
210            .intermediate_contracts
211            .get("REPL")
212            .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
213            .function_definitions
214            .get("run")
215            .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
216            .body
217            .as_ref()
218            .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
219        {
220            pt::Statement::Block { statements, .. } => Ok(statements),
221            _ => eyre::bail!("Could not find statements within run function body!"),
222        }
223    }
224}
225
226// TODO(dani): further migration blocked on upstream work
227#[cfg(false)]
228impl<'gcx> GeneratedOutputRef<'_, '_, 'gcx> {
229    pub fn gcx(&self) -> Gcx<'gcx> {
230        self.compiler.gcx()
231    }
232
233    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
234        self.output.find_first("REPL")
235    }
236
237    pub fn get_event(&self, input: &str) -> Option<hir::EventId> {
238        self.gcx().hir.events_enumerated().find(|(_, e)| e.name.as_str() == input).map(|(id, _)| id)
239    }
240
241    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
242        let deployed_bytecode = contract
243            .get_deployed_bytecode()
244            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
245        let deployed_bytecode_bytes = deployed_bytecode
246            .bytes()
247            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
248
249        // Fetch the run function's body statement
250        let run_body = self.run_func_body();
251
252        // Record loc of first yul block return statement (if any).
253        // This is used to decide which is the final statement within the `run()` method.
254        // see <https://github.com/foundry-rs/foundry/issues/4617>.
255        let last_yul_return_span: Option<Span> = run_body.iter().find_map(|stmt| {
256            // TODO(dani): Yul is not yet lowered to HIR.
257            let _ = stmt;
258            /*
259            if let hir::StmtKind::Assembly { block, .. } = stmt {
260                if let Some(stmt) = block.last() {
261                    if let pt::YulStatement::FunctionCall(yul_call) = stmt {
262                        if yul_call.id.name == "return" {
263                            return Some(stmt.loc())
264                        }
265                    }
266                }
267            }
268            */
269            None
270        });
271
272        // Find the last statement within the "run()" method and get the program
273        // counter via the source map.
274        let Some(last_stmt) = run_body.last() else { return Ok(None) };
275
276        // If the final statement is some type of block (assembly, unchecked, or regular),
277        // we need to find the final statement within that block. Otherwise, default to
278        // the source loc of the final statement of the `run()` function's block.
279        //
280        // There is some code duplication within the arms due to the difference between
281        // the [pt::Statement] type and the [pt::YulStatement] types.
282        let source_stmt = match &last_stmt.kind {
283            // TODO(dani): Yul is not yet lowered to HIR.
284            /*
285            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
286                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
287                let last_statement = block.statements.iter().rev().find(|statement| {
288                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
289                });
290                if let Some(stmt) = last_statement {
291                    stmt
292                } else {
293                    // In the case where the block is empty, attempt to grab the statement
294                    // before the block. Because we use saturating sub to get the second to
295                    // last index, this can always be safely unwrapped.
296                    &run_body[run_body.len().saturating_sub(2)]
297                }
298            }
299            */
300            hir::StmtKind::UncheckedBlock(stmts) | hir::StmtKind::Block(stmts) => {
301                if let Some(stmt) = stmts.last() {
302                    stmt
303                } else {
304                    // In the case where the block is empty, attempt to grab the statement
305                    // before the block. Because we use saturating sub to get the second to
306                    // last index, this can always be safely unwrapped.
307                    &run_body[run_body.len().saturating_sub(2)]
308                }
309            }
310            _ => last_stmt,
311        };
312        let mut source_span = self.stmt_span_without_semicolon(source_stmt);
313
314        // Consider yul return statement as final statement (if it's loc is lower) .
315        if let Some(yul_return_span) = last_yul_return_span
316            && yul_return_span.hi() < source_span.lo()
317        {
318            source_span = yul_return_span;
319        }
320
321        // Map the source location of the final statement of the `run()` function to its
322        // corresponding runtime program counter
323        let (_sf, range) = self.compiler.sess().source_map().span_to_source(source_span).unwrap();
324        dbg!(source_span, &range, &_sf.src[range.clone()]);
325        let offset = range.start as u32;
326        let length = range.len() as u32;
327        let final_pc = deployed_bytecode
328            .source_map()
329            .ok_or_else(|| eyre::eyre!("No source map found for `REPL` contract"))??
330            .into_iter()
331            .zip(InstructionIter::new(deployed_bytecode_bytes))
332            .filter(|(s, _)| s.offset() == offset && s.length() == length)
333            .map(|(_, i)| i.pc)
334            .max()
335            .unwrap_or_default();
336        Ok(Some(final_pc))
337    }
338
339    /// Statements' ranges in the solc source map do not include the semicolon.
340    fn stmt_span_without_semicolon(&self, stmt: &hir::Stmt<'_>) -> Span {
341        match stmt.kind {
342            hir::StmtKind::DeclSingle(id) => {
343                let decl = self.gcx().hir.variable(id);
344                if let Some(expr) = decl.initializer {
345                    stmt.span.with_hi(expr.span.hi())
346                } else {
347                    stmt.span
348                }
349            }
350            hir::StmtKind::DeclMulti(_, expr) => stmt.span.with_hi(expr.span.hi()),
351            hir::StmtKind::Expr(expr) => expr.span,
352            _ => stmt.span,
353        }
354    }
355
356    fn run_func_body(&self) -> hir::Block<'_> {
357        let c = self.repl_contract_hir().expect("REPL contract not found in HIR");
358        let f = c
359            .functions()
360            .find(|&f| self.gcx().hir.function(f).name.as_ref().map(|n| n.as_str()) == Some("run"))
361            .expect("`run()` function not found in REPL contract");
362        self.gcx().hir.function(f).body.expect("`run()` function does not have a body")
363    }
364
365    fn repl_contract_hir(&self) -> Option<&hir::Contract<'_>> {
366        self.gcx().hir.contracts().find(|c| c.name.as_str() == "REPL")
367    }
368}
369
370/// Configuration for the [SessionSource]
371#[derive(Clone, Debug, Default, Serialize, Deserialize)]
372pub struct SessionSourceConfig {
373    /// Foundry configuration
374    pub foundry_config: Config,
375    /// EVM Options
376    pub evm_opts: EvmOpts,
377    /// Disable the default `Vm` import.
378    pub no_vm: bool,
379    /// In-memory REVM db for the session's runner.
380    #[serde(skip)]
381    pub backend: Option<Backend>,
382    /// Optionally enable traces for the REPL contract execution
383    pub traces: bool,
384    /// Optionally set calldata for the REPL contract execution
385    pub calldata: Option<Vec<u8>>,
386    /// Enable viaIR with minimum optimization
387    ///
388    /// This can fix most of the "stack too deep" errors while resulting a
389    /// relatively accurate source map.
390    pub ir_minimum: bool,
391}
392
393impl SessionSourceConfig {
394    /// Detect the solc version to know if VM can be injected.
395    pub fn detect_solc(&mut self) -> Result<()> {
396        if self.foundry_config.solc.is_none() {
397            let version = Solc::ensure_installed(&"*".parse().unwrap())?;
398            self.foundry_config.solc = Some(SolcReq::Version(version));
399        }
400        if !self.no_vm
401            && let Some(version) = self.foundry_config.solc_version()
402            && version < MIN_VM_VERSION
403        {
404            info!(%version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
405            self.no_vm = true;
406        }
407        Ok(())
408    }
409}
410
411/// REPL Session Source wrapper
412///
413/// Heavily based on soli's [`ConstructedSource`](https://github.com/jpopesculian/soli/blob/master/src/main.rs#L166)
414#[derive(Debug, Serialize, Deserialize)]
415pub struct SessionSource {
416    /// The file name
417    pub file_name: String,
418    /// The contract name
419    pub contract_name: String,
420
421    /// Session Source configuration
422    pub config: SessionSourceConfig,
423
424    /// Global level Solidity code.
425    ///
426    /// Above and outside all contract declarations, in the global context.
427    pub global_code: String,
428    /// Top level Solidity code.
429    ///
430    /// Within the contract declaration, but outside of the `run()` function.
431    pub contract_code: String,
432    /// The code to be executed in the `run()` function.
433    pub run_code: String,
434
435    /// Cached VM source code.
436    #[serde(skip, default = "vm_source")]
437    vm_source: Source,
438    /// The generated output
439    #[serde(skip)]
440    output: OnceCell<GeneratedOutput>,
441}
442
443fn vm_source() -> Source {
444    Source::new(VM_SOURCE)
445}
446
447impl Clone for SessionSource {
448    fn clone(&self) -> Self {
449        Self {
450            file_name: self.file_name.clone(),
451            contract_name: self.contract_name.clone(),
452            global_code: self.global_code.clone(),
453            contract_code: self.contract_code.clone(),
454            run_code: self.run_code.clone(),
455            config: self.config.clone(),
456            vm_source: self.vm_source.clone(),
457            output: Default::default(),
458        }
459    }
460}
461
462impl SessionSource {
463    /// Creates a new source given a solidity compiler version
464    ///
465    /// # Panics
466    ///
467    /// If no Solc binary is set, cannot be found or the `--version` command fails
468    ///
469    /// ### Takes
470    ///
471    /// - An instance of [Solc]
472    /// - An instance of [SessionSourceConfig]
473    ///
474    /// ### Returns
475    ///
476    /// A new instance of [SessionSource]
477    pub fn new(mut config: SessionSourceConfig) -> Result<Self> {
478        config.detect_solc()?;
479        Ok(Self {
480            file_name: "ReplContract.sol".to_string(),
481            contract_name: "REPL".to_string(),
482            config,
483            global_code: Default::default(),
484            contract_code: Default::default(),
485            run_code: Default::default(),
486            vm_source: vm_source(),
487            output: Default::default(),
488        })
489    }
490
491    /// Clones the [SessionSource] and appends a new line of code.
492    ///
493    /// Returns `true` if the new line was added to `run()`.
494    pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
495        if let Some((new_source, fragment)) = self
496            .parse_fragment(&content)
497            .or_else(|| {
498                content.push(';');
499                self.parse_fragment(&content)
500            })
501            .or_else(|| {
502                content = content.trim_end().trim_end_matches(';').to_string();
503                self.parse_fragment(&content)
504            })
505        {
506            Ok((new_source, matches!(fragment, ParseTreeFragment::Function)))
507        } else {
508            eyre::bail!("\"{}\"", content.trim());
509        }
510    }
511
512    /// Parses a fragment of Solidity code in memory and assigns it a scope within the
513    /// [`SessionSource`].
514    fn parse_fragment(&self, buffer: &str) -> Option<(Self, ParseTreeFragment)> {
515        #[track_caller]
516        fn debug_errors(errors: &EmittedDiagnostics) {
517            debug!("{errors}");
518        }
519
520        let mut this = self.clone();
521        match this.add_run_code(buffer).parse() {
522            Ok(()) => return Some((this, ParseTreeFragment::Function)),
523            Err(e) => debug_errors(&e),
524        }
525        this = self.clone();
526        match this.add_contract_code(buffer).parse() {
527            Ok(()) => return Some((this, ParseTreeFragment::Contract)),
528            Err(e) => debug_errors(&e),
529        }
530        this = self.clone();
531        match this.add_global_code(buffer).parse() {
532            Ok(()) => return Some((this, ParseTreeFragment::Source)),
533            Err(e) => debug_errors(&e),
534        }
535        None
536    }
537
538    /// Append global-level code to the source.
539    pub fn add_global_code(&mut self, content: &str) -> &mut Self {
540        self.global_code.push_str(content.trim());
541        self.global_code.push('\n');
542        self.clear_output();
543        self
544    }
545
546    /// Append contract-level code to the source.
547    pub fn add_contract_code(&mut self, content: &str) -> &mut Self {
548        self.contract_code.push_str(content.trim());
549        self.contract_code.push('\n');
550        self.clear_output();
551        self
552    }
553
554    /// Append code to the `run()` function of the REPL contract.
555    pub fn add_run_code(&mut self, content: &str) -> &mut Self {
556        self.run_code.push_str(content.trim());
557        self.run_code.push('\n');
558        self.clear_output();
559        self
560    }
561
562    /// Clears all source code.
563    pub fn clear(&mut self) {
564        String::clear(&mut self.global_code);
565        String::clear(&mut self.contract_code);
566        String::clear(&mut self.run_code);
567        self.clear_output();
568    }
569
570    /// Clear the global-level code .
571    pub fn clear_global(&mut self) -> &mut Self {
572        String::clear(&mut self.global_code);
573        self.clear_output();
574        self
575    }
576
577    /// Clear the contract-level code .
578    pub fn clear_contract(&mut self) -> &mut Self {
579        String::clear(&mut self.contract_code);
580        self.clear_output();
581        self
582    }
583
584    /// Clear the `run()` function code.
585    pub fn clear_run(&mut self) -> &mut Self {
586        String::clear(&mut self.run_code);
587        self.clear_output();
588        self
589    }
590
591    fn clear_output(&mut self) {
592        self.output.take();
593    }
594
595    /// Compiles the source if necessary.
596    pub fn build(&self) -> Result<&GeneratedOutput> {
597        // TODO: mimics `get_or_try_init`
598        if let Some(output) = self.output.get() {
599            return Ok(output);
600        }
601        let output = self.compile()?;
602        let intermediate = self.generate_intermediate_output()?;
603        let output = GeneratedOutput { output, intermediate };
604        Ok(self.output.get_or_init(|| output))
605    }
606
607    /// Compiles the source.
608    #[cold]
609    fn compile(&self) -> Result<ProjectCompileOutput> {
610        let sources = self.get_sources();
611
612        let mut project = self.config.foundry_config.ephemeral_project()?;
613        self.config.foundry_config.disable_optimizations(&mut project, self.config.ir_minimum);
614        let mut output = ProjectCompiler::with_sources(&project, sources)?.compile()?;
615
616        if output.has_compiler_errors() {
617            eyre::bail!("{output}");
618        }
619
620        // TODO(dani): re-enable
621        if cfg!(false) {
622            output.parser_mut().solc_mut().compiler_mut().enter_mut(|c| {
623                let _ = c.lower_asts();
624            });
625        }
626
627        Ok(output)
628    }
629
630    fn get_sources(&self) -> Sources {
631        let mut sources = Sources::new();
632
633        let src = self.to_repl_source();
634        sources.insert(self.file_name.clone().into(), Source::new(src));
635
636        // Include Vm.sol if forge-std remapping is not available.
637        if !self.config.no_vm
638            && !self
639                .config
640                .foundry_config
641                .get_all_remappings()
642                .any(|r| r.name.starts_with("forge-std"))
643        {
644            sources.insert("forge-std/Vm.sol".into(), self.vm_source.clone());
645        }
646
647        sources
648    }
649
650    /// Generate intermediate contracts for all contract definitions in the compilation source.
651    ///
652    /// ### Returns
653    ///
654    /// Optionally, a map of contract names to a vec of [IntermediateContract]s.
655    pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
656        let mut res_map = HashMap::default();
657        let parsed_map = self.get_sources();
658        for source in parsed_map.values() {
659            Self::get_intermediate_contract(&source.content, &mut res_map);
660        }
661        Ok(res_map)
662    }
663
664    /// Generate intermediate output for the REPL contract
665    pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
666        // Parse generate intermediate contracts
667        let intermediate_contracts = self.generate_intermediate_contracts()?;
668
669        // Construct variable definitions
670        let variable_definitions = intermediate_contracts
671            .get("REPL")
672            .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
673            .variable_definitions
674            .clone()
675            .into_iter()
676            .map(|(k, v)| (k, v.ty))
677            .collect::<HashMap<String, pt::Expression>>();
678        // Construct intermediate output
679        let mut intermediate_output = IntermediateOutput {
680            repl_contract_expressions: variable_definitions,
681            intermediate_contracts,
682        };
683
684        // Add all statements within the run function to the repl_contract_expressions map
685        for (key, val) in intermediate_output
686            .run_func_body()?
687            .clone()
688            .iter()
689            .flat_map(Self::get_statement_definitions)
690        {
691            intermediate_output.repl_contract_expressions.insert(key, val);
692        }
693
694        Ok(intermediate_output)
695    }
696
697    /// Construct the source as a valid Forge script.
698    pub fn to_script_source(&self) -> String {
699        let Self {
700            contract_name,
701            global_code,
702            contract_code: top_level_code,
703            run_code,
704            config,
705            ..
706        } = self;
707
708        let script_import =
709            if !config.no_vm { "import {Script} from \"forge-std/Script.sol\";\n" } else { "" };
710
711        format!(
712            r#"
713// SPDX-License-Identifier: UNLICENSED
714pragma solidity 0;
715
716{script_import}
717{global_code}
718
719contract {contract_name} is Script {{
720    {top_level_code}
721
722    /// @notice Script entry point
723    function run() public {{
724        {run_code}
725    }}
726}}"#,
727        )
728    }
729
730    /// Construct the REPL source.
731    pub fn to_repl_source(&self) -> String {
732        let Self {
733            contract_name,
734            global_code,
735            contract_code: top_level_code,
736            run_code,
737            config,
738            ..
739        } = self;
740        let (mut vm_import, mut vm_constant) = (String::new(), String::new());
741        // Check if there's any `forge-std` remapping and determine proper path to it by
742        // searching remapping path.
743        if !config.no_vm
744            && let Some(remapping) = config
745                .foundry_config
746                .remappings
747                .iter()
748                .find(|remapping| remapping.name == "forge-std/")
749            && let Some(vm_path) = WalkDir::new(&remapping.path.path)
750                .into_iter()
751                .filter_map(|e| e.ok())
752                .find(|e| e.file_name() == "Vm.sol")
753        {
754            vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
755            vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
756        }
757
758        format!(
759            r#"
760// SPDX-License-Identifier: UNLICENSED
761pragma solidity 0;
762
763{vm_import}
764{global_code}
765
766contract {contract_name} {{
767    {vm_constant}
768    {top_level_code}
769
770    /// @notice REPL contract entry point
771    function run() public {{
772        {run_code}
773    }}
774}}"#,
775        )
776    }
777
778    /// Parse the current source in memory using Solar.
779    pub(crate) fn parse(&self) -> Result<(), EmittedDiagnostics> {
780        let sess =
781            solar::interface::Session::builder().with_buffer_emitter(Default::default()).build();
782        let _ = sess.enter_sequential(|| -> solar::interface::Result<()> {
783            let arena = solar::ast::Arena::new();
784            let filename = self.file_name.clone().into();
785            let src = self.to_repl_source();
786            let mut parser = solar::parse::Parser::from_source_code(&sess, &arena, filename, src)?;
787            let _ast = parser.parse_file().map_err(|e| e.emit())?;
788            Ok(())
789        });
790        sess.dcx.emitted_errors().unwrap()
791    }
792
793    /// Gets the [IntermediateContract] for a Solidity source string and inserts it into the
794    /// passed `res_map`. In addition, recurses on any imported files as well.
795    ///
796    /// ### Takes
797    /// - `content` - A Solidity source string
798    /// - `res_map` - A mutable reference to a map of contract names to [IntermediateContract]s
799    pub fn get_intermediate_contract(
800        content: &str,
801        res_map: &mut HashMap<String, IntermediateContract>,
802    ) {
803        if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
804            let func_defs = source_unit_parts
805                .into_iter()
806                .filter_map(|sup| match sup {
807                    pt::SourceUnitPart::ImportDirective(i) => match i {
808                        pt::Import::Plain(s, _)
809                        | pt::Import::Rename(s, _, _)
810                        | pt::Import::GlobalSymbol(s, _, _) => {
811                            let s = match s {
812                                pt::ImportPath::Filename(s) => s.string,
813                                pt::ImportPath::Path(p) => p.to_string(),
814                            };
815                            let path = PathBuf::from(s);
816
817                            match fs::read_to_string(path) {
818                                Ok(source) => {
819                                    Self::get_intermediate_contract(&source, res_map);
820                                    None
821                                }
822                                Err(_) => None,
823                            }
824                        }
825                    },
826                    pt::SourceUnitPart::ContractDefinition(cd) => {
827                        let mut intermediate = IntermediateContract::default();
828
829                        cd.parts.into_iter().for_each(|part| match part {
830                            pt::ContractPart::FunctionDefinition(def) => {
831                                // Only match normal function definitions here.
832                                if matches!(def.ty, pt::FunctionTy::Function) {
833                                    intermediate
834                                        .function_definitions
835                                        .insert(def.name.clone().unwrap().name, def);
836                                }
837                            }
838                            pt::ContractPart::EventDefinition(def) => {
839                                let event_name = def.name.safe_unwrap().name.clone();
840                                intermediate.event_definitions.insert(event_name, def);
841                            }
842                            pt::ContractPart::StructDefinition(def) => {
843                                let struct_name = def.name.safe_unwrap().name.clone();
844                                intermediate.struct_definitions.insert(struct_name, def);
845                            }
846                            pt::ContractPart::VariableDefinition(def) => {
847                                let var_name = def.name.safe_unwrap().name.clone();
848                                intermediate.variable_definitions.insert(var_name, def);
849                            }
850                            _ => {}
851                        });
852                        Some((cd.name.safe_unwrap().name.clone(), intermediate))
853                    }
854                    _ => None,
855                })
856                .collect::<HashMap<String, IntermediateContract>>();
857            res_map.extend(func_defs);
858        }
859    }
860
861    /// Helper to deconstruct a statement
862    ///
863    /// ### Takes
864    ///
865    /// A reference to a [pt::Statement]
866    ///
867    /// ### Returns
868    ///
869    /// A vector containing tuples of the inner expressions' names, types, and storage locations.
870    pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
871        match statement {
872            pt::Statement::VariableDefinition(_, def, _) => {
873                vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
874            }
875            pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
876                if let pt::Expression::List(_, list) = left.as_ref() {
877                    list.iter()
878                        .filter_map(|(_, param)| {
879                            param.as_ref().and_then(|param| {
880                                param
881                                    .name
882                                    .as_ref()
883                                    .map(|name| (name.name.clone(), param.ty.clone()))
884                            })
885                        })
886                        .collect()
887                } else {
888                    Vec::default()
889                }
890            }
891            _ => Vec::default(),
892        }
893    }
894}
895
/// Classification of a parsed REPL input by the scope it belongs to.
///
/// Each variant names one of the three places an input can end up: the global
/// scope, the top level of the contract, or the body of the "run()" function.
#[derive(Debug)]
enum ParseTreeFragment {
    /// The input is global-scope code.
    Source,
    /// The input is code for the top level of the contract.
    Contract,
    /// The input is code for the body of the "run()" function.
    Function,
}