1use alloy_primitives::map::HashMap;
8use eyre::Result;
9use forge_fmt::solang_ext::SafeUnwrap;
10use foundry_compilers::{
11 artifacts::{CompilerOutput, Settings, SolcInput, Source, Sources},
12 compilers::solc::Solc,
13};
14use foundry_config::{Config, SolcReq};
15use foundry_evm::{backend::Backend, opts::EvmOpts};
16use semver::Version;
17use serde::{Deserialize, Serialize};
18use solang_parser::{diagnostics::Diagnostic, pt};
19use std::{fs, path::PathBuf};
20use walkdir::WalkDir;
21use yansi::Paint;
22
/// Minimum solc version for which the `Vm` cheatcode interface is injected.
///
/// [`SessionSource::new`] forces `no_vm = true` when the selected compiler is older than this.
pub const MIN_VM_VERSION: Version = Version::new(0, 6, 2);
25
/// Embedded copy of the `Vm.sol` cheatcode interface (`testdata/cheats/Vm.sol`), baked in at
/// compile time. `SessionSource::compiler_input` registers it as `forge-std/Vm.sol` when the VM
/// is enabled and no `forge-std` remapping is configured.
static VM_SOURCE: &str = include_str!("../../../testdata/cheats/Vm.sol");
28
/// Parse-tree level information extracted from the REPL source, alongside the compiler output.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct IntermediateOutput {
    /// Variable name -> type expression, for the `REPL` contract's state variables and the
    /// variables declared inside its `run()` function. Not serialized.
    #[serde(skip)]
    pub repl_contract_expressions: HashMap<String, pt::Expression>,
    /// Intermediate definitions of every parsed contract, keyed by contract name. Not serialized.
    #[serde(skip)]
    pub intermediate_contracts: IntermediateContracts,
}
39
/// Named definitions found in a single contract, each keyed by its identifier.
///
/// None of the maps are serialized.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct IntermediateContract {
    /// Plain `function` definitions (constructors, modifiers, fallback/receive are not included).
    #[serde(skip)]
    pub function_definitions: HashMap<String, Box<pt::FunctionDefinition>>,
    /// Event definitions.
    #[serde(skip)]
    pub event_definitions: HashMap<String, Box<pt::EventDefinition>>,
    /// Struct definitions.
    #[serde(skip)]
    pub struct_definitions: HashMap<String, Box<pt::StructDefinition>>,
    /// State variable definitions.
    #[serde(skip)]
    pub variable_definitions: HashMap<String, Box<pt::VariableDefinition>>,
}
57
/// Contract name -> the intermediate definitions extracted from that contract.
type IntermediateContracts = HashMap<String, IntermediateContract>;
60
/// Everything produced by [`SessionSource::build`]: parse-tree information plus solc's output.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct GeneratedOutput {
    /// Definitions extracted from the parsed REPL source.
    pub intermediate: IntermediateOutput,
    /// Raw output of compiling the REPL source with solc.
    pub compiler_output: CompilerOutput,
}
69
/// Configuration shared by every [`SessionSource`] of a REPL session.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct SessionSourceConfig {
    /// Foundry configuration; this module reads the solc requirement, offline mode,
    /// remappings and EVM version from it.
    pub foundry_config: Config,
    /// EVM options for the session (not read in this module).
    pub evm_opts: EvmOpts,
    /// When `true`, no `Vm` cheatcode interface is imported/injected and generated scripts do
    /// not import forge-std's `Script`.
    pub no_vm: bool,
    /// Optional EVM backend; skipped during (de)serialization.
    #[serde(skip)]
    pub backend: Option<Backend>,
    /// Presumably toggles execution traces — not read in this module, verify at call sites.
    pub traces: bool,
    /// Optional calldata — not read in this module; presumably used when executing the REPL
    /// contract, verify at call sites.
    pub calldata: Option<Vec<u8>>,
}
87
88impl SessionSourceConfig {
89 pub(crate) fn solc(&self) -> Result<Solc> {
96 let solc_req = if let Some(solc_req) = self.foundry_config.solc.clone() {
97 solc_req
98 } else if let Some(version) = Solc::installed_versions().into_iter().max() {
99 SolcReq::Version(version)
100 } else {
101 if !self.foundry_config.offline {
102 sh_print!("{}", "No solidity versions installed! ".green())?;
103 }
104 SolcReq::Version(Version::new(0, 8, 19))
106 };
107
108 match solc_req {
109 SolcReq::Version(version) => {
110 let solc = if let Some(solc) = Solc::find_svm_installed_version(&version)? {
111 solc
112 } else {
113 if self.foundry_config.offline {
114 eyre::bail!("can't install missing solc {version} in offline mode")
115 }
116 sh_println!("{}", format!("Installing solidity version {version}...").green())?;
117 Solc::blocking_install(&version)?
118 };
119 Ok(solc)
120 }
121 SolcReq::Local(solc) => {
122 if !solc.is_file() {
123 eyre::bail!("`solc` {} does not exist", solc.display());
124 }
125 Ok(Solc::new(solc)?)
126 }
127 }
128 }
129}
130
/// The Solidity source of a REPL session, split into the sections user input is appended to.
///
/// The sections are rendered into a single contract by `SessionSource::to_repl_source`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SessionSource {
    /// Key of the REPL source in the compiler input (`ReplContract.sol` when built via `new`).
    pub file_name: PathBuf,
    /// Name of the generated contract (`REPL` when built via `new`).
    pub contract_name: String,
    /// Compiler used to build the source.
    pub solc: Solc,
    /// Code placed at source-unit level, outside the contract (imports, free definitions, ...).
    pub global_code: String,
    /// Code placed inside the contract body, outside `run()` (state variables, functions,
    /// events, ...).
    pub top_level_code: String,
    /// Code placed inside the body of the contract's `run()` function.
    pub run_code: String,
    /// Cached result of the last build; reset to `None` whenever any code section changes.
    pub generated_output: Option<GeneratedOutput>,
    /// Session configuration.
    pub config: SessionSourceConfig,
}
158
159impl SessionSource {
160 #[track_caller]
175 pub fn new(solc: Solc, mut config: SessionSourceConfig) -> Self {
176 if solc.version < MIN_VM_VERSION && !config.no_vm {
177 tracing::info!(version=%solc.version, minimum=%MIN_VM_VERSION, "Disabling VM injection");
178 config.no_vm = true;
179 }
180
181 Self {
182 file_name: PathBuf::from("ReplContract.sol".to_string()),
183 contract_name: "REPL".to_string(),
184 solc,
185 config,
186 global_code: Default::default(),
187 top_level_code: Default::default(),
188 run_code: Default::default(),
189 generated_output: None,
190 }
191 }
192
193 pub fn shallow_clone(&self) -> Self {
200 Self {
201 file_name: self.file_name.clone(),
202 contract_name: self.contract_name.clone(),
203 solc: self.solc.clone(),
204 global_code: self.global_code.clone(),
205 top_level_code: self.top_level_code.clone(),
206 run_code: self.run_code.clone(),
207 generated_output: None,
208 config: self.config.clone(),
209 }
210 }
211
212 pub fn clone_with_new_line(&self, mut content: String) -> Result<(Self, bool)> {
220 let new_source = self.shallow_clone();
221 if let Some(parsed) = parse_fragment(new_source.solc, new_source.config, &content)
222 .or_else(|| {
223 let new_source = self.shallow_clone();
224 content.push(';');
225 parse_fragment(new_source.solc, new_source.config, &content)
226 })
227 .or_else(|| {
228 let new_source = self.shallow_clone();
229 content = content.trim_end().trim_end_matches(';').to_string();
230 parse_fragment(new_source.solc, new_source.config, &content)
231 })
232 {
233 let mut new_source = self.shallow_clone();
234 match parsed {
237 ParseTreeFragment::Function => new_source.with_run_code(&content),
238 ParseTreeFragment::Contract => new_source.with_top_level_code(&content),
239 ParseTreeFragment::Source => new_source.with_global_code(&content),
240 };
241
242 Ok((new_source, matches!(parsed, ParseTreeFragment::Function)))
243 } else {
244 eyre::bail!("\"{}\"", content.trim().to_owned());
245 }
246 }
247
248 pub fn with_global_code(&mut self, content: &str) -> &mut Self {
252 self.global_code.push_str(content.trim());
253 self.global_code.push('\n');
254 self.generated_output = None;
255 self
256 }
257
258 pub fn with_top_level_code(&mut self, content: &str) -> &mut Self {
260 self.top_level_code.push_str(content.trim());
261 self.top_level_code.push('\n');
262 self.generated_output = None;
263 self
264 }
265
266 pub fn with_run_code(&mut self, content: &str) -> &mut Self {
268 self.run_code.push_str(content.trim());
269 self.run_code.push('\n');
270 self.generated_output = None;
271 self
272 }
273
274 pub fn drain_global_code(&mut self) -> &mut Self {
278 String::clear(&mut self.global_code);
279 self.generated_output = None;
280 self
281 }
282
283 pub fn drain_top_level_code(&mut self) -> &mut Self {
285 String::clear(&mut self.top_level_code);
286 self.generated_output = None;
287 self
288 }
289
290 pub fn drain_run(&mut self) -> &mut Self {
292 String::clear(&mut self.run_code);
293 self.generated_output = None;
294 self
295 }
296
297 pub fn compiler_input(&self) -> SolcInput {
304 let mut sources = Sources::new();
305 sources.insert(self.file_name.clone(), Source::new(self.to_repl_source()));
306
307 let remappings = self.config.foundry_config.get_all_remappings().collect::<Vec<_>>();
308
309 if !self.config.no_vm && !remappings.iter().any(|r| r.name.starts_with("forge-std")) {
311 sources.insert(PathBuf::from("forge-std/Vm.sol"), Source::new(VM_SOURCE));
312 }
313
314 let settings = Settings {
315 remappings,
316 evm_version: self
317 .config
318 .foundry_config
319 .evm_version
320 .normalize_version_solc(&self.solc.version),
321 ..Default::default()
322 };
323
324 SolcInput::resolve_and_build(sources, settings)
326 .into_iter()
327 .next()
328 .map(|i| i.sanitized(&self.solc.version))
329 .expect("Solidity source not found")
330 }
331
332 pub fn parse(&self) -> Result<pt::SourceUnit, Vec<solang_parser::diagnostics::Diagnostic>> {
339 solang_parser::parse(&self.to_repl_source(), 0).map(|(pt, _)| pt)
340 }
341
342 pub fn generate_intermediate_contracts(&self) -> Result<HashMap<String, IntermediateContract>> {
348 let mut res_map = HashMap::default();
349 let parsed_map = self.compiler_input().sources;
350 for source in parsed_map.values() {
351 Self::get_intermediate_contract(&source.content, &mut res_map);
352 }
353 Ok(res_map)
354 }
355
356 pub fn generate_intermediate_output(&self) -> Result<IntermediateOutput> {
358 let intermediate_contracts = self.generate_intermediate_contracts()?;
360
361 let variable_definitions = intermediate_contracts
363 .get("REPL")
364 .ok_or_else(|| eyre::eyre!("Could not find intermediate REPL contract!"))?
365 .variable_definitions
366 .clone()
367 .into_iter()
368 .map(|(k, v)| (k, v.ty))
369 .collect::<HashMap<String, pt::Expression>>();
370 let mut intermediate_output = IntermediateOutput {
372 repl_contract_expressions: variable_definitions,
373 intermediate_contracts,
374 };
375
376 for (key, val) in intermediate_output
378 .run_func_body()?
379 .clone()
380 .iter()
381 .flat_map(Self::get_statement_definitions)
382 {
383 intermediate_output.repl_contract_expressions.insert(key, val);
384 }
385
386 Ok(intermediate_output)
387 }
388
389 pub fn compile(&self) -> Result<CompilerOutput> {
395 let compiled = self.solc.compile_exact(&self.compiler_input())?;
397
398 let errors =
400 compiled.errors.iter().filter(|error| error.severity.is_error()).collect::<Vec<_>>();
401 if !errors.is_empty() {
402 eyre::bail!(
403 "Compiler errors:\n{}",
404 errors.into_iter().map(|err| err.to_string()).collect::<String>()
405 );
406 }
407
408 Ok(compiled)
409 }
410
411 pub fn build(&mut self) -> Result<GeneratedOutput> {
418 let compiler_output = self.compile()?;
420
421 let intermediate_output = self.generate_intermediate_output()?;
423
424 let generated_output =
426 GeneratedOutput { intermediate: intermediate_output, compiler_output };
427 self.generated_output = Some(generated_output.clone()); Ok(generated_output)
429 }
430
431 pub fn to_script_source(&self) -> String {
437 let Version { major, minor, patch, .. } = self.solc.version;
438 let Self { contract_name, global_code, top_level_code, run_code, config, .. } = self;
439
440 let script_import =
441 if !config.no_vm { "import {Script} from \"forge-std/Script.sol\";\n" } else { "" };
442
443 format!(
444 r#"
445// SPDX-License-Identifier: UNLICENSED
446pragma solidity ^{major}.{minor}.{patch};
447
448{script_import}
449{global_code}
450
451contract {contract_name} is Script {{
452 {top_level_code}
453
454 /// @notice Script entry point
455 function run() public {{
456 {run_code}
457 }}
458}}"#,
459 )
460 }
461
462 pub fn to_repl_source(&self) -> String {
468 let Version { major, minor, patch, .. } = self.solc.version;
469 let Self { contract_name, global_code, top_level_code, run_code, config, .. } = self;
470 let (mut vm_import, mut vm_constant) = (String::new(), String::new());
471 if !config.no_vm {
472 if let Some(remapping) = config
475 .foundry_config
476 .remappings
477 .iter()
478 .find(|remapping| remapping.name == "forge-std/")
479 {
480 if let Some(vm_path) = WalkDir::new(&remapping.path.path)
481 .into_iter()
482 .filter_map(|e| e.ok())
483 .find(|e| e.file_name() == "Vm.sol")
484 {
485 vm_import = format!("import {{Vm}} from \"{}\";\n", vm_path.path().display());
486 vm_constant = "Vm internal constant vm = Vm(address(uint160(uint256(keccak256(\"hevm cheat code\")))));\n".to_string();
487 }
488 }
489 }
490
491 format!(
492 r#"
493// SPDX-License-Identifier: UNLICENSED
494pragma solidity ^{major}.{minor}.{patch};
495
496{vm_import}
497{global_code}
498
499contract {contract_name} {{
500 {vm_constant}
501 {top_level_code}
502
503 /// @notice REPL contract entry point
504 function run() public {{
505 {run_code}
506 }}
507}}"#,
508 )
509 }
510
511 pub fn get_intermediate_contract(
518 content: &str,
519 res_map: &mut HashMap<String, IntermediateContract>,
520 ) {
521 if let Ok((pt::SourceUnit(source_unit_parts), _)) = solang_parser::parse(content, 0) {
522 let func_defs = source_unit_parts
523 .into_iter()
524 .filter_map(|sup| match sup {
525 pt::SourceUnitPart::ImportDirective(i) => match i {
526 pt::Import::Plain(s, _) |
527 pt::Import::Rename(s, _, _) |
528 pt::Import::GlobalSymbol(s, _, _) => {
529 let s = match s {
530 pt::ImportPath::Filename(s) => s.string,
531 pt::ImportPath::Path(p) => p.to_string(),
532 };
533 let path = PathBuf::from(s);
534
535 match fs::read_to_string(path) {
536 Ok(source) => {
537 Self::get_intermediate_contract(&source, res_map);
538 None
539 }
540 Err(_) => None,
541 }
542 }
543 },
544 pt::SourceUnitPart::ContractDefinition(cd) => {
545 let mut intermediate = IntermediateContract::default();
546
547 cd.parts.into_iter().for_each(|part| match part {
548 pt::ContractPart::FunctionDefinition(def) => {
549 if matches!(def.ty, pt::FunctionTy::Function) {
551 intermediate
552 .function_definitions
553 .insert(def.name.clone().unwrap().name, def);
554 }
555 }
556 pt::ContractPart::EventDefinition(def) => {
557 let event_name = def.name.safe_unwrap().name.clone();
558 intermediate.event_definitions.insert(event_name, def);
559 }
560 pt::ContractPart::StructDefinition(def) => {
561 let struct_name = def.name.safe_unwrap().name.clone();
562 intermediate.struct_definitions.insert(struct_name, def);
563 }
564 pt::ContractPart::VariableDefinition(def) => {
565 let var_name = def.name.safe_unwrap().name.clone();
566 intermediate.variable_definitions.insert(var_name, def);
567 }
568 _ => {}
569 });
570 Some((cd.name.safe_unwrap().name.clone(), intermediate))
571 }
572 _ => None,
573 })
574 .collect::<HashMap<String, IntermediateContract>>();
575 res_map.extend(func_defs);
576 }
577 }
578
579 pub fn get_statement_definitions(statement: &pt::Statement) -> Vec<(String, pt::Expression)> {
589 match statement {
590 pt::Statement::VariableDefinition(_, def, _) => {
591 vec![(def.name.safe_unwrap().name.clone(), def.ty.clone())]
592 }
593 pt::Statement::Expression(_, pt::Expression::Assign(_, left, _)) => {
594 if let pt::Expression::List(_, list) = left.as_ref() {
595 list.iter()
596 .filter_map(|(_, param)| {
597 param.as_ref().and_then(|param| {
598 param
599 .name
600 .as_ref()
601 .map(|name| (name.name.clone(), param.ty.clone()))
602 })
603 })
604 .collect()
605 } else {
606 Vec::default()
607 }
608 }
609 _ => Vec::default(),
610 }
611 }
612}
613
614impl IntermediateOutput {
615 pub fn run_func_body(&self) -> Result<&Vec<pt::Statement>> {
621 match self
622 .intermediate_contracts
623 .get("REPL")
624 .ok_or_else(|| eyre::eyre!("Could not find REPL intermediate contract!"))?
625 .function_definitions
626 .get("run")
627 .ok_or_else(|| eyre::eyre!("Could not find run function definition in REPL contract!"))?
628 .body
629 .as_ref()
630 .ok_or_else(|| eyre::eyre!("Could not find run function body!"))?
631 {
632 pt::Statement::Block { statements, .. } => Ok(statements),
633 _ => eyre::bail!("Could not find statements within run function body!"),
634 }
635 }
636}
637
/// The section of the REPL source a code fragment parses as (see [`parse_fragment`]).
#[derive(Debug)]
pub enum ParseTreeFragment {
    /// Parses at source-unit level, outside the contract (global code).
    Source,
    /// Parses inside the contract body (top-level code).
    Contract,
    /// Parses inside the `run()` function body (run code).
    Function,
}
651
652pub fn parse_fragment(
655 solc: Solc,
656 config: SessionSourceConfig,
657 buffer: &str,
658) -> Option<ParseTreeFragment> {
659 let mut base = SessionSource::new(solc, config);
660
661 match base.clone().with_run_code(buffer).parse() {
662 Ok(_) => return Some(ParseTreeFragment::Function),
663 Err(e) => debug_errors(&e),
664 }
665 match base.clone().with_top_level_code(buffer).parse() {
666 Ok(_) => return Some(ParseTreeFragment::Contract),
667 Err(e) => debug_errors(&e),
668 }
669 match base.with_global_code(buffer).parse() {
670 Ok(_) => return Some(ParseTreeFragment::Source),
671 Err(e) => debug_errors(&e),
672 }
673
674 None
675}
676
677fn debug_errors(errors: &[Diagnostic]) {
678 if !tracing::enabled!(tracing::Level::DEBUG) {
679 return;
680 }
681
682 for error in errors {
683 tracing::debug!("error: {}", error.message);
684 }
685}