foundry_common/preprocessor/
deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8    Gcx, Hir,
9    hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10    interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
18/// Holds data about referenced source contracts and bytecode dependencies.
19pub(crate) struct PreprocessorDependencies {
20    // Mapping contract id to preprocess -> contract bytecode dependencies.
21    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22    // Referenced contract ids.
23    pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        gcx: Gcx<'_>,
29        paths: &[PathBuf],
30        src_dir: &Path,
31        root_dir: &Path,
32        mocks: &mut HashSet<PathBuf>,
33    ) -> Self {
34        let mut preprocessed_contracts = BTreeMap::new();
35        let mut referenced_contracts = HashSet::new();
36        for contract_id in gcx.hir.contract_ids() {
37            let contract = gcx.hir.contract(contract_id);
38            let source = gcx.hir.source(contract.source);
39
40            let FileName::Real(path) = &source.file.name else {
41                continue;
42            };
43
44            // Collect dependencies only for tests and scripts.
45            if !paths.contains(path) {
46                let path = path.display();
47                trace!("{path} is not test or script");
48                continue;
49            }
50
51            // Do not collect dependencies for mock contracts. Walk through base contracts and
52            // check if they're from src dir.
53            if contract.linearized_bases.iter().any(|base_contract_id| {
54                let base_contract = gcx.hir.contract(*base_contract_id);
55                let FileName::Real(path) = &gcx.hir.source(base_contract.source).file.name else {
56                    return false;
57                };
58                path.starts_with(src_dir)
59            }) {
60                // Record mock contracts to be evicted from preprocessed cache.
61                mocks.insert(root_dir.join(path));
62                let path = path.display();
63                trace!("found mock contract {path}");
64                continue;
65            } else {
66                // Make sure current contract is not in list of mocks (could happen when a contract
67                // which used to be a mock is refactored to a non-mock implementation).
68                mocks.remove(&root_dir.join(path));
69            }
70
71            let mut deps_collector =
72                BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir);
73            // Analyze current contract.
74            let _ = deps_collector.walk_contract(contract);
75            // Ignore empty test contracts declared in source files with other contracts.
76            if !deps_collector.dependencies.is_empty() {
77                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
78            }
79            // Record collected referenced contract ids.
80            referenced_contracts.extend(deps_collector.referenced_contracts);
81        }
82        Self { preprocessed_contracts, referenced_contracts }
83    }
84}
85
/// Represents a bytecode dependency kind.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// `type(Contract).creationCode`.
    CreationCode,
    /// `new Contract(...)`, optionally with named call options such as `{value: .., salt: ..}`.
    New {
        /// Contract name, as written in the source after `new`.
        name: String,
        /// Length of the source text between the end of the contract name and the end of the
        /// whole creation expression (covers named call options and constructor call args).
        args_length: usize,
        /// Distance from the end of the contract name to the start of the constructor call
        /// args. Non-zero only when named call options precede non-empty call args; it is
        /// used to strip those options from the rewritten code.
        call_args_offset: usize,
        /// Source text of the `value` named argument (if any) used when creating the contract.
        value: Option<String>,
        /// Source text of the `salt` named argument (if any) used when creating the contract.
        salt: Option<String>,
        /// Try-statement context of this creation:
        /// - `None`: the creation is not the expression of a `try` statement;
        /// - `Some(true)`: `try new C(..) returns (C c)`, i.e. the first clause returns exactly
        ///   one variable of a custom (user-defined) type;
        /// - `Some(false)`: any other `try new C(..)` statement.
        try_stmt: Option<bool>,
    },
}
107
108/// Represents a single bytecode dependency.
109#[derive(Debug)]
110pub(crate) struct BytecodeDependency {
111    /// Dependency kind.
112    kind: BytecodeDependencyKind,
113    /// Source map location of this dependency.
114    loc: Range<usize>,
115    /// HIR id of referenced contract.
116    referenced_contract: ContractId,
117}
118
119/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
120struct BytecodeDependencyCollector<'gcx, 'src> {
121    /// Source map, used for determining contract item locations.
122    gcx: Gcx<'gcx>,
123    /// Source content of current contract.
124    src: &'src str,
125    /// Project source dir, used to determine if referenced contract is a source contract.
126    src_dir: &'src Path,
127    /// Dependencies collected for current contract.
128    dependencies: Vec<BytecodeDependency>,
129    /// Unique HIR ids of contracts referenced from current contract.
130    referenced_contracts: HashSet<ContractId>,
131}
132
133impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
134    fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path) -> Self {
135        Self { gcx, src, src_dir, dependencies: vec![], referenced_contracts: HashSet::default() }
136    }
137
138    /// Collects reference identified as bytecode dependency of analyzed contract.
139    /// Discards any reference that is not in project src directory (e.g. external
140    /// libraries or mock contracts that extend source contracts).
141    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
142        let contract = self.gcx.hir.contract(dependency.referenced_contract);
143        let source = self.gcx.hir.source(contract.source);
144        let FileName::Real(path) = &source.file.name else {
145            return;
146        };
147
148        if !path.starts_with(self.src_dir) {
149            let path = path.display();
150            trace!("ignore dependency {path}");
151            return;
152        }
153
154        self.referenced_contracts.insert(dependency.referenced_contract);
155        self.dependencies.push(dependency);
156    }
157}
158
159impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
160    type BreakValue = Never;
161
162    fn hir(&self) -> &'gcx Hir<'gcx> {
163        &self.gcx.hir
164    }
165
166    fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
167        match &expr.kind {
168            ExprKind::Call(call_expr, call_args, named_args) => {
169                if let Some(dependency) = handle_call_expr(
170                    self.src,
171                    self.gcx.sess.source_map(),
172                    expr,
173                    call_expr,
174                    call_args,
175                    named_args,
176                ) {
177                    self.collect_dependency(dependency);
178                }
179            }
180            ExprKind::Member(member_expr, ident) => {
181                if let ExprKind::TypeCall(ty) = &member_expr.kind
182                    && let TypeKind::Custom(contract_id) = &ty.kind
183                    && ident.name.as_str() == "creationCode"
184                    && let Some(contract_id) = contract_id.as_contract()
185                {
186                    self.collect_dependency(BytecodeDependency {
187                        kind: BytecodeDependencyKind::CreationCode,
188                        loc: span_to_range(self.gcx.sess.source_map(), expr.span),
189                        referenced_contract: contract_id,
190                    });
191                }
192            }
193            _ => {}
194        }
195        self.walk_expr(expr)
196    }
197
198    fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
199        if let StmtKind::Try(stmt_try) = stmt.kind
200            && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
201            && let Some(mut dependency) = handle_call_expr(
202                self.src,
203                self.gcx.sess.source_map(),
204                &stmt_try.expr,
205                call_expr,
206                call_args,
207                named_args,
208            )
209        {
210            let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
211                && clause.args.len() == 1
212                && let Some(ret_var) = clause.args.first()
213                && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
214            {
215                true
216            } else {
217                false
218            };
219
220            if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
221                *try_stmt = Some(has_custom_return);
222            }
223            self.collect_dependency(dependency);
224
225            for clause in stmt_try.clauses {
226                for &var in clause.args {
227                    self.visit_nested_var(var)?;
228                }
229                for stmt in clause.block.stmts {
230                    self.visit_stmt(stmt)?;
231                }
232            }
233            return ControlFlow::Continue(());
234        }
235        self.walk_stmt(stmt)
236    }
237}
238
239/// Helper function to analyze and extract bytecode dependency from a given call expression.
240fn handle_call_expr(
241    src: &str,
242    source_map: &SourceMap,
243    parent_expr: &Expr<'_>,
244    call_expr: &Expr<'_>,
245    call_args: &CallArgs<'_>,
246    named_args: &Option<&[NamedArg<'_>]>,
247) -> Option<BytecodeDependency> {
248    if let ExprKind::New(ty_new) = &call_expr.kind
249        && let TypeKind::Custom(item_id) = ty_new.kind
250        && let Some(contract_id) = item_id.as_contract()
251    {
252        let name_loc = span_to_range(source_map, ty_new.span);
253        let name = &src[name_loc];
254
255        // Calculate offset to remove named args, e.g. for an expression like
256        // `new Counter {value: 333} (  address(this))`
257        // the offset will be used to replace `{value: 333} (  ` with `(`
258        let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
259            (call_args.span.lo() - ty_new.span.hi()).to_usize()
260        } else {
261            0
262        };
263
264        let args_len = parent_expr.span.hi() - ty_new.span.hi();
265        return Some(BytecodeDependency {
266            kind: BytecodeDependencyKind::New {
267                name: name.to_string(),
268                args_length: args_len.to_usize(),
269                call_args_offset,
270                value: named_arg(src, named_args, "value", source_map),
271                salt: named_arg(src, named_args, "salt", source_map),
272                try_stmt: None,
273            },
274            loc: span_to_range(source_map, call_expr.span),
275            referenced_contract: contract_id,
276        });
277    }
278    None
279}
280
281/// Helper function to extract value of a given named arg.
282fn named_arg(
283    src: &str,
284    named_args: &Option<&[NamedArg<'_>]>,
285    arg: &str,
286    source_map: &SourceMap,
287) -> Option<String> {
288    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
289        |named_arg| {
290            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
291            src[named_arg_loc].to_string()
292        },
293    )
294}
295
296/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
297/// invocations.
298///
299/// Special handling of try/catch statements with custom returns, where the try statement becomes
300/// ```solidity
301/// try this.addressToCounter() returns (Counter c)
302/// ```
303/// and helper to cast address is appended
304/// ```solidity
305/// function addressToCounter(address addr) returns (Counter) {
306///     return Counter(addr);
307/// }
308/// ```
309pub(crate) fn remove_bytecode_dependencies(
310    gcx: Gcx<'_>,
311    deps: &PreprocessorDependencies,
312    data: &PreprocessorData,
313) -> Updates {
314    let mut updates = Updates::default();
315    for (contract_id, deps) in &deps.preprocessed_contracts {
316        let contract = gcx.hir.contract(*contract_id);
317        let source = gcx.hir.source(contract.source);
318        let FileName::Real(path) = &source.file.name else {
319            continue;
320        };
321
322        let updates = updates.entry(path.clone()).or_default();
323        let mut used_helpers = BTreeSet::new();
324
325        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
326        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
327        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
328        let mut try_catch_helpers: HashSet<&str> = HashSet::default();
329
330        for dep in deps {
331            let Some(ContractData { artifact, constructor_data, .. }) =
332                data.get(&dep.referenced_contract)
333            else {
334                continue;
335            };
336
337            match &dep.kind {
338                BytecodeDependencyKind::CreationCode => {
339                    // for creation code we need to just call getCode
340                    updates.insert((
341                        dep.loc.start,
342                        dep.loc.end,
343                        format!("{vm}.getCode(\"{artifact}\")"),
344                    ));
345                }
346                BytecodeDependencyKind::New {
347                    name,
348                    args_length,
349                    call_args_offset,
350                    value,
351                    salt,
352                    try_stmt,
353                } => {
354                    let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
355                        if *has_ret {
356                            // try this.addressToCounter1() returns (Counter c)
357                            try_catch_helpers.insert(name);
358                            (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
359                        } else {
360                            (String::new(), "})")
361                        }
362                    } else {
363                        (format!("{name}(payable("), "})))")
364                    };
365                    update.push_str(&format!("{vm}.deployCode({{"));
366                    update.push_str(&format!("_artifact: \"{artifact}\""));
367
368                    if let Some(value) = value {
369                        update.push_str(", ");
370                        update.push_str(&format!("_value: {value}"));
371                    }
372
373                    if let Some(salt) = salt {
374                        update.push_str(", ");
375                        update.push_str(&format!("_salt: {salt}"));
376                    }
377
378                    if constructor_data.is_some() {
379                        // Insert our helper
380                        used_helpers.insert(dep.referenced_contract);
381
382                        update.push_str(", ");
383                        update.push_str(&format!(
384                            "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
385                            id = dep.referenced_contract.get()
386                        ));
387                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
388
389                        updates.insert((
390                            dep.loc.end + args_length,
391                            dep.loc.end + args_length,
392                            format!("){closing_seq}"),
393                        ));
394                    } else {
395                        update.push_str(closing_seq);
396                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
397                    }
398                }
399            };
400        }
401
402        // Add try catch statements after last function of the test contract.
403        if !try_catch_helpers.is_empty()
404            && let Some(last_fn_id) = contract.functions().last()
405        {
406            let last_fn_range =
407                span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
408            let to_address_fns = try_catch_helpers
409                .iter()
410                .map(|ty| {
411                    format!(
412                        r#"
413                            function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
414                                return {ty}(addr);
415                            }}
416                        "#,
417                        id = contract_id.get()
418                    )
419                })
420                .collect::<String>();
421
422            updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
423        }
424
425        let helper_imports = used_helpers.into_iter().map(|id| {
426            let id = id.get();
427            format!(
428                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
429            )
430        }).join("\n");
431        updates.insert((
432            source.file.src.len(),
433            source.file.src.len(),
434            format!(
435                r#"
436{helper_imports}
437
438interface {vm_interface_name} {{
439    function deployCode(string memory _artifact) external returns (address);
440    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
441    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
442    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
443    function deployCode(string memory _artifact, uint256 _value) external returns (address);
444    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
445    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
446    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
447    function getCode(string memory _artifact) external returns (bytes memory);
448}}"#
449            ),
450        ));
451    }
452    updates
453}