foundry_common/preprocessor/
deps.rs1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8 Gcx, Hir,
9 hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 gcx: Gcx<'_>,
29 paths: &[PathBuf],
30 script_paths: &HashSet<PathBuf>,
31 src_dir: &Path,
32 root_dir: &Path,
33 mocks: &mut HashSet<PathBuf>,
34 ) -> Self {
35 let mut preprocessed_contracts = BTreeMap::new();
36 let mut referenced_contracts = HashSet::new();
37 let mut current_mocks = HashSet::new();
38
39 let candidate_contracts = || {
41 gcx.hir.contract_ids().filter_map(|id| {
42 let contract = gcx.hir.contract(id);
43 let source = gcx.hir.source(contract.source);
44 let FileName::Real(path) = &source.file.name else {
45 return None;
46 };
47
48 if !paths.contains(path) {
49 trace!("{} is not test or script", path.display());
50 return None;
51 }
52
53 Some((id, contract, source, path))
54 })
55 };
56
57 for (_, contract, _, path) in candidate_contracts() {
59 if contract.linearized_bases.iter().any(|base_id| {
60 let base = gcx.hir.contract(*base_id);
61 matches!(
62 &gcx.hir.source(base.source).file.name,
63 FileName::Real(base_path) if base_path.starts_with(src_dir)
64 )
65 }) {
66 let mock_path = root_dir.join(path);
67 trace!("found mock contract {}", mock_path.display());
68 current_mocks.insert(mock_path);
69 }
70 }
71
72 for (contract_id, contract, source, path) in candidate_contracts() {
74 let full_path = root_dir.join(path);
75
76 if current_mocks.contains(&full_path) {
77 trace!("{} is a mock, skipping", path.display());
78 continue;
79 }
80
81 mocks.remove(&full_path);
84
85 let is_script = script_paths.contains(path)
89 || contract
90 .linearized_bases
91 .iter()
92 .skip(1)
93 .any(|base_id| gcx.hir.contract(*base_id).name.as_str() == "Script");
94 let mut deps_collector =
95 BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir, is_script);
96 let _ = deps_collector.walk_contract(contract);
98 if !deps_collector.dependencies.is_empty() {
100 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
101 }
102
103 referenced_contracts.extend(deps_collector.referenced_contracts);
105 }
106
107 mocks.extend(current_mocks);
109
110 Self { preprocessed_contracts, referenced_contracts }
111 }
112}
113
/// The kind of expression a [`BytecodeDependency`] was collected from.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` member access.
    CreationCode,
    /// A `new Contract(...)` call.
    New {
        /// The contract type exactly as written in the source.
        name: String,
        /// Number of source bytes from the end of the contract type to the end of the whole
        /// call expression (named-argument block plus argument list).
        args_length: usize,
        /// Number of source bytes from the end of the contract type to the start of the call
        /// arguments. Set when a named-argument block precedes non-empty call arguments;
        /// `0` otherwise.
        call_args_offset: usize,
        /// Source text of the `value` named argument, if present.
        value: Option<String>,
        /// Source text of the `salt` named argument, if present.
        salt: Option<String>,
        /// `None` when the call is not the expression of a `try` statement. Otherwise, whether
        /// the statement's first clause binds exactly one variable of a user-defined
        /// (`TypeKind::Custom`) type.
        try_stmt: Option<bool>,
    },
}
135
136#[derive(Debug)]
138pub(crate) struct BytecodeDependency {
139 kind: BytecodeDependencyKind,
141 loc: Range<usize>,
143 referenced_contract: ContractId,
145}
146
147struct BytecodeDependencyCollector<'gcx, 'src> {
149 gcx: Gcx<'gcx>,
151 src: &'src str,
153 src_dir: &'src Path,
155 is_script: bool,
161 dependencies: Vec<BytecodeDependency>,
163 referenced_contracts: HashSet<ContractId>,
165}
166
167impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
168 fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path, is_script: bool) -> Self {
169 Self {
170 gcx,
171 src,
172 src_dir,
173 is_script,
174 dependencies: vec![],
175 referenced_contracts: HashSet::default(),
176 }
177 }
178
179 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
183 if self.is_script
185 && let BytecodeDependencyKind::New { salt: Some(_), .. } = &dependency.kind
186 {
187 trace!("skip salted new-expression in script");
188 return;
189 }
190
191 let contract = self.gcx.hir.contract(dependency.referenced_contract);
192 let source = self.gcx.hir.source(contract.source);
193 let FileName::Real(path) = &source.file.name else {
194 return;
195 };
196
197 if !path.starts_with(self.src_dir) {
198 let path = path.display();
199 trace!("ignore dependency {path}");
200 return;
201 }
202
203 self.referenced_contracts.insert(dependency.referenced_contract);
204 self.dependencies.push(dependency);
205 }
206}
207
208impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
209 type BreakValue = Never;
210
211 fn hir(&self) -> &'gcx Hir<'gcx> {
212 &self.gcx.hir
213 }
214
215 fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
216 #[allow(clippy::collapsible_match)]
217 match &expr.kind {
218 ExprKind::Call(call_expr, call_args, named_args) => {
219 if let Some(dependency) = handle_call_expr(
220 self.src,
221 self.gcx.sess.source_map(),
222 expr,
223 call_expr,
224 call_args,
225 named_args,
226 ) {
227 self.collect_dependency(dependency);
228 }
229 }
230 ExprKind::Member(member_expr, ident) => {
231 if let ExprKind::TypeCall(ty) = &member_expr.kind
232 && let TypeKind::Custom(contract_id) = &ty.kind
233 && ident.name.as_str() == "creationCode"
234 && let Some(contract_id) = contract_id.as_contract()
235 {
236 self.collect_dependency(BytecodeDependency {
237 kind: BytecodeDependencyKind::CreationCode,
238 loc: span_to_range(self.gcx.sess.source_map(), expr.span),
239 referenced_contract: contract_id,
240 });
241 }
242 }
243 _ => {}
244 }
245 self.walk_expr(expr)
246 }
247
248 fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
249 if let StmtKind::Try(stmt_try) = stmt.kind
250 && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
251 && let Some(mut dependency) = handle_call_expr(
252 self.src,
253 self.gcx.sess.source_map(),
254 &stmt_try.expr,
255 call_expr,
256 call_args,
257 named_args,
258 )
259 {
260 let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
261 && clause.args.len() == 1
262 && let Some(ret_var) = clause.args.first()
263 && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
264 {
265 true
266 } else {
267 false
268 };
269
270 if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
271 *try_stmt = Some(has_custom_return);
272 }
273 self.collect_dependency(dependency);
274
275 for clause in stmt_try.clauses {
276 for &var in clause.args {
277 self.visit_nested_var(var)?;
278 }
279 for stmt in clause.block.stmts {
280 self.visit_stmt(stmt)?;
281 }
282 }
283 return ControlFlow::Continue(());
284 }
285 self.walk_stmt(stmt)
286 }
287}
288
289fn handle_call_expr(
291 src: &str,
292 source_map: &SourceMap,
293 parent_expr: &Expr<'_>,
294 call_expr: &Expr<'_>,
295 call_args: &CallArgs<'_>,
296 named_args: &Option<&[NamedArg<'_>]>,
297) -> Option<BytecodeDependency> {
298 if let ExprKind::New(ty_new) = &call_expr.kind
299 && let TypeKind::Custom(item_id) = ty_new.kind
300 && let Some(contract_id) = item_id.as_contract()
301 {
302 let name_loc = span_to_range(source_map, ty_new.span);
303 let name = &src[name_loc];
304
305 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
309 (call_args.span.lo() - ty_new.span.hi()).to_usize()
310 } else {
311 0
312 };
313
314 let args_len = parent_expr.span.hi() - ty_new.span.hi();
315 return Some(BytecodeDependency {
316 kind: BytecodeDependencyKind::New {
317 name: name.to_string(),
318 args_length: args_len.to_usize(),
319 call_args_offset,
320 value: named_arg(src, named_args, "value", source_map),
321 salt: named_arg(src, named_args, "salt", source_map),
322 try_stmt: None,
323 },
324 loc: span_to_range(source_map, call_expr.span),
325 referenced_contract: contract_id,
326 });
327 }
328 None
329}
330
331fn named_arg(
333 src: &str,
334 named_args: &Option<&[NamedArg<'_>]>,
335 arg: &str,
336 source_map: &SourceMap,
337) -> Option<String> {
338 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
339 |named_arg| {
340 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
341 src[named_arg_loc].to_string()
342 },
343 )
344}
345
346pub(crate) fn remove_bytecode_dependencies(
360 gcx: Gcx<'_>,
361 deps: &PreprocessorDependencies,
362 data: &PreprocessorData,
363) -> Updates {
364 let mut updates = Updates::default();
365 for (contract_id, deps) in &deps.preprocessed_contracts {
366 let contract = gcx.hir.contract(*contract_id);
367 let source = gcx.hir.source(contract.source);
368 let FileName::Real(path) = &source.file.name else {
369 continue;
370 };
371
372 let updates = updates.entry(path.clone()).or_default();
373 let mut used_helpers = BTreeSet::new();
374
375 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
376 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
378 let mut try_catch_helpers: HashSet<&str> = HashSet::default();
379
380 for dep in deps {
381 let Some(ContractData { artifact, constructor_data, .. }) =
382 data.get(&dep.referenced_contract)
383 else {
384 continue;
385 };
386
387 match &dep.kind {
388 BytecodeDependencyKind::CreationCode => {
389 updates.insert((
391 dep.loc.start,
392 dep.loc.end,
393 format!("{vm}.getCode(\"{artifact}\")"),
394 ));
395 }
396 BytecodeDependencyKind::New {
397 name,
398 args_length,
399 call_args_offset,
400 value,
401 salt,
402 try_stmt,
403 } => {
404 let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
405 if *has_ret {
406 try_catch_helpers.insert(name);
408 (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
409 } else {
410 (String::new(), "})")
411 }
412 } else {
413 (format!("{name}(payable("), "})))")
414 };
415 update.push_str(&format!("{vm}.deployCode({{"));
416 update.push_str(&format!("_artifact: \"{artifact}\""));
417
418 if let Some(value) = value {
419 update.push_str(", ");
420 update.push_str(&format!("_value: {value}"));
421 }
422
423 if let Some(salt) = salt {
424 update.push_str(", ");
425 update.push_str(&format!("_salt: {salt}"));
426 }
427
428 if constructor_data.is_some() {
429 used_helpers.insert(dep.referenced_contract);
431
432 update.push_str(", ");
433 update.push_str(&format!(
434 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
435 id = dep.referenced_contract.get()
436 ));
437 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
438
439 updates.insert((
440 dep.loc.end + args_length,
441 dep.loc.end + args_length,
442 format!("){closing_seq}"),
443 ));
444 } else {
445 update.push_str(closing_seq);
446 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
447 }
448 }
449 };
450 }
451
452 if !try_catch_helpers.is_empty()
454 && let Some(last_fn_id) = contract.functions().last()
455 {
456 let last_fn_range =
457 span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
458 let to_address_fns = try_catch_helpers
459 .iter()
460 .map(|ty| {
461 format!(
462 r#"
463 function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
464 return {ty}(addr);
465 }}
466 "#,
467 id = contract_id.get()
468 )
469 })
470 .collect::<String>();
471
472 updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
473 }
474
475 let helper_imports = used_helpers.into_iter().map(|id| {
476 let id = id.get();
477 format!(
478 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
479 )
480 }).join("\n");
481 updates.insert((
482 source.file.src.len(),
483 source.file.src.len(),
484 format!(
485 r#"
486{helper_imports}
487
488interface {vm_interface_name} {{
489 function deployCode(string memory _artifact) external returns (address);
490 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
491 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
492 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
493 function deployCode(string memory _artifact, uint256 _value) external returns (address);
494 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
495 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
496 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
497 function getCode(string memory _artifact) external view returns (bytes memory);
498}}"#
499 ),
500 ));
501 }
502 updates
503}