1use super::IncorrectStrictEquality;
2use crate::{
3 linter::{LateLintPass, LintContext},
4 sol::{Severity, SolLint},
5};
6use solar::{
7 ast::{BinOpKind, ContractKind},
8 interface::sym,
9 sema::{
10 Hir,
11 hir::{ElementaryType, Expr, ExprKind, ItemId, Res, StructId, Type, TypeKind},
12 },
13};
14
// `incorrect-strict-equality` (medium severity).
//
// Emitted by `check_expr` below whenever an `==` or `!=` comparison has an operand that
// contains a value this pass treats as externally influenced: the `.balance` of an
// address, or a `balanceOf(..)` call on a receiver that is not a library contract.
declare_forge_lint!(
    INCORRECT_STRICT_EQUALITY,
    Severity::Med,
    "incorrect-strict-equality",
    "dangerous strict equality check on an externally-influenced value"
);
21
22impl<'hir> LateLintPass<'hir> for IncorrectStrictEquality {
23 fn check_expr(&mut self, ctx: &LintContext, hir: &'hir Hir<'hir>, expr: &'hir Expr<'hir>) {
24 if let ExprKind::Binary(lhs, op, rhs) = &expr.kind
25 && matches!(op.kind, BinOpKind::Eq | BinOpKind::Ne)
26 && (contains_externally_influenced(hir, lhs)
27 || contains_externally_influenced(hir, rhs))
28 {
29 ctx.emit(&INCORRECT_STRICT_EQUALITY, expr.span);
30 }
31 }
32}
33
34fn contains_externally_influenced<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
38 let expr = expr.peel_parens();
39 if is_externally_influenced(hir, expr) {
40 return true;
41 }
42 match &expr.kind {
43 ExprKind::Unary(_, inner) => contains_externally_influenced(hir, inner),
44 ExprKind::Binary(lhs, _, rhs) => {
45 contains_externally_influenced(hir, lhs) || contains_externally_influenced(hir, rhs)
46 }
47 ExprKind::Ternary(_, then, else_) => {
48 contains_externally_influenced(hir, then) || contains_externally_influenced(hir, else_)
49 }
50 ExprKind::Call(_, args, _) => args.exprs().any(|a| contains_externally_influenced(hir, a)),
51 _ => false,
52 }
53}
54
55fn is_externally_influenced<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
57 match &expr.peel_parens().kind {
58 ExprKind::Member(base, member) => {
61 member.as_str() == "balance" && is_address_expr(hir, base)
62 }
63
64 ExprKind::Call(callee, _, _) => {
69 if let ExprKind::Member(base, m) = &callee.peel_parens().kind
70 && m.as_str() == "balanceOf"
71 {
72 !matches!(&base.peel_parens().kind, ExprKind::Ident(reses) if reses.iter().any(|r| {
74 matches!(r, Res::Item(ItemId::Contract(cid)) if hir.contract(*cid).kind == ContractKind::Library)
75 }))
76 } else {
77 false
78 }
79 }
80
81 _ => false,
82 }
83}
84
85fn is_address_expr<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
91 let expr = expr.peel_parens();
92 match &expr.kind {
93 ExprKind::Payable(_) => true,
95
96 ExprKind::Call(callee, _, _) => {
98 let callee = callee.peel_parens();
99 if matches!(
101 &callee.kind,
102 ExprKind::Type(Type { kind: TypeKind::Elementary(ElementaryType::Address(_)), .. })
103 ) {
104 return true;
105 }
106 if let ExprKind::Ident(reses) = &callee.kind {
108 return reses.iter().any(|r| {
109 if let Res::Item(ItemId::Function(fid)) = r {
110 let func = hir.function(*fid);
111 if let [ret] = func.returns {
112 return matches!(
113 hir.variable(*ret).ty.kind,
114 TypeKind::Elementary(ElementaryType::Address(_))
115 );
116 }
117 }
118 false
119 });
120 }
121 false
122 }
123
124 ExprKind::Ident(reses) => reses.iter().any(|r| {
126 matches!(
127 r,
128 Res::Item(ItemId::Variable(vid))
129 if matches!(
130 hir.variable(*vid).ty.kind,
131 TypeKind::Elementary(ElementaryType::Address(_))
132 )
133 )
134 }),
135
136 ExprKind::Member(base, member) => {
137 let name = member.as_str();
138 if let ExprKind::Ident(reses) = &base.peel_parens().kind {
140 let is_builtin = reses.iter().any(|r| {
141 matches!(
142 r,
143 Res::Builtin(b) if {
144 let base_sym = b.name();
145 (base_sym == sym::msg && name == "sender")
146 || (base_sym == sym::tx && name == "origin")
147 || (base_sym == sym::block && name == "coinbase")
148 }
149 )
150 });
151 if is_builtin {
152 return true;
153 }
154 }
155 matches!(struct_field_type(hir, base, name), Some(ElementaryType::Address(_)))
157 }
158
159 ExprKind::Index(base, _) => {
161 matches!(indexed_element_type(hir, base), Some(ElementaryType::Address(_)))
162 }
163
164 _ => false,
165 }
166}
167
168fn struct_field_type<'hir>(
171 hir: &'hir Hir<'hir>,
172 base: &Expr<'hir>,
173 field_name: &str,
174) -> Option<ElementaryType> {
175 let strukt_id = struct_of(hir, base)?;
176 let strukt = hir.strukt(strukt_id);
177 for fid in strukt.fields {
178 let v = hir.variable(*fid);
179 if let Some(name) = v.name
180 && name.as_str() == field_name
181 && let TypeKind::Elementary(elem) = v.ty.kind
182 {
183 return Some(elem);
184 }
185 }
186 None
187}
188
189fn indexed_element_type<'hir>(hir: &'hir Hir<'hir>, base: &Expr<'hir>) -> Option<ElementaryType> {
192 let ExprKind::Ident(reses) = &base.peel_parens().kind else { return None };
193 let var = reses.iter().find_map(|r| match r {
194 Res::Item(ItemId::Variable(vid)) => Some(hir.variable(*vid)),
195 _ => None,
196 })?;
197 match &var.ty.kind {
198 TypeKind::Array(arr) => match arr.element.kind {
199 TypeKind::Elementary(elem) => Some(elem),
200 _ => None,
201 },
202 TypeKind::Mapping(m) => match m.value.kind {
203 TypeKind::Elementary(elem) => Some(elem),
204 _ => None,
205 },
206 _ => None,
207 }
208}
209
210fn struct_of<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> Option<StructId> {
212 let expr = expr.peel_parens();
213 match &expr.kind {
214 ExprKind::Ident(reses) => reses.iter().find_map(|r| match r {
215 Res::Item(ItemId::Variable(vid)) => match hir.variable(*vid).ty.kind {
216 TypeKind::Custom(ItemId::Struct(sid)) => Some(sid),
217 _ => None,
218 },
219 _ => None,
220 }),
221 ExprKind::Member(inner, member) => {
222 let strukt_id = struct_of(hir, inner)?;
223 let strukt = hir.strukt(strukt_id);
224 for fid in strukt.fields {
225 let v = hir.variable(*fid);
226 if let Some(name) = v.name
227 && name.as_str() == member.as_str()
228 && let TypeKind::Custom(ItemId::Struct(sid)) = v.ty.kind
229 {
230 return Some(sid);
231 }
232 }
233 None
234 }
235 _ => None,
236 }
237}