chisel/
source.rs

1//! Session Source
2//!
3//! This module contains the `SessionSource` struct, which is a minimal wrapper around
4//! the REPL contract's source code. It provides simple compilation, parsing, and
5//! execution helpers.
6
7use eyre::Result;
8use forge_doc::solang_ext::{CodeLocationExt, SafeUnwrap};
9use foundry_common::fs;
10use foundry_compilers::{
11    Artifact, ProjectCompileOutput,
12    artifacts::{ConfigurableContractArtifact, Source, Sources},
13    project::ProjectCompiler,
14    solc::Solc,
15};
16use foundry_config::{Config, SolcReq};
17use foundry_evm::{backend::Backend, core::bytecode::InstIter, opts::EvmOpts};
18use semver::Version;
19use serde::{Deserialize, Serialize};
20use solang_parser::pt;
21use solar::interface::diagnostics::EmittedDiagnostics;
22use std::{cell::OnceCell, collections::HashMap, fmt, path::PathBuf};
23use walkdir::WalkDir;
24
25/// The minimum Solidity version of the `Vm` interface.
26pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);
27
28/// Solidity source for the `Vm` interface in [forge-std](https://github.com/foundry-rs/forge-std)
29static VM_SOURCE: &str = include_str!("../../../testdata/utils/Vm.sol");
30
31/// [`SessionSource`] build output.
32pub struct GeneratedOutput {
33    output: ProjectCompileOutput,
34    pub(crate) intermediate: IntermediateOutput,
35}
36
37pub struct GeneratedOutputRef<'a> {
38    output: &'a ProjectCompileOutput,
39    // compiler: &'b solar::sema::CompilerRef<'c>,
40    pub(crate) intermediate: &'a IntermediateOutput,
41}
42
43/// Intermediate output for the compiled [SessionSource]
44#[derive(Clone, Debug, PartialEq, Eq)]
45pub struct IntermediateOutput {
46    /// All expressions within the REPL contract's run function and top level scope.
47    pub repl_contract_expressions: HashMap<String, pt::Expression>,
48    /// Intermediate contracts
49    pub intermediate_contracts: IntermediateContracts,
50}
51
52/// A refined intermediate parse tree for a contract that enables easy lookups
53/// of definitions.
54#[derive(Clone, Debug, Default, PartialEq, Eq)]
55pub struct IntermediateContract {
56    /// All function definitions within the contract
57    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
58    /// All event definitions within the contract
59    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
60    /// All struct definitions within the contract
61    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
62    /// All variable definitions within the top level scope of the contract
63    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
64}
65
66/// A defined type for a map of contract names to [IntermediateContract]s
67type IntermediateContracts = HashMap<String, IntermediateContract>;
68
69impl fmt::Debug for GeneratedOutput {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("GeneratedOutput").finish_non_exhaustive()
72    }
73}
74
75impl GeneratedOutput {
76    pub fn enter<T: Send>(&self, f: impl FnOnce(GeneratedOutputRef<'_>) -> T + Send) -> T {
77        // TODO(dani): once intermediate is removed
78        // self.output
79        //     .parser()
80        //     .solc()
81        //     .compiler()
82        //     .enter(|compiler| f(GeneratedOutputRef { output: &self.output, compiler }))
83        f(GeneratedOutputRef { output: &self.output, intermediate: &self.intermediate })
84    }
85}
86
87impl GeneratedOutputRef<'_> {
88    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
89        self.output.find_first("REPL")
90    }
91}
92
93impl std::ops::Deref for GeneratedOutput {
94    type Target = IntermediateOutput;
95    fn deref(&self) -> &Self::Target {
96        &self.intermediate
97    }
98}
99impl std::ops::Deref for GeneratedOutputRef<'_> {
100    type Target = IntermediateOutput;
101    fn deref(&self) -> &Self::Target {
102        self.intermediate
103    }
104}
105
106impl IntermediateOutput {
107    pub fn get_event(&self, input: &str) -> Option<&pt::EventDefinition> {
108        self.intermediate_contracts
109            .get("REPL")
110            .and_then(|contract| contract.event_definitions.get(input).map(std::ops::Deref::deref))
111    }
112
113    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
114        let deployed_bytecode = contract
115            .get_deployed_bytecode()
116            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
117        let deployed_bytecode_bytes = deployed_bytecode
118            .bytes()
119            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
120
121        let run_func_statements = self.run_func_body()?;
122
123        // Record loc of first yul block return statement (if any).
124        // This is used to decide which is the final statement within the `run()` method.
125        // see <https://github.com/foundry-rs/foundry/issues/4617>.
126        let last_yul_return = run_func_statements.iter().find_map(|statement| {
127            if let pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } = statement
128                && let Some(statement) = block.statements.last()
129                && let pt::YulStatement::FunctionCall(yul_call) = statement
130                && yul_call.id.name == "return"
131            {
132                return Some(statement.loc());
133            }
134            None
135        });
136
137        // Find the last statement within the "run()" method and get the program
138        // counter via the source map.
139        let Some(final_statement) = run_func_statements.last() else { return Ok(None) };
140
141        // If the final statement is some type of block (assembly, unchecked, or regular),
142        // we need to find the final statement within that block. Otherwise, default to
143        // the source loc of the final statement of the `run()` function's block.
144        //
145        // There is some code duplication within the arms due to the difference between
146        // the [pt::Statement] type and the [pt::YulStatement] types.
147        let mut source_loc = match final_statement {
148            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
149                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
150                let last_statement = block.statements.iter().rev().find(|statement| {
151                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
152                });
153                if let Some(statement) = last_statement {
154                    statement.loc()
155                } else {
156                    // In the case where the block is empty, attempt to grab the statement
157                    // before the asm block. Because we use saturating sub to get the second
158                    // to last index, this can always be safely unwrapped.
159                    run_func_statements
160                        .get(run_func_statements.len().saturating_sub(2))
161                        .unwrap()
162                        .loc()
163                }
164            }
165            pt::Statement::Block { loc: _, unchecked: _, statements } => {
166                if let Some(statement) = statements.last() {
167                    statement.loc()
168                } else {
169                    // In the case where the block is empty, attempt to grab the statement
170                    // before the block. Because we use saturating sub to get the second to
171                    // last index, this can always be safely unwrapped.
172                    run_func_statements
173                        .get(run_func_statements.len().saturating_sub(2))
174                        .unwrap()
175                        .loc()
176                }
177            }
178            _ => final_statement.loc(),
179        };
180
181        // Consider yul return statement as final statement (if it's loc is lower) .
182        if let Some(yul_return) = last_yul_return
183            && yul_return.end() < source_loc.start()
184        {
185            source_loc = yul_return;
186        }
187
188        // Map the source location of the final statement of the `run()` function to its
189        // corresponding runtime program counter
190        let final_pc = {
191            let offset = source_loc.start() as u32;
192            let length = (source_loc.end() - source_loc.start()) as u32;
193            trace!(%offset, %length, "find pc");
194            contract
195                .get_source_map_deployed()
196                .unwrap()
197                .unwrap()
198                .into_iter()
199                .zip(InstIter::new(deployed_bytecode_bytes).with_pc().map(|(pc, _)| pc))
200                .filter(|(s, _)| s.offset() == offset && s.length() == length)
201                .map(|(_, pc)| pc)
202                .max()
203        };
204        trace!(?final_pc);
205        Ok(final_pc)
206    }
207
208    pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
209        match self
210            .intermediate_contracts
211            .get("REPL")
212            .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
213            .function_definitions
214            .get("run")
215            .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
216            .body
217            .as_ref()
218            .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
219        {
220            pt::Statement::Block { statements, .. } => Ok(statements),
221            _ => eyre::bail!("Could not find statements within run function body!"),
222        }
223    }
224}
225
226// TODO(dani): further migration blocked on upstream work
227#[cfg(false)]
228impl<'gcx> GeneratedOutputRef<'_, '_, 'gcx> {
229    pub fn gcx(&self) -> Gcx<'gcx> {
230        self.compiler.gcx()
231    }
232
233    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
234        self.output.find_first("REPL")
235    }
236
237    pub fn get_event(&self, input: &str) -> Option<hir::EventId> {
238        self.gcx().hir.events_enumerated().find(|(_, e)| e.name.as_str() == input).map(|(id, _)| id)
239    }
240
241    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
242        let deployed_bytecode = contract
243            .get_deployed_bytecode()
244            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
245        let deployed_bytecode_bytes = deployed_bytecode
246            .bytes()
247            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
248
249        // Fetch the run function's body statement
250        let run_body = self.run_func_body();
251
252        // Record loc of first yul block return statement (if any).
253        // This is used to decide which is the final statement within the `run()` method.
254        // see <https://github.com/foundry-rs/foundry/issues/4617>.
255        let last_yul_return_span: Option<Span> = run_body.iter().find_map(|stmt| {
256            // TODO(dani): Yul is not yet lowered to HIR.
257            let _ = stmt;
258            /*
259            if let hir::StmtKind::Assembly { block, .. } = stmt {
260                if let Some(stmt) = block.last() {
261                    if let pt::YulStatement::FunctionCall(yul_call) = stmt {
262                        if yul_call.id.name == "return" {
263                            return Some(stmt.loc())
264                        }
265                    }
266                }
267            }
268            */
269            None
270        });
271
272        // Find the last statement within the "run()" method and get the program
273        // counter via the source map.
274        let Some(last_stmt) = run_body.last() else { return Ok(None) };
275
276        // If the final statement is some type of block (assembly, unchecked, or regular),
277        // we need to find the final statement within that block. Otherwise, default to
278        // the source loc of the final statement of the `run()` function's block.
279        //
280        // There is some code duplication within the arms due to the difference between
281        // the [pt::Statement] type and the [pt::YulStatement] types.
282        let source_stmt = match &last_stmt.kind {
283            // TODO(dani): Yul is not yet lowered to HIR.
284            /*
285            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
286                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
287                let last_statement = block.statements.iter().rev().find(|statement| {
288                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
289                });
290                if let Some(stmt) = last_statement {
291                    stmt
292                } else {
293                    // In the case where the block is empty, attempt to grab the statement
294                    // before the block. Because we use saturating sub to get the second to
295                    // last index, this can always be safely unwrapped.
296                    &run_body[run_body.len().saturating_sub(2)]
297                }
298            }
299            */
300            hir::StmtKind::UncheckedBlock(stmts) | hir::StmtKind::Block(stmts) => {
301                if let Some(stmt) = stmts.last() {
302                    stmt
303                } else {
304                    // In the case where the block is empty, attempt to grab the statement
305                    // before the block. Because we use saturating sub to get the second to
306                    // last index, this can always be safely unwrapped.
307                    &run_body[run_body.len().saturating_sub(2)]
308                }
309            }
310            _ => last_stmt,
311        };
312        let mut source_span = self.stmt_span_without_semicolon(source_stmt);
313
314        // Consider yul return statement as final statement (if it's loc is lower) .
315        if let Some(yul_return_span) = last_yul_return_span
316            && yul_return_span.hi() < source_span.lo()
317        {
318            source_span = yul_return_span;
319        }
320
321        // Map the source location of the final statement of the `run()` function to its
322        // corresponding runtime program counter
323        let (_sf, range) = self.compiler.sess().source_map().span_to_source(source_span).unwrap();
324        dbg!(source_span, &range, &_sf.src[range.clone()]);
325        let offset = range.start as u32;
326        let length = range.len() as u32;
327        let final_pc = deployed_bytecode
328            .source_map()
329            .ok_or_else(|| eyre::eyre!("No source map found for `REPL` contract"))??
330            .into_iter()
331            .zip(InstructionIter::new(deployed_bytecode_bytes))
332            .filter(|(s, _)| s.offset() == offset && s.length() == length)
333            .map(|(_, i)| i.pc)
334            .max()
335            .unwrap_or_default();
336        Ok(Some(final_pc))
337    }
338
339    /// Statements' ranges in the solc source map do not include the semicolon.
340    fn stmt_span_without_semicolon(&self, stmt: &hir::Stmt<'_>) -> Span {
341        match stmt.kind {
342            hir::StmtKind::DeclSingle(id) => {
343                let decl = self.gcx().hir.variable(id);
344                if let Some(expr) = decl.initializer {
345                    stmt.span.with_hi(expr.span.hi())
346                } else {
347                    stmt.span
348                }
349            }
350            hir::StmtKind::DeclMulti(_, expr) => stmt.span.with_hi(expr.span.hi()),
351            hir::StmtKind::Expr(expr) => expr.span,
352            _ => stmt.span,
353        }
354    }
355
356    fn run_func_body(&self) -> hir::Block<'_> {
357        let c = self.repl_contract_hir().expect("REPL contract not found in HIR");
358        let f = c
359            .functions()
360            .find(|&f| self.gcx().hir.function(f).name.as_ref().map(|n| n.as_str()) == Some("run"))
361            .expect("`run()` function not found in REPL contract");
362        self.gcx().hir.function(f).body.expect("`run()` function does not have a body")
363    }
364
365    fn repl_contract_hir(&self) -> Option<&hir::Contract<'_>> {
366        self.gcx().hir.contracts().find(|c| c.name.as_str() == "REPL")
367    }
368}
369
370/// Configuration for the [SessionSource]
371#[derive(Clone, Debug, Default, Serialize, Deserialize)]
372pub struct SessionSourceConfig {
373    /// Foundry configuration
374    pub foundry_config: Config,
375    /// EVM Options
376    pub evm_opts: EvmOpts,
377    /// Disable the default `Vm` import.
378    pub no_vm: bool,
379    /// In-memory REVM db for the session's runner.
380    #[serde(skip)]
381    pub backend: Option<Backend>,
382    /// Optionally enable traces for the REPL contract execution
383    pub traces: bool,
384    /// Optionally set calldata for the REPL contract execution
385    pub calldata: Option<Vec<u8>>,
386    /// Enable viaIR with minimum optimization
387    ///
388    /// This can fix most of the "stack too deep" errors while resulting a
389    /// relatively accurate source map.
390    pub ir_minimum: bool,
391}
392
393impl SessionSourceConfig {
394    /// Detect the solc version to know if VM can be injected.
395    pub fn detect_solc(&mut self) -> Result<()> {
396        if self.foundry_config.solc.is_none() {
397            let version = Solc::ensure_installed(&"*".parse().unwrap())?;
398            self.foundry_config.solc = Some(SolcReq::Version(version));
399        }
400        if !self.no_vm
401            && let Some(version) = self.foundry_config.solc_version()
402            && version < MIN_VM_VERSION
403        {
404            info!(%version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
405            self.no_vm = true;
406        }
407        Ok(())
408    }
409}
410
411/// REPL Session Source wrapper
412///
413/// Heavily based on soli's [`ConstructedSource`](https://github.com/jpopesculian/soli/blob/master/src/main.rs#L166)
414#[derive(Debug, Serialize, Deserialize)]
415pub struct SessionSource {
416    /// The file name
417    pub file_name: String,
418    /// The contract name
419    pub contract_name: String,
420
421    /// Session Source configuration
422    pub config: SessionSourceConfig,
423
424    /// Global level Solidity code.
425    ///
426    /// Above and outside all contract declarations, in the global context.
427    pub global_code: String,
428    /// Top level Solidity code.
429    ///
430    /// Within the contract declaration, but outside of the `run()` function.
431    pub contract_code: String,
432    /// The code to be executed in the `run()` function.
433    pub run_code: String,
434
435    /// Cached VM source code.
436    #[serde(skip, default = "vm_source")]
437    vm_source: Source,
438    /// The generated output
439    #[serde(skip)]
440    output: OnceCell<GeneratedOutput>,
441}
442
443fn vm_source() -> Source {
444    Source::new(VM_SOURCE)
445}
446
447impl Clone for SessionSource {
448    fn clone(&self) -> Self {
449        Self {
450            file_name: self.file_name.clone(),
451            contract_name: self.contract_name.clone(),
452            global_code: self.global_code.clone(),
453            contract_code: self.contract_code.clone(),
454            run_code: self.run_code.clone(),
455            config: self.config.clone(),
456            vm_source: self.vm_source.clone(),
457            output: Default::default(),
458        }
459    }
460}
461
462impl SessionSource {
463    /// Creates a new source given a solidity compiler version
464    ///
465    /// # Panics
466    ///
467    /// If no Solc binary is set, cannot be found or the `--version` command fails
468    ///
469    /// ### Takes
470    ///
471    /// - An instance of [Solc]
472    /// - An instance of [SessionSourceConfig]
473    ///
474    /// ### Returns
475    ///
476    /// A new instance of [SessionSource]
477    pub fn new(mut config: SessionSourceConfig) -> Result<Self> {
478        config.detect_solc()?;
479        Ok(Self {
480            file_name: "ReplContract.sol".to_string(),
481            contract_name: "REPL".to_string(),
482            config,
483            global_code: Default::default(),
484            contract_code: Default::default(),
485            run_code: Default::default(),
486            vm_source: vm_source(),
487            output: Default::default(),
488        })
489    }
490
491    /// Clones the [SessionSource] and appends a new line of code.
492    ///
493    /// Returns `true` if the new line was added to `run()`.
494    pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
495        if let Some((new_source, fragment)) = self
496            .parse_fragment(&content)
497            .or_else(|| {
498                content.push(';');
499                self.parse_fragment(&content)
500            })
501            .or_else(|| {
502                content = content.trim_end().trim_end_matches(';').to_string();
503                self.parse_fragment(&content)
504            })
505        {
506            Ok((new_source, matches!(fragment, ParseTreeFragment::Function)))
507        } else {
508            eyre::bail!("\"{}\"", content.trim());
509        }
510    }
511
512    /// Parses a fragment of Solidity code in memory and assigns it a scope within the
513    /// [`SessionSource`].
514    fn parse_fragment(&self, buffer: &str) -> Option<(Self, ParseTreeFragment)> {
515        #[track_caller]
516        fn debug_errors(errors: &EmittedDiagnostics) {
517            debug!("{errors}");
518        }
519
520        let mut this = self.clone();
521        match this.add_run_code(buffer).parse() {
522            Ok(()) => return Some((this, ParseTreeFragment::Function)),
523            Err(e) => debug_errors(&e),
524        }
525        this = self.clone();
526        match this.add_contract_code(buffer).parse() {
527            Ok(()) => return Some((this, ParseTreeFragment::Contract)),
528            Err(e) => debug_errors(&e),
529        }
530        this = self.clone();
531        match this.add_global_code(buffer).parse() {
532            Ok(()) => return Some((this, ParseTreeFragment::Source)),
533            Err(e) => debug_errors(&e),
534        }
535        None
536    }
537
538    /// Append global-level code to the source.
539    pub fn add_global_code(&mut self, content: &str) -> &mut Self {
540        self.global_code.push_str(content.trim());
541        self.global_code.push('\n');
542        self.clear_output();
543        self
544    }
545
546    /// Append contract-level code to the source.
547    pub fn add_contract_code(&mut self, content: &str) -> &mut Self {
548        self.contract_code.push_str(content.trim());
549        self.contract_code.push('\n');
550        self.clear_output();
551        self
552    }
553
554    /// Append code to the `run()` function of the REPL contract.
555    pub fn add_run_code(&mut self, content: &str) -> &mut Self {
556        self.run_code.push_str(content.trim());
557        self.run_code.push('\n');
558        self.clear_output();
559        self
560    }
561
562    /// Clears all source code.
563    pub fn clear(&mut self) {
564        String::clear(&mut self.global_code);
565        String::clear(&mut self.contract_code);
566        String::clear(&mut self.run_code);
567        self.clear_output();
568    }
569
570    /// Clear the `run()` function code.
571    pub fn clear_run(&mut self) -> &mut Self {
572        String::clear(&mut self.run_code);
573        self.clear_output();
574        self
575    }
576
577    fn clear_output(&mut self) {
578        self.output.take();
579    }
580
581    /// Compiles the source if necessary.
582    pub fn build(&self) -> Result<&GeneratedOutput> {
583        // TODO: mimics `get_or_try_init`
584        if let Some(output) = self.output.get() {
585            return Ok(output);
586        }
587        let output = self.compile()?;
588        let intermediate = self.generate_intermediate_output()?;
589        let output = GeneratedOutput { output, intermediate };
590        Ok(self.output.get_or_init(|| output))
591    }
592
593    /// Compiles the source.
594    #[cold]
595    fn compile(&self) -> Result<ProjectCompileOutput> {
596        let sources = self.get_sources();
597
598        let mut project = self.config.foundry_config.ephemeral_project()?;
599        self.config.foundry_config.disable_optimizations(&mut project, self.config.ir_minimum);
600        let mut output = ProjectCompiler::with_sources(&project, sources)?.compile()?;
601
602        if output.has_compiler_errors() {
603            eyre::bail!("{output}");
604        }
605
606        // TODO(dani): re-enable
607        if cfg!(false) {
608            output.parser_mut().solc_mut().compiler_mut().enter_mut(|c| {
609                let _ = c.lower_asts();
610            });
611        }
612
613        Ok(output)
614    }
615
616    fn get_sources(&self) -> Sources {
617        let mut sources = Sources::new();
618
619        let src = self.to_repl_source();
620        sources.insert(self.file_name.clone().into(), Source::new(src));
621
622        // Include Vm.sol if forge-std remapping is not available.
623        if !self.config.no_vm
624            && !self
625                .config
626                .foundry_config
627                .get_all_remappings()
628                .any(|r| r.name.starts_with("forge-std"))
629        {
630            sources.insert("forge-std/Vm.sol".into(), self.vm_source.clone());
631        }
632
633        sources
634    }
635
636    /// Generate intermediate contracts for all contract definitions in the compilation source.
637    ///
638    /// ### Returns
639    ///
640    /// Optionally, a map of contract names to a vec of [IntermediateContract]s.
641    pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
642        let mut res_map = HashMap::default();
643        let parsed_map = self.get_sources();
644        for source in parsed_map.values() {
645            Self::get_intermediate_contract(&source.content, &mut res_map);
646        }
647        Ok(res_map)
648    }
649
650    /// Generate intermediate output for the REPL contract
651    pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
652        // Parse generate intermediate contracts
653        let intermediate_contracts = self.generate_intermediate_contracts()?;
654
655        // Construct variable definitions
656        let variable_definitions = intermediate_contracts
657            .get("REPL")
658            .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
659            .variable_definitions
660            .clone()
661            .into_iter()
662            .map(|(k, v)| (k, v.ty))
663            .collect::<HashMap<String, pt::Expression>>();
664        // Construct intermediate output
665        let mut intermediate_output = IntermediateOutput {
666            repl_contract_expressions: variable_definitions,
667            intermediate_contracts,
668        };
669
670        // Add all statements within the run function to the repl_contract_expressions map
671        for (key, val) in intermediate_output
672            .run_func_body()?
673            .clone()
674            .iter()
675            .flat_map(Self::get_statement_definitions)
676        {
677            intermediate_output.repl_contract_expressions.insert(key, val);
678        }
679
680        Ok(intermediate_output)
681    }
682
683    /// Construct the REPL source.
684    pub fn to_repl_source(&self) -> String {
685        let Self {
686            contract_name,
687            global_code,
688            contract_code: top_level_code,
689            run_code,
690            config,
691            ..
692        } = self;
693        let (mut vm_import, mut vm_constant) = (String::new(), String::new());
694        // Check if there's any `forge-std` remapping and determine proper path to it by
695        // searching remapping path.
696        if !config.no_vm
697            && let Some(remapping) = config
698                .foundry_config
699                .remappings
700                .iter()
701                .find(|remapping| remapping.name == "forge-std/")
702            && let Some(vm_path) = WalkDir::new(&remapping.path.path)
703                .into_iter()
704                .filter_map(|e| e.ok())
705                .find(|e| e.file_name() == "Vm.sol")
706        {
707            vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
708            vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
709        }
710
711        format!(
712            r#"
713// SPDX-License-Identifier: UNLICENSED
714pragma solidity 0;
715
716{vm_import}
717{global_code}
718
719contract {contract_name} {{
720    {vm_constant}
721    {top_level_code}
722
723    /// @notice REPL contract entry point
724    function run() public {{
725        {run_code}
726    }}
727}}"#,
728        )
729    }
730
731    /// Parse the current source in memory using Solar.
732    pub(crate) fn parse(&self) -> Result<(), EmittedDiagnostics> {
733        let sess =
734            solar::interface::Session::builder().with_buffer_emitter(Default::default()).build();
735        let _ = sess.enter_sequential(|| -> solar::interface::Result<()> {
736            let arena = solar::ast::Arena::new();
737            let filename = self.file_name.clone().into();
738            let src = self.to_repl_source();
739            let mut parser = solar::parse::Parser::from_source_code(&sess, &arena, filename, src)?;
740            let _ast = parser.parse_file().map_err(|e| e.emit())?;
741            Ok(())
742        });
743        sess.dcx.emitted_errors().unwrap()
744    }
745
746    /// Gets the [IntermediateContract] for a Solidity source string and inserts it into the
747    /// passed `res_map`. In addition, recurses on any imported files as well.
748    ///
749    /// ### Takes
750    /// - `content` - A Solidity source string
751    /// - `res_map` - A mutable reference to a map of contract names to [IntermediateContract]s
752    pub fn get_intermediate_contract(
753        content: &str,
754        res_map: &mut HashMap<String, IntermediateContract>,
755    ) {
756        if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
757            let func_defs = source_unit_parts
758                .into_iter()
759                .filter_map(|sup| match sup {
760                    pt::SourceUnitPart::ImportDirective(i) => match i {
761                        pt::Import::Plain(s, _)
762                        | pt::Import::Rename(s, _, _)
763                        | pt::Import::GlobalSymbol(s, _, _) => {
764                            let s = match s {
765                                pt::ImportPath::Filename(s) => s.string,
766                                pt::ImportPath::Path(p) => p.to_string(),
767                            };
768                            let path = PathBuf::from(s);
769
770                            match fs::read_to_string(path) {
771                                Ok(source) => {
772                                    Self::get_intermediate_contract(&source, res_map);
773                                    None
774                                }
775                                Err(_) => None,
776                            }
777                        }
778                    },
779                    pt::SourceUnitPart::ContractDefinition(cd) => {
780                        let mut intermediate = IntermediateContract::default();
781
782                        cd.parts.into_iter().for_each(|part| match part {
783                            pt::ContractPart::FunctionDefinition(def) => {
784                                // Only match normal function definitions here.
785                                if matches!(def.ty, pt::FunctionTy::Function) {
786                                    intermediate
787                                        .function_definitions
788                                        .insert(def.name.clone().unwrap().name, def);
789                                }
790                            }
791                            pt::ContractPart::EventDefinition(def) => {
792                                let event_name = def.name.safe_unwrap().name.clone();
793                                intermediate.event_definitions.insert(event_name, def);
794                            }
795                            pt::ContractPart::StructDefinition(def) => {
796                                let struct_name = def.name.safe_unwrap().name.clone();
797                                intermediate.struct_definitions.insert(struct_name, def);
798                            }
799                            pt::ContractPart::VariableDefinition(def) => {
800                                let var_name = def.name.safe_unwrap().name.clone();
801                                intermediate.variable_definitions.insert(var_name, def);
802                            }
803                            _ => {}
804                        });
805                        Some((cd.name.safe_unwrap().name.clone(), intermediate))
806                    }
807                    _ => None,
808                })
809                .collect::<HashMap<String, IntermediateContract>>();
810            res_map.extend(func_defs);
811        }
812    }
813
814    /// Helper to deconstruct a statement
815    ///
816    /// ### Takes
817    ///
818    /// A reference to a [pt::Statement]
819    ///
820    /// ### Returns
821    ///
822    /// A vector containing tuples of the inner expressions' names, types, and storage locations.
823    pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
824        match statement {
825            pt::Statement::VariableDefinition(_, def, _) => {
826                vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
827            }
828            pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
829                if let pt::Expression::List(_, list) = left.as_ref() {
830                    list.iter()
831                        .filter_map(|(_, param)| {
832                            param.as_ref().and_then(|param| {
833                                param
834                                    .name
835                                    .as_ref()
836                                    .map(|name| (name.name.clone(), param.ty.clone()))
837                            })
838                        })
839                        .collect()
840                } else {
841                    Vec::default()
842                }
843            }
844            _ => Vec::default(),
845        }
846    }
847}
848
/// Scope of the REPL source that a parsed input fragment was placed into.
///
/// The scope is one of: the `run()` function body, the top level of the REPL contract, or
/// the global (file) scope.
#[derive(Debug)]
enum ParseTreeFragment {
    /// Global (file-level) scope, outside of any contract.
    Source,
    /// Top level of the REPL contract, outside of `run()`.
    Contract,
    /// Body of the `run()` function.
    Function,
}