foundry_evm_fuzz/strategies/
literals.rs1use alloy_dyn_abi::DynSolType;
2use alloy_primitives::{
3 B256, Bytes, I256, U256, keccak256,
4 map::{B256IndexSet, HashMap, IndexSet},
5};
6use foundry_common::Analysis;
7use foundry_compilers::ProjectPathsConfig;
8use solar::{
9 ast::{self, Visit},
10 interface::source_map::FileName,
11};
12use std::{
13 ops::ControlFlow,
14 sync::{Arc, OnceLock},
15};
16
17#[derive(Clone, Debug)]
18pub struct LiteralsDictionary {
19 maps: Arc<OnceLock<LiteralMaps>>,
20}
21
22impl Default for LiteralsDictionary {
23 fn default() -> Self {
24 Self::new(None, None, usize::MAX)
25 }
26}
27
28impl LiteralsDictionary {
29 pub fn new(
30 analysis: Option<Analysis>,
31 paths_config: Option<ProjectPathsConfig>,
32 max_values: usize,
33 ) -> Self {
34 let maps = Arc::new(OnceLock::<LiteralMaps>::new());
35 if let Some(analysis) = analysis
36 && max_values > 0
37 {
38 let maps = maps.clone();
39 let _ = std::thread::Builder::new().name("literal-collector".into()).spawn(move || {
42 let _ = maps.get_or_init(|| {
43 let literals =
44 LiteralsCollector::process(&analysis, paths_config.as_ref(), max_values);
45 debug!(
46 words = literals.words.values().map(|set| set.len()).sum::<usize>(),
47 strings = literals.strings.len(),
48 bytes = literals.bytes.len(),
49 "collected source code literals for fuzz dictionary"
50 );
51 literals
52 });
53 });
54 } else {
55 maps.set(Default::default()).unwrap();
56 }
57 Self { maps }
58 }
59
60 pub fn get(&self) -> &LiteralMaps {
62 self.maps.wait()
63 }
64
65 #[cfg(test)]
67 pub(crate) fn set(&mut self, map: super::LiteralMaps) {
68 self.maps = Arc::new(OnceLock::new());
69 self.maps.set(map).unwrap();
70 }
71}
72
73#[derive(Debug, Default)]
74pub struct LiteralMaps {
75 pub words: HashMap<DynSolType, B256IndexSet>,
76 pub strings: IndexSet<String>,
77 pub bytes: IndexSet<Bytes>,
78}
79
80#[derive(Debug, Default)]
81pub struct LiteralsCollector {
82 max_values: usize,
83 total_values: usize,
84 output: LiteralMaps,
85}
86
87impl LiteralsCollector {
88 fn new(max_values: usize) -> Self {
89 Self { max_values, ..Default::default() }
90 }
91
92 pub fn process(
93 analysis: &Analysis,
94 paths_config: Option<&ProjectPathsConfig>,
95 max_values: usize,
96 ) -> LiteralMaps {
97 analysis.enter(|compiler| {
98 let mut literals_collector = Self::new(max_values);
99 for source in compiler.sources().iter() {
100 if let Some(paths) = paths_config
102 && let FileName::Real(source_path) = &source.file.name
103 && !(source_path.starts_with(&paths.sources) || paths.is_test(source_path))
104 {
105 continue;
106 }
107
108 if let Some(ast) = &source.ast
109 && literals_collector.visit_source_unit(ast).is_break()
110 {
111 break;
112 }
113 }
114
115 literals_collector.output
116 })
117 }
118}
119
120impl<'ast> ast::Visit<'ast> for LiteralsCollector {
121 type BreakValue = ();
122
123 fn visit_expr(&mut self, expr: &'ast ast::Expr<'ast>) -> ControlFlow<()> {
124 if self.total_values >= self.max_values {
126 return ControlFlow::Break(());
127 }
128
129 if let ast::ExprKind::Unary(un_op, inner_expr) = &expr.kind
131 && un_op.kind == ast::UnOpKind::Neg
132 && let ast::ExprKind::Lit(lit, _) = &inner_expr.kind
133 && let ast::LitKind::Number(n) = &lit.kind
134 {
135 if let Ok(pos_i256) = I256::try_from(*n) {
137 let neg_value = -pos_i256;
138 let neg_b256 = B256::from(neg_value.into_raw());
139
140 for bits in [16, 32, 64, 128, 256] {
142 if can_fit_int(neg_value, bits)
143 && self
144 .output
145 .words
146 .entry(DynSolType::Int(bits))
147 .or_default()
148 .insert(neg_b256)
149 {
150 self.total_values += 1;
151 }
152 }
153 }
154
155 return self.walk_expr(expr);
157 }
158
159 if let ast::ExprKind::Lit(lit, _) = &expr.kind {
161 let is_new = match &lit.kind {
162 ast::LitKind::Number(n) => {
163 let pos_value = U256::from(*n);
164 let pos_b256 = B256::from(pos_value);
165
166 for bits in [8, 16, 32, 64, 128, 256] {
168 if can_fit_uint(pos_value, bits)
169 && self
170 .output
171 .words
172 .entry(DynSolType::Uint(bits))
173 .or_default()
174 .insert(pos_b256)
175 {
176 self.total_values += 1;
177 }
178 }
179 false }
181 ast::LitKind::Address(addr) => self
182 .output
183 .words
184 .entry(DynSolType::Address)
185 .or_default()
186 .insert(addr.into_word()),
187 ast::LitKind::Str(ast::StrKind::Hex, sym, _) => {
188 self.output.bytes.insert(Bytes::copy_from_slice(sym.as_byte_str()))
189 }
190 ast::LitKind::Str(_, sym, _) => {
191 let s = String::from_utf8_lossy(sym.as_byte_str()).into_owned();
192 let hash = keccak256(s.as_bytes());
194 if self.output.words.entry(DynSolType::FixedBytes(32)).or_default().insert(hash)
195 {
196 self.total_values += 1;
197 }
198 if s.len() <= 32 {
200 let padded = B256::right_padding_from(s.as_bytes());
201 if self
202 .output
203 .words
204 .entry(DynSolType::FixedBytes(32))
205 .or_default()
206 .insert(padded)
207 {
208 self.total_values += 1;
209 }
210 }
211 self.output.strings.insert(s)
212 }
213 ast::LitKind::Bool(..) | ast::LitKind::Rational(..) | ast::LitKind::Err(..) => {
214 false }
216 };
217
218 if is_new {
219 self.total_values += 1;
220 }
221 }
222
223 self.walk_expr(expr)
224 }
225}
226
227fn can_fit_int(value: I256, bits: usize) -> bool {
229 let max_val = I256::try_from((U256::from(1) << (bits - 1)) - U256::from(1))
231 .expect("max value should fit in I256");
232 let min_val = -max_val - I256::ONE;
234
235 value >= min_val && value <= max_val
236}
237
238fn can_fit_uint(value: U256, bits: usize) -> bool {
240 if bits == 256 {
241 return true;
242 }
243 let max_val = (U256::from(1) << bits) - U256::from(1);
245 value <= max_val
246}
247
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::address;
    use solar::interface::{Session, source_map};

    /// Solidity fixture exercising every literal kind the collector records.
    const SOURCE: &str = r#"
    contract Magic {
        // plain literals
        address constant DAI = 0x6B175474E89094C44Da98b954EedeAC495271d0F;
        uint64 constant MAGIC_NUMBER = 1122334455;
        int32 constant MAGIC_INT = -777;
        bytes32 constant MAGIC_WORD = "abcd1234";
        bytes constant MAGIC_BYTES = hex"deadbeef";
        string constant MAGIC_STRING = "xyzzy";

        // constant exprs with folding
        uint256 constant NEG_FOLDING = uint(-2);
        uint256 constant BIN_FOLDING = 2 * 2 ether;
        bytes32 constant IMPLEMENTATION_SLOT = bytes32(uint256(keccak256('eip1967.proxy.implementation')) - 1);
    }"#;

    /// Each literal kind in the fixture ends up in the matching collection.
    #[test]
    fn test_literals_collector_coverage() {
        let map = process_source_literals(SOURCE);

        let expected_words = [
            (
                DynSolType::Address,
                address!("0x6B175474E89094C44Da98b954EedeAC495271d0F").into_word(),
                "Expected DAI in address set",
            ),
            (
                DynSolType::Uint(64),
                B256::from(U256::from(1122334455u64)),
                "Expected MAGIC_NUMBER in uint64 set",
            ),
            (
                DynSolType::Int(32),
                B256::from(I256::try_from(-777i32).unwrap().into_raw()),
                "Expected MAGIC_INT in int32 set",
            ),
            (
                DynSolType::FixedBytes(32),
                B256::right_padding_from(b"abcd1234"),
                "Expected MAGIC_WORD in bytes32 set",
            ),
        ];
        for (ty, value, msg) in expected_words {
            assert_word(&map, ty, value, msg);
        }

        assert!(map.strings.contains("xyzzy"), "Expected MAGIC_STRING to be collected");
        assert!(
            map.strings.contains("eip1967.proxy.implementation"),
            "Expected IMPLEMENTATION_SLOT in string set"
        );

        let dyn_bytes = Bytes::from_static(&[0xde, 0xad, 0xbe, 0xef]);
        assert!(map.bytes.contains(&dyn_bytes), "Expected MAGIC_BYTES in bytes set");
    }

    /// The number of values recorded per type matches the fixture exactly.
    #[test]
    fn test_literals_collector_size() {
        let literals = process_source_literals(SOURCE);

        let words_of =
            |ty: DynSolType| literals.words.get(&ty).map(|set| set.len()).unwrap_or(0);

        assert_eq!(words_of(DynSolType::Address), 1, "Address literal count mismatch");
        assert_eq!(literals.strings.len(), 3, "String literals count mismatch");
        assert_eq!(literals.bytes.len(), 1, "Byte literals count mismatch");

        let expected_counts = [
            (DynSolType::Uint(8), 2, "Uint(8) count mismatch"),
            (DynSolType::Uint(16), 3, "Uint(16) count mismatch"),
            (DynSolType::Uint(32), 4, "Uint(32) count mismatch"),
            (DynSolType::Uint(64), 5, "Uint(64) count mismatch"),
            (DynSolType::Uint(128), 5, "Uint(128) count mismatch"),
            (DynSolType::Uint(256), 5, "Uint(256) count mismatch"),
            (DynSolType::Int(16), 2, "Int(16) count mismatch"),
            (DynSolType::Int(32), 2, "Int(32) count mismatch"),
            (DynSolType::Int(64), 2, "Int(64) count mismatch"),
            (DynSolType::Int(128), 2, "Int(128) count mismatch"),
            (DynSolType::Int(256), 2, "Int(256) count mismatch"),
            (DynSolType::FixedBytes(32), 6, "FixedBytes(32) count mismatch"),
        ];
        for (ty, expected, msg) in expected_counts {
            assert_eq!(words_of(ty), expected, "{msg}");
        }

        let total_words: usize = literals.words.values().map(|set| set.len()).sum();
        assert_eq!(total_words, 41, "Total word values count mismatch");
    }

    /// Parses `source` as a single in-memory file and runs the collector over it without a cap.
    fn process_source_literals(source: &str) -> LiteralMaps {
        let session = Session::builder().with_stderr_emitter().build();
        let mut compiler = solar::sema::Compiler::new(session);

        compiler
            .enter_mut(|c| -> std::io::Result<()> {
                let mut parser = c.parse();
                parser.set_resolve_imports(false);

                let file =
                    c.sess().source_map().new_source_file(source_map::FileName::Stdin, source)?;
                parser.add_file(file);
                parser.parse();

                let _ = c.lower_asts();
                Ok(())
            })
            .expect("Failed to compile test source");

        LiteralsCollector::process(&std::sync::Arc::new(compiler), None, usize::MAX)
    }

    /// Asserts that `value` was recorded under `ty`, failing with `msg` otherwise.
    fn assert_word(literals: &LiteralMaps, ty: DynSolType, value: B256, msg: &str) {
        let present = matches!(literals.words.get(&ty), Some(set) if set.contains(&value));
        assert!(present, "{}", msg);
    }
}