1use eyre::Result;
8use forge_doc::solang_ext::{CodeLocationExt, SafeUnwrap};
9use foundry_common::fs;
10use foundry_compilers::{
11 Artifact, ProjectCompileOutput,
12 artifacts::{ConfigurableContractArtifact, Source, Sources},
13 project::ProjectCompiler,
14 solc::Solc,
15};
16use foundry_config::{Config, SolcReq};
17use foundry_evm::{backend::Backend, core::bytecode::InstIter, opts::EvmOpts};
18use semver::Version;
19use serde::{Deserialize, Serialize};
20use solang_parser::pt;
21use solar::interface::diagnostics::EmittedDiagnostics;
22use std::{cell::OnceCell, collections::HashMap, fmt, path::PathBuf};
23use walkdir::WalkDir;
24
/// Minimum configured solc version for which the `Vm` interface is injected; for older versions
/// `SessionSourceConfig::detect_solc` sets `no_vm`.
pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);

/// Copy of `testdata/cheats/Vm.sol`, embedded into the binary at compile time.
static VM_SOURCE: &str = include_str!("../../../testdata/cheats/Vm.sol");
30
/// Result of building a `SessionSource`: the compiler output together with the solang-derived
/// intermediate output.
pub struct GeneratedOutput {
    /// Output of compiling the session's sources.
    output: ProjectCompileOutput,
    /// Intermediate representation produced by `SessionSource::generate_intermediate_output`.
    pub(crate) intermediate: IntermediateOutput,
}
36
/// Borrowed view of a [`GeneratedOutput`], handed to the closure passed to
/// [`GeneratedOutput::enter`].
pub struct GeneratedOutputRef<'a> {
    /// The compiler output.
    output: &'a ProjectCompileOutput,
    /// The intermediate output.
    pub(crate) intermediate: &'a IntermediateOutput,
}
42
/// Solang-parsed information about the session's sources, used alongside the compiler output.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IntermediateOutput {
    /// Variable name -> declared type expression, for the `REPL` contract's state variables and
    /// the variables declared in its `run()` function.
    pub repl_contract_expressions: HashMap<String, pt::Expression>,
    /// Every contract found in the sources (and their readable imports), keyed by name.
    pub intermediate_contracts: IntermediateContracts,
}
51
/// Selected definitions of a single contract, each keyed by its name.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct IntermediateContract {
    /// Regular (`FunctionTy::Function`) function definitions.
    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
    /// Event definitions.
    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
    /// Struct definitions.
    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
    /// State variable definitions.
    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
}
65
/// Intermediate contracts keyed by contract name.
type IntermediateContracts = HashMap<String, IntermediateContract>;
68
69impl fmt::Debug for GeneratedOutput {
70 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71 f.debug_struct("GeneratedOutput").finish_non_exhaustive()
72 }
73}
74
75impl GeneratedOutput {
76 pub fn enter<T: Send>(&self, f: impl FnOnce(GeneratedOutputRef<'_>) -> T + Send) -> T {
77 f(GeneratedOutputRef { output: &self.output, intermediate: &self.intermediate })
84 }
85}
86
87impl GeneratedOutputRef<'_> {
88 pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
89 self.output.find_first("REPL")
90 }
91}
92
93impl std::ops::Deref for GeneratedOutput {
94 type Target = IntermediateOutput;
95 fn deref(&self) -> &Self::Target {
96 &self.intermediate
97 }
98}
99impl std::ops::Deref for GeneratedOutputRef<'_> {
100 type Target = IntermediateOutput;
101 fn deref(&self) -> &Self::Target {
102 self.intermediate
103 }
104}
105
106impl IntermediateOutput {
107 pub fn get_event(&self, input: &str) -> Option<&pt::EventDefinition> {
108 self.intermediate_contracts
109 .get("REPL")
110 .and_then(|contract| contract.event_definitions.get(input).map(std::ops::Deref::deref))
111 }
112
113 pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
114 let deployed_bytecode = contract
115 .get_deployed_bytecode()
116 .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
117 let deployed_bytecode_bytes = deployed_bytecode
118 .bytes()
119 .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
120
121 let run_func_statements = self.run_func_body()?;
122
123 let last_yul_return = run_func_statements.iter().find_map(|statement| {
127 if let pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } = statement
128 && let Some(statement) = block.statements.last()
129 && let pt::YulStatement::FunctionCall(yul_call) = statement
130 && yul_call.id.name == "return"
131 {
132 return Some(statement.loc());
133 }
134 None
135 });
136
137 let Some(final_statement) = run_func_statements.last() else { return Ok(None) };
140
141 let mut source_loc = match final_statement {
148 pt::Statement::Assembly { loc: _, dialect: _, flags: _, block } => {
149 let last_statement = block.statements.iter().rev().find(|statement| {
151 !matches!(statement, pt::YulStatement::VariableDeclaration(_, _, _))
152 });
153 if let Some(statement) = last_statement {
154 statement.loc()
155 } else {
156 run_func_statements
160 .get(run_func_statements.len().saturating_sub(2))
161 .unwrap()
162 .loc()
163 }
164 }
165 pt::Statement::Block { loc: _, unchecked: _, statements } => {
166 if let Some(statement) = statements.last() {
167 statement.loc()
168 } else {
169 run_func_statements
173 .get(run_func_statements.len().saturating_sub(2))
174 .unwrap()
175 .loc()
176 }
177 }
178 _ => final_statement.loc(),
179 };
180
181 if let Some(yul_return) = last_yul_return
183 && yul_return.end() < source_loc.start()
184 {
185 source_loc = yul_return;
186 }
187
188 let final_pc = {
191 let offset = source_loc.start() as u32;
192 let length = (source_loc.end() - source_loc.start()) as u32;
193 trace!(%offset, %length, "find pc");
194 contract
195 .get_source_map_deployed()
196 .unwrap()
197 .unwrap()
198 .into_iter()
199 .zip(InstIter::new(deployed_bytecode_bytes).with_pc().map(|(pc, _)| pc))
200 .filter(|(s, _)| s.offset() == offset && s.length() == length)
201 .map(|(_, pc)| pc)
202 .max()
203 };
204 trace!(?final_pc);
205 Ok(final_pc)
206 }
207
208 pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
209 match self
210 .intermediate_contracts
211 .get("REPL")
212 .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
213 .function_definitions
214 .get("run")
215 .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
216 .body
217 .as_ref()
218 .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
219 {
220 pt::Statement::Block { statements, .. } => Ok(statements),
221 _ => eyre::bail!("Could not find statements within run function body!"),
222 }
223 }
224}
225
// NOTE(review): this impl is compiled out via `#[cfg(false)]`. It targets a `GeneratedOutputRef`
// with three lifetime parameters and a `compiler` field, which the struct defined in this file
// does not have. It also uses names (`Gcx`, `hir`, `Span`, `InstructionIter`) that are not among
// the imports visible here. It looks like an in-progress port of `IntermediateOutput`'s logic
// onto solar's HIR; confirm before enabling.
#[cfg(false)]
impl<'gcx> GeneratedOutputRef<'_, '_, 'gcx> {
    /// Returns the solar global context of the underlying compiler.
    pub fn gcx(&self) -> Gcx<'gcx> {
        self.compiler.gcx()
    }

    /// Returns the compiled artifact of the first contract named `REPL`, if any.
    pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
        self.output.find_first("REPL")
    }

    /// Returns the HIR id of the first event named `input`, searching every event in the HIR
    /// (not only those of the `REPL` contract).
    pub fn get_event(&self, input: &str) -> Option<hir::EventId> {
        self.gcx().hir.events_enumerated().find(|(_, e)| e.name.as_str() == input).map(|(id, _)| id)
    }

    /// HIR-based counterpart of `IntermediateOutput::final_pc`. It finds the highest pc whose
    /// source-map entry spans exactly the final statement of `run()`.
    ///
    /// Unlike that version, this returns `Ok(Some(0))` rather than `Ok(None)` when no source-map
    /// entry matches, because of the final `unwrap_or_default`.
    pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
        let deployed_bytecode = contract
            .get_deployed_bytecode()
            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
        let deployed_bytecode_bytes = deployed_bytecode
            .bytes()
            .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;

        let run_body = self.run_func_body();

        // NOTE(review): Yul `return` detection is stubbed out; this always evaluates to `None`.
        let last_yul_return_span: Option<Span> = run_body.iter().find_map(|stmt| {
            let _ = stmt;
            None
        });

        let Some(last_stmt) = run_body.last() else { return Ok(None) };

        // Descend one level into a trailing (unchecked) block. If that block is empty, fall back
        // to the statement before it, or to the block itself when it is the only statement.
        let source_stmt = match &last_stmt.kind {
            hir::StmtKind::UncheckedBlock(stmts) | hir::StmtKind::Block(stmts) => {
                if let Some(stmt) = stmts.last() {
                    stmt
                } else {
                    &run_body[run_body.len().saturating_sub(2)]
                }
            }
            _ => last_stmt,
        };
        let mut source_span = self.stmt_span_without_semicolon(source_stmt);

        // Prefer a Yul `return` that ends before the chosen statement.
        if let Some(yul_return_span) = last_yul_return_span
            && yul_return_span.hi() < source_span.lo()
        {
            source_span = yul_return_span;
        }

        let (_sf, range) = self.compiler.sess().source_map().span_to_source(source_span).unwrap();
        // NOTE(review): leftover `dbg!` debugging output; remove before enabling this impl.
        dbg!(source_span, &range, &_sf.src[range.clone()]);
        let offset = range.start as u32;
        let length = range.len() as u32;
        let final_pc = deployed_bytecode
            .source_map()
            .ok_or_else(|| eyre::eyre!("No source map found for `REPL` contract"))??
            .into_iter()
            .zip(InstructionIter::new(deployed_bytecode_bytes))
            .filter(|(s, _)| s.offset() == offset && s.length() == length)
            .map(|(_, i)| i.pc)
            .max()
            .unwrap_or_default();
        Ok(Some(final_pc))
    }

    /// Returns the span of `stmt`, ending it at the initializer/expression for declarations and
    /// expression statements. Judging by the name, this is meant to exclude the trailing `;`.
    fn stmt_span_without_semicolon(&self, stmt: &hir::Stmt<'_>) -> Span {
        match stmt.kind {
            hir::StmtKind::DeclSingle(id) => {
                let decl = self.gcx().hir.variable(id);
                if let Some(expr) = decl.initializer {
                    stmt.span.with_hi(expr.span.hi())
                } else {
                    stmt.span
                }
            }
            hir::StmtKind::DeclMulti(_, expr) => stmt.span.with_hi(expr.span.hi()),
            hir::StmtKind::Expr(expr) => expr.span,
            _ => stmt.span,
        }
    }

    /// Returns the body of the `REPL` contract's `run()` function.
    ///
    /// # Panics
    ///
    /// Panics if the `REPL` contract, its `run` function, or the function's body is missing.
    fn run_func_body(&self) -> hir::Block<'_> {
        let c = self.repl_contract_hir().expect("REPL contract not found in HIR");
        let f = c
            .functions()
            .find(|&f| self.gcx().hir.function(f).name.as_ref().map(|n| n.as_str()) == Some("run"))
            .expect("`run()` function not found in REPL contract");
        self.gcx().hir.function(f).body.expect("`run()` function does not have a body")
    }

    /// Returns the first HIR contract named `REPL`, if any.
    fn repl_contract_hir(&self) -> Option<&hir::Contract<'_>> {
        self.gcx().hir.contracts().find(|c| c.name.as_str() == "REPL")
    }
}
369
/// Configuration of a REPL [`SessionSource`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SessionSourceConfig {
    /// Foundry configuration. It supplies the solc version, remappings and the ephemeral
    /// compilation project.
    pub foundry_config: Config,
    /// EVM options. NOTE(review): not read in this file.
    pub evm_opts: EvmOpts,
    /// When `true`, the forge-std `Vm` interface is neither bundled nor imported.
    pub no_vm: bool,
    /// Optional EVM backend; skipped during (de)serialization.
    #[serde(skip)]
    pub backend: Option<Backend>,
    /// Tracing flag. NOTE(review): not read in this file.
    pub traces: bool,
    /// Optional calldata. NOTE(review): not read in this file; confirm how callers use it.
    pub calldata: Option<Vec<u8>>,
    /// Forwarded to `Config::disable_optimizations` when compiling the session.
    pub ir_minimum: bool,
}
392
393impl SessionSourceConfig {
394 pub fn detect_solc(&mut self) -> Result<()> {
396 if self.foundry_config.solc.is_none() {
397 let version = Solc::ensure_installed(&"*".parse().unwrap())?;
398 self.foundry_config.solc = Some(SolcReq::Version(version));
399 }
400 if !self.no_vm
401 && let Some(version) = self.foundry_config.solc_version()
402 && version < MIN_VM_VERSION
403 {
404 info!(%version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
405 self.no_vm = true;
406 }
407 Ok(())
408 }
409}
410
/// Source code of a REPL session, split into the three regions user input can be added to.
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionSource {
    /// File name of the generated REPL source.
    pub file_name: String,
    /// Name of the generated contract.
    pub contract_name: String,

    /// Session configuration.
    pub config: SessionSourceConfig,

    /// Code placed at file level, before the contract.
    pub global_code: String,
    /// Code placed inside the contract body, outside `run()`.
    pub contract_code: String,
    /// Code placed inside the body of `run()`.
    pub run_code: String,

    /// Bundled `Vm.sol` source. It is not serialized; deserialization restores it via
    /// `vm_source()`.
    #[serde(skip, default = "vm_source")]
    vm_source: Source,
    /// Cached build output. It is not serialized and is cleared whenever the code changes
    /// through the `add_*` / `clear*` methods.
    #[serde(skip)]
    output: OnceCell<GeneratedOutput>,
}
442
443fn vm_source() -> Source {
444 Source::new(VM_SOURCE)
445}
446
447impl Clone for SessionSource {
448 fn clone(&self) -> Self {
449 Self {
450 file_name: self.file_name.clone(),
451 contract_name: self.contract_name.clone(),
452 global_code: self.global_code.clone(),
453 contract_code: self.contract_code.clone(),
454 run_code: self.run_code.clone(),
455 config: self.config.clone(),
456 vm_source: self.vm_source.clone(),
457 output: Default::default(),
458 }
459 }
460}
461
462impl SessionSource {
463 pub fn new(mut config: SessionSourceConfig) -> Result<Self> {
478 config.detect_solc()?;
479 Ok(Self {
480 file_name: "ReplContract.sol".to_string(),
481 contract_name: "REPL".to_string(),
482 config,
483 global_code: Default::default(),
484 contract_code: Default::default(),
485 run_code: Default::default(),
486 vm_source: vm_source(),
487 output: Default::default(),
488 })
489 }
490
491 pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
495 if let Some((new_source, fragment)) = self
496 .parse_fragment(&content)
497 .or_else(|| {
498 content.push(';');
499 self.parse_fragment(&content)
500 })
501 .or_else(|| {
502 content = content.trim_end().trim_end_matches(';').to_string();
503 self.parse_fragment(&content)
504 })
505 {
506 Ok((new_source, matches!(fragment, ParseTreeFragment::Function)))
507 } else {
508 eyre::bail!("\"{}\"", content.trim());
509 }
510 }
511
512 fn parse_fragment(&self, buffer: &str) -> Option<(Self, ParseTreeFragment)> {
515 #[track_caller]
516 fn debug_errors(errors: &EmittedDiagnostics) {
517 debug!("{errors}");
518 }
519
520 let mut this = self.clone();
521 match this.add_run_code(buffer).parse() {
522 Ok(()) => return Some((this, ParseTreeFragment::Function)),
523 Err(e) => debug_errors(&e),
524 }
525 this = self.clone();
526 match this.add_contract_code(buffer).parse() {
527 Ok(()) => return Some((this, ParseTreeFragment::Contract)),
528 Err(e) => debug_errors(&e),
529 }
530 this = self.clone();
531 match this.add_global_code(buffer).parse() {
532 Ok(()) => return Some((this, ParseTreeFragment::Source)),
533 Err(e) => debug_errors(&e),
534 }
535 None
536 }
537
538 pub fn add_global_code(&mut self, content: &str) -> &mut Self {
540 self.global_code.push_str(content.trim());
541 self.global_code.push('\n');
542 self.clear_output();
543 self
544 }
545
546 pub fn add_contract_code(&mut self, content: &str) -> &mut Self {
548 self.contract_code.push_str(content.trim());
549 self.contract_code.push('\n');
550 self.clear_output();
551 self
552 }
553
554 pub fn add_run_code(&mut self, content: &str) -> &mut Self {
556 self.run_code.push_str(content.trim());
557 self.run_code.push('\n');
558 self.clear_output();
559 self
560 }
561
562 pub fn clear(&mut self) {
564 String::clear(&mut self.global_code);
565 String::clear(&mut self.contract_code);
566 String::clear(&mut self.run_code);
567 self.clear_output();
568 }
569
570 pub fn clear_global(&mut self) -> &mut Self {
572 String::clear(&mut self.global_code);
573 self.clear_output();
574 self
575 }
576
577 pub fn clear_contract(&mut self) -> &mut Self {
579 String::clear(&mut self.contract_code);
580 self.clear_output();
581 self
582 }
583
584 pub fn clear_run(&mut self) -> &mut Self {
586 String::clear(&mut self.run_code);
587 self.clear_output();
588 self
589 }
590
591 fn clear_output(&mut self) {
592 self.output.take();
593 }
594
595 pub fn build(&self) -> Result<&GeneratedOutput> {
597 if let Some(output) = self.output.get() {
599 return Ok(output);
600 }
601 let output = self.compile()?;
602 let intermediate = self.generate_intermediate_output()?;
603 let output = GeneratedOutput { output, intermediate };
604 Ok(self.output.get_or_init(|| output))
605 }
606
607 #[cold]
609 fn compile(&self) -> Result<ProjectCompileOutput> {
610 let sources = self.get_sources();
611
612 let mut project = self.config.foundry_config.ephemeral_project()?;
613 self.config.foundry_config.disable_optimizations(&mut project, self.config.ir_minimum);
614 let mut output = ProjectCompiler::with_sources(&project, sources)?.compile()?;
615
616 if output.has_compiler_errors() {
617 eyre::bail!("{output}");
618 }
619
620 if cfg!(false) {
622 output.parser_mut().solc_mut().compiler_mut().enter_mut(|c| {
623 let _ = c.lower_asts();
624 });
625 }
626
627 Ok(output)
628 }
629
630 fn get_sources(&self) -> Sources {
631 let mut sources = Sources::new();
632
633 let src = self.to_repl_source();
634 sources.insert(self.file_name.clone().into(), Source::new(src));
635
636 if !self.config.no_vm
638 && !self
639 .config
640 .foundry_config
641 .get_all_remappings()
642 .any(|r| r.name.starts_with("forge-std"))
643 {
644 sources.insert("forge-std/Vm.sol".into(), self.vm_source.clone());
645 }
646
647 sources
648 }
649
650 pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
656 let mut res_map = HashMap::default();
657 let parsed_map = self.get_sources();
658 for source in parsed_map.values() {
659 Self::get_intermediate_contract(&source.content, &mut res_map);
660 }
661 Ok(res_map)
662 }
663
664 pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
666 let intermediate_contracts = self.generate_intermediate_contracts()?;
668
669 let variable_definitions = intermediate_contracts
671 .get("REPL")
672 .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
673 .variable_definitions
674 .clone()
675 .into_iter()
676 .map(|(k, v)| (k, v.ty))
677 .collect::<HashMap<String, pt::Expression>>();
678 let mut intermediate_output = IntermediateOutput {
680 repl_contract_expressions: variable_definitions,
681 intermediate_contracts,
682 };
683
684 for (key, val) in intermediate_output
686 .run_func_body()?
687 .clone()
688 .iter()
689 .flat_map(Self::get_statement_definitions)
690 {
691 intermediate_output.repl_contract_expressions.insert(key, val);
692 }
693
694 Ok(intermediate_output)
695 }
696
697 pub fn to_script_source(&self) -> String {
699 let Self {
700 contract_name,
701 global_code,
702 contract_code: top_level_code,
703 run_code,
704 config,
705 ..
706 } = self;
707
708 let script_import =
709 if !config.no_vm { "import {Script} from \"forge-std/Script.sol\";\n" } else { "" };
710
711 format!(
712 r#"
713// SPDX-License-Identifier: UNLICENSED
714pragma solidity 0;
715
716{script_import}
717{global_code}
718
719contract {contract_name} is Script {{
720 {top_level_code}
721
722 /// @notice Script entry point
723 function run() public {{
724 {run_code}
725 }}
726}}"#,
727 )
728 }
729
730 pub fn to_repl_source(&self) -> String {
732 let Self {
733 contract_name,
734 global_code,
735 contract_code: top_level_code,
736 run_code,
737 config,
738 ..
739 } = self;
740 let (mut vm_import, mut vm_constant) = (String::new(), String::new());
741 if !config.no_vm
744 && let Some(remapping) = config
745 .foundry_config
746 .remappings
747 .iter()
748 .find(|remapping| remapping.name == "forge-std/")
749 && let Some(vm_path) = WalkDir::new(&remapping.path.path)
750 .into_iter()
751 .filter_map(|e| e.ok())
752 .find(|e| e.file_name() == "Vm.sol")
753 {
754 vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
755 vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
756 }
757
758 format!(
759 r#"
760// SPDX-License-Identifier: UNLICENSED
761pragma solidity 0;
762
763{vm_import}
764{global_code}
765
766contract {contract_name} {{
767 {vm_constant}
768 {top_level_code}
769
770 /// @notice REPL contract entry point
771 function run() public {{
772 {run_code}
773 }}
774}}"#,
775 )
776 }
777
778 pub(crate) fn parse(&self) -> Result<(), EmittedDiagnostics> {
780 let sess =
781 solar::interface::Session::builder().with_buffer_emitter(Default::default()).build();
782 let _ = sess.enter_sequential(|| -> solar::interface::Result<()> {
783 let arena = solar::ast::Arena::new();
784 let filename = self.file_name.clone().into();
785 let src = self.to_repl_source();
786 let mut parser = solar::parse::Parser::from_source_code(&sess, &arena, filename, src)?;
787 let _ast = parser.parse_file().map_err(|e| e.emit())?;
788 Ok(())
789 });
790 sess.dcx.emitted_errors().unwrap()
791 }
792
793 pub fn get_intermediate_contract(
800 content: &str,
801 res_map: &mut HashMap<String, IntermediateContract>,
802 ) {
803 if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
804 let func_defs = source_unit_parts
805 .into_iter()
806 .filter_map(|sup| match sup {
807 pt::SourceUnitPart::ImportDirective(i) => match i {
808 pt::Import::Plain(s, _)
809 | pt::Import::Rename(s, _, _)
810 | pt::Import::GlobalSymbol(s, _, _) => {
811 let s = match s {
812 pt::ImportPath::Filename(s) => s.string,
813 pt::ImportPath::Path(p) => p.to_string(),
814 };
815 let path = PathBuf::from(s);
816
817 match fs::read_to_string(path) {
818 Ok(source) => {
819 Self::get_intermediate_contract(&source, res_map);
820 None
821 }
822 Err(_) => None,
823 }
824 }
825 },
826 pt::SourceUnitPart::ContractDefinition(cd) => {
827 let mut intermediate = IntermediateContract::default();
828
829 cd.parts.into_iter().for_each(|part| match part {
830 pt::ContractPart::FunctionDefinition(def) => {
831 if matches!(def.ty, pt::FunctionTy::Function) {
833 intermediate
834 .function_definitions
835 .insert(def.name.clone().unwrap().name, def);
836 }
837 }
838 pt::ContractPart::EventDefinition(def) => {
839 let event_name = def.name.safe_unwrap().name.clone();
840 intermediate.event_definitions.insert(event_name, def);
841 }
842 pt::ContractPart::StructDefinition(def) => {
843 let struct_name = def.name.safe_unwrap().name.clone();
844 intermediate.struct_definitions.insert(struct_name, def);
845 }
846 pt::ContractPart::VariableDefinition(def) => {
847 let var_name = def.name.safe_unwrap().name.clone();
848 intermediate.variable_definitions.insert(var_name, def);
849 }
850 _ => {}
851 });
852 Some((cd.name.safe_unwrap().name.clone(), intermediate))
853 }
854 _ => None,
855 })
856 .collect::<HashMap<String, IntermediateContract>>();
857 res_map.extend(func_defs);
858 }
859 }
860
861 pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
871 match statement {
872 pt::Statement::VariableDefinition(_, def, _) => {
873 vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
874 }
875 pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
876 if let pt::Expression::List(_, list) = left.as_ref() {
877 list.iter()
878 .filter_map(|(_, param)| {
879 param.as_ref().and_then(|param| {
880 param
881 .name
882 .as_ref()
883 .map(|name| (name.name.clone(), param.ty.clone()))
884 })
885 })
886 .collect()
887 } else {
888 Vec::default()
889 }
890 }
891 _ => Vec::default(),
892 }
893 }
894}
895
/// Region of the REPL source that a parsed input fragment was added to.
#[derive(Debug)]
enum ParseTreeFragment {
    /// File level, outside the contract (`global_code`).
    Source,
    /// Contract body, outside `run()` (`contract_code`).
    Contract,
    /// Body of the `run()` function (`run_code`).
    Function,
}