// foundry_common/preprocessor/deps.rs

1use super::{
2    data::{ContractData, PreprocessorData},
3    span_to_range,
4};
5use foundry_compilers::Updates;
6use itertools::Itertools;
7use solar_parse::interface::Session;
8use solar_sema::{
9    hir::{ContractId, Expr, ExprKind, Hir, NamedArg, TypeKind, Visit},
10    interface::{data_structures::Never, source_map::FileName, SourceMap},
11};
12use std::{
13    collections::{BTreeMap, BTreeSet, HashSet},
14    ops::{ControlFlow, Range},
15    path::{Path, PathBuf},
16};
17
18/// Holds data about referenced source contracts and bytecode dependencies.
19pub(crate) struct PreprocessorDependencies {
20    // Mapping contract id to preprocess -> contract bytecode dependencies.
21    pub preprocessed_contracts: BTreeMap<ContractId, Vec<BytecodeDependency>>,
22    // Referenced contract ids.
23    pub referenced_contracts: HashSet<ContractId>,
24}
25
26impl PreprocessorDependencies {
27    pub fn new(
28        sess: &Session,
29        hir: &Hir<'_>,
30        paths: &[PathBuf],
31        src_dir: &Path,
32        root_dir: &Path,
33        mocks: &mut HashSet<PathBuf>,
34    ) -> Self {
35        let mut preprocessed_contracts = BTreeMap::new();
36        let mut referenced_contracts = HashSet::new();
37        for contract_id in hir.contract_ids() {
38            let contract = hir.contract(contract_id);
39            let source = hir.source(contract.source);
40
41            let FileName::Real(path) = &source.file.name else {
42                continue;
43            };
44
45            // Collect dependencies only for tests and scripts.
46            if !paths.contains(path) {
47                let path = path.display();
48                trace!("{path} is not test or script");
49                continue;
50            }
51
52            // Do not collect dependencies for mock contracts. Walk through base contracts and
53            // check if they're from src dir.
54            if contract.linearized_bases.iter().any(|base_contract_id| {
55                let base_contract = hir.contract(*base_contract_id);
56                let FileName::Real(path) = &hir.source(base_contract.source).file.name else {
57                    return false;
58                };
59                path.starts_with(src_dir)
60            }) {
61                // Record mock contracts to be evicted from preprocessed cache.
62                mocks.insert(root_dir.join(path));
63                let path = path.display();
64                trace!("found mock contract {path}");
65                continue;
66            } else {
67                // Make sure current contract is not in list of mocks (could happen when a contract
68                // which used to be a mock is refactored to a non-mock implementation).
69                mocks.remove(&root_dir.join(path));
70            }
71
72            let mut deps_collector = BytecodeDependencyCollector::new(
73                sess.source_map(),
74                hir,
75                source.file.src.as_str(),
76                src_dir,
77            );
78            // Analyze current contract.
79            let _ = deps_collector.walk_contract(contract);
80            // Ignore empty test contracts declared in source files with other contracts.
81            if !deps_collector.dependencies.is_empty() {
82                preprocessed_contracts.insert(contract_id, deps_collector.dependencies);
83            }
84            // Record collected referenced contract ids.
85            referenced_contracts.extend(deps_collector.referenced_contracts);
86        }
87        Self { preprocessed_contracts, referenced_contracts }
88    }
89}
90
/// The syntactic form through which a contract's bytecode is depended upon.
#[derive(Debug)]
enum BytecodeDependencyKind {
    /// A `type(Contract).creationCode` expression.
    CreationCode,
    /// A `new Contract(...)` expression, optionally carrying `{value: ..., salt: ...}` options.
    New {
        /// Contract name, as written in the source.
        name: String,
        /// Length of the text between the end of the contract name and the end of the whole
        /// `new` call expression.
        args_length: usize,
        /// Distance from the end of the contract name to the first constructor argument.
        /// Zero unless the call has both named options and arguments.
        call_args_offset: usize,
        /// Source text of the `value` option, if one was given.
        value: Option<String>,
        /// Source text of the `salt` option, if one was given.
        salt: Option<String>,
    },
}
110
111/// Represents a single bytecode dependency.
112#[derive(Debug)]
113pub(crate) struct BytecodeDependency {
114    /// Dependency kind.
115    kind: BytecodeDependencyKind,
116    /// Source map location of this dependency.
117    loc: Range<usize>,
118    /// HIR id of referenced contract.
119    referenced_contract: ContractId,
120}
121
122/// Walks over contract HIR and collects [`BytecodeDependency`]s and referenced contracts.
123struct BytecodeDependencyCollector<'hir> {
124    /// Source map, used for determining contract item locations.
125    source_map: &'hir SourceMap,
126    /// Parsed HIR.
127    hir: &'hir Hir<'hir>,
128    /// Source content of current contract.
129    src: &'hir str,
130    /// Project source dir, used to determine if referenced contract is a source contract.
131    src_dir: &'hir Path,
132    /// Dependencies collected for current contract.
133    dependencies: Vec<BytecodeDependency>,
134    /// Unique HIR ids of contracts referenced from current contract.
135    referenced_contracts: HashSet<ContractId>,
136}
137
138impl<'hir> BytecodeDependencyCollector<'hir> {
139    fn new(
140        source_map: &'hir SourceMap,
141        hir: &'hir Hir<'hir>,
142        src: &'hir str,
143        src_dir: &'hir Path,
144    ) -> Self {
145        Self {
146            source_map,
147            hir,
148            src,
149            src_dir,
150            dependencies: vec![],
151            referenced_contracts: HashSet::default(),
152        }
153    }
154
155    /// Collects reference identified as bytecode dependency of analyzed contract.
156    /// Discards any reference that is not in project src directory (e.g. external
157    /// libraries or mock contracts that extend source contracts).
158    fn collect_dependency(&mut self, dependency: BytecodeDependency) {
159        let contract = self.hir.contract(dependency.referenced_contract);
160        let source = self.hir.source(contract.source);
161        let FileName::Real(path) = &source.file.name else {
162            return;
163        };
164
165        if !path.starts_with(self.src_dir) {
166            let path = path.display();
167            trace!("ignore dependency {path}");
168            return;
169        }
170
171        self.referenced_contracts.insert(dependency.referenced_contract);
172        self.dependencies.push(dependency);
173    }
174}
175
176impl<'hir> Visit<'hir> for BytecodeDependencyCollector<'hir> {
177    type BreakValue = Never;
178
179    fn hir(&self) -> &'hir Hir<'hir> {
180        self.hir
181    }
182
183    fn visit_expr(&mut self, expr: &'hir Expr<'hir>) -> ControlFlow<Self::BreakValue> {
184        match &expr.kind {
185            ExprKind::Call(ty, call_args, named_args) => {
186                if let ExprKind::New(ty_new) = &ty.kind {
187                    if let TypeKind::Custom(item_id) = ty_new.kind {
188                        if let Some(contract_id) = item_id.as_contract() {
189                            let name_loc = span_to_range(self.source_map, ty_new.span);
190                            let name = &self.src[name_loc];
191
192                            // Calculate offset to remove named args, e.g. for an expression like
193                            // `new Counter {value: 333} (  address(this))`
194                            // the offset will be used to replace `{value: 333} (  ` with `(`
195                            let call_args_offset = if named_args.is_some() && !call_args.is_empty()
196                            {
197                                (call_args.span().lo() - ty_new.span.hi()).to_usize()
198                            } else {
199                                0
200                            };
201
202                            let args_len = expr.span.hi() - ty_new.span.hi();
203                            self.collect_dependency(BytecodeDependency {
204                                kind: BytecodeDependencyKind::New {
205                                    name: name.to_string(),
206                                    args_length: args_len.to_usize(),
207                                    call_args_offset,
208                                    value: named_arg(
209                                        self.src,
210                                        named_args,
211                                        "value",
212                                        self.source_map,
213                                    ),
214                                    salt: named_arg(self.src, named_args, "salt", self.source_map),
215                                },
216                                loc: span_to_range(self.source_map, ty.span),
217                                referenced_contract: contract_id,
218                            });
219                        }
220                    }
221                }
222            }
223            ExprKind::Member(member_expr, ident) => {
224                if let ExprKind::TypeCall(ty) = &member_expr.kind {
225                    if let TypeKind::Custom(contract_id) = &ty.kind {
226                        if ident.name.as_str() == "creationCode" {
227                            if let Some(contract_id) = contract_id.as_contract() {
228                                self.collect_dependency(BytecodeDependency {
229                                    kind: BytecodeDependencyKind::CreationCode,
230                                    loc: span_to_range(self.source_map, expr.span),
231                                    referenced_contract: contract_id,
232                                });
233                            }
234                        }
235                    }
236                }
237            }
238            _ => {}
239        }
240        self.walk_expr(expr)
241    }
242}
243
244/// Helper function to extract value of a given named arg.
245fn named_arg(
246    src: &str,
247    named_args: &Option<&[NamedArg<'_>]>,
248    arg: &str,
249    source_map: &SourceMap,
250) -> Option<String> {
251    named_args.unwrap_or_default().iter().find(|named_arg| named_arg.name.as_str() == arg).map(
252        |named_arg| {
253            let named_arg_loc = span_to_range(source_map, named_arg.value.span);
254            src[named_arg_loc].to_string()
255        },
256    )
257}
258
259/// Goes over all test/script files and replaces bytecode dependencies with cheatcode
260/// invocations.
261pub(crate) fn remove_bytecode_dependencies(
262    hir: &Hir<'_>,
263    deps: &PreprocessorDependencies,
264    data: &PreprocessorData,
265) -> Updates {
266    let mut updates = Updates::default();
267    for (contract_id, deps) in &deps.preprocessed_contracts {
268        let contract = hir.contract(*contract_id);
269        let source = hir.source(contract.source);
270        let FileName::Real(path) = &source.file.name else {
271            continue;
272        };
273
274        let updates = updates.entry(path.clone()).or_default();
275        let mut used_helpers = BTreeSet::new();
276
277        let vm_interface_name = format!("VmContractHelper{}", contract_id.get());
278        // `address(uint160(uint256(keccak256("hevm cheat code"))))`
279        let vm = format!("{vm_interface_name}(0x7109709ECfa91a80626fF3989D68f67F5b1DD12D)");
280
281        for dep in deps {
282            let Some(ContractData { artifact, constructor_data, .. }) =
283                data.get(&dep.referenced_contract)
284            else {
285                continue;
286            };
287
288            match &dep.kind {
289                BytecodeDependencyKind::CreationCode => {
290                    // for creation code we need to just call getCode
291                    updates.insert((
292                        dep.loc.start,
293                        dep.loc.end,
294                        format!("{vm}.getCode(\"{artifact}\")"),
295                    ));
296                }
297                BytecodeDependencyKind::New {
298                    name,
299                    args_length,
300                    call_args_offset,
301                    value,
302                    salt,
303                } => {
304                    let mut update = format!("{name}(payable({vm}.deployCode({{");
305                    update.push_str(&format!("_artifact: \"{artifact}\""));
306
307                    if let Some(value) = value {
308                        update.push_str(", ");
309                        update.push_str(&format!("_value: {value}"));
310                    }
311
312                    if let Some(salt) = salt {
313                        update.push_str(", ");
314                        update.push_str(&format!("_salt: {salt}"));
315                    }
316
317                    if constructor_data.is_some() {
318                        // Insert our helper
319                        used_helpers.insert(dep.referenced_contract);
320
321                        update.push_str(", ");
322                        update.push_str(&format!(
323                            "_args: encodeArgs{id}(DeployHelper{id}.ConstructorArgs",
324                            id = dep.referenced_contract.get()
325                        ));
326                        if *call_args_offset > 0 {
327                            update.push('(');
328                        }
329                        updates.insert((dep.loc.start, dep.loc.end + call_args_offset, update));
330                        updates.insert((
331                            dep.loc.end + args_length,
332                            dep.loc.end + args_length,
333                            ")})))".to_string(),
334                        ));
335                    } else {
336                        update.push_str("})))");
337                        updates.insert((dep.loc.start, dep.loc.end + args_length, update));
338                    }
339                }
340            };
341        }
342        let helper_imports = used_helpers.into_iter().map(|id| {
343            let id = id.get();
344            format!(
345                "import {{DeployHelper{id}, encodeArgs{id}}} from \"foundry-pp/DeployHelper{id}.sol\";",
346            )
347        }).join("\n");
348        updates.insert((
349            source.file.src.len(),
350            source.file.src.len(),
351            format!(
352                r#"
353{helper_imports}
354
355interface {vm_interface_name} {{
356    function deployCode(string memory _artifact) external returns (address);
357    function deployCode(string memory _artifact, bytes32 _salt) external returns (address);
358    function deployCode(string memory _artifact, bytes memory _args) external returns (address);
359    function deployCode(string memory _artifact, bytes memory _args, bytes32 _salt) external returns (address);
360    function deployCode(string memory _artifact, uint256 _value) external returns (address);
361    function deployCode(string memory _artifact, uint256 _value, bytes32 _salt) external returns (address);
362    function deployCode(string memory _artifact, bytes memory _args, uint256 _value) external returns (address);
363    function deployCode(string memory _artifact, bytes memory _args, uint256 _value, bytes32 _salt) external returns (address);
364    function getCode(string memory _artifact) external returns (bytes memory);
365}}"#
366            ),
367        ));
368    }
369    updates
370}