Skip to main content

foundry_common/preprocessor/
deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8    Gcx, Hir,
9    hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10    interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
18/// Holds data about referenced source contracts and bytecode dependencies.
19pub(crate) struct PreprocessorDependencies {
20    // Mapping contract id to preprocess -> contract bytecode dependencies.
21    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22    // Referenced contract ids.
23    pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        gcx: Gcx<'_>,
29        paths: &[PathBuf],
30        src_dir: &Path,
31        root_dir: &Path,
32        mocks: &mut HashSet<PathBuf>,
33    ) -> Self {
34        let mut preprocessed_contracts = BTreeMap::new();
35        let mut referenced_contracts = HashSet::new();
36        let mut current_mocks = HashSet::new();
37
38        // Helper closure for iterating candidate contracts to preprocess (tests and scripts).
39        let candidate_contracts = || {
40            gcx.hir.contract_ids().filter_map(|id| {
41                let contract = gcx.hir.contract(id);
42                let source = gcx.hir.source(contract.source);
43                let FileName::Real(path) = &source.file.name else {
44                    return None;
45                };
46
47                if !paths.contains(path) {
48                    trace!("{} is not test or script", path.display());
49                    return None;
50                }
51
52                Some((id, contract, source, path))
53            })
54        };
55
56        // Collect current mocks.
57        for (_, contract, _, path) in candidate_contracts() {
58            if contract.linearized_bases.iter().any(|base_id| {
59                let base = gcx.hir.contract(*base_id);
60                matches!(
61                    &gcx.hir.source(base.source).file.name,
62                    FileName::Real(base_path) if base_path.starts_with(src_dir)
63                )
64            }) {
65                let mock_path = root_dir.join(path);
66                trace!("found mock contract {}", mock_path.display());
67                current_mocks.insert(mock_path);
68            }
69        }
70
71        // Collect dependencies for non-mock test/script contracts.
72        for (contract_id, contract, source, path) in candidate_contracts() {
73            let full_path = root_dir.join(path);
74
75            if current_mocks.contains(&full_path) {
76                trace!("{} is a mock, skipping", path.display());
77                continue;
78            }
79
80            // Make sure current contract is not in list of mocks (could happen when a contract
81            // which used to be a mock is refactored to a non-mock implementation).
82            mocks.remove(&full_path);
83
84            let mut deps_collector =
85                BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir);
86            // Analyze current contract.
87            let _ = deps_collector.walk_contract(contract);
88            // Ignore empty test contracts declared in source files with other contracts.
89            if !deps_collector.dependencies.is_empty() {
90                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
91            }
92
93            // Record collected referenced contract ids.
94            referenced_contracts.extend(deps_collector.referenced_contracts);
95        }
96
97        // Add current mocks.
98        mocks.extend(current_mocks);
99
100        Self { preprocessed_contracts, referenced_contracts }
101    }
102}
103
/// Represents a bytecode dependency kind.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// `type(Contract).creationCode`
    CreationCode,
    /// `new Contract`.
    New {
        /// Source text of the created contract type, exactly as written at the call site
        /// (possibly a qualified path such as `Lib.Counter`).
        name: String,
        /// Byte length from the end of the contract type name to the end of the whole creation
        /// call. This covers the optional `{value: .., salt: ..}` block as well as the
        /// constructor argument list.
        args_length: usize,
        /// Byte offset from the end of the contract type name to the start of the constructor
        /// call arguments. It is non-zero only when named call options are present and the
        /// argument list is non-empty; otherwise it is `0`.
        call_args_offset: usize,
        /// Source text of the `value` call option (if any) used when creating the contract.
        value: Option<String>,
        /// Source text of the `salt` call option (if any) used when creating the contract.
        salt: Option<String>,
        /// `None` for a plain creation expression.
        ///
        /// `Some(has_custom_return)` when the creation is the expression of a `try` statement.
        /// `has_custom_return` tells whether the statement's first clause declares exactly one
        /// variable of a custom (user-defined) type, e.g. `returns (Counter c)`.
        try_stmt: Option<bool>,
    },
}
125
126/// Represents a single bytecode dependency.
127#[derive(Debug)]
128pub(crate) struct BytecodeDependency {
129    /// Dependency kind.
130    kind: BytecodeDependencyKind,
131    /// Source map location of this dependency.
132    loc: Range<usize>,
133    /// HIR id of referenced contract.
134    referenced_contract: ContractId,
135}
136
137/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
138struct BytecodeDependencyCollector<'gcx, 'src> {
139    /// Source map, used for determining contract item locations.
140    gcx: Gcx<'gcx>,
141    /// Source content of current contract.
142    src: &'src str,
143    /// Project source dir, used to determine if referenced contract is a source contract.
144    src_dir: &'src Path,
145    /// Dependencies collected for current contract.
146    dependencies: Vec<BytecodeDependency>,
147    /// Unique HIR ids of contracts referenced from current contract.
148    referenced_contracts: HashSet<ContractId>,
149}
150
151impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
152    fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path) -> Self {
153        Self { gcx, src, src_dir, dependencies: vec![], referenced_contracts: HashSet::default() }
154    }
155
156    /// Collects reference identified as bytecode dependency of analyzed contract.
157    /// Discards any reference that is not in project src directory (e.g. external
158    /// libraries or mock contracts that extend source contracts).
159    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
160        let contract = self.gcx.hir.contract(dependency.referenced_contract);
161        let source = self.gcx.hir.source(contract.source);
162        let FileName::Real(path) = &source.file.name else {
163            return;
164        };
165
166        if !path.starts_with(self.src_dir) {
167            let path = path.display();
168            trace!("ignore dependency {path}");
169            return;
170        }
171
172        self.referenced_contracts.insert(dependency.referenced_contract);
173        self.dependencies.push(dependency);
174    }
175}
176
177impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
178    type BreakValue = Never;
179
180    fn hir(&self) -> &'gcx Hir<'gcx> {
181        &self.gcx.hir
182    }
183
184    fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
185        #[allow(clippy::collapsible_match)]
186        match &expr.kind {
187            ExprKind::Call(call_expr, call_args, named_args) => {
188                if let Some(dependency) = handle_call_expr(
189                    self.src,
190                    self.gcx.sess.source_map(),
191                    expr,
192                    call_expr,
193                    call_args,
194                    named_args,
195                ) {
196                    self.collect_dependency(dependency);
197                }
198            }
199            ExprKind::Member(member_expr, ident) => {
200                if let ExprKind::TypeCall(ty) = &member_expr.kind
201                    && let TypeKind::Custom(contract_id) = &ty.kind
202                    && ident.name.as_str() == "creationCode"
203                    && let Some(contract_id) = contract_id.as_contract()
204                {
205                    self.collect_dependency(BytecodeDependency {
206                        kind: BytecodeDependencyKind::CreationCode,
207                        loc: span_to_range(self.gcx.sess.source_map(), expr.span),
208                        referenced_contract: contract_id,
209                    });
210                }
211            }
212            _ => {}
213        }
214        self.walk_expr(expr)
215    }
216
217    fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
218        if let StmtKind::Try(stmt_try) = stmt.kind
219            && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
220            && let Some(mut dependency) = handle_call_expr(
221                self.src,
222                self.gcx.sess.source_map(),
223                &stmt_try.expr,
224                call_expr,
225                call_args,
226                named_args,
227            )
228        {
229            let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
230                && clause.args.len() == 1
231                && let Some(ret_var) = clause.args.first()
232                && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
233            {
234                true
235            } else {
236                false
237            };
238
239            if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
240                *try_stmt = Some(has_custom_return);
241            }
242            self.collect_dependency(dependency);
243
244            for clause in stmt_try.clauses {
245                for &var in clause.args {
246                    self.visit_nested_var(var)?;
247                }
248                for stmt in clause.block.stmts {
249                    self.visit_stmt(stmt)?;
250                }
251            }
252            return ControlFlow::Continue(());
253        }
254        self.walk_stmt(stmt)
255    }
256}
257
258/// Helper function to analyze and extract bytecode dependency from a given call expression.
259fn handle_call_expr(
260    src: &str,
261    source_map: &SourceMap,
262    parent_expr: &Expr<'_>,
263    call_expr: &Expr<'_>,
264    call_args: &CallArgs<'_>,
265    named_args: &Option<&[NamedArg<'_>]>,
266) -> Option<BytecodeDependency> {
267    if let ExprKind::New(ty_new) = &call_expr.kind
268        && let TypeKind::Custom(item_id) = ty_new.kind
269        && let Some(contract_id) = item_id.as_contract()
270    {
271        let name_loc = span_to_range(source_map, ty_new.span);
272        let name = &src[name_loc];
273
274        // Calculate offset to remove named args, e.g. for an expression like
275        // `new Counter {value: 333} (  address(this))`
276        // the offset will be used to replace `{value: 333} (  ` with `(`
277        let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
278            (call_args.span.lo() - ty_new.span.hi()).to_usize()
279        } else {
280            0
281        };
282
283        let args_len = parent_expr.span.hi() - ty_new.span.hi();
284        return Some(BytecodeDependency {
285            kind: BytecodeDependencyKind::New {
286                name: name.to_string(),
287                args_length: args_len.to_usize(),
288                call_args_offset,
289                value: named_arg(src, named_args, "value", source_map),
290                salt: named_arg(src, named_args, "salt", source_map),
291                try_stmt: None,
292            },
293            loc: span_to_range(source_map, call_expr.span),
294            referenced_contract: contract_id,
295        });
296    }
297    None
298}
299
300/// Helper function to extract value of a given named arg.
301fn named_arg(
302    src: &str,
303    named_args: &Option<&[NamedArg<'_>]>,
304    arg: &str,
305    source_map: &SourceMap,
306) -> Option<String> {
307    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
308        |named_arg| {
309            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
310            src[named_arg_loc].to_string()
311        },
312    )
313}
314
315/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
316/// invocations.
317///
318/// Special handling of try/catch statements with custom returns, where the try statement becomes
319/// ```solidity
320/// try this.addressToCounter() returns (Counter c)
321/// ```
322/// and helper to cast address is appended
323/// ```solidity
324/// function addressToCounter(address addr) returns (Counter) {
325///     return Counter(addr);
326/// }
327/// ```
328pub(crate) fn remove_bytecode_dependencies(
329    gcx: Gcx<'_>,
330    deps: &PreprocessorDependencies,
331    data: &PreprocessorData,
332) -> Updates {
333    let mut updates = Updates::default();
334    for (contract_id, deps) in &deps.preprocessed_contracts {
335        let contract = gcx.hir.contract(*contract_id);
336        let source = gcx.hir.source(contract.source);
337        let FileName::Real(path) = &source.file.name else {
338            continue;
339        };
340
341        let updates = updates.entry(path.clone()).or_default();
342        let mut used_helpers = BTreeSet::new();
343
344        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
345        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
346        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
347        let mut try_catch_helpers: HashSet<&str> = HashSet::default();
348
349        for dep in deps {
350            let Some(ContractData { artifact, constructor_data, .. }) =
351                data.get(&dep.referenced_contract)
352            else {
353                continue;
354            };
355
356            match &dep.kind {
357                BytecodeDependencyKind::CreationCode => {
358                    // for creation code we need to just call getCode
359                    updates.insert((
360                        dep.loc.start,
361                        dep.loc.end,
362                        format!("{vm}.getCode(\"{artifact}\")"),
363                    ));
364                }
365                BytecodeDependencyKind::New {
366                    name,
367                    args_length,
368                    call_args_offset,
369                    value,
370                    salt,
371                    try_stmt,
372                } => {
373                    let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
374                        if *has_ret {
375                            // try this.addressToCounter1() returns (Counter c)
376                            try_catch_helpers.insert(name);
377                            (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
378                        } else {
379                            (String::new(), "})")
380                        }
381                    } else {
382                        (format!("{name}(payable("), "})))")
383                    };
384                    update.push_str(&format!("{vm}.deployCode({{"));
385                    update.push_str(&format!("_artifact: \"{artifact}\""));
386
387                    if let Some(value) = value {
388                        update.push_str(", ");
389                        update.push_str(&format!("_value: {value}"));
390                    }
391
392                    if let Some(salt) = salt {
393                        update.push_str(", ");
394                        update.push_str(&format!("_salt: {salt}"));
395                    }
396
397                    if constructor_data.is_some() {
398                        // Insert our helper
399                        used_helpers.insert(dep.referenced_contract);
400
401                        update.push_str(", ");
402                        update.push_str(&format!(
403                            "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
404                            id = dep.referenced_contract.get()
405                        ));
406                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
407
408                        updates.insert((
409                            dep.loc.end + args_length,
410                            dep.loc.end + args_length,
411                            format!("){closing_seq}"),
412                        ));
413                    } else {
414                        update.push_str(closing_seq);
415                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
416                    }
417                }
418            };
419        }
420
421        // Add try catch statements after last function of the test contract.
422        if !try_catch_helpers.is_empty()
423            && let Some(last_fn_id) = contract.functions().last()
424        {
425            let last_fn_range =
426                span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
427            let to_address_fns = try_catch_helpers
428                .iter()
429                .map(|ty| {
430                    format!(
431                        r#"
432                            function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
433                                return {ty}(addr);
434                            }}
435                        "#,
436                        id = contract_id.get()
437                    )
438                })
439                .collect::<String>();
440
441            updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
442        }
443
444        let helper_imports = used_helpers.into_iter().map(|id| {
445            let id = id.get();
446            format!(
447                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
448            )
449        }).join("\n");
450        updates.insert((
451            source.file.src.len(),
452            source.file.src.len(),
453            format!(
454                r#"
455{helper_imports}
456
457interface {vm_interface_name} {{
458    function deployCode(string memory _artifact) external returns (address);
459    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
460    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
461    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
462    function deployCode(string memory _artifact, uint256 _value) external returns (address);
463    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
464    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
465    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
466    function getCode(string memory _artifact) external view returns (bytes memory);
467}}"#
468            ),
469        ));
470    }
471    updates
472}