1use super::{CoverageItem, CoverageItemKind, SourceLocation};
2use alloy_primitives::map::HashMap;
3use foundry_common::TestFunctionExt;
4use foundry_compilers::ProjectCompileOutput;
5use rayon::prelude::*;
6use solar::{
7 ast::{self, ExprKind, ItemKind, StmtKind, yul},
8 data_structures::{Never, map::FxHashSet},
9 interface::{BytePos, Span},
10 sema::{Gcx, hir},
11};
12use std::{
13 ops::{ControlFlow, Range},
14 path::PathBuf,
15 sync::Arc,
16};
17
18#[derive(Clone)]
20struct SourceVisitor<'gcx> {
21 source_id: u32,
23 gcx: Gcx<'gcx>,
25
26 contract_name: Arc<str>,
28
29 branch_id: u32,
31
32 items: Vec<CoverageItem>,
34
35 all_lines: Vec<u32>,
36 function_calls: Vec<Span>,
37 function_calls_set: FxHashSet<Span>,
38}
39
/// Snapshot of the lengths of a [`SourceVisitor`]'s growable buffers.
///
/// Produced by `SourceVisitor::checkpoint` and consumed by `SourceVisitor::restore_checkpoint`.
struct SourceVisitorCheckpoint {
    /// Length of `SourceVisitor::items` when the snapshot was taken.
    items: usize,
    /// Length of `SourceVisitor::all_lines` when the snapshot was taken.
    all_lines: usize,
    /// Length of `SourceVisitor::function_calls` when the snapshot was taken.
    function_calls: usize,
}
45
46impl<'gcx> SourceVisitor<'gcx> {
47 fn new(source_id: u32, gcx: Gcx<'gcx>) -> Self {
48 Self {
49 source_id,
50 gcx,
51 contract_name: Arc::default(),
52 branch_id: 0,
53 all_lines: Default::default(),
54 function_calls: Default::default(),
55 function_calls_set: Default::default(),
56 items: Default::default(),
57 }
58 }
59
60 fn checkpoint(&self) -> SourceVisitorCheckpoint {
61 SourceVisitorCheckpoint {
62 items: self.items.len(),
63 all_lines: self.all_lines.len(),
64 function_calls: self.function_calls.len(),
65 }
66 }
67
68 fn restore_checkpoint(&mut self, checkpoint: SourceVisitorCheckpoint) {
69 let SourceVisitorCheckpoint { items, all_lines, function_calls } = checkpoint;
70 self.items.truncate(items);
71 self.all_lines.truncate(all_lines);
72 self.function_calls.truncate(function_calls);
73 }
74
75 fn visit_contract<'ast>(&mut self, contract: &'ast ast::ItemContract<'ast>) {
76 let _ = ast::Visit::visit_item_contract(self, contract);
77 }
78
79 fn has_tests(&self, checkpoint: &SourceVisitorCheckpoint) -> bool {
81 self.items[checkpoint.items..].iter().any(|item| {
82 if let CoverageItemKind::Function { name } = &item.kind {
83 name.is_any_test()
84 } else {
85 false
86 }
87 })
88 }
89
90 fn disambiguate_functions(&mut self) {
92 let mut dups = HashMap::<_, Vec<usize>>::default();
93 for (i, item) in self.items.iter().enumerate() {
94 if let CoverageItemKind::Function { name } = &item.kind {
95 dups.entry(name.clone()).or_default().push(i);
96 }
97 }
98 for dups in dups.values() {
99 if dups.len() > 1 {
100 for (i, &dup) in dups.iter().enumerate() {
101 let item = &mut self.items[dup];
102 if let CoverageItemKind::Function { name } = &item.kind {
103 item.kind =
104 CoverageItemKind::Function { name: format!("{name}.{i}").into() };
105 }
106 }
107 }
108 }
109 }
110
111 fn resolve_function_calls(&mut self, hir_source_id: hir::SourceId) {
112 self.function_calls_set = self.function_calls.iter().copied().collect();
113 let _ = hir::Visit::visit_nested_source(self, hir_source_id);
114 }
115
116 fn sort(&mut self) {
117 self.items.sort();
118 }
119
120 fn push_lines(&mut self) {
121 self.all_lines.sort_unstable();
122 self.all_lines.dedup();
123 let mut lines = Vec::new();
124 for &line in &self.all_lines {
125 if let Some(reference_item) =
126 self.items.iter().find(|item| item.loc.lines.start == line)
127 {
128 lines.push(CoverageItem {
129 kind: CoverageItemKind::Line,
130 loc: reference_item.loc.clone(),
131 hits: 0,
132 });
133 }
134 }
135 self.items.extend(lines);
136 }
137
138 fn push_stmt(&mut self, span: Span) {
139 self.push_item_kind(CoverageItemKind::Statement, span);
140 }
141
142 fn push_item_kind(&mut self, kind: CoverageItemKind, span: Span) {
145 let item = CoverageItem { kind, loc: self.source_location_for(span), hits: 0 };
146
147 debug_assert!(!matches!(item.kind, CoverageItemKind::Line));
148 self.all_lines.push(item.loc.lines.start);
149
150 self.items.push(item);
151 }
152
153 fn source_location_for(&self, mut span: Span) -> SourceLocation {
154 if let Ok(snippet) = self.gcx.sess.source_map().span_to_snippet(span)
156 && let Some(stripped) = snippet.strip_suffix(';')
157 {
158 let stripped = stripped.trim_end();
159 let skipped = snippet.len() - stripped.len();
160 span = span.with_hi(span.hi() - BytePos::from_usize(skipped));
161 }
162
163 SourceLocation {
164 source_id: self.source_id as usize,
165 contract_name: self.contract_name.clone(),
166 bytes: self.byte_range(span),
167 lines: self.line_range(span),
168 }
169 }
170
171 fn byte_range(&self, span: Span) -> Range<u32> {
172 let bytes_usize = self.gcx.sess.source_map().span_to_source(span).unwrap().data;
173 bytes_usize.start as u32..bytes_usize.end as u32
174 }
175
176 fn line_range(&self, span: Span) -> Range<u32> {
177 let lines = self.gcx.sess.source_map().span_to_lines(span).unwrap().data;
178 assert!(!lines.is_empty());
179 let first = lines.first().unwrap();
180 let last = lines.last().unwrap();
181 first.line_index as u32 + 1..last.line_index as u32 + 2
182 }
183
184 fn next_branch_id(&mut self) -> u32 {
185 let id = self.branch_id;
186 self.branch_id = id + 1;
187 id
188 }
189}
190
191impl<'ast> ast::Visit<'ast> for SourceVisitor<'_> {
192 type BreakValue = Never;
193
194 fn visit_item_contract(
195 &mut self,
196 contract: &'ast ast::ItemContract<'ast>,
197 ) -> ControlFlow<Self::BreakValue> {
198 self.contract_name = contract.name.as_str().into();
199 self.walk_item_contract(contract)
200 }
201
202 #[expect(clippy::single_match)]
203 fn visit_item(&mut self, item: &'ast ast::Item<'ast>) -> ControlFlow<Self::BreakValue> {
204 match &item.kind {
205 ItemKind::Function(func) => {
206 if func.kind != ast::FunctionKind::Function && !has_statements(func.body.as_ref()) {
209 return ControlFlow::Continue(());
210 }
211
212 let name = func.header.name.as_ref().map(|n| n.as_str()).unwrap_or_else(|| {
213 match func.kind {
214 ast::FunctionKind::Constructor => "constructor",
215 ast::FunctionKind::Receive => "receive",
216 ast::FunctionKind::Fallback => "fallback",
217 ast::FunctionKind::Function | ast::FunctionKind::Modifier => unreachable!(),
218 }
219 });
220
221 let exclude_func = func.header.virtual_() && !func.is_implemented();
223 if !exclude_func {
224 self.push_item_kind(
225 CoverageItemKind::Function { name: name.into() },
226 item.span,
227 );
228 }
229
230 self.walk_item(item)?;
231 }
232 _ => {}
233 }
234 ControlFlow::Continue(())
236 }
237
238 fn visit_stmt(&mut self, stmt: &'ast ast::Stmt<'ast>) -> ControlFlow<Self::BreakValue> {
239 match &stmt.kind {
240 StmtKind::Break | StmtKind::Continue | StmtKind::Emit(..) | StmtKind::Revert(..) => {
241 self.push_stmt(stmt.span);
242 return ControlFlow::Continue(());
244 }
245 StmtKind::Return(_) | StmtKind::DeclSingle(_) | StmtKind::DeclMulti(..) => {
246 self.push_stmt(stmt.span);
247 }
248
249 StmtKind::If(_cond, then_stmt, else_stmt) => {
250 let branch_id = self.next_branch_id();
251
252 if stmt_has_statements(then_stmt)
254 || else_stmt.as_ref().is_some_and(|s| stmt_has_statements(s))
255 {
256 self.push_item_kind(
259 CoverageItemKind::Branch { branch_id, path_id: 0, is_first_opcode: true },
260 then_stmt.span,
261 );
262 if else_stmt.is_some() {
263 self.push_item_kind(
267 CoverageItemKind::Branch {
268 branch_id,
269 path_id: 1,
270 is_first_opcode: false,
271 },
272 stmt.span,
273 );
274 }
275 }
276 }
277
278 StmtKind::Try(ast::StmtTry { expr: _, clauses }) => {
279 let branch_id = self.next_branch_id();
280
281 let mut path_id = 0;
282 for catch in clauses.iter() {
283 let ast::TryCatchClause { span, name: _, args, block } = catch;
284 let span = if path_id == 0 { stmt.span.to(*span) } else { *span };
285 if path_id == 0 || has_statements(Some(block)) {
286 self.push_item_kind(
287 CoverageItemKind::Branch { branch_id, path_id, is_first_opcode: true },
288 span,
289 );
290 path_id += 1;
291 } else if !args.is_empty() {
292 self.push_stmt(span);
296 }
297 }
298 }
299
300 StmtKind::Assembly(_)
302 | StmtKind::Block(_)
303 | StmtKind::UncheckedBlock(_)
304 | StmtKind::Placeholder
305 | StmtKind::Expr(_)
306 | StmtKind::While(..)
307 | StmtKind::DoWhile(..)
308 | StmtKind::For { .. } => {}
309 }
310 self.walk_stmt(stmt)
311 }
312
313 fn visit_expr(&mut self, expr: &'ast ast::Expr<'ast>) -> ControlFlow<Self::BreakValue> {
314 match &expr.kind {
315 ExprKind::Assign(..)
316 | ExprKind::Unary(..)
317 | ExprKind::Binary(..)
318 | ExprKind::Ternary(..) => {
319 self.push_stmt(expr.span);
320 if matches!(expr.kind, ExprKind::Binary(..)) {
321 return self.walk_expr(expr);
322 }
323 }
324 ExprKind::Call(callee, _args) => {
325 self.function_calls.push(expr.span);
327
328 if let ExprKind::Ident(ident) = &callee.kind {
329 if ident.as_str() == "require" {
332 let branch_id = self.next_branch_id();
333 self.push_item_kind(
334 CoverageItemKind::Branch {
335 branch_id,
336 path_id: 0,
337 is_first_opcode: false,
338 },
339 expr.span,
340 );
341 self.push_item_kind(
342 CoverageItemKind::Branch {
343 branch_id,
344 path_id: 1,
345 is_first_opcode: false,
346 },
347 expr.span,
348 );
349 }
350 }
351 }
352 _ => {}
353 }
354 ControlFlow::Continue(())
356 }
357
358 fn visit_yul_stmt(&mut self, stmt: &'ast yul::Stmt<'ast>) -> ControlFlow<Self::BreakValue> {
359 use yul::StmtKind;
360 match &stmt.kind {
361 StmtKind::VarDecl(..)
362 | StmtKind::AssignSingle(..)
363 | StmtKind::AssignMulti(..)
364 | StmtKind::Leave
365 | StmtKind::Break
366 | StmtKind::Continue => {
367 self.push_stmt(stmt.span);
368 return ControlFlow::Continue(());
370 }
371 StmtKind::If(..) => {
372 let branch_id = self.next_branch_id();
373 self.push_item_kind(
374 CoverageItemKind::Branch { branch_id, path_id: 0, is_first_opcode: false },
375 stmt.span,
376 );
377 }
378 StmtKind::For(yul::StmtFor { body, .. }) => {
379 self.push_stmt(body.span);
380 }
381 StmtKind::Switch(switch) => {
382 for case in switch.cases.iter() {
383 self.push_stmt(case.span);
384 self.push_stmt(case.body.span);
385 }
386 }
387 StmtKind::FunctionDef(func) => {
388 let name = func.name.as_str();
389 self.push_item_kind(CoverageItemKind::Function { name: name.into() }, stmt.span);
390 }
391 StmtKind::Expr(_) => {
393 self.push_stmt(stmt.span);
394 return ControlFlow::Continue(());
395 }
396 StmtKind::Block(_) => {}
397 }
398 self.walk_yul_stmt(stmt)
399 }
400
401 fn visit_yul_expr(&mut self, expr: &'ast yul::Expr<'ast>) -> ControlFlow<Self::BreakValue> {
402 use yul::ExprKind;
403 match &expr.kind {
404 ExprKind::Path(_) | ExprKind::Lit(_) => {}
405 ExprKind::Call(_) => self.push_stmt(expr.span),
406 }
407 ControlFlow::Continue(())
409 }
410}
411
412impl<'gcx> hir::Visit<'gcx> for SourceVisitor<'gcx> {
413 type BreakValue = Never;
414
415 fn hir(&self) -> &'gcx hir::Hir<'gcx> {
416 &self.gcx.hir
417 }
418
419 fn visit_expr(&mut self, expr: &'gcx hir::Expr<'gcx>) -> ControlFlow<Self::BreakValue> {
420 if let hir::ExprKind::Call(lhs, ..) = &expr.kind
421 && self.function_calls_set.contains(&expr.span)
422 && is_regular_call(lhs)
423 {
424 self.push_stmt(expr.span);
425 }
426 self.walk_expr(expr)
427 }
428}
429
430fn is_regular_call(lhs: &hir::Expr<'_>) -> bool {
433 match lhs.peel_parens().kind {
434 hir::ExprKind::Ident([hir::Res::Item(hir::ItemId::Struct(_))]) => false,
436 hir::ExprKind::Type(_) => false,
438 _ => true,
439 }
440}
441
442fn has_statements(block: Option<&ast::Block<'_>>) -> bool {
443 block.is_some_and(|block| !block.is_empty())
444}
445
446fn stmt_has_statements(stmt: &ast::Stmt<'_>) -> bool {
447 match &stmt.kind {
448 StmtKind::Assembly(a) => !a.block.is_empty(),
449 StmtKind::Block(b) | StmtKind::UncheckedBlock(b) => has_statements(Some(b)),
450 _ => true,
451 }
452}
453
454#[derive(Clone, Debug, Default)]
456pub struct SourceAnalysis {
457 all_items: Vec<CoverageItem>,
459 map: Vec<(u32, u32)>,
461}
462
463impl SourceAnalysis {
464 #[instrument(name = "SourceAnalysis::new", skip_all)]
478 pub fn new(data: &SourceFiles, output: &ProjectCompileOutput) -> eyre::Result<Self> {
479 let mut sourced_items = output.parser().solc().compiler().enter(|compiler| {
480 data.sources
481 .par_iter()
482 .map(|(&source_id, path)| {
483 let _guard = debug_span!("SourceAnalysis::new::visit", ?path).entered();
484
485 let (_, source) = compiler.gcx().get_ast_source(path).unwrap();
486 let ast = source.ast.as_ref().unwrap();
487 let (hir_source_id, _) = compiler.gcx().get_hir_source(path).unwrap();
488
489 let mut visitor = SourceVisitor::new(source_id, compiler.gcx());
490 for item in ast.items.iter() {
491 let ItemKind::Contract(contract) = &item.kind else { continue };
493
494 if contract.kind.is_interface() {
496 continue;
497 }
498
499 let checkpoint = visitor.checkpoint();
500 visitor.visit_contract(contract);
501 if visitor.has_tests(&checkpoint) {
502 visitor.restore_checkpoint(checkpoint);
503 }
504 }
505
506 if !visitor.function_calls.is_empty() {
507 visitor.resolve_function_calls(hir_source_id);
508 }
509
510 if !visitor.items.is_empty() {
511 visitor.disambiguate_functions();
512 visitor.sort();
513 visitor.push_lines();
514 visitor.sort();
515 }
516 (source_id, visitor.items)
517 })
518 .collect::<Vec<(u32, Vec<CoverageItem>)>>()
519 });
520
521 sourced_items.sort_by_key(|(id, items)| (*id, items.first().map(|i| i.loc.bytes.start)));
523 let Some(&(max_idx, _)) = sourced_items.last() else { return Ok(Self::default()) };
524 let len = max_idx + 1;
525 let mut all_items = Vec::new();
526 let mut map = vec![(u32::MAX, 0); len as usize];
527 for (idx, items) in sourced_items {
528 let idx = idx as usize;
530 if map[idx].0 == u32::MAX {
531 map[idx].0 = all_items.len() as u32;
532 }
533 map[idx].1 += items.len() as u32;
534 all_items.extend(items);
535 }
536
537 Ok(Self { all_items, map })
538 }
539
540 pub fn all_items(&self) -> &[CoverageItem] {
542 &self.all_items
543 }
544
545 pub fn all_items_mut(&mut self) -> &mut Vec<CoverageItem> {
547 &mut self.all_items
548 }
549
550 pub fn items_for_source_enumerated(
552 &self,
553 source_id: u32,
554 ) -> impl Iterator<Item = (u32, &CoverageItem)> {
555 let (base_id, items) = self.items_for_source(source_id);
556 items.iter().enumerate().map(move |(idx, item)| (base_id + idx as u32, item))
557 }
558
559 pub fn items_for_source(&self, source_id: u32) -> (u32, &[CoverageItem]) {
561 let (mut offset, len) = self.map.get(source_id as usize).copied().unwrap_or_default();
562 if offset == u32::MAX {
563 offset = 0;
564 }
565 (offset, &self.all_items[offset as usize..][..len as usize])
566 }
567
568 #[inline]
570 pub fn get(&self, item_id: u32) -> Option<&CoverageItem> {
571 self.all_items.get(item_id as usize)
572 }
573}
574
575#[derive(Default)]
577pub struct SourceFiles {
578 pub sources: HashMap<u32, PathBuf>,
580}