foundry_common/preprocessor/
deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar_parse::interface::Session;
8use solar_sema::{
9    hir::{CallArgs, ContractId, Expr, ExprKind, Hir, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10    interface::{data_structures::Never, source_map::FileName, SourceMap},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
/// Holds data about referenced source contracts and bytecode dependencies.
pub(crate) struct PreprocessorDependencies {
    /// Maps the id of each test/script contract to preprocess to the bytecode dependencies
    /// collected for it. Contracts for which no dependency was collected are not present.
    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
    /// Ids of the (source dir) contracts referenced by the collected bytecode dependencies.
    pub referenced_contracts: HashSet<ContractId>,
}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        sess: &Session,
29        hir: &Hir<'_>,
30        paths: &[PathBuf],
31        src_dir: &Path,
32        root_dir: &Path,
33        mocks: &mut HashSet<PathBuf>,
34    ) -> Self {
35        let mut preprocessed_contracts = BTreeMap::new();
36        let mut referenced_contracts = HashSet::new();
37        for contract_id in hir.contract_ids() {
38            let contract = hir.contract(contract_id);
39            let source = hir.source(contract.source);
40
41            let FileName::Real(path) = &source.file.name else {
42                continue;
43            };
44
45            // Collect dependencies only for tests and scripts.
46            if !paths.contains(path) {
47                let path = path.display();
48                trace!("{path} is not test or script");
49                continue;
50            }
51
52            // Do not collect dependencies for mock contracts. Walk through base contracts and
53            // check if they're from src dir.
54            if contract.linearized_bases.iter().any(|base_contract_id| {
55                let base_contract = hir.contract(*base_contract_id);
56                let FileName::Real(path) = &hir.source(base_contract.source).file.name else {
57                    return false;
58                };
59                path.starts_with(src_dir)
60            }) {
61                // Record mock contracts to be evicted from preprocessed cache.
62                mocks.insert(root_dir.join(path));
63                let path = path.display();
64                trace!("found mock contract {path}");
65                continue;
66            } else {
67                // Make sure current contract is not in list of mocks (could happen when a contract
68                // which used to be a mock is refactored to a non-mock implementation).
69                mocks.remove(&root_dir.join(path));
70            }
71
72            let mut deps_collector = BytecodeDependencyCollector::new(
73                sess.source_map(),
74                hir,
75                source.file.src.as_str(),
76                src_dir,
77            );
78            // Analyze current contract.
79            let _ = deps_collector.walk_contract(contract);
80            // Ignore empty test contracts declared in source files with other contracts.
81            if !deps_collector.dependencies.is_empty() {
82                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
83            }
84            // Record collected referenced contract ids.
85            referenced_contracts.extend(deps_collector.referenced_contracts);
86        }
87        Self { preprocessed_contracts, referenced_contracts }
88    }
89}
90
/// Represents a bytecode dependency kind.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// `type(Contract).creationCode`
    CreationCode,
    /// `new Contract`.
    New {
        /// Contract name, as written in the source of the `new` expression.
        name: String,
        /// Constructor args length: number of bytes from the end of the contract name to the
        /// end of the whole `new Contract(..)` call expression.
        args_length: usize,
        /// Constructor call args offset: number of bytes from the end of the contract name to
        /// the start of the call args when named args (e.g. `{value: ..}`) precede non-empty
        /// call args; `0` otherwise.
        call_args_offset: usize,
        /// `msg.value` (if any) used when creating contract: source text of the `value`
        /// named arg.
        value: Option<String>,
        /// `salt` (if any) used when creating contract: source text of the `salt` named arg.
        salt: Option<String>,
        /// Whether it's a try contract creation statement.
        try_stmt: bool,
    },
}
112
/// Represents a single bytecode dependency.
#[derive(Debug)]
pub(crate) struct BytecodeDependency {
    /// Dependency kind.
    kind: BytecodeDependencyKind,
    /// Source map location of this dependency, as a byte range into the file's source text.
    /// For `BytecodeDependencyKind::New` it covers the `new Contract` callee only (the call
    /// args follow it); for `BytecodeDependencyKind::CreationCode` it covers the whole
    /// `type(Contract).creationCode` expression.
    loc: Range<usize>,
    /// HIR id of referenced contract.
    referenced_contract: ContractId,
}
123
/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
struct BytecodeDependencyCollector<'hir> {
    /// Source map, used for determining contract item locations.
    source_map: &'hir SourceMap,
    /// Parsed HIR.
    hir: &'hir Hir<'hir>,
    /// Source content of the file declaring the current contract; sliced to extract contract
    /// names and named arg values from `new` expressions.
    src: &'hir str,
    /// Project source dir, used to determine if referenced contract is a source contract.
    src_dir: &'hir Path,
    /// Dependencies collected for current contract.
    dependencies: Vec<BytecodeDependency>,
    /// Unique HIR ids of contracts referenced from current contract.
    referenced_contracts: HashSet<ContractId>,
}
139
140impl<'hir> BytecodeDependencyCollector<'hir> {
141    fn new(
142        source_map: &'hir SourceMap,
143        hir: &'hir Hir<'hir>,
144        src: &'hir str,
145        src_dir: &'hir Path,
146    ) -> Self {
147        Self {
148            source_map,
149            hir,
150            src,
151            src_dir,
152            dependencies: vec![],
153            referenced_contracts: HashSet::default(),
154        }
155    }
156
157    /// Collects reference identified as bytecode dependency of analyzed contract.
158    /// Discards any reference that is not in project src directory (e.g. external
159    /// libraries or mock contracts that extend source contracts).
160    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
161        let contract = self.hir.contract(dependency.referenced_contract);
162        let source = self.hir.source(contract.source);
163        let FileName::Real(path) = &source.file.name else {
164            return;
165        };
166
167        if !path.starts_with(self.src_dir) {
168            let path = path.display();
169            trace!("ignore dependency {path}");
170            return;
171        }
172
173        self.referenced_contracts.insert(dependency.referenced_contract);
174        self.dependencies.push(dependency);
175    }
176}
177
178impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
179    type BreakValue = Never;
180
181    fn hir(&self) -> &'hir Hir<'hir> {
182        self.hir
183    }
184
185    fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
186        match &expr.kind {
187            ExprKind::Call(call_expr, call_args, named_args) => {
188                if let Some(dependency) = handle_call_expr(
189                    self.src,
190                    self.source_map,
191                    expr,
192                    call_expr,
193                    call_args,
194                    named_args,
195                    false,
196                ) {
197                    self.collect_dependency(dependency);
198                }
199            }
200            ExprKind::Member(member_expr, ident) => {
201                if let ExprKind::TypeCall(ty) = &member_expr.kind {
202                    if let TypeKind::Custom(contract_id) = &ty.kind {
203                        if ident.name.as_str() == "creationCode" {
204                            if let Some(contract_id) = contract_id.as_contract() {
205                                self.collect_dependency(BytecodeDependency {
206                                    kind: BytecodeDependencyKind::CreationCode,
207                                    loc: span_to_range(self.source_map, expr.span),
208                                    referenced_contract: contract_id,
209                                });
210                            }
211                        }
212                    }
213                }
214            }
215            _ => {}
216        }
217        self.walk_expr(expr)
218    }
219
220    fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
221        if let StmtKind::Try(stmt_try) = stmt.kind {
222            if let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind {
223                if let Some(dependency) = handle_call_expr(
224                    self.src,
225                    self.source_map,
226                    &stmt_try.expr,
227                    call_expr,
228                    call_args,
229                    named_args,
230                    true,
231                ) {
232                    self.collect_dependency(dependency);
233                    for clause in stmt_try.clauses {
234                        for &var in clause.args {
235                            self.visit_nested_var(var)?;
236                        }
237                        for stmt in clause.block {
238                            self.visit_stmt(stmt)?;
239                        }
240                    }
241                    return ControlFlow::Continue(());
242                }
243            }
244        }
245        self.walk_stmt(stmt)
246    }
247}
248
249/// Helper function to analyze and extract bytecode dependency from a given call expression.
250fn handle_call_expr(
251    src: &str,
252    source_map: &SourceMap,
253    parent_expr: &Expr<'_>,
254    call_expr: &Expr<'_>,
255    call_args: &CallArgs<'_>,
256    named_args: &Option<&[NamedArg<'_>]>,
257    try_stmt: bool,
258) -> Option<BytecodeDependency> {
259    if let ExprKind::New(ty_new) = &call_expr.kind {
260        if let TypeKind::Custom(item_id) = ty_new.kind {
261            if let Some(contract_id) = item_id.as_contract() {
262                let name_loc = span_to_range(source_map, ty_new.span);
263                let name = &src[name_loc];
264
265                // Calculate offset to remove named args, e.g. for an expression like
266                // `new Counter {value: 333} (  address(this))`
267                // the offset will be used to replace `{value: 333} (  ` with `(`
268                let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
269                    (call_args.span.lo() - ty_new.span.hi()).to_usize()
270                } else {
271                    0
272                };
273
274                let args_len = parent_expr.span.hi() - ty_new.span.hi();
275                return Some(BytecodeDependency {
276                    kind: BytecodeDependencyKind::New {
277                        name: name.to_string(),
278                        args_length: args_len.to_usize(),
279                        call_args_offset,
280                        value: named_arg(src, named_args, "value", source_map),
281                        salt: named_arg(src, named_args, "salt", source_map),
282                        try_stmt,
283                    },
284                    loc: span_to_range(source_map, call_expr.span),
285                    referenced_contract: contract_id,
286                })
287            }
288        }
289    }
290    None
291}
292
293/// Helper function to extract value of a given named arg.
294fn named_arg(
295    src: &str,
296    named_args: &Option<&[NamedArg<'_>]>,
297    arg: &str,
298    source_map: &SourceMap,
299) -> Option<String> {
300    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
301        |named_arg| {
302            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
303            src[named_arg_loc].to_string()
304        },
305    )
306}
307
308/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
309/// invocations.
310pub(crate) fn remove_bytecode_dependencies(
311    hir: &Hir<'_>,
312    deps: &PreprocessorDependencies,
313    data: &PreprocessorData,
314) -> Updates {
315    let mut updates = Updates::default();
316    for (contract_id, deps) in &deps.preprocessed_contracts {
317        let contract = hir.contract(*contract_id);
318        let source = hir.source(contract.source);
319        let FileName::Real(path) = &source.file.name else {
320            continue;
321        };
322
323        let updates = updates.entry(path.clone()).or_default();
324        let mut used_helpers = BTreeSet::new();
325
326        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
327        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
328        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
329
330        for dep in deps {
331            let Some(ContractData { artifact, constructor_data, .. }) =
332                data.get(&dep.referenced_contract)
333            else {
334                continue;
335            };
336
337            match &dep.kind {
338                BytecodeDependencyKind::CreationCode => {
339                    // for creation code we need to just call getCode
340                    updates.insert((
341                        dep.loc.start,
342                        dep.loc.end,
343                        format!("{vm}.getCode(\"{artifact}\")"),
344                    ));
345                }
346                BytecodeDependencyKind::New {
347                    name,
348                    args_length,
349                    call_args_offset,
350                    value,
351                    salt,
352                    try_stmt,
353                } => {
354                    let (mut update, closing_seq) = if *try_stmt {
355                        (String::new(), "})")
356                    } else {
357                        (format!("{name}(payable("), "})))")
358                    };
359                    update.push_str(&format!("{vm}.deployCode({{"));
360                    update.push_str(&format!("_artifact: \"{artifact}\""));
361
362                    if let Some(value) = value {
363                        update.push_str(", ");
364                        update.push_str(&format!("_value: {value}"));
365                    }
366
367                    if let Some(salt) = salt {
368                        update.push_str(", ");
369                        update.push_str(&format!("_salt: {salt}"));
370                    }
371
372                    if constructor_data.is_some() {
373                        // Insert our helper
374                        used_helpers.insert(dep.referenced_contract);
375
376                        update.push_str(", ");
377                        update.push_str(&format!(
378                            "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
379                            id = dep.referenced_contract.get()
380                        ));
381                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
382
383                        updates.insert((
384                            dep.loc.end + args_length,
385                            dep.loc.end + args_length,
386                            format!("){closing_seq}"),
387                        ));
388                    } else {
389                        update.push_str(closing_seq);
390                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
391                    }
392                }
393            };
394        }
395        let helper_imports = used_helpers.into_iter().map(|id| {
396            let id = id.get();
397            format!(
398                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
399            )
400        }).join("\n");
401        updates.insert((
402            source.file.src.len(),
403            source.file.src.len(),
404            format!(
405                r#"
406{helper_imports}
407
408interface {vm_interface_name} {{
409    function deployCode(string memory _artifact) external returns (address);
410    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
411    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
412    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
413    function deployCode(string memory _artifact, uint256 _value) external returns (address);
414    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
415    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
416    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
417    function getCode(string memory _artifact) external returns (bytes memory);
418}}"#
419            ),
420        ));
421    }
422    updates
423}