foundry_common/preprocessor/
deps.rs
1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar_parse::interface::Session;
8use solar_sema::{
9 hir::{CallArgs, ContractId, Expr, ExprKind, Hir, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 sess: &Session,
29 hir: &Hir<'_>,
30 paths: &[PathBuf],
31 src_dir: &Path,
32 root_dir: &Path,
33 mocks: &mut HashSet<PathBuf>,
34 ) -> Self {
35 let mut preprocessed_contracts = BTreeMap::new();
36 let mut referenced_contracts = HashSet::new();
37 for contract_id in hir.contract_ids() {
38 let contract = hir.contract(contract_id);
39 let source = hir.source(contract.source);
40
41 let FileName::Real(path) = &source.file.name else {
42 continue;
43 };
44
45 if !paths.contains(path) {
47 let path = path.display();
48 trace!("{path} is not test or script");
49 continue;
50 }
51
52 if contract.linearized_bases.iter().any(|base_contract_id| {
55 let base_contract = hir.contract(*base_contract_id);
56 let FileName::Real(path) = &hir.source(base_contract.source).file.name else {
57 return false;
58 };
59 path.starts_with(src_dir)
60 }) {
61 mocks.insert(root_dir.join(path));
63 let path = path.display();
64 trace!("found mock contract {path}");
65 continue;
66 } else {
67 mocks.remove(&root_dir.join(path));
70 }
71
72 let mut deps_collector = BytecodeDependencyCollector::new(
73 sess.source_map(),
74 hir,
75 source.file.src.as_str(),
76 src_dir,
77 );
78 let _ = deps_collector.walk_contract(contract);
80 if !deps_collector.dependencies.is_empty() {
82 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
83 }
84 referenced_contracts.extend(deps_collector.referenced_contracts);
86 }
87 Self { preprocessed_contracts, referenced_contracts }
88 }
89}
90
/// What a recorded bytecode dependency looks like in the source, plus the data needed to
/// rewrite it.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` member access.
    CreationCode,
    /// A `new Contract{...}(...)` deployment.
    New {
        /// Source text of the contract type following `new`.
        name: String,
        /// Byte distance from the end of the type name to the end of the whole call expression.
        args_length: usize,
        /// Byte distance from the end of the type name to the start of the positional call
        /// arguments. Non-zero only when named call options are present and the argument list
        /// is not empty.
        call_args_offset: usize,
        /// Source text of the `value` call option, if any.
        value: Option<String>,
        /// Source text of the `salt` call option, if any.
        salt: Option<String>,
        /// Whether the deployment is the call expression of a `try` statement.
        try_stmt: bool,
    },
}
112
113#[derive(Debug)]
115pub(crate) struct BytecodeDependency {
116 kind: BytecodeDependencyKind,
118 loc: Range<usize>,
120 referenced_contract: ContractId,
122}
123
124struct BytecodeDependencyCollector<'hir> {
126 source_map: &'hir SourceMap,
128 hir: &'hir Hir<'hir>,
130 src: &'hir str,
132 src_dir: &'hir Path,
134 dependencies: Vec<BytecodeDependency>,
136 referenced_contracts: HashSet<ContractId>,
138}
139
140impl<'hir> BytecodeDependencyCollector<'hir> {
141 fn new(
142 source_map: &'hir SourceMap,
143 hir: &'hir Hir<'hir>,
144 src: &'hir str,
145 src_dir: &'hir Path,
146 ) -> Self {
147 Self {
148 source_map,
149 hir,
150 src,
151 src_dir,
152 dependencies: vec![],
153 referenced_contracts: HashSet::default(),
154 }
155 }
156
157 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
161 let contract = self.hir.contract(dependency.referenced_contract);
162 let source = self.hir.source(contract.source);
163 let FileName::Real(path) = &source.file.name else {
164 return;
165 };
166
167 if !path.starts_with(self.src_dir) {
168 let path = path.display();
169 trace!("ignore dependency {path}");
170 return;
171 }
172
173 self.referenced_contracts.insert(dependency.referenced_contract);
174 self.dependencies.push(dependency);
175 }
176}
177
178impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
179 type BreakValue = Never;
180
181 fn hir(&self) -> &'hir Hir<'hir> {
182 self.hir
183 }
184
185 fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
186 match &expr.kind {
187 ExprKind::Call(call_expr, call_args, named_args) => {
188 if let Some(dependency) = handle_call_expr(
189 self.src,
190 self.source_map,
191 expr,
192 call_expr,
193 call_args,
194 named_args,
195 false,
196 ) {
197 self.collect_dependency(dependency);
198 }
199 }
200 ExprKind::Member(member_expr, ident) => {
201 if let ExprKind::TypeCall(ty) = &member_expr.kind
202 && let TypeKind::Custom(contract_id) = &ty.kind
203 && ident.name.as_str() == "creationCode"
204 && let Some(contract_id) = contract_id.as_contract()
205 {
206 self.collect_dependency(BytecodeDependency {
207 kind: BytecodeDependencyKind::CreationCode,
208 loc: span_to_range(self.source_map, expr.span),
209 referenced_contract: contract_id,
210 });
211 }
212 }
213 _ => {}
214 }
215 self.walk_expr(expr)
216 }
217
218 fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
219 if let StmtKind::Try(stmt_try) = stmt.kind
220 && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
221 && let Some(dependency) = handle_call_expr(
222 self.src,
223 self.source_map,
224 &stmt_try.expr,
225 call_expr,
226 call_args,
227 named_args,
228 true,
229 )
230 {
231 self.collect_dependency(dependency);
232 for clause in stmt_try.clauses {
233 for &var in clause.args {
234 self.visit_nested_var(var)?;
235 }
236 for stmt in clause.block.stmts {
237 self.visit_stmt(stmt)?;
238 }
239 }
240 return ControlFlow::Continue(());
241 }
242 self.walk_stmt(stmt)
243 }
244}
245
246fn handle_call_expr(
248 src: &str,
249 source_map: &SourceMap,
250 parent_expr: &Expr<'_>,
251 call_expr: &Expr<'_>,
252 call_args: &CallArgs<'_>,
253 named_args: &Option<&[NamedArg<'_>]>,
254 try_stmt: bool,
255) -> Option<BytecodeDependency> {
256 if let ExprKind::New(ty_new) = &call_expr.kind
257 && let TypeKind::Custom(item_id) = ty_new.kind
258 && let Some(contract_id) = item_id.as_contract()
259 {
260 let name_loc = span_to_range(source_map, ty_new.span);
261 let name = &src[name_loc];
262
263 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
267 (call_args.span.lo() - ty_new.span.hi()).to_usize()
268 } else {
269 0
270 };
271
272 let args_len = parent_expr.span.hi() - ty_new.span.hi();
273 return Some(BytecodeDependency {
274 kind: BytecodeDependencyKind::New {
275 name: name.to_string(),
276 args_length: args_len.to_usize(),
277 call_args_offset,
278 value: named_arg(src, named_args, "value", source_map),
279 salt: named_arg(src, named_args, "salt", source_map),
280 try_stmt,
281 },
282 loc: span_to_range(source_map, call_expr.span),
283 referenced_contract: contract_id,
284 });
285 }
286 None
287}
288
289fn named_arg(
291 src: &str,
292 named_args: &Option<&[NamedArg<'_>]>,
293 arg: &str,
294 source_map: &SourceMap,
295) -> Option<String> {
296 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
297 |named_arg| {
298 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
299 src[named_arg_loc].to_string()
300 },
301 )
302}
303
304pub(crate) fn remove_bytecode_dependencies(
307 hir: &Hir<'_>,
308 deps: &PreprocessorDependencies,
309 data: &PreprocessorData,
310) -> Updates {
311 let mut updates = Updates::default();
312 for (contract_id, deps) in &deps.preprocessed_contracts {
313 let contract = hir.contract(*contract_id);
314 let source = hir.source(contract.source);
315 let FileName::Real(path) = &source.file.name else {
316 continue;
317 };
318
319 let updates = updates.entry(path.clone()).or_default();
320 let mut used_helpers = BTreeSet::new();
321
322 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
323 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
325
326 for dep in deps {
327 let Some(ContractData { artifact, constructor_data, .. }) =
328 data.get(&dep.referenced_contract)
329 else {
330 continue;
331 };
332
333 match &dep.kind {
334 BytecodeDependencyKind::CreationCode => {
335 updates.insert((
337 dep.loc.start,
338 dep.loc.end,
339 format!("{vm}.getCode(\"{artifact}\")"),
340 ));
341 }
342 BytecodeDependencyKind::New {
343 name,
344 args_length,
345 call_args_offset,
346 value,
347 salt,
348 try_stmt,
349 } => {
350 let (mut update, closing_seq) = if *try_stmt {
351 (String::new(), "})")
352 } else {
353 (format!("{name}(payable("), "})))")
354 };
355 update.push_str(&format!("{vm}.deployCode({{"));
356 update.push_str(&format!("_artifact: \"{artifact}\""));
357
358 if let Some(value) = value {
359 update.push_str(", ");
360 update.push_str(&format!("_value: {value}"));
361 }
362
363 if let Some(salt) = salt {
364 update.push_str(", ");
365 update.push_str(&format!("_salt: {salt}"));
366 }
367
368 if constructor_data.is_some() {
369 used_helpers.insert(dep.referenced_contract);
371
372 update.push_str(", ");
373 update.push_str(&format!(
374 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
375 id = dep.referenced_contract.get()
376 ));
377 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
378
379 updates.insert((
380 dep.loc.end + args_length,
381 dep.loc.end + args_length,
382 format!("){closing_seq}"),
383 ));
384 } else {
385 update.push_str(closing_seq);
386 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
387 }
388 }
389 };
390 }
391 let helper_imports = used_helpers.into_iter().map(|id| {
392 let id = id.get();
393 format!(
394 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
395 )
396 }).join("\n");
397 updates.insert((
398 source.file.src.len(),
399 source.file.src.len(),
400 format!(
401 r#"
402{helper_imports}
403
404interface {vm_interface_name} {{
405 function deployCode(string memory _artifact) external returns (address);
406 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
407 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
408 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
409 function deployCode(string memory _artifact, uint256 _value) external returns (address);
410 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
411 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
412 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
413 function getCode(string memory _artifact) external returns (bytes memory);
414}}"#
415 ),
416 ));
417 }
418 updates
419}