foundry_common/preprocessor/
deps.rs1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8 Gcx, Hir,
9 hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 gcx: Gcx<'_>,
29 paths: &[PathBuf],
30 src_dir: &Path,
31 root_dir: &Path,
32 mocks: &mut HashSet<PathBuf>,
33 ) -> Self {
34 let mut preprocessed_contracts = BTreeMap::new();
35 let mut referenced_contracts = HashSet::new();
36 for contract_id in gcx.hir.contract_ids() {
37 let contract = gcx.hir.contract(contract_id);
38 let source = gcx.hir.source(contract.source);
39
40 let FileName::Real(path) = &source.file.name else {
41 continue;
42 };
43
44 if !paths.contains(path) {
46 let path = path.display();
47 trace!("{path} is not test or script");
48 continue;
49 }
50
51 if contract.linearized_bases.iter().any(|base_contract_id| {
54 let base_contract = gcx.hir.contract(*base_contract_id);
55 let FileName::Real(path) = &gcx.hir.source(base_contract.source).file.name else {
56 return false;
57 };
58 path.starts_with(src_dir)
59 }) {
60 mocks.insert(root_dir.join(path));
62 let path = path.display();
63 trace!("found mock contract {path}");
64 continue;
65 } else {
66 mocks.remove(&root_dir.join(path));
69 }
70
71 let mut deps_collector =
72 BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir);
73 let _ = deps_collector.walk_contract(contract);
75 if !deps_collector.dependencies.is_empty() {
77 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
78 }
79 referenced_contracts.extend(deps_collector.referenced_contracts);
81 }
82 Self { preprocessed_contracts, referenced_contracts }
83 }
84}
85
/// The syntactic form of a bytecode dependency.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` expression.
    CreationCode,
    /// A `new Contract(...)` expression.
    New {
        /// Contract type name exactly as written in the source.
        name: String,
        /// Number of source bytes from the end of the contract name to the end of the whole
        /// call expression (call options and argument list included).
        args_length: usize,
        /// Number of source bytes from the end of the contract name to the start of the
        /// argument list; `0` unless call options precede a non-empty argument list.
        call_args_offset: usize,
        /// Source text of the `value` call option, if present.
        value: Option<String>,
        /// Source text of the `salt` call option, if present.
        salt: Option<String>,
        /// `None` when not the expression of a `try` statement; otherwise whether the success
        /// clause declares exactly one custom-typed return variable.
        try_stmt: Option<bool>,
    },
}
107
108#[derive(Debug)]
110pub(crate) struct BytecodeDependency {
111 kind: BytecodeDependencyKind,
113 loc: Range<usize>,
115 referenced_contract: ContractId,
117}
118
119struct BytecodeDependencyCollector<'gcx, 'src> {
121 gcx: Gcx<'gcx>,
123 src: &'src str,
125 src_dir: &'src Path,
127 dependencies: Vec<BytecodeDependency>,
129 referenced_contracts: HashSet<ContractId>,
131}
132
133impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
134 fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path) -> Self {
135 Self { gcx, src, src_dir, dependencies: vec![], referenced_contracts: HashSet::default() }
136 }
137
138 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
142 let contract = self.gcx.hir.contract(dependency.referenced_contract);
143 let source = self.gcx.hir.source(contract.source);
144 let FileName::Real(path) = &source.file.name else {
145 return;
146 };
147
148 if !path.starts_with(self.src_dir) {
149 let path = path.display();
150 trace!("ignore dependency {path}");
151 return;
152 }
153
154 self.referenced_contracts.insert(dependency.referenced_contract);
155 self.dependencies.push(dependency);
156 }
157}
158
159impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
160 type BreakValue = Never;
161
162 fn hir(&self) -> &'gcx Hir<'gcx> {
163 &self.gcx.hir
164 }
165
166 fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
167 match &expr.kind {
168 ExprKind::Call(call_expr, call_args, named_args) => {
169 if let Some(dependency) = handle_call_expr(
170 self.src,
171 self.gcx.sess.source_map(),
172 expr,
173 call_expr,
174 call_args,
175 named_args,
176 ) {
177 self.collect_dependency(dependency);
178 }
179 }
180 ExprKind::Member(member_expr, ident) => {
181 if let ExprKind::TypeCall(ty) = &member_expr.kind
182 && let TypeKind::Custom(contract_id) = &ty.kind
183 && ident.name.as_str() == "creationCode"
184 && let Some(contract_id) = contract_id.as_contract()
185 {
186 self.collect_dependency(BytecodeDependency {
187 kind: BytecodeDependencyKind::CreationCode,
188 loc: span_to_range(self.gcx.sess.source_map(), expr.span),
189 referenced_contract: contract_id,
190 });
191 }
192 }
193 _ => {}
194 }
195 self.walk_expr(expr)
196 }
197
198 fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
199 if let StmtKind::Try(stmt_try) = stmt.kind
200 && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
201 && let Some(mut dependency) = handle_call_expr(
202 self.src,
203 self.gcx.sess.source_map(),
204 &stmt_try.expr,
205 call_expr,
206 call_args,
207 named_args,
208 )
209 {
210 let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
211 && clause.args.len() == 1
212 && let Some(ret_var) = clause.args.first()
213 && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
214 {
215 true
216 } else {
217 false
218 };
219
220 if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
221 *try_stmt = Some(has_custom_return);
222 }
223 self.collect_dependency(dependency);
224
225 for clause in stmt_try.clauses {
226 for &var in clause.args {
227 self.visit_nested_var(var)?;
228 }
229 for stmt in clause.block.stmts {
230 self.visit_stmt(stmt)?;
231 }
232 }
233 return ControlFlow::Continue(());
234 }
235 self.walk_stmt(stmt)
236 }
237}
238
239fn handle_call_expr(
241 src: &str,
242 source_map: &SourceMap,
243 parent_expr: &Expr<'_>,
244 call_expr: &Expr<'_>,
245 call_args: &CallArgs<'_>,
246 named_args: &Option<&[NamedArg<'_>]>,
247) -> Option<BytecodeDependency> {
248 if let ExprKind::New(ty_new) = &call_expr.kind
249 && let TypeKind::Custom(item_id) = ty_new.kind
250 && let Some(contract_id) = item_id.as_contract()
251 {
252 let name_loc = span_to_range(source_map, ty_new.span);
253 let name = &src[name_loc];
254
255 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
259 (call_args.span.lo() - ty_new.span.hi()).to_usize()
260 } else {
261 0
262 };
263
264 let args_len = parent_expr.span.hi() - ty_new.span.hi();
265 return Some(BytecodeDependency {
266 kind: BytecodeDependencyKind::New {
267 name: name.to_string(),
268 args_length: args_len.to_usize(),
269 call_args_offset,
270 value: named_arg(src, named_args, "value", source_map),
271 salt: named_arg(src, named_args, "salt", source_map),
272 try_stmt: None,
273 },
274 loc: span_to_range(source_map, call_expr.span),
275 referenced_contract: contract_id,
276 });
277 }
278 None
279}
280
281fn named_arg(
283 src: &str,
284 named_args: &Option<&[NamedArg<'_>]>,
285 arg: &str,
286 source_map: &SourceMap,
287) -> Option<String> {
288 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
289 |named_arg| {
290 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
291 src[named_arg_loc].to_string()
292 },
293 )
294}
295
296pub(crate) fn remove_bytecode_dependencies(
310 gcx: Gcx<'_>,
311 deps: &PreprocessorDependencies,
312 data: &PreprocessorData,
313) -> Updates {
314 let mut updates = Updates::default();
315 for (contract_id, deps) in &deps.preprocessed_contracts {
316 let contract = gcx.hir.contract(*contract_id);
317 let source = gcx.hir.source(contract.source);
318 let FileName::Real(path) = &source.file.name else {
319 continue;
320 };
321
322 let updates = updates.entry(path.clone()).or_default();
323 let mut used_helpers = BTreeSet::new();
324
325 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
326 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
328 let mut try_catch_helpers: HashSet<&str> = HashSet::default();
329
330 for dep in deps {
331 let Some(ContractData { artifact, constructor_data, .. }) =
332 data.get(&dep.referenced_contract)
333 else {
334 continue;
335 };
336
337 match &dep.kind {
338 BytecodeDependencyKind::CreationCode => {
339 updates.insert((
341 dep.loc.start,
342 dep.loc.end,
343 format!("{vm}.getCode(\"{artifact}\")"),
344 ));
345 }
346 BytecodeDependencyKind::New {
347 name,
348 args_length,
349 call_args_offset,
350 value,
351 salt,
352 try_stmt,
353 } => {
354 let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
355 if *has_ret {
356 try_catch_helpers.insert(name);
358 (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
359 } else {
360 (String::new(), "})")
361 }
362 } else {
363 (format!("{name}(payable("), "})))")
364 };
365 update.push_str(&format!("{vm}.deployCode({{"));
366 update.push_str(&format!("_artifact: \"{artifact}\""));
367
368 if let Some(value) = value {
369 update.push_str(", ");
370 update.push_str(&format!("_value: {value}"));
371 }
372
373 if let Some(salt) = salt {
374 update.push_str(", ");
375 update.push_str(&format!("_salt: {salt}"));
376 }
377
378 if constructor_data.is_some() {
379 used_helpers.insert(dep.referenced_contract);
381
382 update.push_str(", ");
383 update.push_str(&format!(
384 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
385 id = dep.referenced_contract.get()
386 ));
387 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
388
389 updates.insert((
390 dep.loc.end + args_length,
391 dep.loc.end + args_length,
392 format!("){closing_seq}"),
393 ));
394 } else {
395 update.push_str(closing_seq);
396 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
397 }
398 }
399 };
400 }
401
402 if !try_catch_helpers.is_empty()
404 && let Some(last_fn_id) = contract.functions().last()
405 {
406 let last_fn_range =
407 span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
408 let to_address_fns = try_catch_helpers
409 .iter()
410 .map(|ty| {
411 format!(
412 r#"
413 function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
414 return {ty}(addr);
415 }}
416 "#,
417 id = contract_id.get()
418 )
419 })
420 .collect::<String>();
421
422 updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
423 }
424
425 let helper_imports = used_helpers.into_iter().map(|id| {
426 let id = id.get();
427 format!(
428 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
429 )
430 }).join("\n");
431 updates.insert((
432 source.file.src.len(),
433 source.file.src.len(),
434 format!(
435 r#"
436{helper_imports}
437
438interface {vm_interface_name} {{
439 function deployCode(string memory _artifact) external returns (address);
440 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
441 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
442 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
443 function deployCode(string memory _artifact, uint256 _value) external returns (address);
444 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
445 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
446 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
447 function getCode(string memory _artifact) external returns (bytes memory);
448}}"#
449 ),
450 ));
451 }
452 updates
453}