Skip to main content

foundry_common/preprocessor/
deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar::sema::{
8    Gcx, Hir,
9    hir::{CallArgs, ContractId, Expr, ExprKind, NamedArg, Stmt, StmtKind, TypeKind, Visit},
10    interface::{SourceMap, data_structures::Never, source_map::FileName},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
18/// Holds data about referenced source contracts and bytecode dependencies.
19pub(crate) struct PreprocessorDependencies {
20    // Mapping contract id to preprocess -> contract bytecode dependencies.
21    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22    // Referenced contract ids.
23    pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        gcx: Gcx<'_>,
29        paths: &[PathBuf],
30        script_paths: &HashSet<PathBuf>,
31        src_dir: &Path,
32        root_dir: &Path,
33        mocks: &mut HashSet<PathBuf>,
34    ) -> Self {
35        let mut preprocessed_contracts = BTreeMap::new();
36        let mut referenced_contracts = HashSet::new();
37        let mut current_mocks = HashSet::new();
38
39        // Helper closure for iterating candidate contracts to preprocess (tests and scripts).
40        let candidate_contracts = || {
41            gcx.hir.contract_ids().filter_map(|id| {
42                let contract = gcx.hir.contract(id);
43                let source = gcx.hir.source(contract.source);
44                let FileName::Real(path) = &source.file.name else {
45                    return None;
46                };
47
48                if !paths.contains(path) {
49                    trace!("{} is not test or script", path.display());
50                    return None;
51                }
52
53                Some((id, contract, source, path))
54            })
55        };
56
57        // Collect current mocks.
58        for (_, contract, _, path) in candidate_contracts() {
59            if contract.linearized_bases.iter().any(|base_id| {
60                let base = gcx.hir.contract(*base_id);
61                matches!(
62                    &gcx.hir.source(base.source).file.name,
63                    FileName::Real(base_path) if base_path.starts_with(src_dir)
64                )
65            }) {
66                let mock_path = root_dir.join(path);
67                trace!("found mock contract {}", mock_path.display());
68                current_mocks.insert(mock_path);
69            }
70        }
71
72        // Collect dependencies for non-mock test/script contracts.
73        for (contract_id, contract, source, path) in candidate_contracts() {
74            let full_path = root_dir.join(path);
75
76            if current_mocks.contains(&full_path) {
77                trace!("{} is a mock, skipping", path.display());
78                continue;
79            }
80
81            // Make sure current contract is not in list of mocks (could happen when a contract
82            // which used to be a mock is refactored to a non-mock implementation).
83            mocks.remove(&full_path);
84
85            // Treat the contract as a script when its file lives under the configured script
86            // directory, or when it inherits from a `Script` base (forge-std). The inheritance
87            // check covers atypical layouts where script contracts are placed under `src/`.
88            let is_script = script_paths.contains(path)
89                || contract
90                    .linearized_bases
91                    .iter()
92                    .skip(1)
93                    .any(|base_id| gcx.hir.contract(*base_id).name.as_str() == "Script");
94            let mut deps_collector =
95                BytecodeDependencyCollector::new(gcx, source.file.src.as_str(), src_dir, is_script);
96            // Analyze current contract.
97            let _ = deps_collector.walk_contract(contract);
98            // Ignore empty test contracts declared in source files with other contracts.
99            if !deps_collector.dependencies.is_empty() {
100                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
101            }
102
103            // Record collected referenced contract ids.
104            referenced_contracts.extend(deps_collector.referenced_contracts);
105        }
106
107        // Add current mocks.
108        mocks.extend(current_mocks);
109
110        Self { preprocessed_contracts, referenced_contracts }
111    }
112}
113
/// The syntactic form through which a contract's bytecode is referenced.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` member access.
    CreationCode,
    /// A `new Contract` creation expression.
    New {
        /// Name of the created contract, as written in the source.
        name: String,
        /// Length of the source text from the end of the contract name to the end of the
        /// whole call expression.
        args_length: usize,
        /// Offset from the end of the contract name to the start of the call arguments.
        call_args_offset: usize,
        /// Source text of the `value` named argument, when present.
        value: Option<String>,
        /// Source text of the `salt` named argument, when present.
        salt: Option<String>,
        /// `None` outside a try statement; inside one, whether the try has a custom return.
        try_stmt: Option<bool>,
    },
}
135
136/// Represents a single bytecode dependency.
137#[derive(Debug)]
138pub(crate) struct BytecodeDependency {
139    /// Dependency kind.
140    kind: BytecodeDependencyKind,
141    /// Source map location of this dependency.
142    loc: Range<usize>,
143    /// HIR id of referenced contract.
144    referenced_contract: ContractId,
145}
146
147/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
148struct BytecodeDependencyCollector<'gcx, 'src> {
149    /// Source map, used for determining contract item locations.
150    gcx: Gcx<'gcx>,
151    /// Source content of current contract.
152    src: &'src str,
153    /// Project source dir, used to determine if referenced contract is a source contract.
154    src_dir: &'src Path,
155    /// Whether the contract being analyzed lives in a script file.
156    /// Salted `new Contract{salt:...}()` in scripts must not be rewritten: at broadcast depth
157    /// Foundry redirects native CREATE2 through the deterministic factory, preserving the salt.
158    /// `vm.deployCode` runs at a deeper call depth and bypasses that redirect, so the broadcast
159    /// would record a plain CREATE and deploy at the wrong address.
160    is_script: bool,
161    /// Dependencies collected for current contract.
162    dependencies: Vec<BytecodeDependency>,
163    /// Unique HIR ids of contracts referenced from current contract.
164    referenced_contracts: HashSet<ContractId>,
165}
166
167impl<'gcx, 'src> BytecodeDependencyCollector<'gcx, 'src> {
168    fn new(gcx: Gcx<'gcx>, src: &'src str, src_dir: &'src Path, is_script: bool) -> Self {
169        Self {
170            gcx,
171            src,
172            src_dir,
173            is_script,
174            dependencies: vec![],
175            referenced_contracts: HashSet::default(),
176        }
177    }
178
179    /// Collects reference identified as bytecode dependency of analyzed contract.
180    /// Discards any reference that is not in project src directory (e.g. external
181    /// libraries or mock contracts that extend source contracts).
182    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
183        // Salted new-expressions in scripts must not be rewritten. See field doc on `is_script`.
184        if self.is_script
185            && let BytecodeDependencyKind::New { salt: Some(_), .. } = &dependency.kind
186        {
187            trace!("skip salted new-expression in script");
188            return;
189        }
190
191        let contract = self.gcx.hir.contract(dependency.referenced_contract);
192        let source = self.gcx.hir.source(contract.source);
193        let FileName::Real(path) = &source.file.name else {
194            return;
195        };
196
197        if !path.starts_with(self.src_dir) {
198            let path = path.display();
199            trace!("ignore dependency {path}");
200            return;
201        }
202
203        self.referenced_contracts.insert(dependency.referenced_contract);
204        self.dependencies.push(dependency);
205    }
206}
207
208impl<'gcx> Visit<'gcx> for BytecodeDependencyCollector<'gcx, '_> {
209    type BreakValue = Never;
210
211    fn hir(&self) -> &'gcx Hir<'gcx> {
212        &self.gcx.hir
213    }
214
215    fn visit_expr(&mut self, expr: &'gcx Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
216        #[allow(clippy::collapsible_match)]
217        match &expr.kind {
218            ExprKind::Call(call_expr, call_args, named_args) => {
219                if let Some(dependency) = handle_call_expr(
220                    self.src,
221                    self.gcx.sess.source_map(),
222                    expr,
223                    call_expr,
224                    call_args,
225                    named_args,
226                ) {
227                    self.collect_dependency(dependency);
228                }
229            }
230            ExprKind::Member(member_expr, ident) => {
231                if let ExprKind::TypeCall(ty) = &member_expr.kind
232                    && let TypeKind::Custom(contract_id) = &ty.kind
233                    && ident.name.as_str() == "creationCode"
234                    && let Some(contract_id) = contract_id.as_contract()
235                {
236                    self.collect_dependency(BytecodeDependency {
237                        kind: BytecodeDependencyKind::CreationCode,
238                        loc: span_to_range(self.gcx.sess.source_map(), expr.span),
239                        referenced_contract: contract_id,
240                    });
241                }
242            }
243            _ => {}
244        }
245        self.walk_expr(expr)
246    }
247
248    fn visit_stmt(&mut self, stmt: &'gcx Stmt<'gcx>) -> ControlFlow<Self::BreakValue> {
249        if let StmtKind::Try(stmt_try) = stmt.kind
250            && let ExprKind::Call(call_expr, call_args, named_args) = &stmt_try.expr.kind
251            && let Some(mut dependency) = handle_call_expr(
252                self.src,
253                self.gcx.sess.source_map(),
254                &stmt_try.expr,
255                call_expr,
256                call_args,
257                named_args,
258            )
259        {
260            let has_custom_return = if let Some(clause) = stmt_try.clauses.first()
261                && clause.args.len() == 1
262                && let Some(ret_var) = clause.args.first()
263                && let TypeKind::Custom(_) = self.hir().variable(*ret_var).ty.kind
264            {
265                true
266            } else {
267                false
268            };
269
270            if let BytecodeDependencyKind::New { try_stmt, .. } = &mut dependency.kind {
271                *try_stmt = Some(has_custom_return);
272            }
273            self.collect_dependency(dependency);
274
275            for clause in stmt_try.clauses {
276                for &var in clause.args {
277                    self.visit_nested_var(var)?;
278                }
279                for stmt in clause.block.stmts {
280                    self.visit_stmt(stmt)?;
281                }
282            }
283            return ControlFlow::Continue(());
284        }
285        self.walk_stmt(stmt)
286    }
287}
288
289/// Helper function to analyze and extract bytecode dependency from a given call expression.
290fn handle_call_expr(
291    src: &str,
292    source_map: &SourceMap,
293    parent_expr: &Expr<'_>,
294    call_expr: &Expr<'_>,
295    call_args: &CallArgs<'_>,
296    named_args: &Option<&[NamedArg<'_>]>,
297) -> Option<BytecodeDependency> {
298    if let ExprKind::New(ty_new) = &call_expr.kind
299        && let TypeKind::Custom(item_id) = ty_new.kind
300        && let Some(contract_id) = item_id.as_contract()
301    {
302        let name_loc = span_to_range(source_map, ty_new.span);
303        let name = &src[name_loc];
304
305        // Calculate offset to remove named args, e.g. for an expression like
306        // `new Counter {value: 333} (  address(this))`
307        // the offset will be used to replace `{value: 333} (  ` with `(`
308        let call_args_offset = if named_args.is_some() && !call_args.is_empty() {
309            (call_args.span.lo() - ty_new.span.hi()).to_usize()
310        } else {
311            0
312        };
313
314        let args_len = parent_expr.span.hi() - ty_new.span.hi();
315        return Some(BytecodeDependency {
316            kind: BytecodeDependencyKind::New {
317                name: name.to_string(),
318                args_length: args_len.to_usize(),
319                call_args_offset,
320                value: named_arg(src, named_args, "value", source_map),
321                salt: named_arg(src, named_args, "salt", source_map),
322                try_stmt: None,
323            },
324            loc: span_to_range(source_map, call_expr.span),
325            referenced_contract: contract_id,
326        });
327    }
328    None
329}
330
331/// Helper function to extract value of a given named arg.
332fn named_arg(
333    src: &str,
334    named_args: &Option<&[NamedArg<'_>]>,
335    arg: &str,
336    source_map: &SourceMap,
337) -> Option<String> {
338    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
339        |named_arg| {
340            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
341            src[named_arg_loc].to_string()
342        },
343    )
344}
345
346/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
347/// invocations.
348///
349/// Special handling of try/catch statements with custom returns, where the try statement becomes
350/// ```solidity
351/// try this.addressToCounter() returns (Counter c)
352/// ```
353/// and helper to cast address is appended
354/// ```solidity
355/// function addressToCounter(address addr) returns (Counter) {
356///     return Counter(addr);
357/// }
358/// ```
359pub(crate) fn remove_bytecode_dependencies(
360    gcx: Gcx<'_>,
361    deps: &PreprocessorDependencies,
362    data: &PreprocessorData,
363) -> Updates {
364    let mut updates = Updates::default();
365    for (contract_id, deps) in &deps.preprocessed_contracts {
366        let contract = gcx.hir.contract(*contract_id);
367        let source = gcx.hir.source(contract.source);
368        let FileName::Real(path) = &source.file.name else {
369            continue;
370        };
371
372        let updates = updates.entry(path.clone()).or_default();
373        let mut used_helpers = BTreeSet::new();
374
375        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
376        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
377        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
378        let mut try_catch_helpers: HashSet<&str> = HashSet::default();
379
380        for dep in deps {
381            let Some(ContractData { artifact, constructor_data, .. }) =
382                data.get(&dep.referenced_contract)
383            else {
384                continue;
385            };
386
387            match &dep.kind {
388                BytecodeDependencyKind::CreationCode => {
389                    // for creation code we need to just call getCode
390                    updates.insert((
391                        dep.loc.start,
392                        dep.loc.end,
393                        format!("{vm}.getCode(\"{artifact}\")"),
394                    ));
395                }
396                BytecodeDependencyKind::New {
397                    name,
398                    args_length,
399                    call_args_offset,
400                    value,
401                    salt,
402                    try_stmt,
403                } => {
404                    let (mut update, closing_seq) = if let Some(has_ret) = try_stmt {
405                        if *has_ret {
406                            // try this.addressToCounter1() returns (Counter c)
407                            try_catch_helpers.insert(name);
408                            (format!("this.addressTo{name}{id}(", id = contract_id.get()), "}))")
409                        } else {
410                            (String::new(), "})")
411                        }
412                    } else {
413                        (format!("{name}(payable("), "})))")
414                    };
415                    update.push_str(&format!("{vm}.deployCode({{"));
416                    update.push_str(&format!("_artifact: \"{artifact}\""));
417
418                    if let Some(value) = value {
419                        update.push_str(", ");
420                        update.push_str(&format!("_value: {value}"));
421                    }
422
423                    if let Some(salt) = salt {
424                        update.push_str(", ");
425                        update.push_str(&format!("_salt: {salt}"));
426                    }
427
428                    if constructor_data.is_some() {
429                        // Insert our helper
430                        used_helpers.insert(dep.referenced_contract);
431
432                        update.push_str(", ");
433                        update.push_str(&format!(
434                            "_args: encodeArgs{id}(DeployHelper{id}.FoundryPpConstructorArgs",
435                            id = dep.referenced_contract.get()
436                        ));
437                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
438
439                        updates.insert((
440                            dep.loc.end + args_length,
441                            dep.loc.end + args_length,
442                            format!("){closing_seq}"),
443                        ));
444                    } else {
445                        update.push_str(closing_seq);
446                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
447                    }
448                }
449            };
450        }
451
452        // Add try catch statements after last function of the test contract.
453        if !try_catch_helpers.is_empty()
454            && let Some(last_fn_id) = contract.functions().last()
455        {
456            let last_fn_range =
457                span_to_range(gcx.sess.source_map(), gcx.hir.function(last_fn_id).span);
458            let to_address_fns = try_catch_helpers
459                .iter()
460                .map(|ty| {
461                    format!(
462                        r#"
463                            function addressTo{ty}{id}(address addr) public pure returns ({ty}) {{
464                                return {ty}(addr);
465                            }}
466                        "#,
467                        id = contract_id.get()
468                    )
469                })
470                .collect::<String>();
471
472            updates.insert((last_fn_range.end, last_fn_range.end, to_address_fns));
473        }
474
475        let helper_imports = used_helpers.into_iter().map(|id| {
476            let id = id.get();
477            format!(
478                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
479            )
480        }).join("\n");
481        updates.insert((
482            source.file.src.len(),
483            source.file.src.len(),
484            format!(
485                r#"
486{helper_imports}
487
488interface {vm_interface_name} {{
489    function deployCode(string memory _artifact) external returns (address);
490    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
491    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
492    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
493    function deployCode(string memory _artifact, uint256 _value) external returns (address);
494    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
495    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
496    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
497    function getCode(string memory _artifact) external view returns (bytes memory);
498}}"#
499            ),
500        ));
501    }
502    updates
503}