foundry_common/preprocessor/
deps.rs1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8 Gcx, Hir,
9 hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 gcx: Gcx<'_>,
29 paths: &[PathBuf],
30 src_dir: &Path,
31 root_dir: &Path,
32 mocks: &mut HashSet<PathBuf>,
33 ) -> Self {
34 let mut preprocessed_contracts = BTreeMap::new();
35 let mut referenced_contracts = HashSet::new();
36 let mut current_mocks = HashSet::new();
37
38 let candidate_contracts = || {
40 gcx.hir.contract_ids().filter_map(|id| {
41 let contract = gcx.hir.contract(id);
42 let source = gcx.hir.source(contract.source);
43 let FileName::Real(path) = &source.file.name else {
44 return None;
45 };
46
47 if !paths.contains(path) {
48 trace!("{} is not test or script", path.display());
49 return None;
50 }
51
52 Some((id, contract, source, path))
53 })
54 };
55
56 for (_, contract, _, path) in candidate_contracts() {
58 if contract.linearized_bases.iter().any(|base_id| {
59 let base = gcx.hir.contract(*base_id);
60 matches!(
61 &gcx.hir.source(base.source).file.name,
62 FileName::Real(base_path) if base_path.starts_with(src_dir)
63 )
64 }) {
65 let mock_path = root_dir.join(path);
66 trace!("found mock contract {}", mock_path.display());
67 current_mocks.insert(mock_path);
68 }
69 }
70
71 for (contract_id, contract, source, path) in candidate_contracts() {
73 let full_path = root_dir.join(path);
74
75 if current_mocks.contains(&full_path) {
76 trace!("{} is a mock, skipping", path.display());
77 continue;
78 }
79
80 mocks.remove(&full_path);
83
84 let mut deps_collector =
85 BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir);
86 let _ = deps_collector.walk_contract(contract);
88 if !deps_collector.dependencies.is_empty() {
90 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
91 }
92
93 referenced_contracts.extend(deps_collector.referenced_contracts);
95 }
96
97 mocks.extend(current_mocks);
99
100 Self { preprocessed_contracts, referenced_contracts }
101 }
102}
103
/// The syntactic form in which a contract's bytecode is depended upon.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(C).creationCode` member access.
    CreationCode,
    /// A `new C(...)` contract creation.
    New {
        /// Source text of the created contract's type name, exactly as written.
        name: String,
        /// Byte length from the end of the type name to the end of the whole call expression.
        args_length: usize,
        /// When the creation has both `{...}` call options and constructor arguments: byte
        /// distance from the end of the type name to the start of the arguments; else `0`.
        call_args_offset: usize,
        /// Source text of the `value` call option, if given.
        value: Option<String>,
        /// Source text of the `salt` call option, if given.
        salt: Option<String>,
        /// `None` outside a `try` statement; inside one, `Some(true)` when the first clause
        /// binds exactly one variable of a custom type, `Some(false)` otherwise.
        try_stmt: Option<bool>,
    },
}
125
126#[derive(Debug)]
128pub(crate) struct BytecodeDependency {
129 kind: BytecodeDependencyKind,
131 loc: Range<usize>,
133 referenced_contract: ContractId,
135}
136
137struct BytecodeDependencyCollector<'gcx, 'src> {
139 gcx: Gcx<'gcx>,
141 src: &'src str,
143 src_dir: &'src Path,
145 dependencies: Vec<BytecodeDependency>,
147 referenced_contracts: HashSet<ContractId>,
149}
150
151impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
152 fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path) -> Self {
153 Self { gcx, src, src_dir, dependencies: vec![], referenced_contracts: HashSet::default() }
154 }
155
156 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
160 let contract = self.gcx.hir.contract(dependency.referenced_contract);
161 let source = self.gcx.hir.source(contract.source);
162 let FileName::Real(path) = &source.file.name else {
163 return;
164 };
165
166 if !path.starts_with(self.src_dir) {
167 let path = path.display();
168 trace!("ignore dependency {path}");
169 return;
170 }
171
172 self.referenced_contracts.insert(dependency.referenced_contract);
173 self.dependencies.push(dependency);
174 }
175}
176
177impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
178 type BreakValue = Never;
179
180 fn hir(&self) -> &'gcx Hir<'gcx> {
181 &self.gcx.hir
182 }
183
184 fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
185 #[allow(clippy::collapsible_match)]
186 match &expr.kind {
187 ExprKind::Call(call_expr, call_args, named_args) => {
188 if let Some(dependency) = handle_call_expr(
189 self.src,
190 self.gcx.sess.source_map(),
191 expr,
192 call_expr,
193 call_args,
194 named_args,
195 ) {
196 self.collect_dependency(dependency);
197 }
198 }
199 ExprKind::Member(member_expr, ident) => {
200 if let ExprKind::TypeCall(ty) = &member_expr.kind
201 && let TypeKind::Custom(contract_id) = &ty.kind
202 && ident.name.as_str() == "creationCode"
203 && let Some(contract_id) = contract_id.as_contract()
204 {
205 self.collect_dependency(BytecodeDependency {
206 kind: BytecodeDependencyKind::CreationCode,
207 loc: span_to_range(self.gcx.sess.source_map(), expr.span),
208 referenced_contract: contract_id,
209 });
210 }
211 }
212 _ => {}
213 }
214 self.walk_expr(expr)
215 }
216
217 fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
218 if let StmtKind::Try(stmt_try) = stmt.kind
219 && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
220 && let Some(mut dependency) = handle_call_expr(
221 self.src,
222 self.gcx.sess.source_map(),
223 &stmt_try.expr,
224 call_expr,
225 call_args,
226 named_args,
227 )
228 {
229 let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
230 && clause.args.len() == 1
231 && let Some(ret_var) = clause.args.first()
232 && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
233 {
234 true
235 } else {
236 false
237 };
238
239 if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
240 *try_stmt = Some(has_custom_return);
241 }
242 self.collect_dependency(dependency);
243
244 for clause in stmt_try.clauses {
245 for &var in clause.args {
246 self.visit_nested_var(var)?;
247 }
248 for stmt in clause.block.stmts {
249 self.visit_stmt(stmt)?;
250 }
251 }
252 return ControlFlow::Continue(());
253 }
254 self.walk_stmt(stmt)
255 }
256}
257
258fn handle_call_expr(
260 src: &str,
261 source_map: &SourceMap,
262 parent_expr: &Expr<'_>,
263 call_expr: &Expr<'_>,
264 call_args: &CallArgs<'_>,
265 named_args: &Option<&[NamedArg<'_>]>,
266) -> Option<BytecodeDependency> {
267 if let ExprKind::New(ty_new) = &call_expr.kind
268 && let TypeKind::Custom(item_id) = ty_new.kind
269 && let Some(contract_id) = item_id.as_contract()
270 {
271 let name_loc = span_to_range(source_map, ty_new.span);
272 let name = &src[name_loc];
273
274 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
278 (call_args.span.lo() - ty_new.span.hi()).to_usize()
279 } else {
280 0
281 };
282
283 let args_len = parent_expr.span.hi() - ty_new.span.hi();
284 return Some(BytecodeDependency {
285 kind: BytecodeDependencyKind::New {
286 name: name.to_string(),
287 args_length: args_len.to_usize(),
288 call_args_offset,
289 value: named_arg(src, named_args, "value", source_map),
290 salt: named_arg(src, named_args, "salt", source_map),
291 try_stmt: None,
292 },
293 loc: span_to_range(source_map, call_expr.span),
294 referenced_contract: contract_id,
295 });
296 }
297 None
298}
299
300fn named_arg(
302 src: &str,
303 named_args: &Option<&[NamedArg<'_>]>,
304 arg: &str,
305 source_map: &SourceMap,
306) -> Option<String> {
307 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
308 |named_arg| {
309 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
310 src[named_arg_loc].to_string()
311 },
312 )
313}
314
315pub(crate) fn remove_bytecode_dependencies(
329 gcx: Gcx<'_>,
330 deps: &PreprocessorDependencies,
331 data: &PreprocessorData,
332) -> Updates {
333 let mut updates = Updates::default();
334 for (contract_id, deps) in &deps.preprocessed_contracts {
335 let contract = gcx.hir.contract(*contract_id);
336 let source = gcx.hir.source(contract.source);
337 let FileName::Real(path) = &source.file.name else {
338 continue;
339 };
340
341 let updates = updates.entry(path.clone()).or_default();
342 let mut used_helpers = BTreeSet::new();
343
344 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
345 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
347 let mut try_catch_helpers: HashSet<&str> = HashSet::default();
348
349 for dep in deps {
350 let Some(ContractData { artifact, constructor_data, .. }) =
351 data.get(&dep.referenced_contract)
352 else {
353 continue;
354 };
355
356 match &dep.kind {
357 BytecodeDependencyKind::CreationCode => {
358 updates.insert((
360 dep.loc.start,
361 dep.loc.end,
362 format!("{vm}.getCode(\"{artifact}\")"),
363 ));
364 }
365 BytecodeDependencyKind::New {
366 name,
367 args_length,
368 call_args_offset,
369 value,
370 salt,
371 try_stmt,
372 } => {
373 let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
374 if *has_ret {
375 try_catch_helpers.insert(name);
377 (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
378 } else {
379 (String::new(), "})")
380 }
381 } else {
382 (format!("{name}(payable("), "})))")
383 };
384 update.push_str(&format!("{vm}.deployCode({{"));
385 update.push_str(&format!("_artifact: \"{artifact}\""));
386
387 if let Some(value) = value {
388 update.push_str(", ");
389 update.push_str(&format!("_value: {value}"));
390 }
391
392 if let Some(salt) = salt {
393 update.push_str(", ");
394 update.push_str(&format!("_salt: {salt}"));
395 }
396
397 if constructor_data.is_some() {
398 used_helpers.insert(dep.referenced_contract);
400
401 update.push_str(", ");
402 update.push_str(&format!(
403 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
404 id = dep.referenced_contract.get()
405 ));
406 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
407
408 updates.insert((
409 dep.loc.end + args_length,
410 dep.loc.end + args_length,
411 format!("){closing_seq}"),
412 ));
413 } else {
414 update.push_str(closing_seq);
415 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
416 }
417 }
418 };
419 }
420
421 if !try_catch_helpers.is_empty()
423 && let Some(last_fn_id) = contract.functions().last()
424 {
425 let last_fn_range =
426 span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
427 let to_address_fns = try_catch_helpers
428 .iter()
429 .map(|ty| {
430 format!(
431 r#"
432 function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
433 return {ty}(addr);
434 }}
435 "#,
436 id = contract_id.get()
437 )
438 })
439 .collect::<String>();
440
441 updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
442 }
443
444 let helper_imports = used_helpers.into_iter().map(|id| {
445 let id = id.get();
446 format!(
447 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
448 )
449 }).join("\n");
450 updates.insert((
451 source.file.src.len(),
452 source.file.src.len(),
453 format!(
454 r#"
455{helper_imports}
456
457interface {vm_interface_name} {{
458 function deployCode(string memory _artifact) external returns (address);
459 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
460 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
461 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
462 function deployCode(string memory _artifact, uint256 _value) external returns (address);
463 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
464 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
465 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
466 function getCode(string memory _artifact) external view returns (bytes memory);
467}}"#
468 ),
469 ));
470 }
471 updates
472}