foundry_evm_fuzz/strategies/
literals.rs

1use alloy_dyn_abi::DynSolType;
2use alloy_primitives::{
3    B256, Bytes, I256, U256, keccak256,
4    map::{B256IndexSet, HashMap, IndexSet},
5};
6use foundry_common::Analysis;
7use foundry_compilers::ProjectPathsConfig;
8use solar::{
9    ast::{self, Visit},
10    interface::source_map::FileName,
11};
12use std::{ops::ControlFlow, sync::OnceLock};
13
/// Lazily built dictionary of literal values found in the analyzed sources.
///
/// Stores the inputs needed to run [`LiteralsCollector::process`] together with a [`OnceLock`]
/// that caches the resulting [`LiteralMaps`], so the collection runs at most once (see
/// [`Self::get`]).
#[derive(Debug, Default)]
pub struct LiteralsDictionary {
    /// Data required for initialization, captured from `EvmFuzzState::new`.
    analysis: Option<Analysis>,
    /// Project paths forwarded to [`LiteralsCollector::process`], which uses them to decide
    /// which sources are scanned.
    paths_config: Option<ProjectPathsConfig>,
    /// Limit forwarded to [`LiteralsCollector::process`] as its `max_values` argument.
    max_values: usize,

    /// Lazy initialized literal maps.
    maps: OnceLock<LiteralMaps>,
}
24
25impl LiteralsDictionary {
26    pub fn new(
27        analysis: Option<Analysis>,
28        paths_config: Option<ProjectPathsConfig>,
29        max_values: usize,
30    ) -> Self {
31        Self { analysis, paths_config, max_values, maps: OnceLock::default() }
32    }
33
34    /// Returns a reference to the `LiteralMaps`, initializing them on the first call.
35    pub fn get(&self) -> &LiteralMaps {
36        self.maps.get_or_init(|| {
37            if let Some(analysis) = &self.analysis {
38                let literals = LiteralsCollector::process(
39                    analysis,
40                    self.paths_config.as_ref(),
41                    self.max_values,
42                );
43                trace!(
44                    words = literals.words.values().map(|set| set.len()).sum::<usize>(),
45                    strings = literals.strings.len(),
46                    bytes = literals.bytes.len(),
47                    "collected source code literals for fuzz dictionary"
48                );
49                literals
50            } else {
51                LiteralMaps::default()
52            }
53        })
54    }
55
56    /// Takes ownership of the dictionary words, leaving an empty map in their place.
57    /// Ensures the map is initialized before taking its contents.
58    pub fn take_words(&mut self) -> HashMap<DynSolType, B256IndexSet> {
59        let _ = self.get();
60        self.maps.get_mut().map(|m| std::mem::take(&mut m.words)).unwrap_or_default()
61    }
62
63    #[cfg(test)]
64    /// Test-only helper to seed the dictionary with literal values.
65    pub(crate) fn set(&mut self, map: super::LiteralMaps) {
66        let _ = self.maps.set(map);
67    }
68}
69
/// Literal values gathered by [`LiteralsCollector`], grouped by how they are stored.
#[derive(Debug, Default)]
pub struct LiteralMaps {
    /// 32-byte words, keyed by the Solidity type they were recorded under.
    pub words: HashMap<DynSolType, B256IndexSet>,
    /// Non-hex string literals, decoded lossily as UTF-8 by the collector.
    pub strings: IndexSet<String>,
    /// Contents of hex string literals.
    pub bytes: IndexSet<Bytes>,
}
76
/// AST visitor that accumulates the literals it encounters into a [`LiteralMaps`].
#[derive(Debug, Default)]
pub struct LiteralsCollector {
    /// Once `total_values` reaches this value, `visit_expr` stops the traversal.
    max_values: usize,
    /// Number of values newly inserted into `output` so far.
    total_values: usize,
    /// Accumulated literals; this is what `process` returns.
    output: LiteralMaps,
}
83
84impl LiteralsCollector {
85    fn new(max_values: usize) -> Self {
86        Self { max_values, ..Default::default() }
87    }
88
89    pub fn process(
90        analysis: &Analysis,
91        paths_config: Option<&ProjectPathsConfig>,
92        max_values: usize,
93    ) -> LiteralMaps {
94        analysis.enter(|compiler| {
95            let mut literals_collector = Self::new(max_values);
96            for source in compiler.sources().iter() {
97                // Ignore scripts, and libs
98                if let Some(paths) = paths_config
99                    && let FileName::Real(source_path) = &source.file.name
100                    && !(source_path.starts_with(&paths.sources) || paths.is_test(source_path))
101                {
102                    continue;
103                }
104
105                if let Some(ref ast) = source.ast {
106                    let _ = literals_collector.visit_source_unit(ast);
107                }
108            }
109
110            literals_collector.output
111        })
112    }
113}
114
115impl<'ast> ast::Visit<'ast> for LiteralsCollector {
116    type BreakValue = ();
117
118    fn visit_expr(&mut self, expr: &'ast ast::Expr<'ast>) -> ControlFlow<()> {
119        // Stop early if we've hit the limit
120        if self.total_values >= self.max_values {
121            return ControlFlow::Break(());
122        }
123
124        // Handle unary negation of number literals
125        if let ast::ExprKind::Unary(un_op, inner_expr) = &expr.kind
126            && un_op.kind == ast::UnOpKind::Neg
127            && let ast::ExprKind::Lit(lit, _) = &inner_expr.kind
128            && let ast::LitKind::Number(n) = &lit.kind
129        {
130            // Compute the negative I256 value
131            if let Ok(pos_i256) = I256::try_from(*n) {
132                let neg_value = -pos_i256;
133                let neg_b256 = B256::from(neg_value.into_raw());
134
135                // Store under all intN sizes that can represent this value
136                for bits in [16, 32, 64, 128, 256] {
137                    if can_fit_int(neg_value, bits)
138                        && self
139                            .output
140                            .words
141                            .entry(DynSolType::Int(bits))
142                            .or_default()
143                            .insert(neg_b256)
144                    {
145                        self.total_values += 1;
146                    }
147                }
148            }
149
150            // Continue walking the expression
151            return self.walk_expr(expr);
152        }
153
154        // Handle literals
155        if let ast::ExprKind::Lit(lit, _) = &expr.kind {
156            let is_new = match &lit.kind {
157                ast::LitKind::Number(n) => {
158                    let pos_value = U256::from(*n);
159                    let pos_b256 = B256::from(pos_value);
160
161                    // Store under all uintN sizes that can represent this value
162                    for bits in [8, 16, 32, 64, 128, 256] {
163                        if can_fit_uint(pos_value, bits)
164                            && self
165                                .output
166                                .words
167                                .entry(DynSolType::Uint(bits))
168                                .or_default()
169                                .insert(pos_b256)
170                        {
171                            self.total_values += 1;
172                        }
173                    }
174                    false // already handled inserts individually
175                }
176                ast::LitKind::Address(addr) => self
177                    .output
178                    .words
179                    .entry(DynSolType::Address)
180                    .or_default()
181                    .insert(addr.into_word()),
182                ast::LitKind::Str(ast::StrKind::Hex, sym, _) => {
183                    self.output.bytes.insert(Bytes::copy_from_slice(sym.as_byte_str()))
184                }
185                ast::LitKind::Str(_, sym, _) => {
186                    let s = String::from_utf8_lossy(sym.as_byte_str()).into_owned();
187                    // For strings, also store the hashed version
188                    let hash = keccak256(s.as_bytes());
189                    if self.output.words.entry(DynSolType::FixedBytes(32)).or_default().insert(hash)
190                    {
191                        self.total_values += 1;
192                    }
193                    // And the right-padded version if it fits.
194                    if s.len() <= 32 {
195                        let padded = B256::right_padding_from(s.as_bytes());
196                        if self
197                            .output
198                            .words
199                            .entry(DynSolType::FixedBytes(32))
200                            .or_default()
201                            .insert(padded)
202                        {
203                            self.total_values += 1;
204                        }
205                    }
206                    self.output.strings.insert(s)
207                }
208                ast::LitKind::Bool(..) | ast::LitKind::Rational(..) | ast::LitKind::Err(..) => {
209                    false // ignore
210                }
211            };
212
213            if is_new {
214                self.total_values += 1;
215            }
216        }
217
218        self.walk_expr(expr)
219    }
220}
221
222/// Checks if a signed integer value can fit in intN type.
223fn can_fit_int(value: I256, bits: usize) -> bool {
224    // Calculate the maximum positive value for intN: 2^(N-1) - 1
225    let max_val = I256::try_from((U256::from(1) << (bits - 1)) - U256::from(1))
226        .expect("max value should fit in I256");
227    // Calculate the minimum negative value for intN: -2^(N-1)
228    let min_val = -max_val - I256::ONE;
229
230    value >= min_val && value <= max_val
231}
232
233/// Checks if an unsigned integer value can fit in uintN type.
234fn can_fit_uint(value: U256, bits: usize) -> bool {
235    if bits == 256 {
236        return true;
237    }
238    // Calculate the maximum value for uintN: 2^N - 1
239    let max_val = (U256::from(1) << bits) - U256::from(1);
240    value <= max_val
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::address;
    use solar::interface::{Session, source_map};

    const SOURCE: &str = r#"
    contract Magic {
        // plain literals
        address constant DAI = 0x6B175474E89094C44Da98b954EedeAC495271d0F;
        uint64 constant MAGIC_NUMBER = 1122334455;
        int32 constant MAGIC_INT = -777;
        bytes32 constant MAGIC_WORD = "abcd1234";
        bytes constant MAGIC_BYTES = hex"deadbeef";
        string constant MAGIC_STRING = "xyzzy";

        // constant exprs with folding
        uint256 constant NEG_FOLDING = uint(-2);
        uint256 constant BIN_FOLDING = 2 * 2 ether;
        bytes32 constant IMPLEMENTATION_SLOT = bytes32(uint256(keccak256('eip1967.proxy.implementation')) - 1);
    }"#;

    /// Every kind of literal in `SOURCE` ends up in the set it belongs to.
    #[test]
    fn test_literals_collector_coverage() {
        let literals = process_source_literals(SOURCE);

        // Word-sized values from the SOURCE contract, with the type they are filed under.
        let expected_words = [
            (
                DynSolType::Address,
                address!("0x6B175474E89094C44Da98b954EedeAC495271d0F").into_word(),
                "Expected DAI in address set",
            ),
            (
                DynSolType::Uint(64),
                B256::from(U256::from(1122334455u64)),
                "Expected MAGIC_NUMBER in uint64 set",
            ),
            (
                DynSolType::Int(32),
                B256::from(I256::try_from(-777i32).unwrap().into_raw()),
                "Expected MAGIC_INT in int32 set",
            ),
            (
                DynSolType::FixedBytes(32),
                B256::right_padding_from(b"abcd1234"),
                "Expected MAGIC_WORD in bytes32 set",
            ),
        ];
        for (ty, value, msg) in expected_words {
            assert_word(&literals, ty, value, msg);
        }

        let expected_strings = [
            ("xyzzy", "Expected MAGIC_STRING to be collected"),
            ("eip1967.proxy.implementation", "Expected IMPLEMENTATION_SLOT in string set"),
        ];
        for (text, msg) in expected_strings {
            assert!(literals.strings.contains(text), "{msg}");
        }

        let magic_bytes = Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]);
        assert!(literals.bytes.contains(&magic_bytes), "Expected MAGIC_BYTES in bytes set");
    }

    /// The number of collected values per type matches what `SOURCE` contains.
    #[test]
    fn test_literals_collector_size() {
        let literals = process_source_literals(SOURCE);

        // Number of words stored under a type; 0 when the type has no entry.
        let count = |ty: DynSolType| match literals.words.get(&ty) {
            Some(set) => set.len(),
            None => 0,
        };

        assert_eq!(count(DynSolType::Address), 1, "Address literal count mismatch");
        assert_eq!(literals.strings.len(), 3, "String literals count mismatch");
        assert_eq!(literals.bytes.len(), 1, "Byte literals count mismatch");

        // Unsigned integers - MAGIC_NUMBER (1122334455) appears in multiple sizes
        let unsigned_counts = [
            (8, 2, "Uint(8) count mismatch"),
            (16, 3, "Uint(16) count mismatch"),
            (32, 4, "Uint(32) count mismatch"),
            (64, 5, "Uint(64) count mismatch"),
            (128, 5, "Uint(128) count mismatch"),
            (256, 5, "Uint(256) count mismatch"),
        ];
        for (bits, expected, msg) in unsigned_counts {
            assert_eq!(count(DynSolType::Uint(bits)), expected, "{msg}");
        }

        // Signed integers - MAGIC_INT (-777) appears in multiple sizes
        let signed_counts = [
            (16, 2, "Int(16) count mismatch"),
            (32, 2, "Int(32) count mismatch"),
            (64, 2, "Int(64) count mismatch"),
            (128, 2, "Int(128) count mismatch"),
            (256, 2, "Int(256) count mismatch"),
        ];
        for (bits, expected, msg) in signed_counts {
            assert_eq!(count(DynSolType::Int(bits)), expected, "{msg}");
        }

        // FixedBytes(32) includes:
        // - MAGIC_WORD
        // - String literals (hashed and right-padded versions)
        assert_eq!(count(DynSolType::FixedBytes(32)), 6, "FixedBytes(32) count mismatch");

        // Total count check
        let total_words: usize = literals.words.values().map(|set| set.len()).sum();
        assert_eq!(total_words, 41, "Total word values count mismatch");
    }

    // -- TEST HELPERS ---------------------------------------------------------

    /// Parses `source` as a single in-memory file and runs the collector over it with no path
    /// filtering and no value limit.
    fn process_source_literals(source: &str) -> LiteralMaps {
        let session = Session::builder().with_stderr_emitter().build();
        let mut compiler = solar::sema::Compiler::new(session);

        let parsed = compiler.enter_mut(|c| -> std::io::Result<()> {
            let mut pcx = c.parse();
            pcx.set_resolve_imports(false);

            let file =
                c.sess().source_map().new_source_file(source_map::FileName::Stdin, source)?;
            pcx.add_file(file);
            pcx.parse();
            let _ = c.lower_asts();
            Ok(())
        });
        parsed.expect("Failed to compile test source");

        LiteralsCollector::process(&std::sync::Arc::new(compiler), None, usize::MAX)
    }

    /// Asserts that `value` was recorded under `ty`, failing with `msg` otherwise.
    fn assert_word(literals: &LiteralMaps, ty: DynSolType, value: B256, msg: &str) {
        let present = match literals.words.get(&ty) {
            Some(set) => set.contains(&value),
            None => false,
        };
        assert!(present, "{}", msg);
    }
}