chisel/
source.rs

1//! Session Source
2//!
3//! This module contains the `SessionSource` struct, which is a minimal wrapper around
4//! the REPL contract's source code. It provides simple compilation, parsing, and
5//! execution helpers.
6
7use eyre::Result;
8use forge_fmt::solang_ext::{CodeLocationExt, SafeUnwrap};
9use foundry_common::fs;
10use foundry_compilers::{
11    Artifact, ProjectCompileOutput,
12    artifacts::{ConfigurableContractArtifact, Source, Sources},
13    project::ProjectCompiler,
14    solc::Solc,
15};
16use foundry_config::{Config, SolcReq};
17use foundry_evm::{backend::Backend, opts::EvmOpts};
18use semver::Version;
19use serde::{Deserialize, Serialize};
20use solang_parser::pt;
21use solar::interface::diagnostics::EmittedDiagnostics;
22use std::{cell::OnceCell, collections::HashMap, fmt, path::PathBuf};
23use walkdir::WalkDir;
24
/// The minimum Solidity version of the `Vm` interface.
///
/// When the configured solc is older than this, `Vm` injection is disabled
/// (see [`SessionSourceConfig::detect_solc`]).
pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);

/// Solidity source for the `Vm` interface in [forge-std](https://github.com/foundry-rs/forge-std)
///
/// Embedded at compile time from `testdata/cheats/Vm.sol`. It is added to the compiled sources as
/// `forge-std/Vm.sol` when `Vm` injection is enabled and no `forge-std` remapping exists
/// (see `SessionSource::get_sources`).
static VM_SOURCE: &str = include_str!("../../../testdata/cheats/Vm.sol");
30
/// [`SessionSource`] build output.
///
/// Pairs the compiler output with the solang-parser based [`IntermediateOutput`].
pub struct GeneratedOutput {
    /// Compiler output of the ephemeral REPL project.
    output: ProjectCompileOutput,
    /// Parse-tree lookups built from the same sources.
    pub(crate) intermediate: IntermediateOutput,
}

/// Borrowed view of a [`GeneratedOutput`], handed to the closure passed to
/// [`GeneratedOutput::enter`].
pub struct GeneratedOutputRef<'a> {
    /// Compiler output of the ephemeral REPL project.
    output: &'a ProjectCompileOutput,
    // NOTE(review): planned field for the solar-based migration; see the TODO in `enter`.
    // compiler: &'b solar::sema::CompilerRef<'c>,
    /// Parse-tree lookups built from the same sources.
    pub(crate) intermediate: &'a IntermediateOutput,
}
42
/// Intermediate output for the compiled [SessionSource]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IntermediateOutput {
    /// Expressions keyed by identifier name.
    ///
    /// Seeded with the type expression of each top-level variable of the `REPL` contract, then
    /// extended with the definitions that `SessionSource::get_statement_definitions` extracts
    /// from the statements of `run()`.
    pub repl_contract_expressions: HashMap<String, pt::Expression>,
    /// Intermediate contracts, keyed by contract name.
    pub intermediate_contracts: IntermediateContracts,
}

/// A refined intermediate parse tree for a contract that enables easy lookups
/// of definitions.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct IntermediateContract {
    /// All function definitions within the contract, keyed by function name
    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
    /// All event definitions within the contract, keyed by event name
    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
    /// All struct definitions within the contract, keyed by struct name
    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
    /// All variable definitions within the top level scope of the contract, keyed by name
    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
}

/// A defined type for a map of contract names to [IntermediateContract]s
type IntermediateContracts = HashMap<String, IntermediateContract>;
68
69impl fmt::Debug for GeneratedOutput {
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("GeneratedOutput").finish_non_exhaustive()
72    }
73}
74
75impl GeneratedOutput {
76    pub fn enter<T: Send>(&self, f: impl FnOnce(GeneratedOutputRef<'_>) -> T + Send) -> T {
77        // TODO(dani): once intermediate is removed
78        // self.output
79        //     .parser()
80        //     .solc()
81        //     .compiler()
82        //     .enter(|compiler| f(GeneratedOutputRef { output: &self.output, compiler }))
83        f(GeneratedOutputRef { output: &self.output, intermediate: &self.intermediate })
84    }
85}
86
87impl GeneratedOutputRef<'_> {
88    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
89        self.output.find_first("REPL")
90    }
91}
92
93impl std::ops::Deref for GeneratedOutput {
94    type Target = IntermediateOutput;
95    fn deref(&self) -> &Self::Target {
96        &self.intermediate
97    }
98}
99impl std::ops::Deref for GeneratedOutputRef<'_> {
100    type Target = IntermediateOutput;
101    fn deref(&self) -> &Self::Target {
102        self.intermediate
103    }
104}
105
106impl IntermediateOutput {
107    pub fn get_event(&self, input: &str) -> Option<&pt::EventDefinition> {
108        self.intermediate_contracts
109            .get("REPL")
110            .and_then(|contract| contract.event_definitions.get(input).map(std::ops::Deref::deref))
111    }
112
113    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
114        let deployed_bytecode = contract
115            .get_deployed_bytecode()
116            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
117        let deployed_bytecode_bytes = deployed_bytecode
118            .bytes()
119            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
120
121        let run_func_statements = self.run_func_body()?;
122
123        // Record loc of first yul block return statement (if any).
124        // This is used to decide which is the final statement within the `run()` method.
125        // see <https://github.com/foundry-rs/foundry/issues/4617>.
126        let last_yul_return = run_func_statements.iter().find_map(|statement| {
127            if let pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } = statement
128                && let Some(statement) = block.statements.last()
129                && let pt::YulStatement::FunctionCall(yul_call) = statement
130                && yul_call.id.name == "return"
131            {
132                return Some(statement.loc());
133            }
134            None
135        });
136
137        // Find the last statement within the "run()" method and get the program
138        // counter via the source map.
139        let Some(final_statement) = run_func_statements.last() else { return Ok(None) };
140
141        // If the final statement is some type of block (assembly, unchecked, or regular),
142        // we need to find the final statement within that block. Otherwise, default to
143        // the source loc of the final statement of the `run()` function's block.
144        //
145        // There is some code duplication within the arms due to the difference between
146        // the [pt::Statement] type and the [pt::YulStatement] types.
147        let mut source_loc = match final_statement {
148            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
149                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
150                let last_statement = block.statements.iter().rev().find(|statement| {
151                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
152                });
153                if let Some(statement) = last_statement {
154                    statement.loc()
155                } else {
156                    // In the case where the block is empty, attempt to grab the statement
157                    // before the asm block. Because we use saturating sub to get the second
158                    // to last index, this can always be safely unwrapped.
159                    run_func_statements
160                        .get(run_func_statements.len().saturating_sub(2))
161                        .unwrap()
162                        .loc()
163                }
164            }
165            pt::Statement::Block { loc: _, unchecked: _, statements } => {
166                if let Some(statement) = statements.last() {
167                    statement.loc()
168                } else {
169                    // In the case where the block is empty, attempt to grab the statement
170                    // before the block. Because we use saturating sub to get the second to
171                    // last index, this can always be safely unwrapped.
172                    run_func_statements
173                        .get(run_func_statements.len().saturating_sub(2))
174                        .unwrap()
175                        .loc()
176                }
177            }
178            _ => final_statement.loc(),
179        };
180
181        // Consider yul return statement as final statement (if it's loc is lower) .
182        if let Some(yul_return) = last_yul_return
183            && yul_return.end() < source_loc.start()
184        {
185            source_loc = yul_return;
186        }
187
188        // Map the source location of the final statement of the `run()` function to its
189        // corresponding runtime program counter
190        let final_pc = {
191            let offset = source_loc.start() as u32;
192            let length = (source_loc.end() - source_loc.start()) as u32;
193            contract
194                .get_source_map_deployed()
195                .unwrap()
196                .unwrap()
197                .into_iter()
198                .zip(InstructionIter::new(deployed_bytecode_bytes))
199                .filter(|(s, _)| s.offset() == offset && s.length() == length)
200                .map(|(_, i)| i.pc)
201                .max()
202                .unwrap_or_default()
203        };
204        Ok(Some(final_pc))
205    }
206
207    pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
208        match self
209            .intermediate_contracts
210            .get("REPL")
211            .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
212            .function_definitions
213            .get("run")
214            .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
215            .body
216            .as_ref()
217            .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
218        {
219            pt::Statement::Block { statements, .. } => Ok(statements),
220            _ => eyre::bail!("Could not find statements within run function body!"),
221        }
222    }
223}
224
225// TODO(dani): further migration blocked on upstream work
226#[cfg(false)]
227impl<'gcx> GeneratedOutputRef<'_, '_, 'gcx> {
228    pub fn gcx(&self) -> Gcx<'gcx> {
229        self.compiler.gcx()
230    }
231
232    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
233        self.output.find_first("REPL")
234    }
235
236    pub fn get_event(&self, input: &str) -> Option<hir::EventId> {
237        self.gcx().hir.events_enumerated().find(|(_, e)| e.name.as_str() == input).map(|(id, _)| id)
238    }
239
240    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
241        let deployed_bytecode = contract
242            .get_deployed_bytecode()
243            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
244        let deployed_bytecode_bytes = deployed_bytecode
245            .bytes()
246            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
247
248        // Fetch the run function's body statement
249        let run_body = self.run_func_body();
250
251        // Record loc of first yul block return statement (if any).
252        // This is used to decide which is the final statement within the `run()` method.
253        // see <https://github.com/foundry-rs/foundry/issues/4617>.
254        let last_yul_return_span: Option<Span> = run_body.iter().find_map(|stmt| {
255            // TODO(dani): Yul is not yet lowered to HIR.
256            let _ = stmt;
257            /*
258            if let hir::StmtKind::Assembly { block, .. } = stmt {
259                if let Some(stmt) = block.last() {
260                    if let pt::YulStatement::FunctionCall(yul_call) = stmt {
261                        if yul_call.id.name == "return" {
262                            return Some(stmt.loc())
263                        }
264                    }
265                }
266            }
267            */
268            None
269        });
270
271        // Find the last statement within the "run()" method and get the program
272        // counter via the source map.
273        let Some(last_stmt) = run_body.last() else { return Ok(None) };
274
275        // If the final statement is some type of block (assembly, unchecked, or regular),
276        // we need to find the final statement within that block. Otherwise, default to
277        // the source loc of the final statement of the `run()` function's block.
278        //
279        // There is some code duplication within the arms due to the difference between
280        // the [pt::Statement] type and the [pt::YulStatement] types.
281        let source_stmt = match &last_stmt.kind {
282            // TODO(dani): Yul is not yet lowered to HIR.
283            /*
284            pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
285                // Select last non variable declaration statement, see <https://github.com/foundry-rs/foundry/issues/4938>.
286                let last_statement = block.statements.iter().rev().find(|statement| {
287                    !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
288                });
289                if let Some(stmt) = last_statement {
290                    stmt
291                } else {
292                    // In the case where the block is empty, attempt to grab the statement
293                    // before the block. Because we use saturating sub to get the second to
294                    // last index, this can always be safely unwrapped.
295                    &run_body[run_body.len().saturating_sub(2)]
296                }
297            }
298            */
299            hir::StmtKind::UncheckedBlock(stmts) | hir::StmtKind::Block(stmts) => {
300                if let Some(stmt) = stmts.last() {
301                    stmt
302                } else {
303                    // In the case where the block is empty, attempt to grab the statement
304                    // before the block. Because we use saturating sub to get the second to
305                    // last index, this can always be safely unwrapped.
306                    &run_body[run_body.len().saturating_sub(2)]
307                }
308            }
309            _ => last_stmt,
310        };
311        let mut source_span = self.stmt_span_without_semicolon(source_stmt);
312
313        // Consider yul return statement as final statement (if it's loc is lower) .
314        if let Some(yul_return_span) = last_yul_return_span
315            && yul_return_span.hi() < source_span.lo()
316        {
317            source_span = yul_return_span;
318        }
319
320        // Map the source location of the final statement of the `run()` function to its
321        // corresponding runtime program counter
322        let (_sf, range) = self.compiler.sess().source_map().span_to_source(source_span).unwrap();
323        dbg!(source_span, &range, &_sf.src[range.clone()]);
324        let offset = range.start as u32;
325        let length = range.len() as u32;
326        let final_pc = deployed_bytecode
327            .source_map()
328            .ok_or_else(|| eyre::eyre!("No source map found for `REPL` contract"))??
329            .into_iter()
330            .zip(InstructionIter::new(deployed_bytecode_bytes))
331            .filter(|(s, _)| s.offset() == offset && s.length() == length)
332            .map(|(_, i)| i.pc)
333            .max()
334            .unwrap_or_default();
335        Ok(Some(final_pc))
336    }
337
338    /// Statements' ranges in the solc source map do not include the semicolon.
339    fn stmt_span_without_semicolon(&self, stmt: &hir::Stmt<'_>) -> Span {
340        match stmt.kind {
341            hir::StmtKind::DeclSingle(id) => {
342                let decl = self.gcx().hir.variable(id);
343                if let Some(expr) = decl.initializer {
344                    stmt.span.with_hi(expr.span.hi())
345                } else {
346                    stmt.span
347                }
348            }
349            hir::StmtKind::DeclMulti(_, expr) => stmt.span.with_hi(expr.span.hi()),
350            hir::StmtKind::Expr(expr) => expr.span,
351            _ => stmt.span,
352        }
353    }
354
355    fn run_func_body(&self) -> hir::Block<'_> {
356        let c = self.repl_contract_hir().expect("REPL contract not found in HIR");
357        let f = c
358            .functions()
359            .find(|&f| self.gcx().hir.function(f).name.as_ref().map(|n| n.as_str()) == Some("run"))
360            .expect("`run()` function not found in REPL contract");
361        self.gcx().hir.function(f).body.expect("`run()` function does not have a body")
362    }
363
364    fn repl_contract_hir(&self) -> Option<&hir::Contract<'_>> {
365        self.gcx().hir.contracts().find(|c| c.name.as_str() == "REPL")
366    }
367}
368
/// Configuration for the [SessionSource]
///
/// All fields except `backend` are (de)serialized with serde.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SessionSourceConfig {
    /// Foundry configuration
    pub foundry_config: Config,
    /// EVM Options
    pub evm_opts: EvmOpts,
    /// Disable the default `Vm` import.
    ///
    /// Also set automatically by [`SessionSourceConfig::detect_solc`] when the configured solc
    /// version is older than [`MIN_VM_VERSION`].
    pub no_vm: bool,
    /// In-memory REVM db for the session's runner.
    ///
    /// Skipped by serde.
    #[serde(skip)]
    pub backend: Option<Backend>,
    /// Optionally enable traces for the REPL contract execution
    pub traces: bool,
    /// Optionally set calldata for the REPL contract execution
    pub calldata: Option<Vec<u8>>,
}
386
387impl SessionSourceConfig {
388    /// Detect the solc version to know if VM can be injected.
389    pub fn detect_solc(&mut self) -> Result<()> {
390        if self.foundry_config.solc.is_none() {
391            let version = Solc::ensure_installed(&"*".parse().unwrap())?;
392            self.foundry_config.solc = Some(SolcReq::Version(version));
393        }
394        if !self.no_vm
395            && let Some(version) = self.foundry_config.solc_version()
396            && version < MIN_VM_VERSION
397        {
398            tracing::info!(%version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
399            self.no_vm = true;
400        }
401        Ok(())
402    }
403}
404
/// REPL Session Source wrapper
///
/// Keeps the session's code split into three scopes (global, contract level, and the `run()`
/// body) and renders them into a single Solidity file (see [`SessionSource::to_repl_source`]).
///
/// Heavily based on soli's [`ConstructedSource`](https://github.com/jpopesculian/soli/blob/master/src/main.rs#L166)
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionSource {
    /// The file name
    pub file_name: String,
    /// The contract name
    pub contract_name: String,

    /// Session Source configuration
    pub config: SessionSourceConfig,

    /// Global level Solidity code.
    ///
    /// Above and outside all contract declarations, in the global context.
    pub global_code: String,
    /// Top level Solidity code.
    ///
    /// Within the contract declaration, but outside of the `run()` function.
    pub contract_code: String,
    /// The code to be executed in the `run()` function.
    pub run_code: String,

    /// Cached VM source code.
    ///
    /// Not serialized; restored from the bundled `Vm.sol` when deserializing.
    #[serde(skip, default = "vm_source")]
    vm_source: Source,
    /// The generated output
    ///
    /// Filled lazily by [`SessionSource::build`] and reset whenever code is added or cleared.
    /// Not serialized, and not carried over by `clone`.
    #[serde(skip)]
    output: OnceCell<GeneratedOutput>,
}
436
437fn vm_source() -> Source {
438    Source::new(VM_SOURCE)
439}
440
441impl Clone for SessionSource {
442    fn clone(&self) -> Self {
443        Self {
444            file_name: self.file_name.clone(),
445            contract_name: self.contract_name.clone(),
446            global_code: self.global_code.clone(),
447            contract_code: self.contract_code.clone(),
448            run_code: self.run_code.clone(),
449            config: self.config.clone(),
450            vm_source: self.vm_source.clone(),
451            output: Default::default(),
452        }
453    }
454}
455
456impl SessionSource {
457    /// Creates a new source given a solidity compiler version
458    ///
459    /// # Panics
460    ///
461    /// If no Solc binary is set, cannot be found or the `--version` command fails
462    ///
463    /// ### Takes
464    ///
465    /// - An instance of [Solc]
466    /// - An instance of [SessionSourceConfig]
467    ///
468    /// ### Returns
469    ///
470    /// A new instance of [SessionSource]
471    pub fn new(mut config: SessionSourceConfig) -> Result<Self> {
472        config.detect_solc()?;
473        Ok(Self {
474            file_name: "ReplContract.sol".to_string(),
475            contract_name: "REPL".to_string(),
476            config,
477            global_code: Default::default(),
478            contract_code: Default::default(),
479            run_code: Default::default(),
480            vm_source: vm_source(),
481            output: Default::default(),
482        })
483    }
484
485    /// Clones the [SessionSource] and appends a new line of code.
486    ///
487    /// Returns `true` if the new line was added to `run()`.
488    pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
489        if let Some((new_source, fragment)) = self
490            .parse_fragment(&content)
491            .or_else(|| {
492                content.push(';');
493                self.parse_fragment(&content)
494            })
495            .or_else(|| {
496                content = content.trim_end().trim_end_matches(';').to_string();
497                self.parse_fragment(&content)
498            })
499        {
500            Ok((new_source, matches!(fragment, ParseTreeFragment::Function)))
501        } else {
502            eyre::bail!("\"{}\"", content.trim());
503        }
504    }
505
506    /// Parses a fragment of Solidity code in memory and assigns it a scope within the
507    /// [`SessionSource`].
508    fn parse_fragment(&self, buffer: &str) -> Option<(Self, ParseTreeFragment)> {
509        #[track_caller]
510        fn debug_errors(errors: &EmittedDiagnostics) {
511            tracing::debug!("{errors}");
512        }
513
514        let mut this = self.clone();
515        match this.add_run_code(buffer).parse() {
516            Ok(()) => return Some((this, ParseTreeFragment::Function)),
517            Err(e) => debug_errors(&e),
518        }
519        this = self.clone();
520        match this.add_contract_code(buffer).parse() {
521            Ok(()) => return Some((this, ParseTreeFragment::Contract)),
522            Err(e) => debug_errors(&e),
523        }
524        this = self.clone();
525        match this.add_global_code(buffer).parse() {
526            Ok(()) => return Some((this, ParseTreeFragment::Source)),
527            Err(e) => debug_errors(&e),
528        }
529        None
530    }
531
532    /// Append global-level code to the source.
533    pub fn add_global_code(&mut self, content: &str) -> &mut Self {
534        self.global_code.push_str(content.trim());
535        self.global_code.push('\n');
536        self.clear_output();
537        self
538    }
539
540    /// Append contract-level code to the source.
541    pub fn add_contract_code(&mut self, content: &str) -> &mut Self {
542        self.contract_code.push_str(content.trim());
543        self.contract_code.push('\n');
544        self.clear_output();
545        self
546    }
547
548    /// Append code to the `run()` function of the REPL contract.
549    pub fn add_run_code(&mut self, content: &str) -> &mut Self {
550        self.run_code.push_str(content.trim());
551        self.run_code.push('\n');
552        self.clear_output();
553        self
554    }
555
556    /// Clears all source code.
557    pub fn clear(&mut self) {
558        String::clear(&mut self.global_code);
559        String::clear(&mut self.contract_code);
560        String::clear(&mut self.run_code);
561        self.clear_output();
562    }
563
564    /// Clear the global-level code .
565    pub fn clear_global(&mut self) -> &mut Self {
566        String::clear(&mut self.global_code);
567        self.clear_output();
568        self
569    }
570
571    /// Clear the contract-level code .
572    pub fn clear_contract(&mut self) -> &mut Self {
573        String::clear(&mut self.contract_code);
574        self.clear_output();
575        self
576    }
577
578    /// Clear the `run()` function code.
579    pub fn clear_run(&mut self) -> &mut Self {
580        String::clear(&mut self.run_code);
581        self.clear_output();
582        self
583    }
584
585    fn clear_output(&mut self) {
586        self.output.take();
587    }
588
589    /// Compiles the source if necessary.
590    pub fn build(&self) -> Result<&GeneratedOutput> {
591        // TODO: mimics `get_or_try_init`
592        if let Some(output) = self.output.get() {
593            return Ok(output);
594        }
595        let output = self.compile()?;
596        let intermediate = self.generate_intermediate_output()?;
597        let output = GeneratedOutput { output, intermediate };
598        Ok(self.output.get_or_init(|| output))
599    }
600
601    /// Compiles the source.
602    #[cold]
603    fn compile(&self) -> Result<ProjectCompileOutput> {
604        let sources = self.get_sources();
605
606        let project = self.config.foundry_config.ephemeral_project()?;
607        let mut output = ProjectCompiler::with_sources(&project, sources)?.compile()?;
608
609        if output.has_compiler_errors() {
610            eyre::bail!("{output}");
611        }
612
613        // TODO(dani): re-enable
614        if cfg!(false) {
615            output.parser_mut().solc_mut().compiler_mut().enter_mut(|c| {
616                let _ = c.lower_asts();
617            });
618        }
619
620        Ok(output)
621    }
622
623    fn get_sources(&self) -> Sources {
624        let mut sources = Sources::new();
625
626        let src = self.to_repl_source();
627        sources.insert(self.file_name.clone().into(), Source::new(src));
628
629        // Include Vm.sol if forge-std remapping is not available.
630        if !self.config.no_vm
631            && !self
632                .config
633                .foundry_config
634                .get_all_remappings()
635                .any(|r| r.name.starts_with("forge-std"))
636        {
637            sources.insert("forge-std/Vm.sol".into(), self.vm_source.clone());
638        }
639
640        sources
641    }
642
643    /// Generate intermediate contracts for all contract definitions in the compilation source.
644    ///
645    /// ### Returns
646    ///
647    /// Optionally, a map of contract names to a vec of [IntermediateContract]s.
648    pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
649        let mut res_map = HashMap::default();
650        let parsed_map = self.get_sources();
651        for source in parsed_map.values() {
652            Self::get_intermediate_contract(&source.content, &mut res_map);
653        }
654        Ok(res_map)
655    }
656
657    /// Generate intermediate output for the REPL contract
658    pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
659        // Parse generate intermediate contracts
660        let intermediate_contracts = self.generate_intermediate_contracts()?;
661
662        // Construct variable definitions
663        let variable_definitions = intermediate_contracts
664            .get("REPL")
665            .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
666            .variable_definitions
667            .clone()
668            .into_iter()
669            .map(|(k, v)| (k, v.ty))
670            .collect::<HashMap<String, pt::Expression>>();
671        // Construct intermediate output
672        let mut intermediate_output = IntermediateOutput {
673            repl_contract_expressions: variable_definitions,
674            intermediate_contracts,
675        };
676
677        // Add all statements within the run function to the repl_contract_expressions map
678        for (key, val) in intermediate_output
679            .run_func_body()?
680            .clone()
681            .iter()
682            .flat_map(Self::get_statement_definitions)
683        {
684            intermediate_output.repl_contract_expressions.insert(key, val);
685        }
686
687        Ok(intermediate_output)
688    }
689
690    /// Construct the source as a valid Forge script.
691    pub fn to_script_source(&self) -> String {
692        let Self {
693            contract_name,
694            global_code,
695            contract_code: top_level_code,
696            run_code,
697            config,
698            ..
699        } = self;
700
701        let script_import =
702            if !config.no_vm { "import {Script} from \"forge-std/Script.sol\";\n" } else { "" };
703
704        format!(
705            r#"
706// SPDX-License-Identifier: UNLICENSED
707pragma solidity 0;
708
709{script_import}
710{global_code}
711
712contract {contract_name} is Script {{
713    {top_level_code}
714
715    /// @notice Script entry point
716    function run() public {{
717        {run_code}
718    }}
719}}"#,
720        )
721    }
722
723    /// Construct the REPL source.
724    pub fn to_repl_source(&self) -> String {
725        let Self {
726            contract_name,
727            global_code,
728            contract_code: top_level_code,
729            run_code,
730            config,
731            ..
732        } = self;
733        let (mut vm_import, mut vm_constant) = (String::new(), String::new());
734        // Check if there's any `forge-std` remapping and determine proper path to it by
735        // searching remapping path.
736        if !config.no_vm
737            && let Some(remapping) = config
738                .foundry_config
739                .remappings
740                .iter()
741                .find(|remapping| remapping.name == "forge-std/")
742            && let Some(vm_path) = WalkDir::new(&remapping.path.path)
743                .into_iter()
744                .filter_map(|e| e.ok())
745                .find(|e| e.file_name() == "Vm.sol")
746        {
747            vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
748            vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
749        }
750
751        format!(
752            r#"
753// SPDX-License-Identifier: UNLICENSED
754pragma solidity 0;
755
756{vm_import}
757{global_code}
758
759contract {contract_name} {{
760    {vm_constant}
761    {top_level_code}
762
763    /// @notice REPL contract entry point
764    function run() public {{
765        {run_code}
766    }}
767}}"#,
768        )
769    }
770
771    /// Parse the current source in memory using Solar.
772    pub(crate) fn parse(&self) -> Result<(), EmittedDiagnostics> {
773        let sess =
774            solar::interface::Session::builder().with_buffer_emitter(Default::default()).build();
775        let _ = sess.enter_sequential(|| -> solar::interface::Result<()> {
776            let arena = solar::ast::Arena::new();
777            let filename = self.file_name.clone().into();
778            let src = self.to_repl_source();
779            let mut parser = solar::parse::Parser::from_source_code(&sess, &arena, filename, src)?;
780            let _ast = parser.parse_file().map_err(|e| e.emit())?;
781            Ok(())
782        });
783        sess.dcx.emitted_errors().unwrap()
784    }
785
786    /// Gets the [IntermediateContract] for a Solidity source string and inserts it into the
787    /// passed `res_map`. In addition, recurses on any imported files as well.
788    ///
789    /// ### Takes
790    /// - `content` - A Solidity source string
791    /// - `res_map` - A mutable reference to a map of contract names to [IntermediateContract]s
792    pub fn get_intermediate_contract(
793        content: &str,
794        res_map: &mut HashMap<String, IntermediateContract>,
795    ) {
796        if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
797            let func_defs = source_unit_parts
798                .into_iter()
799                .filter_map(|sup| match sup {
800                    pt::SourceUnitPart::ImportDirective(i) => match i {
801                        pt::Import::Plain(s, _)
802                        | pt::Import::Rename(s, _, _)
803                        | pt::Import::GlobalSymbol(s, _, _) => {
804                            let s = match s {
805                                pt::ImportPath::Filename(s) => s.string,
806                                pt::ImportPath::Path(p) => p.to_string(),
807                            };
808                            let path = PathBuf::from(s);
809
810                            match fs::read_to_string(path) {
811                                Ok(source) => {
812                                    Self::get_intermediate_contract(&source, res_map);
813                                    None
814                                }
815                                Err(_) => None,
816                            }
817                        }
818                    },
819                    pt::SourceUnitPart::ContractDefinition(cd) => {
820                        let mut intermediate = IntermediateContract::default();
821
822                        cd.parts.into_iter().for_each(|part| match part {
823                            pt::ContractPart::FunctionDefinition(def) => {
824                                // Only match normal function definitions here.
825                                if matches!(def.ty, pt::FunctionTy::Function) {
826                                    intermediate
827                                        .function_definitions
828                                        .insert(def.name.clone().unwrap().name, def);
829                                }
830                            }
831                            pt::ContractPart::EventDefinition(def) => {
832                                let event_name = def.name.safe_unwrap().name.clone();
833                                intermediate.event_definitions.insert(event_name, def);
834                            }
835                            pt::ContractPart::StructDefinition(def) => {
836                                let struct_name = def.name.safe_unwrap().name.clone();
837                                intermediate.struct_definitions.insert(struct_name, def);
838                            }
839                            pt::ContractPart::VariableDefinition(def) => {
840                                let var_name = def.name.safe_unwrap().name.clone();
841                                intermediate.variable_definitions.insert(var_name, def);
842                            }
843                            _ => {}
844                        });
845                        Some((cd.name.safe_unwrap().name.clone(), intermediate))
846                    }
847                    _ => None,
848                })
849                .collect::<HashMap<String, IntermediateContract>>();
850            res_map.extend(func_defs);
851        }
852    }
853
854    /// Helper to deconstruct a statement
855    ///
856    /// ### Takes
857    ///
858    /// A reference to a [pt::Statement]
859    ///
860    /// ### Returns
861    ///
862    /// A vector containing tuples of the inner expressions' names, types, and storage locations.
863    pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
864        match statement {
865            pt::Statement::VariableDefinition(_, def, _) => {
866                vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
867            }
868            pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
869                if let pt::Expression::List(_, list) = left.as_ref() {
870                    list.iter()
871                        .filter_map(|(_, param)| {
872                            param.as_ref().and_then(|param| {
873                                param
874                                    .name
875                                    .as_ref()
876                                    .map(|name| (name.name.clone(), param.ty.clone()))
877                            })
878                        })
879                        .collect()
880                } else {
881                    Vec::default()
882                }
883            }
884            _ => Vec::default(),
885        }
886    }
887}
888
/// Where a piece of REPL input belongs in the generated source.
///
/// Used to route an input either into the body of the REPL contract's `run()` function,
/// onto the contract's top level, or into the file's global scope.
#[derive(Debug)]
enum ParseTreeFragment {
    /// Belongs in the global (file-level) scope, outside the REPL contract
    Source,
    /// Belongs at the top level of the REPL contract's body
    Contract,
    /// Belongs inside the body of the contract's `run()` function
    Function,
}
902
/// A single decoded EVM instruction.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct Instruction {
    /// Byte offset of the opcode within the bytecode.
    pub pc: usize,
    /// The opcode byte.
    pub opcode: u8,
    /// Immediate bytes of a `PUSH1`..=`PUSH32`, left-aligned and zero-padded. Bytes missing
    /// because the code ends mid-push are also zero. All zeros for other opcodes.
    pub data: [u8; 32],
    /// Number of immediate bytes the opcode declares (1..=32 for pushes, 0 otherwise).
    pub data_len: u8,
}

/// Iterator decoding raw EVM bytecode into [`Instruction`]s.
///
/// The immediate bytes of `PUSH` opcodes are consumed together with their opcode, so every
/// yielded `pc` points at an opcode and never at push data.
struct InstructionIter<'a> {
    /// Bytecode being decoded.
    bytes: &'a [u8],
    /// Offset of the next opcode to decode.
    offset: usize,
}

impl<'a> InstructionIter<'a> {
    /// Creates an iterator that decodes `bytes` from offset 0.
    pub fn new(bytes: &'a [u8]) -> Self {
        Self { bytes, offset: 0 }
    }
}

impl Iterator for InstructionIter<'_> {
    type Item = Instruction;

    fn next(&mut self) -> Option<Self::Item> {
        let pc = self.offset;
        let opcode = *self.bytes.get(pc)?;
        self.offset = pc + 1;

        let mut data = [0; 32];
        let mut data_len = 0;
        // PUSH1 (0x60) ..= PUSH32 (0x7F) carry `opcode - 0x5F` immediate bytes.
        if matches!(opcode, 0x60..=0x7F) {
            data_len = opcode - 0x5F;
            let declared = usize::from(data_len);
            // Bytecode may end in the middle of a push (e.g. inside trailing metadata). Copy
            // only the bytes that are present and leave the rest zeroed, rather than panicking
            // on an out-of-bounds slice. `self.offset <= self.bytes.len()` holds here because
            // `pc` was a valid index.
            let remaining = &self.bytes[self.offset..];
            let present = remaining.len().min(declared);
            data[..present].copy_from_slice(&remaining[..present]);
            self.offset += declared;
        }

        Some(Instruction { pc, opcode, data, data_len })
    }
}