foundry_common/preprocessor/
deps.rs
1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar_parse::interface::Session;
8use solar_sema::{
9 hir::{CallArgs, ContractId, Expr, ExprKind, Hir, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{data_structures::Never, source_map::FileName, SourceMap},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 sess: &Session,
29 hir: &Hir<'_>,
30 paths: &[PathBuf],
31 src_dir: &Path,
32 root_dir: &Path,
33 mocks: &mut HashSet<PathBuf>,
34 ) -> Self {
35 let mut preprocessed_contracts = BTreeMap::new();
36 let mut referenced_contracts = HashSet::new();
37 for contract_id in hir.contract_ids() {
38 let contract = hir.contract(contract_id);
39 let source = hir.source(contract.source);
40
41 let FileName::Real(path) = &source.file.name else {
42 continue;
43 };
44
45 if !paths.contains(path) {
47 let path = path.display();
48 trace!("{path} is not test or script");
49 continue;
50 }
51
52 if contract.linearized_bases.iter().any(|base_contract_id| {
55 let base_contract = hir.contract(*base_contract_id);
56 let FileName::Real(path) = &hir.source(base_contract.source).file.name else {
57 return false;
58 };
59 path.starts_with(src_dir)
60 }) {
61 mocks.insert(root_dir.join(path));
63 let path = path.display();
64 trace!("found mock contract {path}");
65 continue;
66 } else {
67 mocks.remove(&root_dir.join(path));
70 }
71
72 let mut deps_collector = BytecodeDependencyCollector::new(
73 sess.source_map(),
74 hir,
75 source.file.src.as_str(),
76 src_dir,
77 );
78 let _ = deps_collector.walk_contract(contract);
80 if !deps_collector.dependencies.is_empty() {
82 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
83 }
84 referenced_contracts.extend(deps_collector.referenced_contracts);
86 }
87 Self { preprocessed_contracts, referenced_contracts }
88 }
89}
90
/// The kind of bytecode dependency found in a test or script contract.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` expression.
    CreationCode,
    /// A `new Contract{...}(...)` expression.
    New {
        /// Source text of the contract type name used in the `new` expression.
        name: String,
        /// Byte length from the end of the type name to the end of the whole call expression.
        args_length: usize,
        /// Byte offset from the end of the type name to the start of the call arguments.
        /// Non-zero only when both call options (named args) and call arguments are present.
        call_args_offset: usize,
        /// Source text of the `value` call option, if any.
        value: Option<String>,
        /// Source text of the `salt` call option, if any.
        salt: Option<String>,
        /// Whether the `new` expression is the expression of a `try` statement.
        try_stmt: bool,
    },
}
112
/// A bytecode dependency of a preprocessed contract, to be rewritten in its source.
#[derive(Debug)]
pub(crate) struct BytecodeDependency {
    /// Kind of the dependency (`creationCode` access or `new` deployment).
    kind: BytecodeDependencyKind,
    /// Byte range, in the source of the contract being preprocessed, of the dependency
    /// expression (the whole member expression for `creationCode`, the `new T` callee for
    /// `New`).
    loc: Range<usize>,
    /// The contract whose bytecode is referenced.
    referenced_contract: ContractId,
}
123
/// HIR visitor collecting the bytecode dependencies of a single contract.
struct BytecodeDependencyCollector<'hir> {
    /// Source map used to convert spans into byte ranges.
    source_map: &'hir SourceMap,
    /// The HIR being walked.
    hir: &'hir Hir<'hir>,
    /// Source text of the file containing the visited contract.
    src: &'hir str,
    /// Project source directory; only dependencies on contracts defined under it are kept.
    src_dir: &'hir Path,
    /// Dependencies collected so far.
    dependencies: Vec<BytecodeDependency>,
    /// Ids of the contracts referenced by the collected dependencies.
    referenced_contracts: HashSet<ContractId>,
}
139
140impl<'hir> BytecodeDependencyCollector<'hir> {
141 fn new(
142 source_map: &'hir SourceMap,
143 hir: &'hir Hir<'hir>,
144 src: &'hir str,
145 src_dir: &'hir Path,
146 ) -> Self {
147 Self {
148 source_map,
149 hir,
150 src,
151 src_dir,
152 dependencies: vec![],
153 referenced_contracts: HashSet::default(),
154 }
155 }
156
157 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
161 let contract = self.hir.contract(dependency.referenced_contract);
162 let source = self.hir.source(contract.source);
163 let FileName::Real(path) = &source.file.name else {
164 return;
165 };
166
167 if !path.starts_with(self.src_dir) {
168 let path = path.display();
169 trace!("ignore dependency {path}");
170 return;
171 }
172
173 self.referenced_contracts.insert(dependency.referenced_contract);
174 self.dependencies.push(dependency);
175 }
176}
177
178impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
179 type BreakValue = Never;
180
181 fn hir(&self) -> &'hir Hir<'hir> {
182 self.hir
183 }
184
185 fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
186 match &expr.kind {
187 ExprKind::Call(call_expr, call_args, named_args) => {
188 if let Some(dependency) = handle_call_expr(
189 self.src,
190 self.source_map,
191 expr,
192 call_expr,
193 call_args,
194 named_args,
195 false,
196 ) {
197 self.collect_dependency(dependency);
198 }
199 }
200 ExprKind::Member(member_expr, ident) => {
201 if let ExprKind::TypeCall(ty) = &member_expr.kind {
202 if let TypeKind::Custom(contract_id) = &ty.kind {
203 if ident.name.as_str() == "creationCode" {
204 if let Some(contract_id) = contract_id.as_contract() {
205 self.collect_dependency(BytecodeDependency {
206 kind: BytecodeDependencyKind::CreationCode,
207 loc: span_to_range(self.source_map, expr.span),
208 referenced_contract: contract_id,
209 });
210 }
211 }
212 }
213 }
214 }
215 _ => {}
216 }
217 self.walk_expr(expr)
218 }
219
220 fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
221 if let StmtKind::Try(stmt_try) = stmt.kind {
222 if let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind {
223 if let Some(dependency) = handle_call_expr(
224 self.src,
225 self.source_map,
226 &stmt_try.expr,
227 call_expr,
228 call_args,
229 named_args,
230 true,
231 ) {
232 self.collect_dependency(dependency);
233 for clause in stmt_try.clauses {
234 for &var in clause.args {
235 self.visit_nested_var(var)?;
236 }
237 for stmt in clause.block {
238 self.visit_stmt(stmt)?;
239 }
240 }
241 return ControlFlow::Continue(());
242 }
243 }
244 }
245 self.walk_stmt(stmt)
246 }
247}
248
249fn handle_call_expr(
251 src: &str,
252 source_map: &SourceMap,
253 parent_expr: &Expr<'_>,
254 call_expr: &Expr<'_>,
255 call_args: &CallArgs<'_>,
256 named_args: &Option<&[NamedArg<'_>]>,
257 try_stmt: bool,
258) -> Option<BytecodeDependency> {
259 if let ExprKind::New(ty_new) = &call_expr.kind {
260 if let TypeKind::Custom(item_id) = ty_new.kind {
261 if let Some(contract_id) = item_id.as_contract() {
262 let name_loc = span_to_range(source_map, ty_new.span);
263 let name = &src[name_loc];
264
265 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
269 (call_args.span.lo() - ty_new.span.hi()).to_usize()
270 } else {
271 0
272 };
273
274 let args_len = parent_expr.span.hi() - ty_new.span.hi();
275 return Some(BytecodeDependency {
276 kind: BytecodeDependencyKind::New {
277 name: name.to_string(),
278 args_length: args_len.to_usize(),
279 call_args_offset,
280 value: named_arg(src, named_args, "value", source_map),
281 salt: named_arg(src, named_args, "salt", source_map),
282 try_stmt,
283 },
284 loc: span_to_range(source_map, call_expr.span),
285 referenced_contract: contract_id,
286 })
287 }
288 }
289 }
290 None
291}
292
293fn named_arg(
295 src: &str,
296 named_args: &Option<&[NamedArg<'_>]>,
297 arg: &str,
298 source_map: &SourceMap,
299) -> Option<String> {
300 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
301 |named_arg| {
302 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
303 src[named_arg_loc].to_string()
304 },
305 )
306}
307
308pub(crate) fn remove_bytecode_dependencies(
311 hir: &Hir<'_>,
312 deps: &PreprocessorDependencies,
313 data: &PreprocessorData,
314) -> Updates {
315 let mut updates = Updates::default();
316 for (contract_id, deps) in &deps.preprocessed_contracts {
317 let contract = hir.contract(*contract_id);
318 let source = hir.source(contract.source);
319 let FileName::Real(path) = &source.file.name else {
320 continue;
321 };
322
323 let updates = updates.entry(path.clone()).or_default();
324 let mut used_helpers = BTreeSet::new();
325
326 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
327 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
329
330 for dep in deps {
331 let Some(ContractData { artifact, constructor_data, .. }) =
332 data.get(&dep.referenced_contract)
333 else {
334 continue;
335 };
336
337 match &dep.kind {
338 BytecodeDependencyKind::CreationCode => {
339 updates.insert((
341 dep.loc.start,
342 dep.loc.end,
343 format!("{vm}.getCode(\"{artifact}\")"),
344 ));
345 }
346 BytecodeDependencyKind::New {
347 name,
348 args_length,
349 call_args_offset,
350 value,
351 salt,
352 try_stmt,
353 } => {
354 let (mut update, closing_seq) = if *try_stmt {
355 (String::new(), "})")
356 } else {
357 (format!("{name}(payable("), "})))")
358 };
359 update.push_str(&format!("{vm}.deployCode({{"));
360 update.push_str(&format!("_artifact: \"{artifact}\""));
361
362 if let Some(value) = value {
363 update.push_str(", ");
364 update.push_str(&format!("_value: {value}"));
365 }
366
367 if let Some(salt) = salt {
368 update.push_str(", ");
369 update.push_str(&format!("_salt: {salt}"));
370 }
371
372 if constructor_data.is_some() {
373 used_helpers.insert(dep.referenced_contract);
375
376 update.push_str(", ");
377 update.push_str(&format!(
378 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
379 id = dep.referenced_contract.get()
380 ));
381 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
382
383 updates.insert((
384 dep.loc.end + args_length,
385 dep.loc.end + args_length,
386 format!("){closing_seq}"),
387 ));
388 } else {
389 update.push_str(closing_seq);
390 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
391 }
392 }
393 };
394 }
395 let helper_imports = used_helpers.into_iter().map(|id| {
396 let id = id.get();
397 format!(
398 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
399 )
400 }).join("\n");
401 updates.insert((
402 source.file.src.len(),
403 source.file.src.len(),
404 format!(
405 r#"
406{helper_imports}
407
408interface {vm_interface_name} {{
409 function deployCode(string memory _artifact) external returns (address);
410 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
411 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
412 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
413 function deployCode(string memory _artifact, uint256 _value) external returns (address);
414 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
415 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
416 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
417 function getCode(string memory _artifact) external returns (bytes memory);
418}}"#
419 ),
420 ));
421 }
422 updates
423}