foundry_common/preprocessor/
deps.rs1use super::{
2 data::{ContractData, PreprocessorData},
3 span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8 Gcx, Hir,
9 hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10 interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13 collections::{BTreeMap, BTreeSet, HashSet},
14 ops::{ControlFlow, Range},
15 path::{Path, PathBuf},
16};
17
18pub(crate) struct PreprocessorDependencies {
20 pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22 pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27 pub fn new(
28 gcx: Gcx<'_>,
29 paths: &[PathBuf],
30 src_dir: &Path,
31 root_dir: &Path,
32 mocks: &mut HashSet<PathBuf>,
33 ) -> Self {
34 let mut preprocessed_contracts = BTreeMap::new();
35 let mut referenced_contracts = HashSet::new();
36 for contract_id in gcx.hir.contract_ids() {
37 let contract = gcx.hir.contract(contract_id);
38 let source = gcx.hir.source(contract.source);
39
40 let FileName::Real(path) = &source.file.name else {
41 continue;
42 };
43
44 if !paths.contains(path) {
46 let path = path.display();
47 trace!("{path} is not test or script");
48 continue;
49 }
50
51 if contract.linearized_bases.iter().any(|base_contract_id| {
54 let base_contract = gcx.hir.contract(*base_contract_id);
55 let FileName::Real(path) = &gcx.hir.source(base_contract.source).file.name else {
56 return false;
57 };
58 path.starts_with(src_dir)
59 }) {
60 mocks.insert(root_dir.join(path));
62 let path = path.display();
63 trace!("found mock contract {path}");
64 continue;
65 } else {
66 mocks.remove(&root_dir.join(path));
69 }
70
71 let mut deps_collector =
72 BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir);
73 let _ = deps_collector.walk_contract(contract);
75 if !deps_collector.dependencies.is_empty() {
77 preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
78 }
79 referenced_contracts.extend(deps_collector.referenced_contracts);
81 }
82 Self { preprocessed_contracts, referenced_contracts }
83 }
84}
85
/// The way a test or script references a source contract's bytecode.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(C).creationCode` expression.
    CreationCode,
    /// A `new C{...}(...)` expression, possibly used as the expression of a `try` statement.
    New {
        /// Source text of the contract type as written after `new`.
        name: String,
        /// Byte length from the end of the type name to the end of the whole call expression
        /// (call options and arguments included).
        args_length: usize,
        /// Byte distance from the end of the type name to the start of the call arguments. It is
        /// set when call options are present and the argument list is non-empty, and is `0`
        /// otherwise.
        call_args_offset: usize,
        /// Source text of the `value` call option, if present.
        value: Option<String>,
        /// Source text of the `salt` call option, if present.
        salt: Option<String>,
        /// Whether the call is the expression of a `try` statement.
        try_stmt: bool,
    },
}
107
108#[derive(Debug)]
110pub(crate) struct BytecodeDependency {
111 kind: BytecodeDependencyKind,
113 loc: Range<usize>,
115 referenced_contract: ContractId,
117}
118
119struct BytecodeDependencyCollector<'gcx, 'src> {
121 gcx: Gcx<'gcx>,
123 src: &'src str,
125 src_dir: &'src Path,
127 dependencies: Vec<BytecodeDependency>,
129 referenced_contracts: HashSet<ContractId>,
131}
132
133impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
134 fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path) -> Self {
135 Self { gcx, src, src_dir, dependencies: vec![], referenced_contracts: HashSet::default() }
136 }
137
138 fn collect_dependency(&mut self, dependency: BytecodeDependency) {
142 let contract = self.gcx.hir.contract(dependency.referenced_contract);
143 let source = self.gcx.hir.source(contract.source);
144 let FileName::Real(path) = &source.file.name else {
145 return;
146 };
147
148 if !path.starts_with(self.src_dir) {
149 let path = path.display();
150 trace!("ignore dependency {path}");
151 return;
152 }
153
154 self.referenced_contracts.insert(dependency.referenced_contract);
155 self.dependencies.push(dependency);
156 }
157}
158
159impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
160 type BreakValue = Never;
161
162 fn hir(&self) -> &'gcx Hir<'gcx> {
163 &self.gcx.hir
164 }
165
166 fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
167 match &expr.kind {
168 ExprKind::Call(call_expr, call_args, named_args) => {
169 if let Some(dependency) = handle_call_expr(
170 self.src,
171 self.gcx.sess.source_map(),
172 expr,
173 call_expr,
174 call_args,
175 named_args,
176 false,
177 ) {
178 self.collect_dependency(dependency);
179 }
180 }
181 ExprKind::Member(member_expr, ident) => {
182 if let ExprKind::TypeCall(ty) = &member_expr.kind
183 && let TypeKind::Custom(contract_id) = &ty.kind
184 && ident.name.as_str() == "creationCode"
185 && let Some(contract_id) = contract_id.as_contract()
186 {
187 self.collect_dependency(BytecodeDependency {
188 kind: BytecodeDependencyKind::CreationCode,
189 loc: span_to_range(self.gcx.sess.source_map(), expr.span),
190 referenced_contract: contract_id,
191 });
192 }
193 }
194 _ => {}
195 }
196 self.walk_expr(expr)
197 }
198
199 fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
200 if let StmtKind::Try(stmt_try) = stmt.kind
201 && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
202 && let Some(dependency) = handle_call_expr(
203 self.src,
204 self.gcx.sess.source_map(),
205 &stmt_try.expr,
206 call_expr,
207 call_args,
208 named_args,
209 true,
210 )
211 {
212 self.collect_dependency(dependency);
213 for clause in stmt_try.clauses {
214 for &var in clause.args {
215 self.visit_nested_var(var)?;
216 }
217 for stmt in clause.block.stmts {
218 self.visit_stmt(stmt)?;
219 }
220 }
221 return ControlFlow::Continue(());
222 }
223 self.walk_stmt(stmt)
224 }
225}
226
227fn handle_call_expr(
229 src: &str,
230 source_map: &SourceMap,
231 parent_expr: &Expr<'_>,
232 call_expr: &Expr<'_>,
233 call_args: &CallArgs<'_>,
234 named_args: &Option<&[NamedArg<'_>]>,
235 try_stmt: bool,
236) -> Option<BytecodeDependency> {
237 if let ExprKind::New(ty_new) = &call_expr.kind
238 && let TypeKind::Custom(item_id) = ty_new.kind
239 && let Some(contract_id) = item_id.as_contract()
240 {
241 let name_loc = span_to_range(source_map, ty_new.span);
242 let name = &src[name_loc];
243
244 let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
248 (call_args.span.lo() - ty_new.span.hi()).to_usize()
249 } else {
250 0
251 };
252
253 let args_len = parent_expr.span.hi() - ty_new.span.hi();
254 return Some(BytecodeDependency {
255 kind: BytecodeDependencyKind::New {
256 name: name.to_string(),
257 args_length: args_len.to_usize(),
258 call_args_offset,
259 value: named_arg(src, named_args, "value", source_map),
260 salt: named_arg(src, named_args, "salt", source_map),
261 try_stmt,
262 },
263 loc: span_to_range(source_map, call_expr.span),
264 referenced_contract: contract_id,
265 });
266 }
267 None
268}
269
270fn named_arg(
272 src: &str,
273 named_args: &Option<&[NamedArg<'_>]>,
274 arg: &str,
275 source_map: &SourceMap,
276) -> Option<String> {
277 named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
278 |named_arg| {
279 let named_arg_loc = span_to_range(source_map, named_arg.value.span);
280 src[named_arg_loc].to_string()
281 },
282 )
283}
284
285pub(crate) fn remove_bytecode_dependencies(
288 gcx: Gcx<'_>,
289 deps: &PreprocessorDependencies,
290 data: &PreprocessorData,
291) -> Updates {
292 let mut updates = Updates::default();
293 for (contract_id, deps) in &deps.preprocessed_contracts {
294 let contract = gcx.hir.contract(*contract_id);
295 let source = gcx.hir.source(contract.source);
296 let FileName::Real(path) = &source.file.name else {
297 continue;
298 };
299
300 let updates = updates.entry(path.clone()).or_default();
301 let mut used_helpers = BTreeSet::new();
302
303 let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
304 let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
306
307 for dep in deps {
308 let Some(ContractData { artifact, constructor_data, .. }) =
309 data.get(&dep.referenced_contract)
310 else {
311 continue;
312 };
313
314 match &dep.kind {
315 BytecodeDependencyKind::CreationCode => {
316 updates.insert((
318 dep.loc.start,
319 dep.loc.end,
320 format!("{vm}.getCode(\"{artifact}\")"),
321 ));
322 }
323 BytecodeDependencyKind::New {
324 name,
325 args_length,
326 call_args_offset,
327 value,
328 salt,
329 try_stmt,
330 } => {
331 let (mut update, closing_seq) = if *try_stmt {
332 (String::new(), "})")
333 } else {
334 (format!("{name}(payable("), "})))")
335 };
336 update.push_str(&format!("{vm}.deployCode({{"));
337 update.push_str(&format!("_artifact: \"{artifact}\""));
338
339 if let Some(value) = value {
340 update.push_str(", ");
341 update.push_str(&format!("_value: {value}"));
342 }
343
344 if let Some(salt) = salt {
345 update.push_str(", ");
346 update.push_str(&format!("_salt: {salt}"));
347 }
348
349 if constructor_data.is_some() {
350 used_helpers.insert(dep.referenced_contract);
352
353 update.push_str(", ");
354 update.push_str(&format!(
355 "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
356 id = dep.referenced_contract.get()
357 ));
358 updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
359
360 updates.insert((
361 dep.loc.end + args_length,
362 dep.loc.end + args_length,
363 format!("){closing_seq}"),
364 ));
365 } else {
366 update.push_str(closing_seq);
367 updates.insert((dep.loc.start, dep.loc.end + args_length, update));
368 }
369 }
370 };
371 }
372 let helper_imports = used_helpers.into_iter().map(|id| {
373 let id = id.get();
374 format!(
375 "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
376 )
377 }).join("\n");
378 updates.insert((
379 source.file.src.len(),
380 source.file.src.len(),
381 format!(
382 r#"
383{helper_imports}
384
385interface {vm_interface_name} {{
386 function deployCode(string memory _artifact) external returns (address);
387 function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
388 function deployCode(string memory _artifact, bytes memory _args) external returns (address);
389 function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
390 function deployCode(string memory _artifact, uint256 _value) external returns (address);
391 function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
392 function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
393 function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
394 function getCode(string memory _artifact) external returns (bytes memory);
395}}"#
396 ),
397 ));
398 }
399 updates
400}