1use super::ExternalFunction;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{ContractKind, DataLocation, UnOpKind, Visibility},
8 interface::{Symbol, data_structures::Never},
9 sema::hir::{
10 self, ContractId, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId, Visit as _,
11 },
12};
13use std::{
14 cell::RefCell,
15 collections::{HashMap, HashSet},
16 ops::ControlFlow,
17 rc::Rc,
18};
19
// `external-function` (gas): reports `public` functions that could be declared `external`.
// The detection logic lives in the `LateLintPass` impl for `ExternalFunction` in this file; it
// only considers functions taking at least one `memory` reference-type parameter.
declare_forge_lint!(
    EXTERNAL_FUNCTION,
    Severity::Gas,
    "external-function",
    "public function can be declared external"
);
26
/// Project-wide facts built by `build_project_index` and shared (via `project_index_for`) by
/// every per-contract invocation of the lint.
#[derive(Default)]
struct ProjectIndex {
    /// Functions that `IndexBuilder` found referenced from function bodies or state-variable
    /// initializers. A checked function (or a same-name, same-arity override of it) that appears
    /// here is not reported.
    referenced: HashSet<FunctionId>,
    /// For each member name `f`, the set of contracts whose code contains a `super.f` access.
    super_called: HashMap<Symbol, HashSet<ContractId>>,
}
36
thread_local! {
    /// Per-thread cache holding the most recently built `ProjectIndex`, tagged with the address
    /// of the `Hir` it was built from (read and written by `project_index_for`).
    static PROJECT_INDEX: RefCell<Option<(usize, Rc<ProjectIndex>)>> = const { RefCell::new(None) };
}
43
44fn project_index_for<'hir>(hir: &'hir hir::Hir<'hir>) -> Rc<ProjectIndex> {
45 let key = std::ptr::from_ref::<hir::Hir<'_>>(hir) as usize;
46 PROJECT_INDEX.with(|cell| {
47 let mut slot = cell.borrow_mut();
48 if let Some((cached_key, cached)) = slot.as_ref()
49 && *cached_key == key
50 {
51 return cached.clone();
52 }
53 let fresh = Rc::new(build_project_index(hir));
54 *slot = Some((key, fresh.clone()));
55 fresh
56 })
57}
58
59impl<'hir> LateLintPass<'hir> for ExternalFunction {
60 fn check_nested_contract(
61 &mut self,
62 ctx: &LintContext,
63 _gcx: solar::sema::Gcx<'hir>,
64 hir: &'hir hir::Hir<'hir>,
65 contract_id: ContractId,
66 ) {
67 if !ctx.is_lint_enabled(EXTERNAL_FUNCTION.id) {
68 return;
69 }
70
71 let contract = hir.contract(contract_id);
72
73 if !matches!(contract.kind, ContractKind::Contract | ContractKind::AbstractContract) {
76 return;
77 }
78 if contract.linearization_failed() {
79 return;
80 }
81
82 let index = project_index_for(hir);
83
84 for fid in contract.functions() {
85 let func = hir.function(fid);
86
87 if func.visibility != Visibility::Public || !func.is_ordinary() {
89 continue;
90 }
91 if func.override_ {
94 continue;
95 }
96 let Some(body) = func.body else { continue };
98
99 let has_memory_reference_param = func.parameters.iter().any(|&pid| {
102 let p = hir.variable(pid);
103 p.ty.kind.is_reference_type() && p.data_location == Some(DataLocation::Memory)
104 });
105 if !has_memory_reference_param {
106 continue;
107 }
108
109 if body_escapes_params(hir, &body, func.parameters)
110 || modifier_args_reference_params(func.modifiers, func.parameters)
111 {
112 continue;
113 }
114
115 let Some(name) = func.name else { continue };
116
117 if super_called_from_derivative(hir, contract_id, &name.name, &index.super_called) {
120 continue;
121 }
122 if any_override_referenced(hir, contract_id, func, &index.referenced) {
123 continue;
124 }
125
126 ctx.emit(&EXTERNAL_FUNCTION, name.span);
127 }
128 }
129}
130
131fn build_project_index<'hir>(hir: &'hir hir::Hir<'hir>) -> ProjectIndex {
132 let mut builder = IndexBuilder { hir, idx: ProjectIndex::default(), current_contract: None };
133 for func in hir.functions() {
134 builder.current_contract = func.contract;
135 let _ = builder.visit_function(func);
136 }
137 for var in hir.variables() {
140 if var.is_state_variable() {
141 builder.current_contract = var.contract;
142 let _ = builder.visit_var(var);
143 }
144 }
145 builder.idx
146}
147
/// HIR visitor that accumulates a `ProjectIndex`; driven by `build_project_index`.
struct IndexBuilder<'hir> {
    /// The HIR being indexed.
    hir: &'hir hir::Hir<'hir>,
    /// Results accumulated so far.
    idx: ProjectIndex,
    /// Contract owning the item currently being visited, as set by `build_project_index` from
    /// the item's `contract` field (`None` when the item has no contract). Used to attribute
    /// `super.<name>` accesses.
    current_contract: Option<ContractId>,
}
157
158impl<'hir> hir::Visit<'hir> for IndexBuilder<'hir> {
159 type BreakValue = Never;
160
161 fn hir(&self) -> &'hir hir::Hir<'hir> {
162 self.hir
163 }
164
165 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
166 match &expr.kind {
167 ExprKind::Ident(reses) => {
168 for res in *reses {
169 if let Res::Item(ItemId::Function(fid)) = res {
170 self.idx.referenced.insert(*fid);
171 }
172 }
173 }
174 ExprKind::Member(base, member) => {
175 if let Some(cid) = self.current_contract
176 && let ExprKind::Ident(reses) = &base.peel_parens().kind
177 && reses.iter().any(|r| matches!(r, Res::Builtin(b) if b.name() == solar::interface::sym::super_))
178 {
179 self.idx.super_called.entry(member.name).or_default().insert(cid);
180 }
181 }
182 _ => {}
183 }
184 self.walk_expr(expr)
185 }
186}
187
188fn super_called_from_derivative(
191 hir: &hir::Hir<'_>,
192 base_contract_id: ContractId,
193 name: &Symbol,
194 super_called: &HashMap<Symbol, HashSet<ContractId>>,
195) -> bool {
196 let Some(callers) = super_called.get(name) else { return false };
197 callers.iter().any(|&caller_cid| {
198 caller_cid != base_contract_id
199 && hir.contract(caller_cid).linearized_bases.contains(&base_contract_id)
200 })
201}
202
203fn any_override_referenced(
208 hir: &hir::Hir<'_>,
209 contract_id: ContractId,
210 base: &hir::Function<'_>,
211 referenced: &HashSet<FunctionId>,
212) -> bool {
213 let Some(base_name) = base.name else { return false };
214 let arity = base.parameters.len();
215
216 for (other_cid, other_contract) in hir.contracts_enumerated() {
217 if other_cid != contract_id && !other_contract.linearized_bases.contains(&contract_id) {
218 continue;
219 }
220 for fid in other_contract.functions() {
221 if referenced.contains(&fid) {
222 let other = hir.function(fid);
223 if let Some(other_name) = other.name
224 && other_name.name == base_name.name
225 && other.parameters.len() == arity
226 {
227 return true;
228 }
229 }
230 }
231 }
232 false
233}
234
235fn body_escapes_params<'hir>(
238 hir: &'hir hir::Hir<'hir>,
239 body: &hir::Block<'hir>,
240 params: &[VariableId],
241) -> bool {
242 let mut finder = ParamEscapeFinder { hir, params };
243 body.stmts.iter().any(|stmt| finder.visit_stmt(stmt).is_break())
244}
245
246fn modifier_args_reference_params(modifiers: &[hir::Modifier<'_>], params: &[VariableId]) -> bool {
249 modifiers.iter().any(|m| m.args.exprs().any(|arg| expr_root_is_param(arg, params)))
250}
251
/// HIR visitor that breaks with `()` when one of `params` escapes.
///
/// An escape is any of: the parameter is assigned to, deleted, incremented or decremented;
/// it is aliased into a `memory` reference-type local; it is passed as a call argument or call
/// option; or it is used as a call receiver (type conversions excluded).
struct ParamEscapeFinder<'a, 'hir> {
    /// HIR used to resolve variable declarations.
    hir: &'hir hir::Hir<'hir>,
    /// The function parameters being tracked.
    params: &'a [VariableId],
}
256
257impl<'hir> hir::Visit<'hir> for ParamEscapeFinder<'_, 'hir> {
258 type BreakValue = ();
259
260 fn hir(&self) -> &'hir hir::Hir<'hir> {
261 self.hir
262 }
263
264 fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
265 if let StmtKind::DeclSingle(vid) = &stmt.kind {
266 let var = self.hir.variable(*vid);
267 if let Some(init) = var.initializer
268 && var.ty.kind.is_reference_type()
269 && var.data_location == Some(DataLocation::Memory)
270 && expr_root_is_param(init, self.params)
271 {
272 return ControlFlow::Break(());
273 }
274 }
275 self.walk_stmt(stmt)
276 }
277
278 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
279 match &expr.kind {
280 ExprKind::Assign(lhs, op, rhs) => {
281 if expr_root_is_param(lhs, self.params) {
282 return ControlFlow::Break(());
283 }
284 if op.is_none()
285 && lhs_is_local_memory_reference(self.hir, lhs)
286 && expr_root_is_param(rhs, self.params)
287 {
288 return ControlFlow::Break(());
289 }
290 }
291 ExprKind::Delete(inner) if expr_root_is_param(inner, self.params) => {
292 return ControlFlow::Break(());
293 }
294 ExprKind::Unary(op, inner)
295 if matches!(
296 op.kind,
297 UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
298 ) && expr_root_is_param(inner, self.params) =>
299 {
300 return ControlFlow::Break(());
301 }
302 ExprKind::Call(callee, args, opts) if !is_type_conversion_callee(callee) => {
303 for arg in args.exprs() {
304 if expr_root_is_param(arg, self.params) {
305 return ControlFlow::Break(());
306 }
307 }
308 if let Some(opts) = opts {
309 for opt in opts.args {
310 if expr_root_is_param(&opt.value, self.params) {
311 return ControlFlow::Break(());
312 }
313 }
314 }
315 if let ExprKind::Member(receiver, _) = &callee.peel_parens().kind
316 && expr_root_is_param(receiver, self.params)
317 {
318 return ControlFlow::Break(());
319 }
320 }
321 _ => {}
322 }
323 self.walk_expr(expr)
324 }
325}
326
327fn is_type_conversion_callee(callee: &hir::Expr<'_>) -> bool {
329 let c = callee.peel_parens();
330 match &c.kind {
331 ExprKind::Type(_) | ExprKind::TypeCall(_) | ExprKind::New(_) => true,
332 ExprKind::Ident(reses) => reses.iter().any(|r| {
333 matches!(
334 r,
335 Res::Item(
336 ItemId::Struct(_) | ItemId::Contract(_) | ItemId::Enum(_) | ItemId::Udvt(_)
337 )
338 )
339 }),
340 _ => false,
341 }
342}
343
344fn lhs_is_local_memory_reference(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
347 let mut cur = lhs.peel_parens();
348 loop {
349 match &cur.kind {
350 ExprKind::Member(base, _) | ExprKind::Payable(base) => cur = base.peel_parens(),
351 ExprKind::Index(base, _) | ExprKind::Slice(base, _, _) => cur = base.peel_parens(),
352 ExprKind::Ident(reses) => {
353 return reses.iter().any(|r| {
354 if let Res::Item(ItemId::Variable(vid)) = r {
355 let v = hir.variable(*vid);
356 v.is_local_variable()
357 && v.ty.kind.is_reference_type()
358 && v.data_location == Some(DataLocation::Memory)
359 } else {
360 false
361 }
362 });
363 }
364 _ => return false,
365 }
366 }
367}
368
369fn expr_root_is_param(expr: &hir::Expr<'_>, params: &[VariableId]) -> bool {
372 let mut cur = expr.peel_parens();
373 loop {
374 match &cur.kind {
375 ExprKind::Member(base, _) | ExprKind::Payable(base) => cur = base.peel_parens(),
376 ExprKind::Index(base, _) | ExprKind::Slice(base, _, _) => cur = base.peel_parens(),
377 ExprKind::Ident(reses) => {
378 return reses.iter().any(
379 |r| matches!(r, Res::Item(ItemId::Variable(vid)) if params.contains(vid)),
380 );
381 }
382 _ => return false,
383 }
384 }
385}