1use eyre::Result;
8use foundry_compilers::{
9 Artifact, ProjectCompileOutput,
10 artifacts::{ConfigurableContractArtifact, Source, Sources},
11 project::ProjectCompiler,
12 solc::Solc,
13};
14use foundry_config::{Config, SolcReq};
15use foundry_evm::{backend::Backend, core::bytecode::InstIter, opts::EvmOpts};
16use semver::Version;
17use serde::{Deserialize, Serialize};
18use solar::{
19 ast::{ItemKind, StmtKind as AstStmtKind, yul},
20 interface::{Span, diagnostics::EmittedDiagnostics},
21 sema::{
22 CompilerRef,
23 hir::{Block, Contract, EventId, ItemId, Stmt, StmtKind as HirStmtKind},
24 ty::Gcx,
25 },
26};
27use std::{cell::OnceCell, fmt};
28use walkdir::WalkDir;
29
/// Lowest solc version for which the `Vm` cheatcode interface is injected into the REPL source.
///
/// [`SessionSourceConfig::detect_solc`] turns VM injection off (`no_vm = true`) when the
/// configured compiler is older than this.
pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);

/// Bundled `Vm.sol` source, embedded at build time from the repository's test data.
static VM_SOURCE: &str = include_str!("../../../testdata/utils/Vm.sol");
35
/// Compilation output of a [`SessionSource`], cached by [`SessionSource::build`].
pub struct GeneratedOutput {
    /// Result of the project compiler; also owns the solar compiler session used by `enter`.
    output: ProjectCompileOutput,
}
40
41impl fmt::Debug for GeneratedOutput {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("GeneratedOutput").finish_non_exhaustive()
44 }
45}
46
47impl GeneratedOutput {
48 pub fn enter<R: Send>(
50 &self,
51 f: impl for<'a, 'b, 'gcx> FnOnce(GeneratedOutputRef<'a, 'b, 'gcx>) -> R + Send,
52 ) -> R {
53 self.output
54 .parser()
55 .solc()
56 .compiler()
57 .enter(|c| f(GeneratedOutputRef { output: &self.output, compiler: c }))
58 }
59}
60
/// Borrowed view of a [`GeneratedOutput`], available inside [`GeneratedOutput::enter`].
pub struct GeneratedOutputRef<'a, 'b, 'gcx> {
    /// The compile output, used to look up artifacts.
    output: &'a ProjectCompileOutput,
    /// The entered compiler session, used for the HIR, the AST sources and the source map.
    pub(crate) compiler: &'b CompilerRef<'gcx>,
}
66
67impl<'gcx> GeneratedOutputRef<'_, '_, 'gcx> {
68 pub fn gcx(&self) -> Gcx<'gcx> {
69 self.compiler.gcx()
70 }
71
72 pub fn repl_contract(&self) -> Option<&ConfigurableContractArtifact> {
73 self.output.find_first("REPL")
74 }
75
76 pub fn repl_contract_hir(&self) -> Option<&'gcx Contract<'gcx>> {
78 self.gcx().hir.contracts().find(|c| c.name.as_str() == "REPL")
79 }
80
81 pub fn run_func_body(&self) -> Block<'gcx> {
83 let hir = &self.gcx().hir;
84 let c = self.repl_contract_hir().expect("REPL contract not found in HIR");
85 let f = c
86 .functions()
87 .find(|&f| hir.function(f).name.as_ref().map(|n| n.as_str()) == Some("run"))
88 .expect("`run()` function not found in REPL contract");
89 hir.function(f).body.expect("`run()` function does not have a body")
90 }
91
92 pub fn get_event(&self, input: &str) -> Option<EventId> {
94 let hir = &self.gcx().hir;
95 let c = self.repl_contract_hir()?;
96 c.items.iter().find_map(|id| {
97 if let ItemId::Event(eid) = id
98 && hir.event(*eid).name.as_str() == input
99 {
100 Some(*eid)
101 } else {
102 None
103 }
104 })
105 }
106
107 pub fn final_pc(&self, contract: &ConfigurableContractArtifact) -> Result<Option<usize>> {
108 let deployed_bytecode = contract
109 .get_deployed_bytecode()
110 .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
111 let deployed_bytecode_bytes = deployed_bytecode
112 .bytes()
113 .ok_or_else(|| eyre::eyre!("No deployed bytecode found for `REPL` contract"))?;
114
115 let run_body = self.run_func_body();
117
118 let last_yul_return_span: Option<Span> = self.first_yul_return_span();
126
127 let Some(last_stmt) = run_body.last() else { return Ok(None) };
130
131 let source_stmt = match &last_stmt.kind {
139 HirStmtKind::UncheckedBlock(stmts) | HirStmtKind::Block(stmts) => {
140 if let Some(stmt) = stmts.last() {
141 stmt
142 } else {
143 &run_body[run_body.len().saturating_sub(2)]
147 }
148 }
149 _ => last_stmt,
150 };
151 let mut source_span = if matches!(last_stmt.kind, HirStmtKind::Err(_))
162 && let Some(span) = self.trailing_assembly_last_stmt_span()
163 {
164 span
165 } else {
166 self.stmt_span_without_semicolon(source_stmt)
167 };
168
169 if let Some(yul_return_span) = last_yul_return_span
171 && yul_return_span.hi() < source_span.lo()
172 {
173 source_span = yul_return_span;
174 }
175
176 let result = self
179 .compiler
180 .sess()
181 .source_map()
182 .span_to_source(source_span)
183 .map_err(|e| eyre::eyre!("failed to resolve span: {e:?}"))?;
184 let range = result.data;
185 let offset = range.start as u32;
186 let length = range.len() as u32;
187 trace!(%offset, %length, "find pc");
188 let final_pc = contract
189 .get_source_map_deployed()
190 .ok_or_else(|| eyre::eyre!("No source map found for `REPL` contract"))??
191 .into_iter()
192 .zip(InstIter::new(deployed_bytecode_bytes).with_pc().map(|(pc, _)| pc))
193 .filter(|(s, _)| s.offset() == offset && s.length() == length)
194 .map(|(_, pc)| pc)
195 .max();
196 trace!(?final_pc);
197 Ok(final_pc)
198 }
199
200 fn stmt_span_without_semicolon(&self, stmt: &Stmt<'_>) -> Span {
202 match stmt.kind {
203 HirStmtKind::DeclSingle(id) => {
204 let decl = self.gcx().hir.variable(id);
205 if let Some(expr) = decl.initializer {
206 stmt.span.with_hi(expr.span.hi())
207 } else {
208 stmt.span
209 }
210 }
211 HirStmtKind::DeclMulti(_, expr) => stmt.span.with_hi(expr.span.hi()),
212 HirStmtKind::Expr(expr) => expr.span,
213 _ => stmt.span,
214 }
215 }
216
217 fn repl_run_ast_body(&self) -> Option<&'gcx solar::ast::Block<'gcx>> {
222 let contract = self.repl_contract_hir()?;
223 let source = self.gcx().sources.get(contract.source)?;
224 let ast = source.ast.as_ref()?;
225
226 let contract_ast = ast.items.iter().find_map(|i| match &i.kind {
227 ItemKind::Contract(c) if c.name.as_str() == "REPL" => Some(c),
228 _ => None,
229 })?;
230 contract_ast.body.iter().find_map(|i| match &i.kind {
231 ItemKind::Function(f) if f.header.name.is_some_and(|n| n.as_str() == "run") => {
232 f.body.as_ref()
233 }
234 _ => None,
235 })
236 }
237
238 fn first_yul_return_span(&self) -> Option<Span> {
241 let run_body = self.repl_run_ast_body()?;
242 for stmt in run_body.stmts.iter() {
243 let AstStmtKind::Assembly(asm) = &stmt.kind else { continue };
244 for ystmt in asm.block.stmts.iter() {
245 if let yul::StmtKind::Expr(e) = &ystmt.kind
246 && let yul::ExprKind::Call(call) = &e.kind
247 && call.name.as_str() == "return"
248 {
249 return Some(ystmt.span);
250 }
251 }
252 }
253 None
254 }
255
256 fn trailing_assembly_last_stmt_span(&self) -> Option<Span> {
262 let run_body = self.repl_run_ast_body()?;
263 let AstStmtKind::Assembly(asm) = &run_body.stmts.last()?.kind else { return None };
264 asm.block
265 .stmts
266 .iter()
267 .rev()
268 .find(|s| !matches!(s.kind, yul::StmtKind::VarDecl(_, _)))
269 .map(|s| s.span)
270 }
271}
272
/// Configuration of a [`SessionSource`].
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SessionSourceConfig {
    /// Foundry configuration; provides the solc selection and the remappings used for `Vm`.
    pub foundry_config: Config,
    /// EVM options — presumably used when executing the REPL contract; not read in this file.
    pub evm_opts: EvmOpts,
    /// Disables injection of the `Vm` import/constant and of the bundled `Vm.sol` source.
    pub no_vm: bool,
    /// Optional backend; skipped by serde, so it is `None` after deserialization.
    #[serde(skip)]
    pub backend: Option<Backend>,
    /// Presumably toggles trace collection/output — not read in this file; TODO confirm.
    pub traces: bool,
    /// Optional raw calldata — presumably for calling the REPL contract; verify against caller.
    pub calldata: Option<Vec<u8>>,
    /// Passed to `Config::disable_optimizations` when compiling the session.
    pub ir_minimum: bool,
}
295
296impl SessionSourceConfig {
297 pub fn detect_solc(&mut self) -> Result<()> {
299 if self.foundry_config.solc.is_none() {
300 let version = Solc::ensure_installed(&"*".parse().unwrap())?;
301 self.foundry_config.solc = Some(SolcReq::Version(version));
302 }
303 if !self.no_vm
304 && let Some(version) = self.foundry_config.solc_version()
305 && version < MIN_VM_VERSION
306 {
307 info!(%version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
308 self.no_vm = true;
309 }
310 Ok(())
311 }
312}
313
/// Solidity source of a REPL session, assembled from the code fragments entered so far.
#[derive(Debug, Serialize, Deserialize)]
pub struct SessionSource {
    /// Name of the generated source file (`ReplContract.sol` when built via `new`).
    pub file_name: String,
    /// Name of the generated contract (`REPL` when built via `new`).
    pub contract_name: String,

    /// Session configuration.
    pub config: SessionSourceConfig,

    /// Code placed at file level, after the optional `Vm` import and before the contract.
    pub global_code: String,
    /// Code placed in the contract body, before `run()`.
    pub contract_code: String,
    /// Code placed inside the body of `run()`.
    pub run_code: String,

    /// Bundled `Vm.sol` source; not serialized, restored with `vm_source()` on deserialization.
    #[serde(skip, default = "vm_source")]
    vm_source: Source,
    /// Cached compilation output; not serialized, and reset whenever the code changes.
    #[serde(skip)]
    output: OnceCell<GeneratedOutput>,
}
345
346fn vm_source() -> Source {
347 Source::new(VM_SOURCE)
348}
349
350impl Clone for SessionSource {
351 fn clone(&self) -> Self {
352 Self {
353 file_name: self.file_name.clone(),
354 contract_name: self.contract_name.clone(),
355 global_code: self.global_code.clone(),
356 contract_code: self.contract_code.clone(),
357 run_code: self.run_code.clone(),
358 config: self.config.clone(),
359 vm_source: self.vm_source.clone(),
360 output: Default::default(),
361 }
362 }
363}
364
365impl SessionSource {
366 pub fn new(mut config: SessionSourceConfig) -> Result<Self> {
381 config.detect_solc()?;
382 Ok(Self {
383 file_name: "ReplContract.sol".to_string(),
384 contract_name: "REPL".to_string(),
385 config,
386 global_code: Default::default(),
387 contract_code: Default::default(),
388 run_code: Default::default(),
389 vm_source: vm_source(),
390 output: Default::default(),
391 })
392 }
393
394 pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
398 if let Some((new_source, fragment)) = self
399 .parse_fragment(&content)
400 .or_else(|| {
401 content.push(';');
402 self.parse_fragment(&content)
403 })
404 .or_else(|| {
405 content = content.trim_end().trim_end_matches(';').to_string();
406 self.parse_fragment(&content)
407 })
408 {
409 Ok((new_source, matches!(fragment, ParseTreeFragment::Function)))
410 } else {
411 eyre::bail!("\"{}\"", content.trim());
412 }
413 }
414
415 fn parse_fragment(&self, buffer: &str) -> Option<(Self, ParseTreeFragment)> {
418 #[track_caller]
419 fn debug_errors(errors: &EmittedDiagnostics) {
420 debug!("{errors}");
421 }
422
423 let mut this = self.clone();
424 match this.add_run_code(buffer).parse() {
425 Ok(()) => return Some((this, ParseTreeFragment::Function)),
426 Err(e) => debug_errors(&e),
427 }
428 this = self.clone();
429 match this.add_contract_code(buffer).parse() {
430 Ok(()) => return Some((this, ParseTreeFragment::Contract)),
431 Err(e) => debug_errors(&e),
432 }
433 this = self.clone();
434 match this.add_global_code(buffer).parse() {
435 Ok(()) => return Some((this, ParseTreeFragment::Source)),
436 Err(e) => debug_errors(&e),
437 }
438 None
439 }
440
441 pub fn add_global_code(&mut self, content: &str) -> &mut Self {
443 self.global_code.push_str(content.trim());
444 self.global_code.push('\n');
445 self.clear_output();
446 self
447 }
448
449 pub fn add_contract_code(&mut self, content: &str) -> &mut Self {
451 self.contract_code.push_str(content.trim());
452 self.contract_code.push('\n');
453 self.clear_output();
454 self
455 }
456
457 pub fn add_run_code(&mut self, content: &str) -> &mut Self {
459 self.run_code.push_str(content.trim());
460 self.run_code.push('\n');
461 self.clear_output();
462 self
463 }
464
465 pub fn clear(&mut self) {
467 String::clear(&mut self.global_code);
468 String::clear(&mut self.contract_code);
469 String::clear(&mut self.run_code);
470 self.clear_output();
471 }
472
473 pub fn clear_run(&mut self) -> &mut Self {
475 String::clear(&mut self.run_code);
476 self.clear_output();
477 self
478 }
479
480 fn clear_output(&mut self) {
481 self.output.take();
482 }
483
484 pub fn build(&self) -> Result<&GeneratedOutput> {
486 if let Some(output) = self.output.get() {
488 return Ok(output);
489 }
490 let output = self.compile()?;
491 let output = GeneratedOutput { output };
492 Ok(self.output.get_or_init(|| output))
493 }
494
495 #[cold]
497 fn compile(&self) -> Result<ProjectCompileOutput> {
498 let sources = self.get_sources();
499
500 let mut project = self.config.foundry_config.ephemeral_project()?;
501 self.config.foundry_config.disable_optimizations(&mut project, self.config.ir_minimum);
502 let mut output = ProjectCompiler::with_sources(&project, sources)?.compile()?;
503
504 if output.has_compiler_errors() {
505 eyre::bail!("{output}");
506 }
507
508 output.parser_mut().solc_mut().compiler_mut().enter_mut(|c| {
510 let _ = c.lower_asts();
511 let _ = c.analysis();
512 });
513
514 Ok(output)
515 }
516
517 fn get_sources(&self) -> Sources {
518 let mut sources = Sources::new();
519
520 let src = self.to_repl_source();
521 sources.insert(self.file_name.clone().into(), Source::new(src));
522
523 if !self.config.no_vm
525 && !self
526 .config
527 .foundry_config
528 .get_all_remappings()
529 .any(|r| r.name.starts_with("forge-std"))
530 {
531 sources.insert("forge-std/Vm.sol".into(), self.vm_source.clone());
532 }
533
534 sources
535 }
536
537 pub fn to_repl_source(&self) -> String {
539 let Self {
540 contract_name,
541 global_code,
542 contract_code: top_level_code,
543 run_code,
544 config,
545 ..
546 } = self;
547 let (mut vm_import, mut vm_constant) = (String::new(), String::new());
548 if !config.no_vm
551 && let Some(remapping) = config
552 .foundry_config
553 .remappings
554 .iter()
555 .find(|remapping| remapping.name == "forge-std/")
556 && let Some(vm_path) = WalkDir::new(&remapping.path.path)
557 .into_iter()
558 .filter_map(|e| e.ok())
559 .find(|e| e.file_name() == "Vm.sol")
560 {
561 vm_import = format!(
562 "import {{Vm}} from \"{}\";\n",
563 vm_path.path().to_string_lossy().replace('\\', "/")
564 );
565 vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
566 }
567
568 format!(
569 r#"
570// SPDX-License-Identifier: UNLICENSED
571pragma solidity 0;
572
573{vm_import}
574{global_code}
575
576contract {contract_name} {{
577 {vm_constant}
578 {top_level_code}
579
580 /// @notice REPL contract entry point
581 function run() public {{
582 {run_code}
583 }}
584}}"#,
585 )
586 }
587
588 pub(crate) fn parse(&self) -> Result<(), EmittedDiagnostics> {
590 let sess =
591 solar::interface::Session::builder().with_buffer_emitter(Default::default()).build();
592 let _ = sess.enter_sequential(|| -> solar::interface::Result<()> {
593 let arena = solar::ast::Arena::new();
594 let filename = self.file_name.clone().into();
595 let src = self.to_repl_source();
596 let mut parser = solar::parse::Parser::from_source_code(&sess, &arena, filename, src)?;
597 let _ast = parser.parse_file().map_err(|e| e.emit())?;
598 Ok(())
599 });
600 sess.dcx.emitted_errors().unwrap()
601 }
602}
603
/// Section of the generated source that a parsed input fragment was added to
/// (see `SessionSource::parse_fragment`).
#[derive(Debug)]
enum ParseTreeFragment {
    /// File-level code (`SessionSource::global_code`).
    Source,
    /// Contract-body code (`SessionSource::contract_code`).
    Contract,
    /// Code inside `run()` (`SessionSource::run_code`).
    Function,
}
617
#[cfg(test)]
mod tests {
    use super::*;
    use foundry_compilers::artifacts::remappings::{RelativeRemapping, RelativeRemappingPathBuf};
    use std::fs;

    /// The `Vm` import generated from a `forge-std/` remapping must use `/` separators and must
    /// never contain a backslash.
    #[test]
    fn test_vm_import_path_uses_forward_slashes() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("Vm.sol"), "// dummy").unwrap();

        let forge_std = RelativeRemapping {
            context: None,
            name: "forge-std/".to_string(),
            path: RelativeRemappingPathBuf { parent: None, path: dir.path().to_path_buf() },
        };
        let foundry_config = Config {
            solc: Some(SolcReq::Version(Version::new(0, 8, 29))),
            remappings: vec![forge_std],
            ..Default::default()
        };
        let mut config = SessionSourceConfig { foundry_config, ..Default::default() };
        config.detect_solc().unwrap();

        let source = SessionSource {
            file_name: "ReplContract.sol".to_string(),
            contract_name: "REPL".to_string(),
            config,
            global_code: Default::default(),
            contract_code: Default::default(),
            run_code: Default::default(),
            vm_source: vm_source(),
            output: Default::default(),
        };

        let generated = source.to_repl_source();
        let import_line =
            generated.lines().find(|line| line.contains("import {Vm}")).unwrap();
        assert!(
            !import_line.contains('\\'),
            "Vm import path must not contain backslashes, got: {import_line}"
        );
        assert!(import_line.contains('/'), "Vm import path must use forward slashes");
    }
}