Skip to main content

forge_lint/sol/med/
incorrect_strict_equality.rs

1use super::IncorrectStrictEquality;
2use crate::{
3    linter::{LateLintPass, LintContext},
4    sol::{Severity, SolLint},
5};
6use solar::{
7    ast::{BinOpKind, ContractKind},
8    interface::sym,
9    sema::{
10        Gcx, Hir,
11        hir::{ElementaryType, Expr, ExprKind, ItemId, Res, StructId, Type, TypeKind},
12    },
13};
14
15declare_forge_lint!(
16    INCORRECT_STRICT_EQUALITY,
17    Severity::Med,
18    "incorrect-strict-equality",
19    "dangerous strict equality check on an externally-influenced value"
20);
21
22impl<'hir> LateLintPass<'hir> for IncorrectStrictEquality {
23    fn check_expr(
24        &mut self,
25        ctx: &LintContext,
26        _gcx: Gcx<'hir>,
27        hir: &'hir Hir<'hir>,
28        expr: &'hir Expr<'hir>,
29    ) {
30        if let ExprKind::Binary(lhs, op, rhs) = &expr.kind
31            && matches!(op.kind, BinOpKind::Eq | BinOpKind::Ne)
32            && (contains_externally_influenced(hir, lhs)
33                || contains_externally_influenced(hir, rhs))
34        {
35            ctx.emit(&INCORRECT_STRICT_EQUALITY, expr.span);
36        }
37    }
38}
39
40/// Recursively checks whether an expression tree contains an externally-influenced
41/// balance read. This makes the lint fire on cases like
42/// `address(this).balance + 1 == target` or `target == token.balanceOf(address(this)) - 1`.
43fn contains_externally_influenced<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
44    let expr = expr.peel_parens();
45    if is_externally_influenced(hir, expr) {
46        return true;
47    }
48    match &expr.kind {
49        ExprKind::Unary(_, inner) => contains_externally_influenced(hir, inner),
50        ExprKind::Binary(lhs, _, rhs) => {
51            contains_externally_influenced(hir, lhs) || contains_externally_influenced(hir, rhs)
52        }
53        ExprKind::Ternary(_, then, else_) => {
54            contains_externally_influenced(hir, then) || contains_externally_influenced(hir, else_)
55        }
56        ExprKind::Call(_, args, _) => args.exprs().any(|a| contains_externally_influenced(hir, a)),
57        _ => false,
58    }
59}
60
61/// Returns `true` if `expr` is `<address>.balance` or `<expr>.balanceOf(...)`.
62fn is_externally_influenced<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
63    match &expr.peel_parens().kind {
64        // `<expr>.balance`, only flag when we can prove the receiver is an `address`.
65        // Otherwise any user-defined struct field named `balance` would trigger this lint.
66        ExprKind::Member(base, member) => {
67            member.as_str() == "balance" && is_address_expr(hir, base)
68        }
69
70        // `<expr>.balanceOf(...)`, ERC-20 style external call. We match by name, since
71        // `balanceOf` is overwhelmingly an ERC-20 / token method.
72        // Skip calls where the receiver resolves to a library to avoid false positives
73        // on internal library helpers named `balanceOf`.
74        ExprKind::Call(callee, _, _) => {
75            if let ExprKind::Member(base, m) = &callee.peel_parens().kind
76                && m.as_str() == "balanceOf"
77            {
78                // Skip if the receiver resolves to a library contract.
79                !matches!(&base.peel_parens().kind, ExprKind::Ident(reses) if reses.iter().any(|r| {
80                    matches!(r, Res::Item(ItemId::Contract(cid)) if hir.contract(*cid).kind == ContractKind::Library)
81                }))
82            } else {
83                false
84            }
85        }
86
87        _ => false,
88    }
89}
90
91/// Conservatively returns `true` if `expr` is provably of type `address`
92/// (or `address payable`).
93///
94/// Returning `false` simply skips the lint, so being conservative is preferred over
95/// being exhaustive (see `docs/incorrect-strict-equality.md`).
96fn is_address_expr<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> bool {
97    let expr = expr.peel_parens();
98    match &expr.kind {
99        // `payable(x)` always returns `address payable`.
100        ExprKind::Payable(_) => true,
101
102        // `address(x)` cast, or a function call whose single return type is address.
103        ExprKind::Call(callee, _, _) => {
104            let callee = callee.peel_parens();
105            // Type cast: `address(x)` / `address payable(x)`.
106            if matches!(
107                &callee.kind,
108                ExprKind::Type(Type { kind: TypeKind::Elementary(ElementaryType::Address(_)), .. })
109            ) {
110                return true;
111            }
112            // Function call returning a single `address`.
113            if let ExprKind::Ident(reses) = &callee.kind {
114                return reses.iter().any(|r| {
115                    if let Res::Item(ItemId::Function(fid)) = r {
116                        let func = hir.function(*fid);
117                        if let [ret] = func.returns {
118                            return matches!(
119                                hir.variable(*ret).ty.kind,
120                                TypeKind::Elementary(ElementaryType::Address(_))
121                            );
122                        }
123                    }
124                    false
125                });
126            }
127            false
128        }
129
130        // Identifier resolving to a variable declared as `address` / `address payable`.
131        ExprKind::Ident(reses) => reses.iter().any(|r| {
132            matches!(
133                r,
134                Res::Item(ItemId::Variable(vid))
135                    if matches!(
136                        hir.variable(*vid).ty.kind,
137                        TypeKind::Elementary(ElementaryType::Address(_))
138                    )
139            )
140        }),
141
142        ExprKind::Member(base, member) => {
143            let name = member.as_str();
144            // Built-in members that return `address`: `msg.sender`, `tx.origin`, `block.coinbase`.
145            if let ExprKind::Ident(reses) = &base.peel_parens().kind {
146                let is_builtin = reses.iter().any(|r| {
147                    matches!(
148                        r,
149                        Res::Builtin(b) if {
150                            let base_sym = b.name();
151                            (base_sym == sym::msg && name == "sender")
152                                || (base_sym == sym::tx && name == "origin")
153                                || (base_sym == sym::block && name == "coinbase")
154                        }
155                    )
156                });
157                if is_builtin {
158                    return true;
159                }
160            }
161            // Struct field whose declared type is `address` (e.g. `holder.owner`).
162            matches!(struct_field_type(hir, base, name), Some(ElementaryType::Address(_)))
163        }
164
165        // Indexing into an array/mapping of `address` (e.g. `holders[i]`).
166        ExprKind::Index(base, _) => {
167            matches!(indexed_element_type(hir, base), Some(ElementaryType::Address(_)))
168        }
169
170        _ => false,
171    }
172}
173
174/// Resolves the declared elementary type of `field_name` on `base`, when `base` is
175/// known to be a struct value.
176fn struct_field_type<'hir>(
177    hir: &'hir Hir<'hir>,
178    base: &Expr<'hir>,
179    field_name: &str,
180) -> Option<ElementaryType> {
181    let strukt_id = struct_of(hir, base)?;
182    let strukt = hir.strukt(strukt_id);
183    for fid in strukt.fields {
184        let v = hir.variable(*fid);
185        if let Some(name) = v.name
186            && name.as_str() == field_name
187            && let TypeKind::Elementary(elem) = v.ty.kind
188        {
189            return Some(elem);
190        }
191    }
192    None
193}
194
195/// Returns the element type of `base` when it is an array or the value type when it is
196/// a mapping, restricted to elementary types.
197fn indexed_element_type<'hir>(hir: &'hir Hir<'hir>, base: &Expr<'hir>) -> Option<ElementaryType> {
198    let ExprKind::Ident(reses) = &base.peel_parens().kind else { return None };
199    let var = reses.iter().find_map(|r| match r {
200        Res::Item(ItemId::Variable(vid)) => Some(hir.variable(*vid)),
201        _ => None,
202    })?;
203    match &var.ty.kind {
204        TypeKind::Array(arr) => match arr.element.kind {
205            TypeKind::Elementary(elem) => Some(elem),
206            _ => None,
207        },
208        TypeKind::Mapping(m) => match m.value.kind {
209            TypeKind::Elementary(elem) => Some(elem),
210            _ => None,
211        },
212        _ => None,
213    }
214}
215
216/// Returns the [`StructId`] of `expr` when it is a (possibly chained) struct value.
217fn struct_of<'hir>(hir: &'hir Hir<'hir>, expr: &Expr<'hir>) -> Option<StructId> {
218    let expr = expr.peel_parens();
219    match &expr.kind {
220        ExprKind::Ident(reses) => reses.iter().find_map(|r| match r {
221            Res::Item(ItemId::Variable(vid)) => match hir.variable(*vid).ty.kind {
222                TypeKind::Custom(ItemId::Struct(sid)) => Some(sid),
223                _ => None,
224            },
225            _ => None,
226        }),
227        ExprKind::Member(inner, member) => {
228            let strukt_id = struct_of(hir, inner)?;
229            let strukt = hir.strukt(strukt_id);
230            for fid in strukt.fields {
231                let v = hir.variable(*fid);
232                if let Some(name) = v.name
233                    && name.as_str() == member.as_str()
234                    && let TypeKind::Custom(ItemId::Struct(sid)) = v.ty.kind
235                {
236                    return Some(sid);
237                }
238            }
239            None
240        }
241        _ => None,
242    }
243}