1use alloy_primitives::{Address, I256, U256};
2use solang_parser::pt::*;
3use std::str::FromStr;
4
5fn to_num(string: &str) -> I256 {
7 if string.is_empty() {
8 return I256::ZERO
9 }
10 string.replace('_', "").trim().parse().unwrap()
11}
12
13fn to_num_reversed(string: &str) -> U256 {
16 if string.is_empty() {
17 return U256::from(0)
18 }
19 string.replace('_', "").trim().chars().rev().collect::<String>().parse().unwrap()
20}
21
22fn filter_params(list: &ParameterList) -> ParameterList {
25 list.iter().filter(|(_, param)| param.is_some()).cloned().collect::<Vec<_>>()
26}
27
/// Structural equality between two parse-tree values.
///
/// The implementations in this file deliberately ignore details that do not
/// change meaning: every [`Loc`] compares equal, and the attribute lists of
/// variable and function definitions are compared after sorting.
pub trait AstEq {
    /// Returns `true` when `self` and `other` are considered equivalent.
    fn ast_eq(&self, other: &Self) -> bool;
}
33
34impl AstEq for Loc {
35 fn ast_eq(&self, _other: &Self) -> bool {
36 true
37 }
38}
39
40impl AstEq for IdentifierPath {
41 fn ast_eq(&self, other: &Self) -> bool {
42 self.identifiers.ast_eq(&other.identifiers)
43 }
44}
45
46impl AstEq for SourceUnit {
47 fn ast_eq(&self, other: &Self) -> bool {
48 self.0.ast_eq(&other.0)
49 }
50}
51
52impl AstEq for VariableDefinition {
53 fn ast_eq(&self, other: &Self) -> bool {
54 let sorted_attrs = |def: &Self| {
55 let mut attrs = def.attrs.clone();
56 attrs.sort();
57 attrs
58 };
59 self.ty.ast_eq(&other.ty) &&
60 self.name.ast_eq(&other.name) &&
61 self.initializer.ast_eq(&other.initializer) &&
62 sorted_attrs(self).ast_eq(&sorted_attrs(other))
63 }
64}
65
66impl AstEq for FunctionDefinition {
67 fn ast_eq(&self, other: &Self) -> bool {
68 let sorted_attrs = |def: &Self| {
70 let mut attrs = def.attributes.clone();
71 attrs.sort();
72 attrs
73 };
74
75 let left_params = filter_params(&self.params);
77 let right_params = filter_params(&other.params);
78 let left_returns = filter_params(&self.returns);
79 let right_returns = filter_params(&other.returns);
80
81 self.ty.ast_eq(&other.ty) &&
82 self.name.ast_eq(&other.name) &&
83 left_params.ast_eq(&right_params) &&
84 self.return_not_returns.ast_eq(&other.return_not_returns) &&
85 left_returns.ast_eq(&right_returns) &&
86 self.body.ast_eq(&other.body) &&
87 sorted_attrs(self).ast_eq(&sorted_attrs(other))
88 }
89}
90
91impl AstEq for Base {
92 fn ast_eq(&self, other: &Self) -> bool {
93 self.name.ast_eq(&other.name) &&
94 self.args.clone().unwrap_or_default().ast_eq(&other.args.clone().unwrap_or_default())
95 }
96}
97
98impl<T> AstEq for Vec<T>
99where
100 T: AstEq,
101{
102 fn ast_eq(&self, other: &Self) -> bool {
103 if self.len() != other.len() {
104 false
105 } else {
106 self.iter().zip(other.iter()).all(|(left, right)| left.ast_eq(right))
107 }
108 }
109}
110
111impl<T> AstEq for Option<T>
112where
113 T: AstEq,
114{
115 fn ast_eq(&self, other: &Self) -> bool {
116 match (self, other) {
117 (Some(left), Some(right)) => left.ast_eq(right),
118 (None, None) => true,
119 _ => false,
120 }
121 }
122}
123
124impl<T> AstEq for Box<T>
125where
126 T: AstEq,
127{
128 fn ast_eq(&self, other: &Self) -> bool {
129 T::ast_eq(self, other)
130 }
131}
132
133impl AstEq for () {
134 fn ast_eq(&self, _other: &Self) -> bool {
135 true
136 }
137}
138
139impl<T> AstEq for &T
140where
141 T: AstEq,
142{
143 fn ast_eq(&self, other: &Self) -> bool {
144 T::ast_eq(self, other)
145 }
146}
147
148impl AstEq for String {
149 fn ast_eq(&self, other: &Self) -> bool {
150 match (Address::from_str(self), Address::from_str(other)) {
151 (Ok(left), Ok(right)) => left == right,
152 _ => self == other,
153 }
154 }
155}
156
// Expands to the expression used for one field in a comparison.
//
// * `#[ast_eq_use(f)] field` expands to `f(field)`.
// * `field` alone expands to `field` unchanged.
macro_rules! ast_eq_field {
    (#[ast_eq_use($adapter:ident)] $wrapped:ident) => {
        $adapter($wrapped)
    };
    ($plain:ident) => {
        $plain
    };
}
165
// Generates the body of an `ast_eq` comparison for an enum.
//
// The entry arm takes `self`, `other`, the enum name and three variant lists
// separated by a literal `_` token: unit variants, tuple variants and struct
// variants. A field may be prefixed with `#[ast_eq_use(func)]`, in which case
// `func(field)` is compared instead of the field itself (see `ast_eq_field!`).
//
// The remaining arms are internal helpers that the entry arm recurses into,
// one per kind of variant.
macro_rules! gen_ast_eq_enum {
    // Entry arm: match on `$self` and hand each variant to a helper arm.
    // The generated `match` has no wildcard arm, so every variant of the enum
    // has to appear in one of the three lists.
    ($self:expr, $other:expr, $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),* $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),* $(,)?
    }) => {
        match $self {
            $($name::$unit_variant => gen_ast_eq_enum!($other, $name, $unit_variant),)*
            $($name::$tuple_variant($($tuple_field),*) =>
                gen_ast_eq_enum!($other, $name, $tuple_variant ($($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),*)),)*
            $($name::$struct_variant { $($struct_field),* } =>
                gen_ast_eq_enum!($other, $name, $struct_variant {$($(#[ast_eq_use($struct_convert_func)])? $struct_field),*}),)*
        }
    };
    // Unit variant: equal when `$other` is the same variant.
    ($other:expr, $name:ident, $unit_variant:ident) => {
        {
            matches!($other, $name::$unit_variant)
        }
    };
    // Tuple variant: `left` is built from the bindings of the enclosing match
    // arm; the same names are then re-bound from `$other` to build `right`.
    // A different variant on the other side yields `false`.
    ($other:expr, $name:ident, $tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? ) ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
            if let $name::$tuple_variant($($tuple_field),*) = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
    // Struct variant: same scheme as the tuple arm, with named fields.
    ($other:expr, $name:ident, $struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? } ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
            if let $name::$struct_variant { $($struct_field),* } = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
}
210
// Normalizes a boxed statement to a block.
//
// `$stmt` must be a reference to a boxed `Statement`. A statement that is
// already a `Statement::Block` is cloned as is; anything else is cloned into a
// new block (`unchecked: false`, located at `$loc`) whose only statement it is.
macro_rules! wrap_in_box {
    ($stmt:expr, $loc:expr) => {
        if matches!(**$stmt, Statement::Block { .. }) {
            $stmt.clone()
        } else {
            Box::new(Statement::Block {
                loc: $loc,
                unchecked: false,
                statements: vec![(**$stmt).clone()],
            })
        }
    };
}
224
/// Statement comparison that normalizes bodies before comparing.
///
/// The bodies of `If`, `While`, `DoWhile` and `For` are wrapped into a block
/// (see `wrap_in_box!`) when they are not one already, so a bare statement and
/// the same statement inside a block compare equal. `Try` ignores empty
/// parameter slots in its returns list. Every other variant is compared field
/// by field through `gen_ast_eq_enum!`.
impl AstEq for Statement {
    fn ast_eq(&self, other: &Self) -> bool {
        match self {
            Self::If(loc, expr, stmt1, stmt2) => {
                // Wraps the `if` body, and the `else` body unless that body is
                // itself an `If` (an `else if` chain), which is kept as is.
                // The closure captures `loc` of `self`, so the same location is
                // used when wrapping both sides.
                #[expect(clippy::borrowed_box)]
                let wrap_if = |stmt1: &Box<Self>, stmt2: &Option<Box<Self>>| {
                    (
                        wrap_in_box!(stmt1, *loc),
                        stmt2.as_ref().map(|stmt2| {
                            if matches!(**stmt2, Self::If(..)) {
                                stmt2.clone()
                            } else {
                                wrap_in_box!(stmt2, *loc)
                            }
                        }),
                    )
                };
                let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                let left = (loc, expr, &stmt1, &stmt2);
                // The bindings below shadow the ones taken from `self`.
                if let Self::If(loc, expr, stmt1, stmt2) = other {
                    let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                    let right = (loc, expr, &stmt1, &stmt2);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::While(loc, expr, stmt1) => {
                // Normalize the loop body to a block on both sides.
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, expr, &stmt1);
                if let Self::While(loc, expr, stmt1) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, expr, &stmt1);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::DoWhile(loc, stmt1, expr) => {
                // Normalize the loop body to a block on both sides.
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, &stmt1, expr);
                if let Self::DoWhile(loc, stmt1, expr) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, &stmt1, expr);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::For(loc, stmt1, expr, stmt2, stmt3) => {
                // Only the last field (`stmt3`) is wrapped; the other fields
                // are compared unchanged.
                let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                let left = (loc, stmt1, expr, stmt2, &stmt3);
                if let Self::For(loc, stmt1, expr, stmt2, stmt3) = other {
                    let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                    let right = (loc, stmt1, expr, stmt2, &stmt3);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::Try(loc, expr, returns, catch) => {
                // Drop empty parameter slots from the returns list before
                // comparing; the accompanying statement is compared as is.
                let left_returns =
                    returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                let left = (loc, expr, left_returns, catch);
                if let Self::Try(loc, expr, returns, catch) = other {
                    let right_returns =
                        returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                    let right = (loc, expr, right_returns, catch);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            // Everything else is compared field by field. The variants handled
            // above are listed again because the `match` generated by the macro
            // has no wildcard arm; `self` is never one of them at this point.
            _ => gen_ast_eq_enum!(self, other, Statement {
                _
                Args(loc, args),
                Expression(loc, expr),
                VariableDefinition(loc, decl, expr),
                Continue(loc, ),
                Break(loc, ),
                Return(loc, expr),
                Revert(loc, expr, expr2),
                RevertNamedArgs(loc, expr, args),
                Emit(loc, expr),
                If(loc, expr, stmt1, stmt2),
                While(loc, expr, stmt1),
                DoWhile(loc, stmt1, expr),
                For(loc, stmt1, expr, stmt2, stmt3),
                Try(loc, expr, params, clause),
                Error(loc)
                _
                Block {
                    loc,
                    unchecked,
                    statements,
                },
                Assembly {
                    loc,
                    dialect,
                    block,
                    flags,
                },
            }),
        }
    }
}
332
// Generates `AstEq` implementations. Four input forms are accepted:
//
// * `Name` - delegates to `==`.
// * `(0 A, 1 B, ...)` - a tuple of the listed index/type-parameter pairs,
//   compared element by element.
// * `struct Name { fields }` - compares the listed fields.
// * `enum Name { units _ tuple variants _ struct variants }` - forwards to
//   `gen_ast_eq_enum!`.
macro_rules! derive_ast_eq {
    // Plain type: equality is the type's own `==`.
    ($name:ident) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                self == other
            }
        }
    };
    // Tuple: every element type must implement `AstEq`; the comparison
    // returns `false` at the first element that differs.
    (($($index:tt $gen:tt),*)) => {
        impl < $( $gen ),* > AstEq for ($($gen,)*) where $($gen: AstEq),* {
            fn ast_eq(&self, other: &Self) -> bool {
                $(
                    if !self.$index.ast_eq(&other.$index) {
                        return false
                    }
                )*
                true
            }
        }
    };
    // Struct: destructure both sides into the listed fields and compare the
    // resulting tuples of field references.
    (struct $name:ident { $($field:ident),* $(,)? }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                let $name { $($field),* } = self;
                let left = ($($field),*);
                let $name { $($field),* } = other;
                let right = ($($field),*);
                left.ast_eq(&right)
            }
        }
    };
    // Enum: the variant lists are passed through to `gen_ast_eq_enum!`.
    (enum $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),* $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),* $(,)?
    }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                gen_ast_eq_enum!(self, other, $name {
                    $($unit_variant),*
                    _
                    $($tuple_variant ( $($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),* )),*
                    _
                    $($struct_variant { $($(#[ast_eq_use($struct_convert_func)])? $struct_field),* }),*
                })
            }
        }
    }
}
384
// Tuple implementations for arities 1 through 7. The struct and enum arms of
// the macros above compare tuples of fields, so these impls back them.
derive_ast_eq! { (0 A) }
derive_ast_eq! { (0 A, 1 B) }
derive_ast_eq! { (0 A, 1 B, 2 C) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E, 5 F) }
derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E, 5 F, 6 G) }
// Plain value types: compared with `==`.
derive_ast_eq! { bool }
derive_ast_eq! { u8 }
derive_ast_eq! { u16 }
derive_ast_eq! { I256 }
derive_ast_eq! { U256 }
// Structs compared field by field. `loc` is listed because the generated code
// destructures the struct, but `Loc` values always compare equal.
derive_ast_eq! { struct Identifier { loc, name } }
derive_ast_eq! { struct HexLiteral { loc, hex } }
derive_ast_eq! { struct StringLiteral { loc, unicode, string } }
derive_ast_eq! { struct Parameter { loc, annotation, ty, storage, name } }
derive_ast_eq! { struct NamedArgument { loc, name, expr } }
// Yul (inline assembly) structs.
derive_ast_eq! { struct YulBlock { loc, statements } }
derive_ast_eq! { struct YulFunctionCall { loc, id, arguments } }
derive_ast_eq! { struct YulFunctionDefinition { loc, id, params, returns, body } }
derive_ast_eq! { struct YulSwitch { loc, condition, cases, default } }
derive_ast_eq! { struct YulFor {
    loc,
    init_block,
    condition,
    post_block,
    execution_block,
}}
derive_ast_eq! { struct YulTypedIdentifier { loc, id, ty } }
// Declarations and definitions.
derive_ast_eq! { struct VariableDeclaration { loc, ty, storage, name } }
derive_ast_eq! { struct Using { loc, list, ty, global } }
derive_ast_eq! { struct UsingFunction { loc, path, oper } }
derive_ast_eq! { struct TypeDefinition { loc, name, ty } }
derive_ast_eq! { struct ContractDefinition { loc, ty, name, base, layout, parts } }
derive_ast_eq! { struct EventParameter { loc, ty, indexed, name } }
derive_ast_eq! { struct ErrorParameter { loc, ty, name } }
derive_ast_eq! { struct EventDefinition { loc, name, fields, anonymous } }
derive_ast_eq! { struct ErrorDefinition { loc, keyword, name, fields } }
derive_ast_eq! { struct StructDefinition { loc, name, fields } }
derive_ast_eq! { struct EnumDefinition { loc, name, values } }
derive_ast_eq! { struct Annotation { loc, id, value } }
// Enum derives list unit variants, then tuple variants, then struct variants,
// with a literal `_` between the groups; a group may be empty.
//
// NOTE(review): `version` presumably holds `VersionComparator` values, whose
// `AstEq` impl always returns `true` - confirm against the parser types.
derive_ast_eq! { enum PragmaDirective {
    _
    Identifier(loc, id1, id2),
    StringLiteral(loc, id, lit),
    Version(loc, id, version),
    _
}}
// `Error` is the only unit variant of this enum.
derive_ast_eq! { enum UsingList {
    Error,
    _
    Library(expr),
    Functions(exprs),
    _
}}
// Unit variants only: two operators are equal when they are the same variant.
// Both the tuple and the struct group are empty.
derive_ast_eq! { enum UserDefinedOperator {
    BitwiseAnd,
    BitwiseNot,
    Negate,
    BitwiseOr,
    BitwiseXor,
    Add,
    Divide,
    Modulo,
    Multiply,
    Subtract,
    Equal,
    More,
    MoreEqual,
    Less,
    LessEqual,
    NotEqual,
    _
    _
}}
// Visibility, mutability and storage location carry a single field bound as
// `loc`, so only the variant itself decides the comparison.
derive_ast_eq! { enum Visibility {
    _
    External(loc),
    Public(loc),
    Internal(loc),
    Private(loc),
    _
}}
derive_ast_eq! { enum Mutability {
    _
    Pure(loc),
    View(loc),
    Constant(loc),
    Payable(loc),
    _
}}
// Function attributes; the ordering of a function's attribute list is handled
// separately in `impl AstEq for FunctionDefinition`.
derive_ast_eq! { enum FunctionAttribute {
    _
    Mutability(muta),
    Visibility(visi),
    Virtual(loc),
    Immutable(loc),
    Override(loc, idents),
    BaseOrModifier(loc, base),
    Error(loc),
    _
}}
derive_ast_eq! { enum StorageLocation {
    _
    Memory(loc),
    Storage(loc),
    Calldata(loc),
    Transient(loc),
    _
}}
// Types: unit variants for the simple types, tuple variants carrying one
// value each (bound as `int`), and struct variants for mappings and function
// types.
derive_ast_eq! { enum Type {
    Address,
    AddressPayable,
    Payable,
    Bool,
    Rational,
    DynamicBytes,
    String,
    _
    Int(int),
    Uint(int),
    Bytes(int),
    _
    Mapping{ loc, key, key_name, value, value_name },
    Function { params, attributes, returns },
}}
// Expressions: tuple variants only. Fields marked `#[ast_eq_use(f)]` are
// compared as `f(field)` rather than as raw strings.
derive_ast_eq! { enum Expression {
    _
    PostIncrement(loc, expr1),
    PostDecrement(loc, expr1),
    New(loc, expr1),
    ArraySubscript(loc, expr1, expr2),
    ArraySlice(
        loc,
        expr1,
        expr2,
        expr3,
    ),
    MemberAccess(loc, expr1, ident1),
    FunctionCall(loc, expr1, exprs1),
    FunctionCallBlock(loc, expr1, stmt),
    NamedFunctionCall(loc, expr1, args),
    Not(loc, expr1),
    BitwiseNot(loc, expr1),
    Delete(loc, expr1),
    PreIncrement(loc, expr1),
    PreDecrement(loc, expr1),
    UnaryPlus(loc, expr1),
    Negate(loc, expr1),
    Power(loc, expr1, expr2),
    Multiply(loc, expr1, expr2),
    Divide(loc, expr1, expr2),
    Modulo(loc, expr1, expr2),
    Add(loc, expr1, expr2),
    Subtract(loc, expr1, expr2),
    ShiftLeft(loc, expr1, expr2),
    ShiftRight(loc, expr1, expr2),
    BitwiseAnd(loc, expr1, expr2),
    BitwiseXor(loc, expr1, expr2),
    BitwiseOr(loc, expr1, expr2),
    Less(loc, expr1, expr2),
    More(loc, expr1, expr2),
    LessEqual(loc, expr1, expr2),
    MoreEqual(loc, expr1, expr2),
    Equal(loc, expr1, expr2),
    NotEqual(loc, expr1, expr2),
    And(loc, expr1, expr2),
    Or(loc, expr1, expr2),
    ConditionalOperator(loc, expr1, expr2, expr3),
    Assign(loc, expr1, expr2),
    AssignOr(loc, expr1, expr2),
    AssignAnd(loc, expr1, expr2),
    AssignXor(loc, expr1, expr2),
    AssignShiftLeft(loc, expr1, expr2),
    AssignShiftRight(loc, expr1, expr2),
    AssignAdd(loc, expr1, expr2),
    AssignSubtract(loc, expr1, expr2),
    AssignMultiply(loc, expr1, expr2),
    AssignDivide(loc, expr1, expr2),
    AssignModulo(loc, expr1, expr2),
    BoolLiteral(loc, bool1),
    // Both string fields are converted with `to_num`, so `_` separators do
    // not affect the result.
    NumberLiteral(loc, #[ast_eq_use(to_num)] str1, #[ast_eq_use(to_num)] str2, unit),
    // The second string goes through `to_num_reversed`, which makes its
    // trailing zeros insignificant.
    // NOTE(review): presumably that field is the fractional part - confirm.
    RationalNumberLiteral(
        loc,
        #[ast_eq_use(to_num)] str1,
        #[ast_eq_use(to_num_reversed)] str2,
        #[ast_eq_use(to_num)] str3,
        unit
    ),
    HexNumberLiteral(loc, str1, unit),
    StringLiteral(strs1),
    Type(loc, ty1),
    HexLiteral(hexs1),
    AddressLiteral(loc, str1),
    Variable(ident1),
    List(loc, params1),
    ArrayLiteral(loc, exprs1),
    Parenthesis(loc, expr)
    _
}}
// Catch clauses: both variants are tuple variants compared field by field.
derive_ast_eq! { enum CatchClause {
    _
    Simple(param, ident, stmt),
    Named(loc, ident, param, stmt),
    _
}}
// Yul (inline assembly) enums; all variants are tuple variants.
derive_ast_eq! { enum YulStatement {
    _
    Assign(loc, exprs, expr),
    VariableDeclaration(loc, idents, expr),
    If(loc, expr, block),
    For(yul_for),
    Switch(switch),
    Leave(loc),
    Break(loc),
    Continue(loc),
    Block(block),
    FunctionDefinition(def),
    FunctionCall(func),
    Error(loc),
    _
}}
// Unlike `Expression::NumberLiteral`, the string fields here carry no
// `#[ast_eq_use(...)]` conversion and are compared through the `String` impl.
derive_ast_eq! { enum YulExpression {
    _
    BoolLiteral(loc, boo, ident),
    NumberLiteral(loc, string1, string2, ident),
    HexNumberLiteral(loc, string, ident),
    HexStringLiteral(hex, ident),
    StringLiteral(string, ident),
    Variable(ident),
    FunctionCall(func),
    SuffixAccess(loc, expr, ident),
    _
}}
derive_ast_eq! { enum YulSwitchOptions {
    _
    Case(loc, expr, block),
    Default(loc, block),
    _
}}
// Top-level items of a source unit; each variant wraps one value that is
// compared through its own `AstEq` impl.
derive_ast_eq! { enum SourceUnitPart {
    _
    ContractDefinition(def),
    PragmaDirective(pragma),
    ImportDirective(import),
    EnumDefinition(def),
    StructDefinition(def),
    EventDefinition(def),
    ErrorDefinition(def),
    FunctionDefinition(def),
    VariableDefinition(def),
    TypeDefinition(def),
    Using(using),
    StraySemicolon(loc),
    Annotation(annotation),
    _
}}
// Import paths and import directives.
derive_ast_eq! { enum ImportPath {
    _
    Filename(lit),
    Path(path),
    _
}}
derive_ast_eq! { enum Import {
    _
    Plain(string, loc),
    GlobalSymbol(string, ident, loc),
    Rename(string, idents, loc),
    _
}}
// Function kinds: unit variants only, equal when the variant is the same.
derive_ast_eq! { enum FunctionTy {
    Constructor,
    Function,
    Fallback,
    Receive,
    Modifier,
    _
    _
}}
// Items that can appear inside a contract.
derive_ast_eq! { enum ContractPart {
    _
    StructDefinition(def),
    EventDefinition(def),
    EnumDefinition(def),
    ErrorDefinition(def),
    VariableDefinition(def),
    FunctionDefinition(def),
    TypeDefinition(def),
    StraySemicolon(loc),
    Using(using),
    Annotation(annotation),
    _
}}
// Contract kinds carry a single field bound as `loc`, so only the variant
// matters.
derive_ast_eq! { enum ContractTy {
    _
    Abstract(loc),
    Contract(loc),
    Interface(loc),
    Library(loc),
    _
}}
// Variable attributes; the ordering of a variable's attribute list is handled
// separately in `impl AstEq for VariableDefinition`.
derive_ast_eq! { enum VariableAttribute {
    _
    Visibility(visi),
    Constant(loc),
    Immutable(loc),
    Override(loc, idents),
    StorageType(st),
    StorageLocation(st),
    _
}}
696
697impl AstEq for StorageType {
699 fn ast_eq(&self, _other: &Self) -> bool {
700 true
701 }
702}
703
704impl AstEq for VersionComparator {
705 fn ast_eq(&self, _other: &Self) -> bool {
706 true
707 }
708}