foundry_evm_fuzz/strategies/
literals.rs

1use alloy_dyn_abi::DynSolType;
2use alloy_primitives::{
3    B256, Bytes, I256, U256, keccak256,
4    map::{B256IndexSet, HashMap, IndexSet},
5};
6use foundry_common::Analysis;
7use foundry_compilers::ProjectPathsConfig;
8use solar::{
9    ast::{self, Visit},
10    interface::source_map::FileName,
11};
12use std::{
13    ops::ControlFlow,
14    sync::{Arc, OnceLock},
15};
16
17#[derive(Clone, Debug)]
18pub struct LiteralsDictionary {
19    maps: Arc<OnceLock<LiteralMaps>>,
20}
21
22impl Default for LiteralsDictionary {
23    fn default() -> Self {
24        Self::new(None, None, usize::MAX)
25    }
26}
27
28impl LiteralsDictionary {
29    pub fn new(
30        analysis: Option<Analysis>,
31        paths_config: Option<ProjectPathsConfig>,
32        max_values: usize,
33    ) -> Self {
34        let maps = Arc::new(OnceLock::<LiteralMaps>::new());
35        if let Some(analysis) = analysis
36            && max_values > 0
37        {
38            let maps = maps.clone();
39            // This can't be done in a rayon task (including inside of `get`) because it can cause a
40            // deadlock, since internally `solar` also uses rayon.
41            let _ = std::thread::Builder::new().name("literal-collector".into()).spawn(move || {
42                let _ = maps.get_or_init(|| {
43                    let literals =
44                        LiteralsCollector::process(&analysis, paths_config.as_ref(), max_values);
45                    debug!(
46                        words = literals.words.values().map(|set| set.len()).sum::<usize>(),
47                        strings = literals.strings.len(),
48                        bytes = literals.bytes.len(),
49                        "collected source code literals for fuzz dictionary"
50                    );
51                    literals
52                });
53            });
54        } else {
55            maps.set(Default::default()).unwrap();
56        }
57        Self { maps }
58    }
59
60    /// Returns a reference to the `LiteralMaps`.
61    pub fn get(&self) -> &LiteralMaps {
62        self.maps.wait()
63    }
64
65    /// Test-only helper to seed the dictionary with literal values.
66    #[cfg(test)]
67    pub(crate) fn set(&mut self, map: super::LiteralMaps) {
68        self.maps = Arc::new(OnceLock::new());
69        self.maps.set(map).unwrap();
70    }
71}
72
73#[derive(Debug, Default)]
74pub struct LiteralMaps {
75    pub words: HashMap<DynSolType, B256IndexSet>,
76    pub strings: IndexSet<String>,
77    pub bytes: IndexSet<Bytes>,
78}
79
80#[derive(Debug, Default)]
81pub struct LiteralsCollector {
82    max_values: usize,
83    total_values: usize,
84    output: LiteralMaps,
85}
86
87impl LiteralsCollector {
88    fn new(max_values: usize) -> Self {
89        Self { max_values, ..Default::default() }
90    }
91
92    pub fn process(
93        analysis: &Analysis,
94        paths_config: Option<&ProjectPathsConfig>,
95        max_values: usize,
96    ) -> LiteralMaps {
97        analysis.enter(|compiler| {
98            let mut literals_collector = Self::new(max_values);
99            for source in compiler.sources().iter() {
100                // Ignore scripts, and libs
101                if let Some(paths) = paths_config
102                    && let FileName::Real(source_path) = &source.file.name
103                    && !(source_path.starts_with(&paths.sources) || paths.is_test(source_path))
104                {
105                    continue;
106                }
107
108                if let Some(ast) = &source.ast
109                    && literals_collector.visit_source_unit(ast).is_break()
110                {
111                    break;
112                }
113            }
114
115            literals_collector.output
116        })
117    }
118}
119
120impl<'ast> ast::Visit<'ast> for LiteralsCollector {
121    type BreakValue = ();
122
123    fn visit_expr(&mut self, expr: &'ast ast::Expr<'ast>) -> ControlFlow<()> {
124        // Stop early if we've hit the limit
125        if self.total_values >= self.max_values {
126            return ControlFlow::Break(());
127        }
128
129        // Handle unary negation of number literals
130        if let ast::ExprKind::Unary(un_op, inner_expr) = &expr.kind
131            && un_op.kind == ast::UnOpKind::Neg
132            && let ast::ExprKind::Lit(lit, _) = &inner_expr.kind
133            && let ast::LitKind::Number(n) = &lit.kind
134        {
135            // Compute the negative I256 value
136            if let Ok(pos_i256) = I256::try_from(*n) {
137                let neg_value = -pos_i256;
138                let neg_b256 = B256::from(neg_value.into_raw());
139
140                // Store under all intN sizes that can represent this value
141                for bits in [16, 32, 64, 128, 256] {
142                    if can_fit_int(neg_value, bits)
143                        && self
144                            .output
145                            .words
146                            .entry(DynSolType::Int(bits))
147                            .or_default()
148                            .insert(neg_b256)
149                    {
150                        self.total_values += 1;
151                    }
152                }
153            }
154
155            // Continue walking the expression
156            return self.walk_expr(expr);
157        }
158
159        // Handle literals
160        if let ast::ExprKind::Lit(lit, _) = &expr.kind {
161            let is_new = match &lit.kind {
162                ast::LitKind::Number(n) => {
163                    let pos_value = U256::from(*n);
164                    let pos_b256 = B256::from(pos_value);
165
166                    // Store under all uintN sizes that can represent this value
167                    for bits in [8, 16, 32, 64, 128, 256] {
168                        if can_fit_uint(pos_value, bits)
169                            && self
170                                .output
171                                .words
172                                .entry(DynSolType::Uint(bits))
173                                .or_default()
174                                .insert(pos_b256)
175                        {
176                            self.total_values += 1;
177                        }
178                    }
179                    false // already handled inserts individually
180                }
181                ast::LitKind::Address(addr) => self
182                    .output
183                    .words
184                    .entry(DynSolType::Address)
185                    .or_default()
186                    .insert(addr.into_word()),
187                ast::LitKind::Str(ast::StrKind::Hex, sym, _) => {
188                    self.output.bytes.insert(Bytes::copy_from_slice(sym.as_byte_str()))
189                }
190                ast::LitKind::Str(_, sym, _) => {
191                    let s = String::from_utf8_lossy(sym.as_byte_str()).into_owned();
192                    // For strings, also store the hashed version
193                    let hash = keccak256(s.as_bytes());
194                    if self.output.words.entry(DynSolType::FixedBytes(32)).or_default().insert(hash)
195                    {
196                        self.total_values += 1;
197                    }
198                    // And the right-padded version if it fits.
199                    if s.len() <= 32 {
200                        let padded = B256::right_padding_from(s.as_bytes());
201                        if self
202                            .output
203                            .words
204                            .entry(DynSolType::FixedBytes(32))
205                            .or_default()
206                            .insert(padded)
207                        {
208                            self.total_values += 1;
209                        }
210                    }
211                    self.output.strings.insert(s)
212                }
213                ast::LitKind::Bool(..) | ast::LitKind::Rational(..) | ast::LitKind::Err(..) => {
214                    false // ignore
215                }
216            };
217
218            if is_new {
219                self.total_values += 1;
220            }
221        }
222
223        self.walk_expr(expr)
224    }
225}
226
227/// Checks if a signed integer value can fit in intN type.
228fn can_fit_int(value: I256, bits: usize) -> bool {
229    // Calculate the maximum positive value for intN: 2^(N-1) - 1
230    let max_val = I256::try_from((U256::from(1) << (bits - 1)) - U256::from(1))
231        .expect("max value should fit in I256");
232    // Calculate the minimum negative value for intN: -2^(N-1)
233    let min_val = -max_val - I256::ONE;
234
235    value >= min_val && value <= max_val
236}
237
238/// Checks if an unsigned integer value can fit in uintN type.
239fn can_fit_uint(value: U256, bits: usize) -> bool {
240    if bits == 256 {
241        return true;
242    }
243    // Calculate the maximum value for uintN: 2^N - 1
244    let max_val = (U256::from(1) << bits) - U256::from(1);
245    value <= max_val
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::address;
    use solar::interface::{Session, source_map};

    const SOURCE: &str = r#"
    contract Magic {
        // plain literals
        address constant DAI = 0x6B175474E89094C44Da98b954EedeAC495271d0F;
        uint64 constant MAGIC_NUMBER = 1122334455;
        int32 constant MAGIC_INT = -777;
        bytes32 constant MAGIC_WORD = "abcd1234";
        bytes constant MAGIC_BYTES = hex"deadbeef";
        string constant MAGIC_STRING = "xyzzy";

        // constant exprs with folding
        uint256 constant NEG_FOLDING = uint(-2);
        uint256 constant BIN_FOLDING = 2 * 2 ether;
        bytes32 constant IMPLEMENTATION_SLOT = bytes32(uint256(keccak256('eip1967.proxy.implementation')) - 1);
    }"#;

    /// Every kind of literal in `SOURCE` ends up in the bucket it belongs to.
    #[test]
    fn test_literals_collector_coverage() {
        let literals = process_source_literals(SOURCE);

        // Words expected from the SOURCE contract, with the bucket they must be found in.
        let expected_words = [
            (
                DynSolType::Address,
                address!("0x6B175474E89094C44Da98b954EedeAC495271d0F").into_word(),
                "Expected DAI in address set",
            ),
            (
                DynSolType::Uint(64),
                B256::from(U256::from(1122334455u64)),
                "Expected MAGIC_NUMBER in uint64 set",
            ),
            (
                DynSolType::Int(32),
                B256::from(I256::try_from(-777i32).unwrap().into_raw()),
                "Expected MAGIC_INT in int32 set",
            ),
            (
                DynSolType::FixedBytes(32),
                B256::right_padding_from(b"abcd1234"),
                "Expected MAGIC_WORD in bytes32 set",
            ),
        ];
        let magic_bytes = Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]);

        for (ty, word, msg) in expected_words {
            assert_word(&literals, ty, word, msg);
        }
        assert!(literals.strings.contains("xyzzy"), "Expected MAGIC_STRING to be collected");
        assert!(
            literals.strings.contains("eip1967.proxy.implementation"),
            "Expected IMPLEMENTATION_SLOT in string set"
        );
        assert!(literals.bytes.contains(&magic_bytes), "Expected MAGIC_BYTES in bytes set");
    }

    /// The number of values collected per bucket matches the literals present in `SOURCE`.
    #[test]
    fn test_literals_collector_size() {
        let literals = process_source_literals(SOURCE);

        // Number of words stored for a type; a missing bucket counts as empty.
        let count = |ty: DynSolType| literals.words.get(&ty).map_or(0, |set| set.len());

        assert_eq!(count(DynSolType::Address), 1, "Address literal count mismatch");
        assert_eq!(literals.strings.len(), 3, "String literals count mismatch");
        assert_eq!(literals.bytes.len(), 1, "Byte literals count mismatch");

        let expected_counts: [(DynSolType, usize, &str); 11] = [
            // Unsigned integers - MAGIC_NUMBER (1122334455) appears in multiple sizes
            (DynSolType::Uint(8), 2, "Uint(8) count mismatch"),
            (DynSolType::Uint(16), 3, "Uint(16) count mismatch"),
            (DynSolType::Uint(32), 4, "Uint(32) count mismatch"),
            (DynSolType::Uint(64), 5, "Uint(64) count mismatch"),
            (DynSolType::Uint(128), 5, "Uint(128) count mismatch"),
            (DynSolType::Uint(256), 5, "Uint(256) count mismatch"),
            // Signed integers - MAGIC_INT (-777) appears in multiple sizes
            (DynSolType::Int(16), 2, "Int(16) count mismatch"),
            (DynSolType::Int(32), 2, "Int(32) count mismatch"),
            (DynSolType::Int(64), 2, "Int(64) count mismatch"),
            (DynSolType::Int(128), 2, "Int(128) count mismatch"),
            (DynSolType::Int(256), 2, "Int(256) count mismatch"),
        ];
        for (ty, expected, msg) in expected_counts {
            assert_eq!(count(ty), expected, "{msg}");
        }

        // FixedBytes(32) includes:
        // - MAGIC_WORD
        // - String literals (hashed and right-padded versions)
        assert_eq!(count(DynSolType::FixedBytes(32)), 6, "FixedBytes(32) count mismatch");

        // Total count check
        let total_words: usize = literals.words.values().map(|set| set.len()).sum();
        assert_eq!(total_words, 41, "Total word values count mismatch");
    }

    // -- TEST HELPERS ---------------------------------------------------------

    /// Parses `source` as a single stdin file and runs the literals collector over it without a
    /// paths filter or value limit.
    fn process_source_literals(source: &str) -> LiteralMaps {
        let session = Session::builder().with_stderr_emitter().build();
        let mut compiler = solar::sema::Compiler::new(session);

        let parsed = compiler.enter_mut(|c| -> std::io::Result<()> {
            let mut pcx = c.parse();
            pcx.set_resolve_imports(false);

            let file =
                c.sess().source_map().new_source_file(source_map::FileName::Stdin, source)?;
            pcx.add_file(file);
            pcx.parse();
            let _ = c.lower_asts();
            Ok(())
        });
        parsed.expect("Failed to compile test source");

        LiteralsCollector::process(&std::sync::Arc::new(compiler), None, usize::MAX)
    }

    /// Asserts that `value` was collected into the word set for `ty`, failing with `msg`.
    fn assert_word(literals: &LiteralMaps, ty: DynSolType, value: B256, msg: &str) {
        let found = literals.words.get(&ty).is_some_and(|set| set.contains(&value));
        assert!(found, "{msg}");
    }
}
357}