1use super::ExternalFunction;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{ContractKind, DataLocation, UnOpKind, Visibility},
8 interface::{Symbol, data_structures::Never},
9 sema::hir::{
10 self, ContractId, ExprKind, FunctionId, ItemId, Res, StmtKind, VariableId, Visit as _,
11 },
12};
13use std::{
14 cell::RefCell,
15 collections::{HashMap, HashSet},
16 ops::ControlFlow,
17 rc::Rc,
18};
19
// Gas-severity lint reported by `check_nested_contract` below for `public` functions that
// could be declared `external`. It is only reported for functions that take at least one
// `memory` reference-type parameter and pass the internal-usage checks in this file.
declare_forge_lint!(
    EXTERNAL_FUNCTION,
    Severity::Gas,
    "external-function",
    "public function can be declared external"
);
26
27#[derive(Default)]
28struct ProjectIndex {
29 referenced: HashSet<FunctionId>,
32 super_called: HashMap<Symbol, HashSet<ContractId>>,
35}
36
// Per-thread cache holding the most recently built `ProjectIndex`, tagged with the address
// of the `Hir` it was built from (see `project_index_for`). Only a single entry is kept; a
// request for a different `Hir` replaces it.
// NOTE(review): the key is only a pointer address. If a `Hir` is freed and another one is
// later allocated at the same address on this thread, the stale index would be reused.
// Confirm the driver's lifetimes rule this out.
thread_local! {
    static PROJECT_INDEX: RefCell<Option<(usize, Rc<ProjectIndex>)>> = const { RefCell::new(None) };
}
43
44fn project_index_for<'hir>(hir: &'hir hir::Hir<'hir>) -> Rc<ProjectIndex> {
45 let key = std::ptr::from_ref::<hir::Hir<'_>>(hir) as usize;
46 PROJECT_INDEX.with(|cell| {
47 let mut slot = cell.borrow_mut();
48 if let Some((cached_key, cached)) = slot.as_ref()
49 && *cached_key == key
50 {
51 return cached.clone();
52 }
53 let fresh = Rc::new(build_project_index(hir));
54 *slot = Some((key, fresh.clone()));
55 fresh
56 })
57}
58
59impl<'hir> LateLintPass<'hir> for ExternalFunction {
60 fn check_nested_contract(
61 &mut self,
62 ctx: &LintContext,
63 hir: &'hir hir::Hir<'hir>,
64 contract_id: ContractId,
65 ) {
66 if !ctx.is_lint_enabled(EXTERNAL_FUNCTION.id) {
67 return;
68 }
69
70 let contract = hir.contract(contract_id);
71
72 if !matches!(contract.kind, ContractKind::Contract | ContractKind::AbstractContract) {
75 return;
76 }
77 if contract.linearization_failed() {
78 return;
79 }
80
81 let index = project_index_for(hir);
82
83 for fid in contract.functions() {
84 let func = hir.function(fid);
85
86 if func.visibility != Visibility::Public || !func.is_ordinary() {
88 continue;
89 }
90 if func.override_ {
93 continue;
94 }
95 let Some(body) = func.body else { continue };
97
98 let has_memory_reference_param = func.parameters.iter().any(|&pid| {
101 let p = hir.variable(pid);
102 p.ty.kind.is_reference_type() && p.data_location == Some(DataLocation::Memory)
103 });
104 if !has_memory_reference_param {
105 continue;
106 }
107
108 if body_escapes_params(hir, &body, func.parameters)
109 || modifier_args_reference_params(func.modifiers, func.parameters)
110 {
111 continue;
112 }
113
114 let Some(name) = func.name else { continue };
115
116 if super_called_from_derivative(hir, contract_id, &name.name, &index.super_called) {
119 continue;
120 }
121 if any_override_referenced(hir, contract_id, func, &index.referenced) {
122 continue;
123 }
124
125 ctx.emit(&EXTERNAL_FUNCTION, name.span);
126 }
127 }
128}
129
130fn build_project_index<'hir>(hir: &'hir hir::Hir<'hir>) -> ProjectIndex {
131 let mut builder = IndexBuilder { hir, idx: ProjectIndex::default(), current_contract: None };
132 for func in hir.functions() {
133 builder.current_contract = func.contract;
134 let _ = builder.visit_function(func);
135 }
136 for var in hir.variables() {
139 if var.is_state_variable() {
140 builder.current_contract = var.contract;
141 let _ = builder.visit_var(var);
142 }
143 }
144 builder.idx
145}
146
147struct IndexBuilder<'hir> {
151 hir: &'hir hir::Hir<'hir>,
152 idx: ProjectIndex,
153 current_contract: Option<ContractId>,
155}
156
157impl<'hir> hir::Visit<'hir> for IndexBuilder<'hir> {
158 type BreakValue = Never;
159
160 fn hir(&self) -> &'hir hir::Hir<'hir> {
161 self.hir
162 }
163
164 fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
165 match &expr.kind {
166 ExprKind::Ident(reses) => {
167 for res in *reses {
168 if let Res::Item(ItemId::Function(fid)) = res {
169 self.idx.referenced.insert(*fid);
170 }
171 }
172 }
173 ExprKind::Member(base, member) => {
174 if let Some(cid) = self.current_contract
175 && let ExprKind::Ident(reses) = &base.peel_parens().kind
176 && reses.iter().any(|r| matches!(r, Res::Builtin(b) if b.name() == solar::interface::sym::super_))
177 {
178 self.idx.super_called.entry(member.name).or_default().insert(cid);
179 }
180 }
181 _ => {}
182 }
183 self.walk_expr(expr)
184 }
185}
186
187fn super_called_from_derivative(
190 hir: &hir::Hir<'_>,
191 base_contract_id: ContractId,
192 name: &Symbol,
193 super_called: &HashMap<Symbol, HashSet<ContractId>>,
194) -> bool {
195 let Some(callers) = super_called.get(name) else { return false };
196 callers.iter().any(|&caller_cid| {
197 caller_cid != base_contract_id
198 && hir.contract(caller_cid).linearized_bases.contains(&base_contract_id)
199 })
200}
201
202fn any_override_referenced(
207 hir: &hir::Hir<'_>,
208 contract_id: ContractId,
209 base: &hir::Function<'_>,
210 referenced: &HashSet<FunctionId>,
211) -> bool {
212 let Some(base_name) = base.name else { return false };
213 let arity = base.parameters.len();
214
215 for (other_cid, other_contract) in hir.contracts_enumerated() {
216 if other_cid != contract_id && !other_contract.linearized_bases.contains(&contract_id) {
217 continue;
218 }
219 for fid in other_contract.functions() {
220 if referenced.contains(&fid) {
221 let other = hir.function(fid);
222 if let Some(other_name) = other.name
223 && other_name.name == base_name.name
224 && other.parameters.len() == arity
225 {
226 return true;
227 }
228 }
229 }
230 }
231 false
232}
233
234fn body_escapes_params<'hir>(
237 hir: &'hir hir::Hir<'hir>,
238 body: &hir::Block<'hir>,
239 params: &[VariableId],
240) -> bool {
241 let mut finder = ParamEscapeFinder { hir, params };
242 body.stmts.iter().any(|stmt| finder.visit_stmt(stmt).is_break())
243}
244
245fn modifier_args_reference_params(modifiers: &[hir::Modifier<'_>], params: &[VariableId]) -> bool {
248 modifiers.iter().any(|m| m.args.exprs().any(|arg| expr_root_is_param(arg, params)))
249}
250
251struct ParamEscapeFinder<'a, 'hir> {
252 hir: &'hir hir::Hir<'hir>,
253 params: &'a [VariableId],
254}
255
// Breaks (`ControlFlow::Break(())`) on the first statement or expression that uses one of
// `self.params` in a way the lint treats as an escape. Presumably these are uses that could
// change meaning or stop compiling if the parameter became `calldata`.
impl<'hir> hir::Visit<'hir> for ParamEscapeFinder<'_, 'hir> {
    type BreakValue = ();

    fn hir(&self) -> &'hir hir::Hir<'hir> {
        self.hir
    }

    // A local `memory` reference variable initialized from a parameter-rooted expression
    // aliases the parameter, so it counts as an escape.
    fn visit_stmt(&mut self, stmt: &'hir hir::Stmt<'hir>) -> ControlFlow<Self::BreakValue> {
        if let StmtKind::DeclSingle(vid) = &stmt.kind {
            let var = self.hir.variable(*vid);
            if let Some(init) = var.initializer
                && var.ty.kind.is_reference_type()
                && var.data_location == Some(DataLocation::Memory)
                && expr_root_is_param(init, self.params)
            {
                return ControlFlow::Break(());
            }
        }
        self.walk_stmt(stmt)
    }

    fn visit_expr(&mut self, expr: &'hir hir::Expr<'hir>) -> ControlFlow<Self::BreakValue> {
        match &expr.kind {
            ExprKind::Assign(lhs, op, rhs) => {
                // Any assignment (plain or compound) whose target is rooted at a parameter.
                if expr_root_is_param(lhs, self.params) {
                    return ControlFlow::Break(());
                }
                // Plain `=` storing a parameter-rooted value into a local memory reference
                // creates an alias of the parameter.
                if op.is_none()
                    && lhs_is_local_memory_reference(self.hir, lhs)
                    && expr_root_is_param(rhs, self.params)
                {
                    return ControlFlow::Break(());
                }
            }
            // `delete` applied to a parameter or one of its parts is a write.
            ExprKind::Delete(inner) => {
                if expr_root_is_param(inner, self.params) {
                    return ControlFlow::Break(());
                }
            }
            // Increment/decrement operators write to their operand; other unary operators
            // fall through to the `_` arm.
            ExprKind::Unary(op, inner)
                if matches!(
                    op.kind,
                    UnOpKind::PreInc | UnOpKind::PreDec | UnOpKind::PostInc | UnOpKind::PostDec
                ) =>
            {
                if expr_root_is_param(inner, self.params) {
                    return ControlFlow::Break(());
                }
            }
            // Calls other than type conversions / `new` (see `is_type_conversion_callee`):
            // passing a parameter-rooted value as an argument, as a call option value, or as
            // the receiver of a member call counts as forwarding it.
            ExprKind::Call(callee, args, opts) if !is_type_conversion_callee(callee) => {
                for arg in args.exprs() {
                    if expr_root_is_param(arg, self.params) {
                        return ControlFlow::Break(());
                    }
                }
                if let Some(opts) = opts {
                    for opt in *opts {
                        if expr_root_is_param(&opt.value, self.params) {
                            return ControlFlow::Break(());
                        }
                    }
                }
                if let ExprKind::Member(receiver, _) = &callee.peel_parens().kind
                    && expr_root_is_param(receiver, self.params)
                {
                    return ControlFlow::Break(());
                }
            }
            _ => {}
        }
        self.walk_expr(expr)
    }
}
329
330fn is_type_conversion_callee(callee: &hir::Expr<'_>) -> bool {
332 let c = callee.peel_parens();
333 match &c.kind {
334 ExprKind::Type(_) | ExprKind::TypeCall(_) | ExprKind::New(_) => true,
335 ExprKind::Ident(reses) => reses.iter().any(|r| {
336 matches!(
337 r,
338 Res::Item(
339 ItemId::Struct(_) | ItemId::Contract(_) | ItemId::Enum(_) | ItemId::Udvt(_)
340 )
341 )
342 }),
343 _ => false,
344 }
345}
346
347fn lhs_is_local_memory_reference(hir: &hir::Hir<'_>, lhs: &hir::Expr<'_>) -> bool {
350 let mut cur = lhs.peel_parens();
351 loop {
352 match &cur.kind {
353 ExprKind::Member(base, _) | ExprKind::Payable(base) => cur = base.peel_parens(),
354 ExprKind::Index(base, _) | ExprKind::Slice(base, _, _) => cur = base.peel_parens(),
355 ExprKind::Ident(reses) => {
356 return reses.iter().any(|r| {
357 if let Res::Item(ItemId::Variable(vid)) = r {
358 let v = hir.variable(*vid);
359 v.is_local_variable()
360 && v.ty.kind.is_reference_type()
361 && v.data_location == Some(DataLocation::Memory)
362 } else {
363 false
364 }
365 });
366 }
367 _ => return false,
368 }
369 }
370}
371
372fn expr_root_is_param(expr: &hir::Expr<'_>, params: &[VariableId]) -> bool {
375 let mut cur = expr.peel_parens();
376 loop {
377 match &cur.kind {
378 ExprKind::Member(base, _) | ExprKind::Payable(base) => cur = base.peel_parens(),
379 ExprKind::Index(base, _) | ExprKind::Slice(base, _, _) => cur = base.peel_parens(),
380 ExprKind::Ident(reses) => {
381 return reses.iter().any(
382 |r| matches!(r, Res::Item(ItemId::Variable(vid)) if params.contains(vid)),
383 );
384 }
385 _ => return false,
386 }
387 }
388}