forge_fmt/solang_ext/
ast_eq.rs

1use alloy_primitives::{Address, I256, U256};
2use solang_parser::pt::*;
3use std::str::FromStr;
4
5/// Helper to convert a string number into a comparable one
6fn to_num(string: &str) -> I256 {
7    if string.is_empty() {
8        return I256::ZERO
9    }
10    string.replace('_', "").trim().parse().unwrap()
11}
12
13/// Helper to convert the fractional part of a number into a comparable one.
14/// This will reverse the number so that 0's can be ignored
15fn to_num_reversed(string: &str) -> U256 {
16    if string.is_empty() {
17        return U256::from(0)
18    }
19    string.replace('_', "").trim().chars().rev().collect::<String>().parse().unwrap()
20}
21
22/// Helper to filter [ParameterList] to omit empty
23/// parameters
24fn filter_params(list: &ParameterList) -> ParameterList {
25    list.iter().filter(|(_, param)| param.is_some()).cloned().collect::<Vec<_>>()
26}
27
/// Structural equality for Solidity parse trees.
///
/// Unlike `PartialEq`, implementations ignore source locations ([`Loc`] always compares equal)
/// and ignore ordering where it carries no meaning (e.g. function and variable attributes).
/// Some implementations also normalize equivalent spellings; see the individual impls.
pub trait AstEq {
    /// Returns `true` if `self` and `other` describe the same syntax tree.
    fn ast_eq(&self, other: &Self) -> bool;
}
33
34impl AstEq for Loc {
35    fn ast_eq(&self, _other: &Self) -> bool {
36        true
37    }
38}
39
40impl AstEq for IdentifierPath {
41    fn ast_eq(&self, other: &Self) -> bool {
42        self.identifiers.ast_eq(&other.identifiers)
43    }
44}
45
46impl AstEq for SourceUnit {
47    fn ast_eq(&self, other: &Self) -> bool {
48        self.0.ast_eq(&other.0)
49    }
50}
51
52impl AstEq for VariableDefinition {
53    fn ast_eq(&self, other: &Self) -> bool {
54        let sorted_attrs = |def: &Self| {
55            let mut attrs = def.attrs.clone();
56            attrs.sort();
57            attrs
58        };
59        self.ty.ast_eq(&other.ty) &&
60            self.name.ast_eq(&other.name) &&
61            self.initializer.ast_eq(&other.initializer) &&
62            sorted_attrs(self).ast_eq(&sorted_attrs(other))
63    }
64}
65
66impl AstEq for FunctionDefinition {
67    fn ast_eq(&self, other: &Self) -> bool {
68        // attributes
69        let sorted_attrs = |def: &Self| {
70            let mut attrs = def.attributes.clone();
71            attrs.sort();
72            attrs
73        };
74
75        // params
76        let left_params = filter_params(&self.params);
77        let right_params = filter_params(&other.params);
78        let left_returns = filter_params(&self.returns);
79        let right_returns = filter_params(&other.returns);
80
81        self.ty.ast_eq(&other.ty) &&
82            self.name.ast_eq(&other.name) &&
83            left_params.ast_eq(&right_params) &&
84            self.return_not_returns.ast_eq(&other.return_not_returns) &&
85            left_returns.ast_eq(&right_returns) &&
86            self.body.ast_eq(&other.body) &&
87            sorted_attrs(self).ast_eq(&sorted_attrs(other))
88    }
89}
90
91impl AstEq for Base {
92    fn ast_eq(&self, other: &Self) -> bool {
93        self.name.ast_eq(&other.name) &&
94            self.args.clone().unwrap_or_default().ast_eq(&other.args.clone().unwrap_or_default())
95    }
96}
97
98impl<T> AstEq for Vec<T>
99where
100    T: AstEq,
101{
102    fn ast_eq(&self, other: &Self) -> bool {
103        if self.len() != other.len() {
104            false
105        } else {
106            self.iter().zip(other.iter()).all(|(left, right)| left.ast_eq(right))
107        }
108    }
109}
110
111impl<T> AstEq for Option<T>
112where
113    T: AstEq,
114{
115    fn ast_eq(&self, other: &Self) -> bool {
116        match (self, other) {
117            (Some(left), Some(right)) => left.ast_eq(right),
118            (None, None) => true,
119            _ => false,
120        }
121    }
122}
123
124impl<T> AstEq for Box<T>
125where
126    T: AstEq,
127{
128    fn ast_eq(&self, other: &Self) -> bool {
129        T::ast_eq(self, other)
130    }
131}
132
133impl AstEq for () {
134    fn ast_eq(&self, _other: &Self) -> bool {
135        true
136    }
137}
138
139impl<T> AstEq for &T
140where
141    T: AstEq,
142{
143    fn ast_eq(&self, other: &Self) -> bool {
144        T::ast_eq(self, other)
145    }
146}
147
148impl AstEq for String {
149    fn ast_eq(&self, other: &Self) -> bool {
150        match (Address::from_str(self), Address::from_str(other)) {
151            (Ok(left), Ok(right)) => left == right,
152            _ => self == other,
153        }
154    }
155}
156
/// Produces the value that `gen_ast_eq_enum!` compares for one bound field.
///
/// By default this is the field itself. With `#[ast_eq_use(func)]` it is `func(field)`, which
/// allows fields to be normalized before comparison (e.g. numeric literal strings via `to_num`).
macro_rules! ast_eq_field {
    (#[ast_eq_use($convert_func:ident)] $field:ident) => {
        $convert_func($field)
    };
    ($field:ident) => {
        $field
    };
}
165
/// Generates a `match` expression that compares two values of enum `$name` variant by variant.
///
/// The variant list has three `_`-separated sections:
/// - unit variants;
/// - tuple variants, with fields bound by position;
/// - struct variants, with fields bound by name.
///
/// Any field may be prefixed with `#[ast_eq_use(func)]` to compare `func(field)` instead (see
/// `ast_eq_field!`). Values of different variants compare unequal. The generated `match` has
/// no wildcard arm, so every variant of the enum must be listed.
///
/// Only the first rule is meant to be invoked directly. The remaining rules are internal
/// helpers that build the comparison for a single unit, tuple, or struct variant.
macro_rules! gen_ast_eq_enum {
    ($self:expr, $other:expr, $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),*  $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),*  $(,)?
    }) => {
        match $self {
            $($name::$unit_variant => gen_ast_eq_enum!($other, $name, $unit_variant),)*
            $($name::$tuple_variant($($tuple_field),*) =>
                gen_ast_eq_enum!($other, $name, $tuple_variant ($($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),*)),)*
            $($name::$struct_variant { $($struct_field),* } =>
                gen_ast_eq_enum!($other, $name, $struct_variant {$($(#[ast_eq_use($struct_convert_func)])? $struct_field),*}),)*
        }
    };
    // Unit variant: equal iff `$other` is the same variant.
    ($other:expr, $name:ident, $unit_variant:ident) => {
        {
            matches!($other, $name::$unit_variant)
        }
    };
    // Tuple variant: collect the (possibly converted) fields of the left side into a tuple,
    // re-bind the same field names from `$other` (shadowing the left bindings), and compare the
    // two tuples. A single field yields a parenthesized value rather than a 1-tuple.
    ($other:expr, $name:ident, $tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? ) ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
            if let $name::$tuple_variant($($tuple_field),*) = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
    // Struct variant: same approach as the tuple case, binding fields by name.
    ($other:expr, $name:ident, $struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? } ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
            if let $name::$struct_variant { $($struct_field),* } = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
}
210
/// Normalizes a boxed statement into a block statement.
///
/// - If `$stmt` is already a `Statement::Block`, it returns a clone of the box.
/// - Otherwise it returns a new boxed `Statement::Block` with location `$loc`,
///   `unchecked: false`, and the cloned statement as its only element.
///
/// This lets a single-statement body compare equal to the same statement wrapped in braces.
/// `$stmt` is expanded several times, so it should be a cheap, side-effect-free expression.
/// It is dereferenced twice, so it is presumably a `&Box<Statement>`.
macro_rules! wrap_in_box {
    ($stmt:expr, $loc:expr) => {
        if !matches!(**$stmt, Statement::Block { .. }) {
            Box::new(Statement::Block {
                loc: $loc,
                unchecked: false,
                statements: vec![*$stmt.clone()],
            })
        } else {
            $stmt.clone()
        }
    };
}
224
/// Statements are compared structurally, with these normalizations:
/// - the bodies of `if`/`else`, `while`, `do ... while` and `for` are wrapped into blocks (via
///   `wrap_in_box!`), so braced and unbraced bodies compare equal;
/// - `try ... returns (...)` parameter lists ignore entries without a parameter (via
///   `filter_params`).
///
/// All other variants are compared field by field.
impl AstEq for Statement {
    fn ast_eq(&self, other: &Self) -> bool {
        match self {
            Self::If(loc, expr, stmt1, stmt2) => {
                // Normalize both branches. An `else if` is kept as-is rather than wrapped.
                // NOTE(review): as a result, `else if (...)` on one side does not equal
                // `else { if (...) ... }` on the other.
                #[expect(clippy::borrowed_box)]
                let wrap_if = |stmt1: &Box<Self>, stmt2: &Option<Box<Self>>| {
                    (
                        wrap_in_box!(stmt1, *loc),
                        stmt2.as_ref().map(|stmt2| {
                            if matches!(**stmt2, Self::If(..)) {
                                stmt2.clone()
                            } else {
                                wrap_in_box!(stmt2, *loc)
                            }
                        }),
                    )
                };
                let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                let left = (loc, expr, &stmt1, &stmt2);
                if let Self::If(loc, expr, stmt1, stmt2) = other {
                    let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                    let right = (loc, expr, &stmt1, &stmt2);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::While(loc, expr, stmt1) => {
                // Only the loop body is normalized; the condition is compared as-is.
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, expr, &stmt1);
                if let Self::While(loc, expr, stmt1) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, expr, &stmt1);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::DoWhile(loc, stmt1, expr) => {
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, &stmt1, expr);
                if let Self::DoWhile(loc, stmt1, expr) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, &stmt1, expr);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::For(loc, stmt1, expr, stmt2, stmt3) => {
                // Only the (optional) body `stmt3` is normalized. The init, condition and
                // update parts are compared as-is.
                let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                let left = (loc, stmt1, expr, stmt2, &stmt3);
                if let Self::For(loc, stmt1, expr, stmt2, stmt3) = other {
                    let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                    let right = (loc, stmt1, expr, stmt2, &stmt3);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::Try(loc, expr, returns, catch) => {
                // Drop parameter entries without a parameter from the `returns (...)` list
                // before comparing. The catch clauses are compared as-is.
                let left_returns =
                    returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                let left = (loc, expr, left_returns, catch);
                if let Self::Try(loc, expr, returns, catch) = other {
                    let right_returns =
                        returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                    let right = (loc, expr, right_returns, catch);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            _ => gen_ast_eq_enum!(self, other, Statement {
                _
                Args(loc, args),
                Expression(loc, expr),
                VariableDefinition(loc, decl, expr),
                Continue(loc, ),
                Break(loc, ),
                Return(loc, expr),
                Revert(loc, expr, expr2),
                RevertNamedArgs(loc, expr, args),
                Emit(loc, expr),
                // The variants below are handled by the dedicated arms above and never reach
                // this match. They must still be listed because the generated `match` has to
                // be exhaustive.
                If(loc, expr, stmt1, stmt2),
                While(loc, expr, stmt1),
                DoWhile(loc, stmt1, expr),
                For(loc, stmt1, expr, stmt2, stmt3),
                Try(loc, expr, params, clause),
                Error(loc)
                _
                Block {
                    loc,
                    unchecked,
                    statements,
                },
                Assembly {
                    loc,
                    dialect,
                    block,
                    flags,
                },
            }),
        }
    }
}
332
/// Generates [`AstEq`] implementations. It accepts four forms:
///
/// - `derive_ast_eq! { Type }`: delegates to `PartialEq` (`self == other`). This is only
///   appropriate for types whose `PartialEq` is the desired notion of equality.
/// - `derive_ast_eq! { (0 A, 1 B, ...) }`: element-wise comparison for a generic tuple. It
///   short-circuits on the first unequal element.
/// - `derive_ast_eq! { struct Name { field, ... } }`: destructures both values and compares the
///   tuple of all fields. The pattern has no `..`, so every field must be listed.
/// - `derive_ast_eq! { enum Name { unit, ... _ tuple(...), ... _ strukt { ... }, ... } }`:
///   compares variant by variant via `gen_ast_eq_enum!`.
macro_rules! derive_ast_eq {
    ($name:ident) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                self == other
            }
        }
    };
    // `($($gen,)*)` keeps a trailing comma, so a single generic produces the 1-tuple `(A,)`.
    (($($index:tt $gen:tt),*)) => {
        impl < $( $gen ),* > AstEq for ($($gen,)*) where $($gen: AstEq),* {
            fn ast_eq(&self, other: &Self) -> bool {
                $(
                if !self.$index.ast_eq(&other.$index) {
                    return false
                }
                )*
                true
            }
        }
    };
    (struct $name:ident { $($field:ident),* $(,)? }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                let $name { $($field),* } = self;
                let left = ($($field),*);
                let $name { $($field),* } = other;
                let right = ($($field),*);
                left.ast_eq(&right)
            }
        }
    };
    (enum $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),*  $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),*  $(,)?
    }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                gen_ast_eq_enum!(self, other, $name {
                    $($unit_variant),*
                    _
                    $($tuple_variant ( $($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),* )),*
                    _
                    $($struct_variant { $($(#[ast_eq_use($struct_convert_func)])? $struct_field),* }),*
                })
            }
        }
    }
}
384
// Tuples up to arity 7 compare element-wise. The struct/enum comparisons above build tuples
// of fields, so no struct or variant can compare more than seven fields at once.
derive_ast_eq! { (0 A) }
derive_ast_eq! { (0 A, 1 B) }
derive_ast_eq! { (0 A, 1 B, 2 C) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E, 5 F) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E, 5 F, 6 G) }
// Leaf values with no location information compare with plain `PartialEq`.
derive_ast_eq! { bool }
derive_ast_eq! { u8 }
derive_ast_eq! { u16 }
derive_ast_eq! { I256 }
derive_ast_eq! { U256 }
// Structs compare all of their fields; `loc` fields always compare equal. The struct form
// destructures without `..`, so every field of the parser struct must be listed here.
derive_ast_eq! { struct Identifier { loc, name } }
derive_ast_eq! { struct HexLiteral { loc, hex } }
derive_ast_eq! { struct StringLiteral { loc, unicode, string } }
derive_ast_eq! { struct Parameter { loc, annotation, ty, storage, name } }
derive_ast_eq! { struct NamedArgument { loc, name, expr } }
derive_ast_eq! { struct YulBlock { loc, statements } }
derive_ast_eq! { struct YulFunctionCall { loc, id, arguments } }
derive_ast_eq! { struct YulFunctionDefinition { loc, id, params, returns, body } }
derive_ast_eq! { struct YulSwitch { loc, condition, cases, default } }
derive_ast_eq! { struct YulFor {
    loc,
    init_block,
    condition,
    post_block,
    execution_block,
}}
derive_ast_eq! { struct YulTypedIdentifier { loc, id, ty } }
derive_ast_eq! { struct VariableDeclaration { loc, ty, storage, name } }
derive_ast_eq! { struct Using { loc, list, ty, global } }
derive_ast_eq! { struct UsingFunction { loc, path, oper } }
derive_ast_eq! { struct TypeDefinition { loc, name, ty } }
derive_ast_eq! { struct ContractDefinition { loc, ty, name, base, layout, parts } }
derive_ast_eq! { struct EventParameter { loc, ty, indexed, name } }
derive_ast_eq! { struct ErrorParameter { loc, ty, name } }
derive_ast_eq! { struct EventDefinition { loc, name, fields, anonymous } }
derive_ast_eq! { struct ErrorDefinition { loc, keyword, name, fields } }
derive_ast_eq! { struct StructDefinition { loc, name, fields } }
derive_ast_eq! { struct EnumDefinition { loc, name, values } }
derive_ast_eq! { struct Annotation { loc, id, value } }
// Enums are compared variant by variant via `gen_ast_eq_enum!`. Each list gives the unit
// variants, then `_` and the tuple variants, then `_` and the struct variants. The generated
// `match` has no wildcard arm, so every variant of the parser enum must appear here.
derive_ast_eq! { enum PragmaDirective {
    _
    Identifier(loc, id1, id2),
    StringLiteral(loc, id, lit),
    Version(loc, id, version),
    _
}}
derive_ast_eq! { enum UsingList {
    Error,
    _
    Library(expr),
    Functions(exprs),
    _
}}
derive_ast_eq! { enum UserDefinedOperator {
    BitwiseAnd,
    BitwiseNot,
    Negate,
    BitwiseOr,
    BitwiseXor,
    Add,
    Divide,
    Modulo,
    Multiply,
    Subtract,
    Equal,
    More,
    MoreEqual,
    Less,
    LessEqual,
    NotEqual,
    _
    _
}}
derive_ast_eq! { enum Visibility {
    _
    External(loc),
    Public(loc),
    Internal(loc),
    Private(loc),
    _
}}
derive_ast_eq! { enum Mutability {
    _
    Pure(loc),
    View(loc),
    Constant(loc),
    Payable(loc),
    _
}}
derive_ast_eq! { enum FunctionAttribute {
    _
    Mutability(muta),
    Visibility(visi),
    Virtual(loc),
    Immutable(loc),
    Override(loc, idents),
    BaseOrModifier(loc, base),
    Error(loc),
    _
}}
derive_ast_eq! { enum StorageLocation {
    _
    Memory(loc),
    Storage(loc),
    Calldata(loc),
    Transient(loc),
    _
}}
derive_ast_eq! { enum Type {
    Address,
    AddressPayable,
    Payable,
    Bool,
    Rational,
    DynamicBytes,
    String,
    _
    Int(int),
    Uint(int),
    Bytes(int),
    _
    Mapping{ loc, key, key_name, value, value_name },
    Function { params, attributes, returns },
}}
derive_ast_eq! { enum Expression {
    _
    PostIncrement(loc, expr1),
    PostDecrement(loc, expr1),
    New(loc, expr1),
    ArraySubscript(loc, expr1, expr2),
    ArraySlice(
        loc,
        expr1,
        expr2,
        expr3,
    ),
    MemberAccess(loc, expr1, ident1),
    FunctionCall(loc, expr1, exprs1),
    FunctionCallBlock(loc, expr1, stmt),
    NamedFunctionCall(loc, expr1, args),
    Not(loc, expr1),
    BitwiseNot(loc, expr1),
    Delete(loc, expr1),
    PreIncrement(loc, expr1),
    PreDecrement(loc, expr1),
    UnaryPlus(loc, expr1),
    Negate(loc, expr1),
    Power(loc, expr1, expr2),
    Multiply(loc, expr1, expr2),
    Divide(loc, expr1, expr2),
    Modulo(loc, expr1, expr2),
    Add(loc, expr1, expr2),
    Subtract(loc, expr1, expr2),
    ShiftLeft(loc, expr1, expr2),
    ShiftRight(loc, expr1, expr2),
    BitwiseAnd(loc, expr1, expr2),
    BitwiseXor(loc, expr1, expr2),
    BitwiseOr(loc, expr1, expr2),
    Less(loc, expr1, expr2),
    More(loc, expr1, expr2),
    LessEqual(loc, expr1, expr2),
    MoreEqual(loc, expr1, expr2),
    Equal(loc, expr1, expr2),
    NotEqual(loc, expr1, expr2),
    And(loc, expr1, expr2),
    Or(loc, expr1, expr2),
    ConditionalOperator(loc, expr1, expr2, expr3),
    Assign(loc, expr1, expr2),
    AssignOr(loc, expr1, expr2),
    AssignAnd(loc, expr1, expr2),
    AssignXor(loc, expr1, expr2),
    AssignShiftLeft(loc, expr1, expr2),
    AssignShiftRight(loc, expr1, expr2),
    AssignAdd(loc, expr1, expr2),
    AssignSubtract(loc, expr1, expr2),
    AssignMultiply(loc, expr1, expr2),
    AssignDivide(loc, expr1, expr2),
    AssignModulo(loc, expr1, expr2),
    BoolLiteral(loc, bool1),
    // Numeric literals compare by value. `_` separators are ignored (`1_000` == `1000`), and
    // trailing zeros in a fractional part are ignored (`1.50` == `1.5`).
    NumberLiteral(loc, #[ast_eq_use(to_num)] str1, #[ast_eq_use(to_num)] str2, unit),
    RationalNumberLiteral(
        loc,
        #[ast_eq_use(to_num)] str1,
        #[ast_eq_use(to_num_reversed)] str2,
        #[ast_eq_use(to_num)] str3,
        unit
    ),
    HexNumberLiteral(loc, str1, unit),
    StringLiteral(strs1),
    Type(loc, ty1),
    HexLiteral(hexs1),
    AddressLiteral(loc, str1),
    Variable(ident1),
    List(loc, params1),
    ArrayLiteral(loc, exprs1),
    Parenthesis(loc, expr)
    _
}}
derive_ast_eq! { enum CatchClause {
    _
    Simple(param, ident, stmt),
    Named(loc, ident, param, stmt),
    _
}}
derive_ast_eq! { enum YulStatement {
    _
    Assign(loc, exprs, expr),
    VariableDeclaration(loc, idents, expr),
    If(loc, expr, block),
    For(yul_for),
    Switch(switch),
    Leave(loc),
    Break(loc),
    Continue(loc),
    Block(block),
    FunctionDefinition(def),
    FunctionCall(func),
    Error(loc),
    _
}}
derive_ast_eq! { enum YulExpression {
    _
    BoolLiteral(loc, boo, ident),
    // NOTE(review): unlike `Expression::NumberLiteral`, Yul number literals have no
    // `#[ast_eq_use(...)]` conversion, so their strings are compared via `String::ast_eq`
    // rather than by numeric value.
    NumberLiteral(loc, string1, string2, ident),
    HexNumberLiteral(loc, string, ident),
    HexStringLiteral(hex, ident),
    StringLiteral(string, ident),
    Variable(ident),
    FunctionCall(func),
    SuffixAccess(loc, expr, ident),
    _
}}
derive_ast_eq! { enum YulSwitchOptions {
    _
    Case(loc, expr, block),
    Default(loc, block),
    _
}}
derive_ast_eq! { enum SourceUnitPart {
    _
    ContractDefinition(def),
    PragmaDirective(pragma),
    ImportDirective(import),
    EnumDefinition(def),
    StructDefinition(def),
    EventDefinition(def),
    ErrorDefinition(def),
    FunctionDefinition(def),
    VariableDefinition(def),
    TypeDefinition(def),
    Using(using),
    StraySemicolon(loc),
    Annotation(annotation),
    _
}}
derive_ast_eq! { enum ImportPath {
    _
    Filename(lit),
    Path(path),
    _
}}
derive_ast_eq! { enum Import {
    _
    Plain(string, loc),
    GlobalSymbol(string, ident, loc),
    Rename(string, idents, loc),
    _
}}
derive_ast_eq! { enum FunctionTy {
    Constructor,
    Function,
    Fallback,
    Receive,
    Modifier,
    _
    _
}}
derive_ast_eq! { enum ContractPart {
    _
    StructDefinition(def),
    EventDefinition(def),
    EnumDefinition(def),
    ErrorDefinition(def),
    VariableDefinition(def),
    FunctionDefinition(def),
    TypeDefinition(def),
    StraySemicolon(loc),
    Using(using),
    Annotation(annotation),
    _
}}
derive_ast_eq! { enum ContractTy {
    _
    Abstract(loc),
    Contract(loc),
    Interface(loc),
    Library(loc),
    _
}}
derive_ast_eq! { enum VariableAttribute {
    _
    Visibility(visi),
    Constant(loc),
    Immutable(loc),
    Override(loc, idents),
    StorageType(st),
    StorageLocation(st),
    _
}}
696
697// Who cares
698impl AstEq for StorageType {
699    fn ast_eq(&self, _other: &Self) -> bool {
700        true
701    }
702}
703
704impl AstEq for VersionComparator {
705    fn ast_eq(&self, _other: &Self) -> bool {
706        true
707    }
708}