foundry_common/preprocessor/deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar_parse::interface::Session;
8use solar_sema::{
9    hir::{CallArgs, ContractId, Expr, ExprKind, Hir, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10    interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
18/// Holds data about referenced source contracts and bytecode dependencies.
19pub(crate) struct PreprocessorDependencies {
20    // Mapping contract id to preprocess -> contract bytecode dependencies.
21    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22    // Referenced contract ids.
23    pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        sess: &Session,
29        hir: &Hir<'_>,
30        paths: &[PathBuf],
31        src_dir: &Path,
32        root_dir: &Path,
33        mocks: &mut HashSet<PathBuf>,
34    ) -> Self {
35        let mut preprocessed_contracts = BTreeMap::new();
36        let mut referenced_contracts = HashSet::new();
37        for contract_id in hir.contract_ids() {
38            let contract = hir.contract(contract_id);
39            let source = hir.source(contract.source);
40
41            let FileName::Real(path) = &source.file.name else {
42                continue;
43            };
44
45            // Collect dependencies only for tests and scripts.
46            if !paths.contains(path) {
47                let path = path.display();
48                trace!("{path} is not test or script");
49                continue;
50            }
51
52            // Do not collect dependencies for mock contracts. Walk through base contracts and
53            // check if they're from src dir.
54            if contract.linearized_bases.iter().any(|base_contract_id| {
55                let base_contract = hir.contract(*base_contract_id);
56                let FileName::Real(path) = &hir.source(base_contract.source).file.name else {
57                    return false;
58                };
59                path.starts_with(src_dir)
60            }) {
61                // Record mock contracts to be evicted from preprocessed cache.
62                mocks.insert(root_dir.join(path));
63                let path = path.display();
64                trace!("found mock contract {path}");
65                continue;
66            } else {
67                // Make sure current contract is not in list of mocks (could happen when a contract
68                // which used to be a mock is refactored to a non-mock implementation).
69                mocks.remove(&root_dir.join(path));
70            }
71
72            let mut deps_collector = BytecodeDependencyCollector::new(
73                sess.source_map(),
74                hir,
75                source.file.src.as_str(),
76                src_dir,
77            );
78            // Analyze current contract.
79            let _ = deps_collector.walk_contract(contract);
80            // Ignore empty test contracts declared in source files with other contracts.
81            if !deps_collector.dependencies.is_empty() {
82                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
83            }
84            // Record collected referenced contract ids.
85            referenced_contracts.extend(deps_collector.referenced_contracts);
86        }
87        Self { preprocessed_contracts, referenced_contracts }
88    }
89}
90
/// Represents a bytecode dependency kind.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// `type(Contract).creationCode`
    CreationCode,
    /// `new Contract`.
    New {
        /// Contract name, as written in the source.
        name: String,
        /// Number of source bytes from the end of the contract name to the end of the whole
        /// creation call expression. This covers the parenthesized constructor arguments and
        /// any named call options in between (e.g. `{value: 1}`), not just the arguments.
        args_length: usize,
        /// Number of source bytes from the end of the contract name to the start of the call
        /// arguments span. Non-zero only when named args are present and the argument list is
        /// not empty; it lets the rewrite skip over the named args.
        call_args_offset: usize,
        /// Source text of the `value` named arg (if any) used when creating contract.
        value: Option<String>,
        /// Source text of the `salt` named arg (if any) used when creating contract.
        salt: Option<String>,
        /// Whether the creation call is the expression of a `try` statement.
        try_stmt: bool,
    },
}
112
113/// Represents a single bytecode dependency.
114#[derive(Debug)]
115pub(crate) struct BytecodeDependency {
116    /// Dependency kind.
117    kind: BytecodeDependencyKind,
118    /// Source map location of this dependency.
119    loc: Range<usize>,
120    /// HIR id of referenced contract.
121    referenced_contract: ContractId,
122}
123
124/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
125struct BytecodeDependencyCollector<'hir> {
126    /// Source map, used for determining contract item locations.
127    source_map: &'hir SourceMap,
128    /// Parsed HIR.
129    hir: &'hir Hir<'hir>,
130    /// Source content of current contract.
131    src: &'hir str,
132    /// Project source dir, used to determine if referenced contract is a source contract.
133    src_dir: &'hir Path,
134    /// Dependencies collected for current contract.
135    dependencies: Vec<BytecodeDependency>,
136    /// Unique HIR ids of contracts referenced from current contract.
137    referenced_contracts: HashSet<ContractId>,
138}
139
140impl<'hir> BytecodeDependencyCollector<'hir> {
141    fn new(
142        source_map: &'hir SourceMap,
143        hir: &'hir Hir<'hir>,
144        src: &'hir str,
145        src_dir: &'hir Path,
146    ) -> Self {
147        Self {
148            source_map,
149            hir,
150            src,
151            src_dir,
152            dependencies: vec![],
153            referenced_contracts: HashSet::default(),
154        }
155    }
156
157    /// Collects reference identified as bytecode dependency of analyzed contract.
158    /// Discards any reference that is not in project src directory (e.g. external
159    /// libraries or mock contracts that extend source contracts).
160    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
161        let contract = self.hir.contract(dependency.referenced_contract);
162        let source = self.hir.source(contract.source);
163        let FileName::Real(path) = &source.file.name else {
164            return;
165        };
166
167        if !path.starts_with(self.src_dir) {
168            let path = path.display();
169            trace!("ignore dependency {path}");
170            return;
171        }
172
173        self.referenced_contracts.insert(dependency.referenced_contract);
174        self.dependencies.push(dependency);
175    }
176}
177
178impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
179    type BreakValue = Never;
180
181    fn hir(&self) -> &'hir Hir<'hir> {
182        self.hir
183    }
184
185    fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
186        match &expr.kind {
187            ExprKind::Call(call_expr, call_args, named_args) => {
188                if let Some(dependency) = handle_call_expr(
189                    self.src,
190                    self.source_map,
191                    expr,
192                    call_expr,
193                    call_args,
194                    named_args,
195                    false,
196                ) {
197                    self.collect_dependency(dependency);
198                }
199            }
200            ExprKind::Member(member_expr, ident) => {
201                if let ExprKind::TypeCall(ty) = &member_expr.kind
202                    && let TypeKind::Custom(contract_id) = &ty.kind
203                    && ident.name.as_str() == "creationCode"
204                    && let Some(contract_id) = contract_id.as_contract()
205                {
206                    self.collect_dependency(BytecodeDependency {
207                        kind: BytecodeDependencyKind::CreationCode,
208                        loc: span_to_range(self.source_map, expr.span),
209                        referenced_contract: contract_id,
210                    });
211                }
212            }
213            _ => {}
214        }
215        self.walk_expr(expr)
216    }
217
218    fn visit_stmt(&mut self, stmt: &'hir Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
219        if let StmtKind::Try(stmt_try) = stmt.kind
220            && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
221            && let Some(dependency) = handle_call_expr(
222                self.src,
223                self.source_map,
224                &stmt_try.expr,
225                call_expr,
226                call_args,
227                named_args,
228                true,
229            )
230        {
231            self.collect_dependency(dependency);
232            for clause in stmt_try.clauses {
233                for &var in clause.args {
234                    self.visit_nested_var(var)?;
235                }
236                for stmt in clause.block.stmts {
237                    self.visit_stmt(stmt)?;
238                }
239            }
240            return ControlFlow::Continue(());
241        }
242        self.walk_stmt(stmt)
243    }
244}
245
246/// Helper function to analyze and extract bytecode dependency from a given call expression.
247fn handle_call_expr(
248    src: &str,
249    source_map: &SourceMap,
250    parent_expr: &Expr<'_>,
251    call_expr: &Expr<'_>,
252    call_args: &CallArgs<'_>,
253    named_args: &Option<&[NamedArg<'_>]>,
254    try_stmt: bool,
255) -> Option<BytecodeDependency> {
256    if let ExprKind::New(ty_new) = &call_expr.kind
257        && let TypeKind::Custom(item_id) = ty_new.kind
258        && let Some(contract_id) = item_id.as_contract()
259    {
260        let name_loc = span_to_range(source_map, ty_new.span);
261        let name = &src[name_loc];
262
263        // Calculate offset to remove named args, e.g. for an expression like
264        // `new Counter {value: 333} (  address(this))`
265        // the offset will be used to replace `{value: 333} (  ` with `(`
266        let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
267            (call_args.span.lo() - ty_new.span.hi()).to_usize()
268        } else {
269            0
270        };
271
272        let args_len = parent_expr.span.hi() - ty_new.span.hi();
273        return Some(BytecodeDependency {
274            kind: BytecodeDependencyKind::New {
275                name: name.to_string(),
276                args_length: args_len.to_usize(),
277                call_args_offset,
278                value: named_arg(src, named_args, "value", source_map),
279                salt: named_arg(src, named_args, "salt", source_map),
280                try_stmt,
281            },
282            loc: span_to_range(source_map, call_expr.span),
283            referenced_contract: contract_id,
284        });
285    }
286    None
287}
288
289/// Helper function to extract value of a given named arg.
290fn named_arg(
291    src: &str,
292    named_args: &Option<&[NamedArg<'_>]>,
293    arg: &str,
294    source_map: &SourceMap,
295) -> Option<String> {
296    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
297        |named_arg| {
298            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
299            src[named_arg_loc].to_string()
300        },
301    )
302}
303
304/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
305/// invocations.
306pub(crate) fn remove_bytecode_dependencies(
307    hir: &Hir<'_>,
308    deps: &PreprocessorDependencies,
309    data: &PreprocessorData,
310) -> Updates {
311    let mut updates = Updates::default();
312    for (contract_id, deps) in &deps.preprocessed_contracts {
313        let contract = hir.contract(*contract_id);
314        let source = hir.source(contract.source);
315        let FileName::Real(path) = &source.file.name else {
316            continue;
317        };
318
319        let updates = updates.entry(path.clone()).or_default();
320        let mut used_helpers = BTreeSet::new();
321
322        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
323        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
324        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
325
326        for dep in deps {
327            let Some(ContractData { artifact, constructor_data, .. }) =
328                data.get(&dep.referenced_contract)
329            else {
330                continue;
331            };
332
333            match &dep.kind {
334                BytecodeDependencyKind::CreationCode => {
335                    // for creation code we need to just call getCode
336                    updates.insert((
337                        dep.loc.start,
338                        dep.loc.end,
339                        format!("{vm}.getCode(\"{artifact}\")"),
340                    ));
341                }
342                BytecodeDependencyKind::New {
343                    name,
344                    args_length,
345                    call_args_offset,
346                    value,
347                    salt,
348                    try_stmt,
349                } => {
350                    let (mut update, closing_seq) = if *try_stmt {
351                        (String::new(), "})")
352                    } else {
353                        (format!("{name}(payable("), "})))")
354                    };
355                    update.push_str(&format!("{vm}.deployCode({{"));
356                    update.push_str(&format!("_artifact: \"{artifact}\""));
357
358                    if let Some(value) = value {
359                        update.push_str(", ");
360                        update.push_str(&format!("_value: {value}"));
361                    }
362
363                    if let Some(salt) = salt {
364                        update.push_str(", ");
365                        update.push_str(&format!("_salt: {salt}"));
366                    }
367
368                    if constructor_data.is_some() {
369                        // Insert our helper
370                        used_helpers.insert(dep.referenced_contract);
371
372                        update.push_str(", ");
373                        update.push_str(&format!(
374                            "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
375                            id = dep.referenced_contract.get()
376                        ));
377                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
378
379                        updates.insert((
380                            dep.loc.end + args_length,
381                            dep.loc.end + args_length,
382                            format!("){closing_seq}"),
383                        ));
384                    } else {
385                        update.push_str(closing_seq);
386                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
387                    }
388                }
389            };
390        }
391        let helper_imports = used_helpers.into_iter().map(|id| {
392            let id = id.get();
393            format!(
394                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
395            )
396        }).join("\n");
397        updates.insert((
398            source.file.src.len(),
399            source.file.src.len(),
400            format!(
401                r#"
402{helper_imports}
403
404interface {vm_interface_name} {{
405    function deployCode(string memory _artifact) external returns (address);
406    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
407    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
408    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
409    function deployCode(string memory _artifact, uint256 _value) external returns (address);
410    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
411    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
412    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
413    function getCode(string memory _artifact) external returns (bytes memory);
414}}"#
415            ),
416        ));
417    }
418    updates
419}