1use super::{CoverageItem, CoverageItemKind, SourceLocation};
2use alloy_primitives::map::HashMap;
3use foundry_common::TestFunctionExt;
4use foundry_compilers::ProjectCompileOutput;
5use rayon::prelude::*;
6use solar::{
7 ast::{self, ExprKind, ItemKind, StmtKind, yul},
8 data_structures::{Never, map::FxHashSet},
9 interface::{BytePos, Span},
10 sema::{Gcx, hir},
11};
12use std::{
13 ops::{ControlFlow, Range},
14 path::PathBuf,
15 sync::Arc,
16};
17
18#[derive(Clone)]
20struct SourceVisitor<'gcx> {
21 source_id: u32,
23 gcx: Gcx<'gcx>,
25
26 contract_name: Arc<str>,
28
29 branch_id: u32,
31
32 items: Vec<CoverageItem>,
34
35 all_lines: Vec<u32>,
36 function_calls: Vec<Span>,
37 function_calls_set: FxHashSet<Span>,
38}
39
/// Snapshot of the lengths of a `SourceVisitor`'s growable vectors.
///
/// Restoring a checkpoint truncates each vector back to the recorded length.
struct SourceVisitorCheckpoint {
    /// Length of `SourceVisitor::items` when the snapshot was taken.
    items: usize,
    /// Length of `SourceVisitor::all_lines` when the snapshot was taken.
    all_lines: usize,
    /// Length of `SourceVisitor::function_calls` when the snapshot was taken.
    function_calls: usize,
}
45
46impl<'gcx> SourceVisitor<'gcx> {
47 fn new(source_id: u32, gcx: Gcx<'gcx>) -> Self {
48 Self {
49 source_id,
50 gcx,
51 contract_name: Arc::default(),
52 branch_id: 0,
53 all_lines: Default::default(),
54 function_calls: Default::default(),
55 function_calls_set: Default::default(),
56 items: Default::default(),
57 }
58 }
59
60 fn checkpoint(&self) -> SourceVisitorCheckpoint {
61 SourceVisitorCheckpoint {
62 items: self.items.len(),
63 all_lines: self.all_lines.len(),
64 function_calls: self.function_calls.len(),
65 }
66 }
67
68 fn restore_checkpoint(&mut self, checkpoint: SourceVisitorCheckpoint) {
69 let SourceVisitorCheckpoint { items, all_lines, function_calls } = checkpoint;
70 self.items.truncate(items);
71 self.all_lines.truncate(all_lines);
72 self.function_calls.truncate(function_calls);
73 }
74
75 fn visit_contract<'ast>(&mut self, contract: &'ast ast::ItemContract<'ast>) {
76 let _ = ast::Visit::visit_item_contract(self, contract);
77 }
78
79 fn has_tests(&self, checkpoint: &SourceVisitorCheckpoint) -> bool {
81 self.items[checkpoint.items..].iter().any(|item| {
82 if let CoverageItemKind::Function { name } = &item.kind {
83 name.is_any_test()
84 } else {
85 false
86 }
87 })
88 }
89
90 fn disambiguate_functions(&mut self) {
92 let mut dups = HashMap::<_, Vec<usize>>::default();
93 for (i, item) in self.items.iter().enumerate() {
94 if let CoverageItemKind::Function { name } = &item.kind {
95 dups.entry(name.clone()).or_default().push(i);
96 }
97 }
98 for dups in dups.values() {
99 if dups.len() > 1 {
100 for (i, &dup) in dups.iter().enumerate() {
101 let item = &mut self.items[dup];
102 if let CoverageItemKind::Function { name } = &item.kind {
103 item.kind =
104 CoverageItemKind::Function { name: format!("{name}.{i}").into() };
105 }
106 }
107 }
108 }
109 }
110
111 fn resolve_function_calls(&mut self, hir_source_id: hir::SourceId) {
112 self.function_calls_set = self.function_calls.iter().copied().collect();
113 let _ = hir::Visit::visit_nested_source(self, hir_source_id);
114 }
115
116 fn sort(&mut self) {
117 self.items.sort();
118 }
119
120 fn push_lines(&mut self) {
121 self.all_lines.sort_unstable();
122 self.all_lines.dedup();
123 let mut lines = Vec::new();
124 for &line in &self.all_lines {
125 if let Some(reference_item) =
126 self.items.iter().find(|item| item.loc.lines.start == line)
127 {
128 lines.push(CoverageItem {
129 kind: CoverageItemKind::Line,
130 loc: reference_item.loc.clone(),
131 hits: 0,
132 });
133 }
134 }
135 self.items.extend(lines);
136 }
137
138 fn push_stmt(&mut self, span: Span) {
139 self.push_item_kind(CoverageItemKind::Statement, span);
140 }
141
142 fn push_item_kind(&mut self, kind: CoverageItemKind, span: Span) {
145 let item = CoverageItem { kind, loc: self.source_location_for(span), hits: 0 };
146
147 debug_assert!(!matches!(item.kind, CoverageItemKind::Line));
148 self.all_lines.push(item.loc.lines.start);
149
150 self.items.push(item);
151 }
152
153 fn source_location_for(&self, mut span: Span) -> SourceLocation {
154 if let Ok(snippet) = self.gcx.sess.source_map().span_to_snippet(span)
156 && let Some(stripped) = snippet.strip_suffix(';')
157 {
158 let stripped = stripped.trim_end();
159 let skipped = snippet.len() - stripped.len();
160 span = span.with_hi(span.hi() - BytePos::from_usize(skipped));
161 }
162
163 SourceLocation {
164 source_id: self.source_id as usize,
165 contract_name: self.contract_name.clone(),
166 bytes: self.byte_range(span),
167 lines: self.line_range(span),
168 }
169 }
170
171 fn byte_range(&self, span: Span) -> Range<u32> {
172 let bytes_usize = self.gcx.sess.source_map().span_to_source(span).unwrap().data;
173 bytes_usize.start as u32..bytes_usize.end as u32
174 }
175
176 fn line_range(&self, span: Span) -> Range<u32> {
177 let lines = self.gcx.sess.source_map().span_to_lines(span).unwrap().data;
178 assert!(!lines.is_empty());
179 let first = lines.first().unwrap();
180 let last = lines.last().unwrap();
181 first.line_index as u32 + 1..last.line_index as u32 + 2
182 }
183
184 fn next_branch_id(&mut self) -> u32 {
185 let id = self.branch_id;
186 self.branch_id = id + 1;
187 id
188 }
189}
190
191impl<'ast> ast::Visit<'ast> for SourceVisitor<'_> {
192 type BreakValue = Never;
193
194 fn visit_item_contract(
195 &mut self,
196 contract: &'ast ast::ItemContract<'ast>,
197 ) -> ControlFlow<Self::BreakValue> {
198 self.contract_name = contract.name.as_str().into();
199 self.walk_item_contract(contract)
200 }
201
202 #[expect(clippy::single_match)]
203 fn visit_item(&mut self, item: &'ast ast::Item<'ast>) -> ControlFlow<Self::BreakValue> {
204 match &item.kind {
205 ItemKind::Function(func) => {
206 if func.kind != ast::FunctionKind::Function && !has_statements(func.body.as_ref()) {
209 return ControlFlow::Continue(());
210 }
211
212 let name = func.header.name.as_ref().map(|n| n.as_str()).unwrap_or_else(|| {
213 match func.kind {
214 ast::FunctionKind::Constructor => "constructor",
215 ast::FunctionKind::Receive => "receive",
216 ast::FunctionKind::Fallback => "fallback",
217 ast::FunctionKind::Function | ast::FunctionKind::Modifier => unreachable!(),
218 }
219 });
220
221 self.push_item_kind(CoverageItemKind::Function { name: name.into() }, item.span);
222 self.walk_item(item)?;
223 }
224 _ => {}
225 }
226 ControlFlow::Continue(())
228 }
229
230 fn visit_stmt(&mut self, stmt: &'ast ast::Stmt<'ast>) -> ControlFlow<Self::BreakValue> {
231 match &stmt.kind {
232 StmtKind::Break | StmtKind::Continue | StmtKind::Emit(..) | StmtKind::Revert(..) => {
233 self.push_stmt(stmt.span);
234 return ControlFlow::Continue(());
236 }
237 StmtKind::Return(_) | StmtKind::DeclSingle(_) | StmtKind::DeclMulti(..) => {
238 self.push_stmt(stmt.span);
239 }
240
241 StmtKind::If(_cond, then_stmt, else_stmt) => {
242 let branch_id = self.next_branch_id();
243
244 if stmt_has_statements(then_stmt)
246 || else_stmt.as_ref().is_some_and(|s| stmt_has_statements(s))
247 {
248 self.push_item_kind(
251 CoverageItemKind::Branch { branch_id, path_id: 0, is_first_opcode: true },
252 then_stmt.span,
253 );
254 if else_stmt.is_some() {
255 self.push_item_kind(
259 CoverageItemKind::Branch {
260 branch_id,
261 path_id: 1,
262 is_first_opcode: false,
263 },
264 stmt.span,
265 );
266 }
267 }
268 }
269
270 StmtKind::Try(ast::StmtTry { expr: _, clauses }) => {
271 let branch_id = self.next_branch_id();
272
273 let mut path_id = 0;
274 for catch in clauses.iter() {
275 let ast::TryCatchClause { span, name: _, args, block } = catch;
276 let span = if path_id == 0 { stmt.span.to(*span) } else { *span };
277 if path_id == 0 || has_statements(Some(block)) {
278 self.push_item_kind(
279 CoverageItemKind::Branch { branch_id, path_id, is_first_opcode: true },
280 span,
281 );
282 path_id += 1;
283 } else if !args.is_empty() {
284 self.push_stmt(span);
288 }
289 }
290 }
291
292 StmtKind::Assembly(_)
294 | StmtKind::Block(_)
295 | StmtKind::UncheckedBlock(_)
296 | StmtKind::Placeholder
297 | StmtKind::Expr(_)
298 | StmtKind::While(..)
299 | StmtKind::DoWhile(..)
300 | StmtKind::For { .. } => {}
301 }
302 self.walk_stmt(stmt)
303 }
304
305 fn visit_expr(&mut self, expr: &'ast ast::Expr<'ast>) -> ControlFlow<Self::BreakValue> {
306 match &expr.kind {
307 ExprKind::Assign(..)
308 | ExprKind::Unary(..)
309 | ExprKind::Binary(..)
310 | ExprKind::Ternary(..) => {
311 self.push_stmt(expr.span);
312 if matches!(expr.kind, ExprKind::Binary(..)) {
313 return self.walk_expr(expr);
314 }
315 }
316 ExprKind::Call(callee, _args) => {
317 self.function_calls.push(expr.span);
319
320 if let ExprKind::Ident(ident) = &callee.kind {
321 if ident.as_str() == "require" {
324 let branch_id = self.next_branch_id();
325 self.push_item_kind(
326 CoverageItemKind::Branch {
327 branch_id,
328 path_id: 0,
329 is_first_opcode: false,
330 },
331 expr.span,
332 );
333 self.push_item_kind(
334 CoverageItemKind::Branch {
335 branch_id,
336 path_id: 1,
337 is_first_opcode: false,
338 },
339 expr.span,
340 );
341 }
342 }
343 }
344 _ => {}
345 }
346 ControlFlow::Continue(())
348 }
349
350 fn visit_yul_stmt(&mut self, stmt: &'ast yul::Stmt<'ast>) -> ControlFlow<Self::BreakValue> {
351 use yul::StmtKind;
352 match &stmt.kind {
353 StmtKind::VarDecl(..)
354 | StmtKind::AssignSingle(..)
355 | StmtKind::AssignMulti(..)
356 | StmtKind::Leave
357 | StmtKind::Break
358 | StmtKind::Continue => {
359 self.push_stmt(stmt.span);
360 return ControlFlow::Continue(());
362 }
363 StmtKind::If(..) => {
364 let branch_id = self.next_branch_id();
365 self.push_item_kind(
366 CoverageItemKind::Branch { branch_id, path_id: 0, is_first_opcode: false },
367 stmt.span,
368 );
369 }
370 StmtKind::For { body, .. } => {
371 self.push_stmt(body.span);
372 }
373 StmtKind::Switch(switch) => {
374 for case in switch.cases.iter() {
375 self.push_stmt(case.span);
376 self.push_stmt(case.body.span);
377 }
378 }
379 StmtKind::FunctionDef(func) => {
380 let name = func.name.as_str();
381 self.push_item_kind(CoverageItemKind::Function { name: name.into() }, stmt.span);
382 }
383 StmtKind::Expr(_) => {
385 self.push_stmt(stmt.span);
386 return ControlFlow::Continue(());
387 }
388 StmtKind::Block(_) => {}
389 }
390 self.walk_yul_stmt(stmt)
391 }
392
393 fn visit_yul_expr(&mut self, expr: &'ast yul::Expr<'ast>) -> ControlFlow<Self::BreakValue> {
394 use yul::ExprKind;
395 match &expr.kind {
396 ExprKind::Path(_) | ExprKind::Lit(_) => {}
397 ExprKind::Call(_) => self.push_stmt(expr.span),
398 }
399 ControlFlow::Continue(())
401 }
402}
403
404impl<'gcx> hir::Visit<'gcx> for SourceVisitor<'gcx> {
405 type BreakValue = Never;
406
407 fn hir(&self) -> &'gcx hir::Hir<'gcx> {
408 &self.gcx.hir
409 }
410
411 fn visit_expr(&mut self, expr: &'gcx hir::Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
412 if let hir::ExprKind::Call(lhs, ..) = &expr.kind
413 && self.function_calls_set.contains(&expr.span)
414 && is_regular_call(lhs)
415 {
416 self.push_stmt(expr.span);
417 }
418 self.walk_expr(expr)
419 }
420}
421
422fn is_regular_call(lhs: &hir::Expr<'_>) -> bool {
425 match lhs.peel_parens().kind {
426 hir::ExprKind::Ident([hir::Res::Item(hir::ItemId::Struct(_))]) => false,
428 hir::ExprKind::Type(_) => false,
430 _ => true,
431 }
432}
433
434fn has_statements(block: Option<&ast::Block<'_>>) -> bool {
435 block.is_some_and(|block| !block.is_empty())
436}
437
438fn stmt_has_statements(stmt: &ast::Stmt<'_>) -> bool {
439 match &stmt.kind {
440 StmtKind::Assembly(a) => !a.block.is_empty(),
441 StmtKind::Block(b) | StmtKind::UncheckedBlock(b) => has_statements(Some(b)),
442 _ => true,
443 }
444}
445
446#[derive(Clone, Debug, Default)]
448pub struct SourceAnalysis {
449 all_items: Vec<CoverageItem>,
451 map: Vec<(u32, u32)>,
453}
454
455impl SourceAnalysis {
456 #[instrument(name = "SourceAnalysis::new", skip_all)]
470 pub fn new(data: &SourceFiles, output: &ProjectCompileOutput) -> eyre::Result<Self> {
471 let mut sourced_items = output.parser().solc().compiler().enter(|compiler| {
472 data.sources
473 .par_iter()
474 .map(|(&source_id, path)| {
475 let _guard = debug_span!("SourceAnalysis::new::visit", ?path).entered();
476
477 let (_, source) = compiler.gcx().get_ast_source(path).unwrap();
478 let ast = source.ast.as_ref().unwrap();
479 let (hir_source_id, _) = compiler.gcx().get_hir_source(path).unwrap();
480
481 let mut visitor = SourceVisitor::new(source_id, compiler.gcx());
482 for item in ast.items.iter() {
483 let ItemKind::Contract(contract) = &item.kind else { continue };
485
486 if contract.kind.is_interface() {
488 continue;
489 }
490
491 let checkpoint = visitor.checkpoint();
492 visitor.visit_contract(contract);
493 if visitor.has_tests(&checkpoint) {
494 visitor.restore_checkpoint(checkpoint);
495 }
496 }
497
498 if !visitor.function_calls.is_empty() {
499 visitor.resolve_function_calls(hir_source_id);
500 }
501
502 if !visitor.items.is_empty() {
503 visitor.disambiguate_functions();
504 visitor.sort();
505 visitor.push_lines();
506 visitor.sort();
507 }
508 (source_id, visitor.items)
509 })
510 .collect::<Vec<(u32, Vec<CoverageItem>)>>()
511 });
512
513 sourced_items.sort_by_key(|(id, items)| (*id, items.first().map(|i| i.loc.bytes.start)));
515 let Some(&(max_idx, _)) = sourced_items.last() else { return Ok(Self::default()) };
516 let len = max_idx + 1;
517 let mut all_items = Vec::new();
518 let mut map = vec![(u32::MAX, 0); len as usize];
519 for (idx, items) in sourced_items {
520 let idx = idx as usize;
522 if map[idx].0 == u32::MAX {
523 map[idx].0 = all_items.len() as u32;
524 }
525 map[idx].1 += items.len() as u32;
526 all_items.extend(items);
527 }
528
529 Ok(Self { all_items, map })
530 }
531
532 pub fn all_items(&self) -> &[CoverageItem] {
534 &self.all_items
535 }
536
537 pub fn all_items_mut(&mut self) -> &mut Vec<CoverageItem> {
539 &mut self.all_items
540 }
541
542 pub fn items_for_source_enumerated(
544 &self,
545 source_id: u32,
546 ) -> impl Iterator<Item = (u32, &CoverageItem)> {
547 let (base_id, items) = self.items_for_source(source_id);
548 items.iter().enumerate().map(move |(idx, item)| (base_id + idx as u32, item))
549 }
550
551 pub fn items_for_source(&self, source_id: u32) -> (u32, &[CoverageItem]) {
553 let (mut offset, len) = self.map.get(source_id as usize).copied().unwrap_or_default();
554 if offset == u32::MAX {
555 offset = 0;
556 }
557 (offset, &self.all_items[offset as usize..][..len as usize])
558 }
559
560 #[inline]
562 pub fn get(&self, item_id: u32) -> Option<&CoverageItem> {
563 self.all_items.get(item_id as usize)
564 }
565}
566
567#[derive(Default)]
569pub struct SourceFiles {
570 pub sources: HashMap<u32, PathBuf>,
572}