// forge_fmt/solang_ext/ast_eq.rs

1use alloy_primitives::{Address, I256, U256};
2use solang_parser::pt::*;
3use std::str::FromStr;
4
/// Helper to convert a string number into a comparable one.
///
/// Returns a canonical decimal form of `string`:
/// - `_` separators and surrounding whitespace are removed.
/// - Leading zeros are stripped.
/// - A `-` sign is kept only for non-zero values.
///
/// An empty or all-zero input yields `"0"`, so `""`, `"0"`, `"-0"` and `"000"`
/// all compare equal.
///
/// A string is returned instead of a fixed-width integer so that literals of any
/// size can be compared. Parsing into `I256` panicked for values of 2^255 or more,
/// for example an explicitly written-out `type(uint256).max`.
fn to_num(string: &str) -> String {
    let cleaned = string.replace('_', "");
    let cleaned = cleaned.trim();
    let (negative, magnitude) = match cleaned.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, cleaned.strip_prefix('+').unwrap_or(cleaned)),
    };
    let magnitude = magnitude.trim_start_matches('0');
    if magnitude.is_empty() {
        // Zero has a single representation regardless of sign or zero padding.
        String::from("0")
    } else if negative {
        format!("-{magnitude}")
    } else {
        magnitude.to_owned()
    }
}
12
/// Helper to convert the fractional part of a number into a comparable one.
///
/// Trailing zeros carry no value in a fraction (`.5` == `.50`), so they are
/// stripped, together with `_` separators and surrounding whitespace. Leading
/// zeros are significant and kept (`.05` != `.5`). An empty or all-zero fraction
/// yields an empty string.
///
/// This matches the equalities of the former "reverse the digits and parse as
/// `U256`" approach. Returning a string means arbitrarily long fractions no
/// longer overflow and panic.
fn to_num_reversed(string: &str) -> String {
    string.replace('_', "").trim().trim_end_matches('0').to_owned()
}
21
22/// Helper to filter [ParameterList] to omit empty
23/// parameters
24fn filter_params(list: &ParameterList) -> ParameterList {
25    list.iter().filter(|(_, param)| param.is_some()).cloned().collect::<Vec<_>>()
26}
27
/// Structural equality for parse trees.
///
/// Source locations are ignored. Element ordering is also ignored wherever the
/// order carries no meaning, such as attribute lists.
pub trait AstEq {
    /// Returns `true` when `self` and `other` describe the same syntax tree.
    fn ast_eq(&self, other: &Self) -> bool;
}
33
34impl AstEq for Loc {
35    fn ast_eq(&self, _other: &Self) -> bool {
36        true
37    }
38}
39
40impl AstEq for IdentifierPath {
41    fn ast_eq(&self, other: &Self) -> bool {
42        self.identifiers.ast_eq(&other.identifiers)
43    }
44}
45
46impl AstEq for SourceUnit {
47    fn ast_eq(&self, other: &Self) -> bool {
48        self.0.ast_eq(&other.0)
49    }
50}
51
52impl AstEq for VariableDefinition {
53    fn ast_eq(&self, other: &Self) -> bool {
54        let sorted_attrs = |def: &Self| {
55            let mut attrs = def.attrs.clone();
56            attrs.sort();
57            attrs
58        };
59        self.ty.ast_eq(&other.ty) &&
60            self.name.ast_eq(&other.name) &&
61            self.initializer.ast_eq(&other.initializer) &&
62            sorted_attrs(self).ast_eq(&sorted_attrs(other))
63    }
64}
65
66impl AstEq for FunctionDefinition {
67    fn ast_eq(&self, other: &Self) -> bool {
68        // attributes
69        let sorted_attrs = |def: &Self| {
70            let mut attrs = def.attributes.clone();
71            attrs.sort();
72            attrs
73        };
74
75        // params
76        let left_params = filter_params(&self.params);
77        let right_params = filter_params(&other.params);
78        let left_returns = filter_params(&self.returns);
79        let right_returns = filter_params(&other.returns);
80
81        self.ty.ast_eq(&other.ty) &&
82            self.name.ast_eq(&other.name) &&
83            left_params.ast_eq(&right_params) &&
84            self.return_not_returns.ast_eq(&other.return_not_returns) &&
85            left_returns.ast_eq(&right_returns) &&
86            self.body.ast_eq(&other.body) &&
87            sorted_attrs(self).ast_eq(&sorted_attrs(other))
88    }
89}
90
91impl AstEq for Base {
92    fn ast_eq(&self, other: &Self) -> bool {
93        self.name.ast_eq(&other.name) &&
94            self.args.clone().unwrap_or_default().ast_eq(&other.args.clone().unwrap_or_default())
95    }
96}
97
98impl<T> AstEq for Vec<T>
99where
100    T: AstEq,
101{
102    fn ast_eq(&self, other: &Self) -> bool {
103        if self.len() != other.len() {
104            false
105        } else {
106            self.iter().zip(other.iter()).all(|(left, right)| left.ast_eq(right))
107        }
108    }
109}
110
111impl<T> AstEq for Option<T>
112where
113    T: AstEq,
114{
115    fn ast_eq(&self, other: &Self) -> bool {
116        match (self, other) {
117            (Some(left), Some(right)) => left.ast_eq(right),
118            (None, None) => true,
119            _ => false,
120        }
121    }
122}
123
124impl<T> AstEq for Box<T>
125where
126    T: AstEq,
127{
128    fn ast_eq(&self, other: &Self) -> bool {
129        T::ast_eq(self, other)
130    }
131}
132
133impl AstEq for () {
134    fn ast_eq(&self, _other: &Self) -> bool {
135        true
136    }
137}
138
139impl<T> AstEq for &T
140where
141    T: AstEq,
142{
143    fn ast_eq(&self, other: &Self) -> bool {
144        T::ast_eq(self, other)
145    }
146}
147
148impl AstEq for String {
149    fn ast_eq(&self, other: &Self) -> bool {
150        match (Address::from_str(self), Address::from_str(other)) {
151            (Ok(left), Ok(right)) => left == right,
152            _ => self == other,
153        }
154    }
155}
156
/// Expands to the value compared for a single enum field.
///
/// A plain binding expands to itself. A binding annotated with
/// `#[ast_eq_use(func)]` expands to `func(binding)`.
macro_rules! ast_eq_field {
    ($binding:ident) => {
        $binding
    };
    (#[ast_eq_use($convert:ident)] $binding:ident) => {
        $convert($binding)
    };
}
165
/// Generates an exhaustive `match` that compares two values of the enum
/// `$enum_name` variant by variant using `AstEq`.
///
/// The variant listing has three sections separated by `_`:
/// 1. unit variants,
/// 2. tuple variants,
/// 3. struct variants.
///
/// A tuple or struct field may be prefixed with `#[ast_eq_use(func)]`, in which
/// case `func(field)` is compared instead of the field itself. The last three
/// rules are internal helpers invoked by the entry rule.
macro_rules! gen_ast_eq_enum {
    // Entry rule: match on the left-hand value, then let the helper rules below
    // check whether the right-hand value is the same variant with equal fields.
    ($lhs:expr, $rhs:expr, $enum_name:ident {
        $($unit:ident),* $(,)?
        _
        $($tuple:ident ( $($(#[ast_eq_use($tuple_conv:ident)])? $tuple_bind:ident),* $(,)? )),* $(,)?
        _
        $($record:ident { $($(#[ast_eq_use($record_conv:ident)])? $record_bind:ident),* $(,)? }),* $(,)?
    }) => {
        match $lhs {
            $(
                $enum_name::$unit => gen_ast_eq_enum!($rhs, $enum_name, $unit),
            )*
            $(
                $enum_name::$tuple($($tuple_bind),*) => gen_ast_eq_enum!(
                    $rhs, $enum_name, $tuple ($($(#[ast_eq_use($tuple_conv)])? $tuple_bind),*)
                ),
            )*
            $(
                $enum_name::$record { $($record_bind),* } => gen_ast_eq_enum!(
                    $rhs, $enum_name, $record { $($(#[ast_eq_use($record_conv)])? $record_bind),* }
                ),
            )*
        }
    };
    // Unit variant: equal exactly when the right-hand value is the same variant.
    ($rhs:expr, $enum_name:ident, $unit:ident) => {
        matches!($rhs, $enum_name::$unit)
    };
    // Tuple variant: compare the (possibly converted) fields as one tuple.
    ($rhs:expr, $enum_name:ident, $tuple:ident ( $($(#[ast_eq_use($conv:ident)])? $bind:ident),* $(,)? ) ) => {{
        let lhs_fields = ($(ast_eq_field!($(#[ast_eq_use($conv)])? $bind)),*);
        if let $enum_name::$tuple($($bind),*) = $rhs {
            let rhs_fields = ($(ast_eq_field!($(#[ast_eq_use($conv)])? $bind)),*);
            lhs_fields.ast_eq(&rhs_fields)
        } else {
            false
        }
    }};
    // Struct variant: compare the (possibly converted) named fields as one tuple.
    ($rhs:expr, $enum_name:ident, $record:ident { $($(#[ast_eq_use($conv:ident)])? $bind:ident),* $(,)? } ) => {{
        let lhs_fields = ($(ast_eq_field!($(#[ast_eq_use($conv)])? $bind)),*);
        if let $enum_name::$record { $($bind),* } = $rhs {
            let rhs_fields = ($(ast_eq_field!($(#[ast_eq_use($conv)])? $bind)),*);
            lhs_fields.ast_eq(&rhs_fields)
        } else {
            false
        }
    }};
}
210
/// Normalizes a boxed statement into block form.
///
/// A `Statement::Block` is cloned as is. Any other statement is cloned into a new,
/// non-`unchecked` block at `$loc` containing just that statement. `$stmt` must be
/// a reference to a `Box<Statement>`.
macro_rules! wrap_in_box {
    ($stmt:expr, $loc:expr) => {
        match **$stmt {
            Statement::Block { .. } => $stmt.clone(),
            _ => Box::new(Statement::Block {
                loc: $loc,
                unchecked: false,
                statements: vec![*$stmt.clone()],
            }),
        }
    };
}
224
225impl AstEq for Statement {
226    fn ast_eq(&self, other: &Self) -> bool {
227        match self {
228            Self::If(loc, expr, stmt1, stmt2) => {
229                #[expect(clippy::borrowed_box)]
230                let wrap_if = |stmt1: &Box<Self>, stmt2: &Option<Box<Self>>| {
231                    (
232                        wrap_in_box!(stmt1, *loc),
233                        stmt2.as_ref().map(|stmt2| {
234                            if matches!(**stmt2, Self::If(..)) {
235                                stmt2.clone()
236                            } else {
237                                wrap_in_box!(stmt2, *loc)
238                            }
239                        }),
240                    )
241                };
242                let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
243                let left = (loc, expr, &stmt1, &stmt2);
244                if let Self::If(loc, expr, stmt1, stmt2) = other {
245                    let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
246                    let right = (loc, expr, &stmt1, &stmt2);
247                    left.ast_eq(&right)
248                } else {
249                    false
250                }
251            }
252            Self::While(loc, expr, stmt1) => {
253                let stmt1 = wrap_in_box!(stmt1, *loc);
254                let left = (loc, expr, &stmt1);
255                if let Self::While(loc, expr, stmt1) = other {
256                    let stmt1 = wrap_in_box!(stmt1, *loc);
257                    let right = (loc, expr, &stmt1);
258                    left.ast_eq(&right)
259                } else {
260                    false
261                }
262            }
263            Self::DoWhile(loc, stmt1, expr) => {
264                let stmt1 = wrap_in_box!(stmt1, *loc);
265                let left = (loc, &stmt1, expr);
266                if let Self::DoWhile(loc, stmt1, expr) = other {
267                    let stmt1 = wrap_in_box!(stmt1, *loc);
268                    let right = (loc, &stmt1, expr);
269                    left.ast_eq(&right)
270                } else {
271                    false
272                }
273            }
274            Self::For(loc, stmt1, expr, stmt2, stmt3) => {
275                let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
276                let left = (loc, stmt1, expr, stmt2, &stmt3);
277                if let Self::For(loc, stmt1, expr, stmt2, stmt3) = other {
278                    let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
279                    let right = (loc, stmt1, expr, stmt2, &stmt3);
280                    left.ast_eq(&right)
281                } else {
282                    false
283                }
284            }
285            Self::Try(loc, expr, returns, catch) => {
286                let left_returns =
287                    returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
288                let left = (loc, expr, left_returns, catch);
289                if let Self::Try(loc, expr, returns, catch) = other {
290                    let right_returns =
291                        returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
292                    let right = (loc, expr, right_returns, catch);
293                    left.ast_eq(&right)
294                } else {
295                    false
296                }
297            }
298            _ => gen_ast_eq_enum!(self, other, Statement {
299                _
300                Args(loc, args),
301                Expression(loc, expr),
302                VariableDefinition(loc, decl, expr),
303                Continue(loc, ),
304                Break(loc, ),
305                Return(loc, expr),
306                Revert(loc, expr, expr2),
307                RevertNamedArgs(loc, expr, args),
308                Emit(loc, expr),
309                // provide overridden variants regardless
310                If(loc, expr, stmt1, stmt2),
311                While(loc, expr, stmt1),
312                DoWhile(loc, stmt1, expr),
313                For(loc, stmt1, expr, stmt2, stmt3),
314                Try(loc, expr, params, clause),
315                Error(loc)
316                _
317                Block {
318                    loc,
319                    unchecked,
320                    statements,
321                },
322                Assembly {
323                    loc,
324                    dialect,
325                    block,
326                    flags,
327                },
328            }),
329        }
330    }
331}
332
/// Generates `AstEq` implementations.
///
/// Four forms are supported:
/// - `derive_ast_eq!(Ty)`: leaf type, compared with `PartialEq`.
/// - `derive_ast_eq!((0 A, 1 B, ...))`: tuple, compared element-wise from left to
///   right.
/// - `derive_ast_eq!(struct Ty { a, b, ... })`: struct. Every field must be listed,
///   because the destructuring has no `..`, so a new upstream field breaks the
///   build instead of being silently skipped.
/// - `derive_ast_eq!(enum Ty { units _ tuples _ structs })`: enum, compared with
///   `gen_ast_eq_enum!`.
macro_rules! derive_ast_eq {
    ($leaf:ident) => {
        impl AstEq for $leaf {
            fn ast_eq(&self, other: &Self) -> bool {
                PartialEq::eq(self, other)
            }
        }
    };
    (($($idx:tt $param:tt),*)) => {
        impl<$($param),*> AstEq for ($($param,)*)
        where
            $($param: AstEq),*
        {
            fn ast_eq(&self, other: &Self) -> bool {
                $(
                    if !self.$idx.ast_eq(&other.$idx) {
                        return false;
                    }
                )*
                true
            }
        }
    };
    (struct $record:ident { $($field:ident),* $(,)? }) => {
        impl AstEq for $record {
            fn ast_eq(&self, other: &Self) -> bool {
                let lhs = {
                    let $record { $($field),* } = self;
                    ($($field),*)
                };
                let rhs = {
                    let $record { $($field),* } = other;
                    ($($field),*)
                };
                lhs.ast_eq(&rhs)
            }
        }
    };
    (enum $enum_name:ident {
        $($unit:ident),* $(,)?
        _
        $($tuple:ident ( $($(#[ast_eq_use($tuple_conv:ident)])? $tuple_bind:ident),* $(,)? )),* $(,)?
        _
        $($record:ident { $($(#[ast_eq_use($record_conv:ident)])? $record_bind:ident),* $(,)? }),* $(,)?
    }) => {
        impl AstEq for $enum_name {
            fn ast_eq(&self, other: &Self) -> bool {
                gen_ast_eq_enum!(self, other, $enum_name {
                    $($unit),*
                    _
                    $($tuple ( $($(#[ast_eq_use($tuple_conv)])? $tuple_bind),* )),*
                    _
                    $($record { $($(#[ast_eq_use($record_conv)])? $record_bind),* }),*
                })
            }
        }
    };
}
384
385derive_ast_eq! { (0 A) }
386derive_ast_eq! { (0 A, 1 B) }
387derive_ast_eq! { (0 A, 1 B, 2 C) }
388derive_ast_eq! { (0 A, 1 B, 2 C, 3 D) }
389derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E) }
390derive_ast_eq! { bool }
391derive_ast_eq! { u8 }
392derive_ast_eq! { u16 }
393derive_ast_eq! { I256 }
394derive_ast_eq! { U256 }
395derive_ast_eq! { struct Identifier { loc, name } }
396derive_ast_eq! { struct HexLiteral { loc, hex } }
397derive_ast_eq! { struct StringLiteral { loc, unicode, string } }
398derive_ast_eq! { struct Parameter { loc, annotation, ty, storage, name } }
399derive_ast_eq! { struct NamedArgument { loc, name, expr } }
400derive_ast_eq! { struct YulBlock { loc, statements } }
401derive_ast_eq! { struct YulFunctionCall { loc, id, arguments } }
402derive_ast_eq! { struct YulFunctionDefinition { loc, id, params, returns, body } }
403derive_ast_eq! { struct YulSwitch { loc, condition, cases, default } }
404derive_ast_eq! { struct YulFor {
405    loc,
406    init_block,
407    condition,
408    post_block,
409    execution_block,
410}}
411derive_ast_eq! { struct YulTypedIdentifier { loc, id, ty } }
412derive_ast_eq! { struct VariableDeclaration { loc, ty, storage, name } }
413derive_ast_eq! { struct Using { loc, list, ty, global } }
414derive_ast_eq! { struct UsingFunction { loc, path, oper } }
415derive_ast_eq! { struct TypeDefinition { loc, name, ty } }
416derive_ast_eq! { struct ContractDefinition { loc, ty, name, base, parts } }
417derive_ast_eq! { struct EventParameter { loc, ty, indexed, name } }
418derive_ast_eq! { struct ErrorParameter { loc, ty, name } }
419derive_ast_eq! { struct EventDefinition { loc, name, fields, anonymous } }
420derive_ast_eq! { struct ErrorDefinition { loc, keyword, name, fields } }
421derive_ast_eq! { struct StructDefinition { loc, name, fields } }
422derive_ast_eq! { struct EnumDefinition { loc, name, values } }
423derive_ast_eq! { struct Annotation { loc, id, value } }
424derive_ast_eq! { enum UsingList {
425    Error,
426    _
427    Library(expr),
428    Functions(exprs),
429    _
430}}
431derive_ast_eq! { enum UserDefinedOperator {
432    BitwiseAnd,
433    BitwiseNot,
434    Negate,
435    BitwiseOr,
436    BitwiseXor,
437    Add,
438    Divide,
439    Modulo,
440    Multiply,
441    Subtract,
442    Equal,
443    More,
444    MoreEqual,
445    Less,
446    LessEqual,
447    NotEqual,
448    _
449    _
450}}
451derive_ast_eq! { enum Visibility {
452    _
453    External(loc),
454    Public(loc),
455    Internal(loc),
456    Private(loc),
457    _
458}}
459derive_ast_eq! { enum Mutability {
460    _
461    Pure(loc),
462    View(loc),
463    Constant(loc),
464    Payable(loc),
465    _
466}}
467derive_ast_eq! { enum FunctionAttribute {
468    _
469    Mutability(muta),
470    Visibility(visi),
471    Virtual(loc),
472    Immutable(loc),
473    Override(loc, idents),
474    BaseOrModifier(loc, base),
475    Error(loc),
476    _
477}}
478derive_ast_eq! { enum StorageLocation {
479    _
480    Memory(loc),
481    Storage(loc),
482    Calldata(loc),
483    _
484}}
485derive_ast_eq! { enum Type {
486    Address,
487    AddressPayable,
488    Payable,
489    Bool,
490    Rational,
491    DynamicBytes,
492    String,
493    _
494    Int(int),
495    Uint(int),
496    Bytes(int),
497    _
498    Mapping{ loc, key, key_name, value, value_name },
499    Function { params, attributes, returns },
500}}
501derive_ast_eq! { enum Expression {
502    _
503    PostIncrement(loc, expr1),
504    PostDecrement(loc, expr1),
505    New(loc, expr1),
506    ArraySubscript(loc, expr1, expr2),
507    ArraySlice(
508        loc,
509        expr1,
510        expr2,
511        expr3,
512    ),
513    MemberAccess(loc, expr1, ident1),
514    FunctionCall(loc, expr1, exprs1),
515    FunctionCallBlock(loc, expr1, stmt),
516    NamedFunctionCall(loc, expr1, args),
517    Not(loc, expr1),
518    BitwiseNot(loc, expr1),
519    Delete(loc, expr1),
520    PreIncrement(loc, expr1),
521    PreDecrement(loc, expr1),
522    UnaryPlus(loc, expr1),
523    Negate(loc, expr1),
524    Power(loc, expr1, expr2),
525    Multiply(loc, expr1, expr2),
526    Divide(loc, expr1, expr2),
527    Modulo(loc, expr1, expr2),
528    Add(loc, expr1, expr2),
529    Subtract(loc, expr1, expr2),
530    ShiftLeft(loc, expr1, expr2),
531    ShiftRight(loc, expr1, expr2),
532    BitwiseAnd(loc, expr1, expr2),
533    BitwiseXor(loc, expr1, expr2),
534    BitwiseOr(loc, expr1, expr2),
535    Less(loc, expr1, expr2),
536    More(loc, expr1, expr2),
537    LessEqual(loc, expr1, expr2),
538    MoreEqual(loc, expr1, expr2),
539    Equal(loc, expr1, expr2),
540    NotEqual(loc, expr1, expr2),
541    And(loc, expr1, expr2),
542    Or(loc, expr1, expr2),
543    ConditionalOperator(loc, expr1, expr2, expr3),
544    Assign(loc, expr1, expr2),
545    AssignOr(loc, expr1, expr2),
546    AssignAnd(loc, expr1, expr2),
547    AssignXor(loc, expr1, expr2),
548    AssignShiftLeft(loc, expr1, expr2),
549    AssignShiftRight(loc, expr1, expr2),
550    AssignAdd(loc, expr1, expr2),
551    AssignSubtract(loc, expr1, expr2),
552    AssignMultiply(loc, expr1, expr2),
553    AssignDivide(loc, expr1, expr2),
554    AssignModulo(loc, expr1, expr2),
555    BoolLiteral(loc, bool1),
556    NumberLiteral(loc, #[ast_eq_use(to_num)] str1, #[ast_eq_use(to_num)] str2, unit),
557    RationalNumberLiteral(
558        loc,
559        #[ast_eq_use(to_num)] str1,
560        #[ast_eq_use(to_num_reversed)] str2,
561        #[ast_eq_use(to_num)] str3,
562        unit
563    ),
564    HexNumberLiteral(loc, str1, unit),
565    StringLiteral(strs1),
566    Type(loc, ty1),
567    HexLiteral(hexs1),
568    AddressLiteral(loc, str1),
569    Variable(ident1),
570    List(loc, params1),
571    ArrayLiteral(loc, exprs1),
572    Parenthesis(loc, expr)
573    _
574}}
575derive_ast_eq! { enum CatchClause {
576    _
577    Simple(param, ident, stmt),
578    Named(loc, ident, param, stmt),
579    _
580}}
581derive_ast_eq! { enum YulStatement {
582    _
583    Assign(loc, exprs, expr),
584    VariableDeclaration(loc, idents, expr),
585    If(loc, expr, block),
586    For(yul_for),
587    Switch(switch),
588    Leave(loc),
589    Break(loc),
590    Continue(loc),
591    Block(block),
592    FunctionDefinition(def),
593    FunctionCall(func),
594    Error(loc),
595    _
596}}
597derive_ast_eq! { enum YulExpression {
598    _
599    BoolLiteral(loc, boo, ident),
600    NumberLiteral(loc, string1, string2, ident),
601    HexNumberLiteral(loc, string, ident),
602    HexStringLiteral(hex, ident),
603    StringLiteral(string, ident),
604    Variable(ident),
605    FunctionCall(func),
606    SuffixAccess(loc, expr, ident),
607    _
608}}
609derive_ast_eq! { enum YulSwitchOptions {
610    _
611    Case(loc, expr, block),
612    Default(loc, block),
613    _
614}}
615derive_ast_eq! { enum SourceUnitPart {
616    _
617    ContractDefinition(def),
618    PragmaDirective(loc, ident, string),
619    ImportDirective(import),
620    EnumDefinition(def),
621    StructDefinition(def),
622    EventDefinition(def),
623    ErrorDefinition(def),
624    FunctionDefinition(def),
625    VariableDefinition(def),
626    TypeDefinition(def),
627    Using(using),
628    StraySemicolon(loc),
629    Annotation(annotation),
630    _
631}}
632derive_ast_eq! { enum ImportPath {
633    _
634    Filename(lit),
635    Path(path),
636    _
637}}
638derive_ast_eq! { enum Import {
639    _
640    Plain(string, loc),
641    GlobalSymbol(string, ident, loc),
642    Rename(string, idents, loc),
643    _
644}}
645derive_ast_eq! { enum FunctionTy {
646    Constructor,
647    Function,
648    Fallback,
649    Receive,
650    Modifier,
651    _
652    _
653}}
654derive_ast_eq! { enum ContractPart {
655    _
656    StructDefinition(def),
657    EventDefinition(def),
658    EnumDefinition(def),
659    ErrorDefinition(def),
660    VariableDefinition(def),
661    FunctionDefinition(def),
662    TypeDefinition(def),
663    StraySemicolon(loc),
664    Using(using),
665    Annotation(annotation),
666    _
667}}
668derive_ast_eq! { enum ContractTy {
669    _
670    Abstract(loc),
671    Contract(loc),
672    Interface(loc),
673    Library(loc),
674    _
675}}
676derive_ast_eq! { enum VariableAttribute {
677    _
678    Visibility(visi),
679    Constant(loc),
680    Immutable(loc),
681    Override(loc, idents),
682    _
683}}