1use alloy_primitives::{Address, I256, U256};
2use solang_parser::pt::*;
3use std::str::FromStr;
4
5fn to_num(string: &str) -> I256 {
7 if string.is_empty() {
8 return I256::ZERO
9 }
10 string.replace('_', "").trim().parse().unwrap()
11}
12
13fn to_num_reversed(string: &str) -> U256 {
16 if string.is_empty() {
17 return U256::from(0)
18 }
19 string.replace('_', "").trim().chars().rev().collect::<String>().parse().unwrap()
20}
21
22fn filter_params(list: &ParameterList) -> ParameterList {
25 list.iter().filter(|(_, param)| param.is_some()).cloned().collect::<Vec<_>>()
26}
27
/// Structural equality between AST nodes, looser than `PartialEq`.
///
/// The implementations in this module ignore source locations (`Loc` always compares equal),
/// compare some attribute lists order-insensitively, and normalize certain literals before
/// comparing them.
pub trait AstEq {
    /// Returns `true` if `self` and `other` are considered equivalent under this relaxed
    /// structural comparison.
    fn ast_eq(&self, other: &Self) -> bool;
}
33
34impl AstEq for Loc {
35 fn ast_eq(&self, _other: &Self) -> bool {
36 true
37 }
38}
39
40impl AstEq for IdentifierPath {
41 fn ast_eq(&self, other: &Self) -> bool {
42 self.identifiers.ast_eq(&other.identifiers)
43 }
44}
45
46impl AstEq for SourceUnit {
47 fn ast_eq(&self, other: &Self) -> bool {
48 self.0.ast_eq(&other.0)
49 }
50}
51
52impl AstEq for VariableDefinition {
53 fn ast_eq(&self, other: &Self) -> bool {
54 let sorted_attrs = |def: &Self| {
55 let mut attrs = def.attrs.clone();
56 attrs.sort();
57 attrs
58 };
59 self.ty.ast_eq(&other.ty) &&
60 self.name.ast_eq(&other.name) &&
61 self.initializer.ast_eq(&other.initializer) &&
62 sorted_attrs(self).ast_eq(&sorted_attrs(other))
63 }
64}
65
66impl AstEq for FunctionDefinition {
67 fn ast_eq(&self, other: &Self) -> bool {
68 let sorted_attrs = |def: &Self| {
70 let mut attrs = def.attributes.clone();
71 attrs.sort();
72 attrs
73 };
74
75 let left_params = filter_params(&self.params);
77 let right_params = filter_params(&other.params);
78 let left_returns = filter_params(&self.returns);
79 let right_returns = filter_params(&other.returns);
80
81 self.ty.ast_eq(&other.ty) &&
82 self.name.ast_eq(&other.name) &&
83 left_params.ast_eq(&right_params) &&
84 self.return_not_returns.ast_eq(&other.return_not_returns) &&
85 left_returns.ast_eq(&right_returns) &&
86 self.body.ast_eq(&other.body) &&
87 sorted_attrs(self).ast_eq(&sorted_attrs(other))
88 }
89}
90
91impl AstEq for Base {
92 fn ast_eq(&self, other: &Self) -> bool {
93 self.name.ast_eq(&other.name) &&
94 self.args.clone().unwrap_or_default().ast_eq(&other.args.clone().unwrap_or_default())
95 }
96}
97
98impl<T> AstEq for Vec<T>
99where
100 T: AstEq,
101{
102 fn ast_eq(&self, other: &Self) -> bool {
103 if self.len() != other.len() {
104 false
105 } else {
106 self.iter().zip(other.iter()).all(|(left, right)| left.ast_eq(right))
107 }
108 }
109}
110
111impl<T> AstEq for Option<T>
112where
113 T: AstEq,
114{
115 fn ast_eq(&self, other: &Self) -> bool {
116 match (self, other) {
117 (Some(left), Some(right)) => left.ast_eq(right),
118 (None, None) => true,
119 _ => false,
120 }
121 }
122}
123
124impl<T> AstEq for Box<T>
125where
126 T: AstEq,
127{
128 fn ast_eq(&self, other: &Self) -> bool {
129 T::ast_eq(self, other)
130 }
131}
132
133impl AstEq for () {
134 fn ast_eq(&self, _other: &Self) -> bool {
135 true
136 }
137}
138
139impl<T> AstEq for &T
140where
141 T: AstEq,
142{
143 fn ast_eq(&self, other: &Self) -> bool {
144 T::ast_eq(self, other)
145 }
146}
147
148impl AstEq for String {
149 fn ast_eq(&self, other: &Self) -> bool {
150 match (Address::from_str(self), Address::from_str(other)) {
151 (Ok(left), Ok(right)) => left == right,
152 _ => self == other,
153 }
154 }
155}
156
/// Expands to the value compared for one bound enum field.
///
/// A field annotated with `#[ast_eq_use(func)]` expands to `func(field)`; an unannotated field
/// expands to the binding itself.
macro_rules! ast_eq_field {
    (#[ast_eq_use($normalize:ident)] $binding:ident) => {
        $normalize($binding)
    };
    ($binding:ident) => {
        $binding
    };
}
165
/// Generates a `match` expression comparing two values of enum `$name` variant by variant.
///
/// Invocation grammar: `gen_ast_eq_enum!(self_expr, other_expr, Name { units _ tuples _ structs })`
/// where the three `_`-separated sections list the unit, tuple and struct variants to handle.
/// Any tuple/struct field may be prefixed with `#[ast_eq_use(func)]`, in which case `func(field)`
/// is compared instead of the raw field (see `ast_eq_field!`).
///
/// The remaining arms are internal helpers that build the comparison for a single variant.
macro_rules! gen_ast_eq_enum {
    // Entry point: one match arm per listed variant, each delegating to a helper arm below.
    ($self:expr, $other:expr, $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),* $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),* $(,)?
    }) => {
        match $self {
            $($name::$unit_variant => gen_ast_eq_enum!($other, $name, $unit_variant),)*
            $($name::$tuple_variant($($tuple_field),*) =>
                gen_ast_eq_enum!($other, $name, $tuple_variant ($($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),*)),)*
            $($name::$struct_variant { $($struct_field),* } =>
                gen_ast_eq_enum!($other, $name, $struct_variant {$($(#[ast_eq_use($struct_convert_func)])? $struct_field),*}),)*
        }
    };
    // Unit variant: equal iff `other` is the same variant.
    ($other:expr, $name:ident, $unit_variant:ident) => {
        {
            matches!($other, $name::$unit_variant)
        }
    };
    // Tuple variant: pack `self`'s (possibly converted) fields into a tuple, then re-bind the
    // same names from `other` — shadowing `self`'s bindings — pack those too, and compare.
    // With a single field the parentheses are just grouping, not a 1-tuple.
    ($other:expr, $name:ident, $tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? ) ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
            if let $name::$tuple_variant($($tuple_field),*) = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($tuple_convert_func)])? $tuple_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
    // Struct variant: same shadowing technique as the tuple case, using named fields.
    ($other:expr, $name:ident, $struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? } ) => {
        {
            let left = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
            if let $name::$struct_variant { $($struct_field),* } = $other {
                let right = ($(ast_eq_field!($(#[ast_eq_use($struct_convert_func)])? $struct_field)),*);
                left.ast_eq(&right)
            } else {
                false
            }
        }
    };
}
210
/// Normalizes a `&Box<Statement>` into a boxed `Statement::Block`.
///
/// A statement that already is a block is cloned unchanged; any other statement is cloned into
/// a new safe (`unchecked: false`) block at `$loc` containing just that statement. This lets
/// `if (c) x;` and `if (c) { x; }` compare equal.
macro_rules! wrap_in_box {
    ($stmt:expr, $loc:expr) => {
        if matches!(**$stmt, Statement::Block { .. }) {
            $stmt.clone()
        } else {
            Box::new(Statement::Block {
                loc: $loc,
                unchecked: false,
                statements: vec![*$stmt.clone()],
            })
        }
    };
}
224
impl AstEq for Statement {
    /// Compares two statements structurally.
    ///
    /// Control-flow bodies (`if`/`else`, `while`, `do … while`, `for`) are normalized with
    /// `wrap_in_box!` first, so a bare statement body and the same statement wrapped in a block
    /// compare equal. In `try`, empty `(loc, None)` entries of the `returns` parameter list are
    /// dropped via `filter_params`. All other variants are compared field by field.
    fn ast_eq(&self, other: &Self) -> bool {
        match self {
            Self::If(loc, expr, stmt1, stmt2) => {
                // Wrap the `then` branch; wrap the `else` branch too unless it is itself an
                // `if` (an `else if` chain), which stays as-is.
                // The closure captures `self`'s `loc` and also uses it for `other`'s branches;
                // this is harmless because `Loc` always compares equal.
                #[expect(clippy::borrowed_box)]
                let wrap_if = |stmt1: &Box<Self>, stmt2: &Option<Box<Self>>| {
                    (
                        wrap_in_box!(stmt1, *loc),
                        stmt2.as_ref().map(|stmt2| {
                            if matches!(**stmt2, Self::If(..)) {
                                stmt2.clone()
                            } else {
                                wrap_in_box!(stmt2, *loc)
                            }
                        }),
                    )
                };
                let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                let left = (loc, expr, &stmt1, &stmt2);
                if let Self::If(loc, expr, stmt1, stmt2) = other {
                    let (stmt1, stmt2) = wrap_if(stmt1, stmt2);
                    let right = (loc, expr, &stmt1, &stmt2);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::While(loc, expr, stmt1) => {
                // Normalize the loop body to a block before comparing.
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, expr, &stmt1);
                if let Self::While(loc, expr, stmt1) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, expr, &stmt1);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::DoWhile(loc, stmt1, expr) => {
                // Normalize the loop body to a block before comparing.
                let stmt1 = wrap_in_box!(stmt1, *loc);
                let left = (loc, &stmt1, expr);
                if let Self::DoWhile(loc, stmt1, expr) = other {
                    let stmt1 = wrap_in_box!(stmt1, *loc);
                    let right = (loc, &stmt1, expr);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::For(loc, stmt1, expr, stmt2, stmt3) => {
                // Only the last component is wrapped; it is presumably the loop body (the
                // others being init/condition/update) — NOTE(review): confirm against
                // `solang_parser::pt::Statement::For`.
                let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                let left = (loc, stmt1, expr, stmt2, &stmt3);
                if let Self::For(loc, stmt1, expr, stmt2, stmt3) = other {
                    let stmt3 = stmt3.as_ref().map(|stmt3| wrap_in_box!(stmt3, *loc));
                    let right = (loc, stmt1, expr, stmt2, &stmt3);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            Self::Try(loc, expr, returns, catch) => {
                // Drop `(loc, None)` placeholders from the `returns` parameter list.
                let left_returns =
                    returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                let left = (loc, expr, left_returns, catch);
                if let Self::Try(loc, expr, returns, catch) = other {
                    let right_returns =
                        returns.as_ref().map(|(params, stmt)| (filter_params(params), stmt));
                    let right = (loc, expr, right_returns, catch);
                    left.ast_eq(&right)
                } else {
                    false
                }
            }
            // Every other variant is compared field by field. `If`, `While`, `DoWhile`, `For`
            // and `Try` are listed again only because the generated inner `match` must be
            // exhaustive; those arms are never reached from here.
            _ => gen_ast_eq_enum!(self, other, Statement {
                _
                Args(loc, args),
                Expression(loc, expr),
                VariableDefinition(loc, decl, expr),
                Continue(loc, ),
                Break(loc, ),
                Return(loc, expr),
                Revert(loc, expr, expr2),
                RevertNamedArgs(loc, expr, args),
                Emit(loc, expr),
                If(loc, expr, stmt1, stmt2),
                While(loc, expr, stmt1),
                DoWhile(loc, stmt1, expr),
                For(loc, stmt1, expr, stmt2, stmt3),
                Try(loc, expr, params, clause),
                Error(loc)
                _
                Block {
                    loc,
                    unchecked,
                    statements,
                },
                Assembly {
                    loc,
                    dialect,
                    block,
                    flags,
                },
            }),
        }
    }
}
332
/// Generates `AstEq` implementations.
///
/// Forms:
/// - `derive_ast_eq! { Type }` — delegates to `PartialEq` (`self == other`).
/// - `derive_ast_eq! { (0 A, 1 B, …) }` — element-wise comparison for a tuple of that arity,
///   short-circuiting on the first mismatch.
/// - `derive_ast_eq! { struct Name { f1, f2, … } }` — destructures both values and compares
///   the listed fields as a tuple. The destructuring has no `..`, so every field of the struct
///   must be listed.
/// - `derive_ast_eq! { enum Name { units _ tuples _ structs } }` — forwards to
///   `gen_ast_eq_enum!`, supporting `#[ast_eq_use(func)]` field conversions.
macro_rules! derive_ast_eq {
    ($name:ident) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                self == other
            }
        }
    };
    (($($index:tt $gen:tt),*)) => {
        impl < $( $gen ),* > AstEq for ($($gen,)*) where $($gen: AstEq),* {
            fn ast_eq(&self, other: &Self) -> bool {
                // Compare element `$index` of both tuples in order; stop at the first mismatch.
                $(
                if !self.$index.ast_eq(&other.$index) {
                    return false
                }
                )*
                true
            }
        }
    };
    (struct $name:ident { $($field:ident),* $(,)? }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                // Bind `self`'s fields, pack them, then re-bind (shadow) the same names from
                // `other` and pack those for comparison.
                let $name { $($field),* } = self;
                let left = ($($field),*);
                let $name { $($field),* } = other;
                let right = ($($field),*);
                left.ast_eq(&right)
            }
        }
    };
    (enum $name:ident {
        $($unit_variant:ident),* $(,)?
        _
        $($tuple_variant:ident ( $($(#[ast_eq_use($tuple_convert_func:ident)])? $tuple_field:ident),* $(,)? )),* $(,)?
        _
        $($struct_variant:ident { $($(#[ast_eq_use($struct_convert_func:ident)])? $struct_field:ident),* $(,)? }),* $(,)?
    }) => {
        impl AstEq for $name {
            fn ast_eq(&self, other: &Self) -> bool {
                gen_ast_eq_enum!(self, other, $name {
                    $($unit_variant),*
                    _
                    $($tuple_variant ( $($(#[ast_eq_use($tuple_convert_func)])? $tuple_field),* )),*
                    _
                    $($struct_variant { $($(#[ast_eq_use($struct_convert_func)])? $struct_field),* }),*
                })
            }
        }
    }
}
384
385derive_ast_eq! { (0 A) }
386derive_ast_eq! { (0 A, 1 B) }
387derive_ast_eq! { (0 A, 1 B, 2 C) }
388derive_ast_eq! { (0 A, 1 B, 2 C, 3 D) }
389derive_ast_eq! { (0 A, 1 B, 2 C, 3 D, 4 E) }
390derive_ast_eq! { bool }
391derive_ast_eq! { u8 }
392derive_ast_eq! { u16 }
393derive_ast_eq! { I256 }
394derive_ast_eq! { U256 }
// Plain AST structs. The struct form of `derive_ast_eq!` destructures without `..`, so each
// list below must name every field of the corresponding type; all fields, including `loc`,
// are compared through `AstEq`.
derive_ast_eq! { struct Identifier { loc, name } }
derive_ast_eq! { struct HexLiteral { loc, hex } }
derive_ast_eq! { struct StringLiteral { loc, unicode, string } }
derive_ast_eq! { struct Parameter { loc, annotation, ty, storage, name } }
derive_ast_eq! { struct NamedArgument { loc, name, expr } }
derive_ast_eq! { struct YulBlock { loc, statements } }
derive_ast_eq! { struct YulFunctionCall { loc, id, arguments } }
derive_ast_eq! { struct YulFunctionDefinition { loc, id, params, returns, body } }
derive_ast_eq! { struct YulSwitch { loc, condition, cases, default } }
derive_ast_eq! { struct YulFor {
    loc,
    init_block,
    condition,
    post_block,
    execution_block,
}}
derive_ast_eq! { struct YulTypedIdentifier { loc, id, ty } }
derive_ast_eq! { struct VariableDeclaration { loc, ty, storage, name } }
derive_ast_eq! { struct Using { loc, list, ty, global } }
derive_ast_eq! { struct UsingFunction { loc, path, oper } }
derive_ast_eq! { struct TypeDefinition { loc, name, ty } }
derive_ast_eq! { struct ContractDefinition { loc, ty, name, base, parts } }
derive_ast_eq! { struct EventParameter { loc, ty, indexed, name } }
derive_ast_eq! { struct ErrorParameter { loc, ty, name } }
derive_ast_eq! { struct EventDefinition { loc, name, fields, anonymous } }
derive_ast_eq! { struct ErrorDefinition { loc, keyword, name, fields } }
derive_ast_eq! { struct StructDefinition { loc, name, fields } }
derive_ast_eq! { struct EnumDefinition { loc, name, values } }
derive_ast_eq! { struct Annotation { loc, id, value } }
// Enum derives. Each invocation lists unit variants, then `_`, tuple variants, then `_`, struct
// variants. The generated `match` is exhaustive, so every variant must appear exactly once.

// `using` directive targets.
derive_ast_eq! { enum UsingList {
    Error,
    _
    Library(expr),
    Functions(exprs),
    _
}}
// Operators bindable in `using … for … global` (unit variants only).
derive_ast_eq! { enum UserDefinedOperator {
    BitwiseAnd,
    BitwiseNot,
    Negate,
    BitwiseOr,
    BitwiseXor,
    Add,
    Divide,
    Modulo,
    Multiply,
    Subtract,
    Equal,
    More,
    MoreEqual,
    Less,
    LessEqual,
    NotEqual,
    _
    _
}}
derive_ast_eq! { enum Visibility {
    _
    External(loc),
    Public(loc),
    Internal(loc),
    Private(loc),
    _
}}
derive_ast_eq! { enum Mutability {
    _
    Pure(loc),
    View(loc),
    Constant(loc),
    Payable(loc),
    _
}}
derive_ast_eq! { enum FunctionAttribute {
    _
    Mutability(muta),
    Visibility(visi),
    Virtual(loc),
    Immutable(loc),
    Override(loc, idents),
    BaseOrModifier(loc, base),
    Error(loc),
    _
}}
derive_ast_eq! { enum StorageLocation {
    _
    Memory(loc),
    Storage(loc),
    Calldata(loc),
    _
}}
// Type names: elementary types are unit variants, sized ones carry their width, and
// mappings/function types are struct variants.
derive_ast_eq! { enum Type {
    Address,
    AddressPayable,
    Payable,
    Bool,
    Rational,
    DynamicBytes,
    String,
    _
    Int(int),
    Uint(int),
    Bytes(int),
    _
    Mapping{ loc, key, key_name, value, value_name },
    Function { params, attributes, returns },
}}
// Expressions. Numeric literal components are normalized before comparison so equivalent
// spellings match: integer parts and exponents go through `to_num`, and rational fractions go
// through `to_num_reversed`. Hex number literals are compared as raw strings via the `String`
// impl.
derive_ast_eq! { enum Expression {
    _
    PostIncrement(loc, expr1),
    PostDecrement(loc, expr1),
    New(loc, expr1),
    ArraySubscript(loc, expr1, expr2),
    ArraySlice(
        loc,
        expr1,
        expr2,
        expr3,
    ),
    MemberAccess(loc, expr1, ident1),
    FunctionCall(loc, expr1, exprs1),
    FunctionCallBlock(loc, expr1, stmt),
    NamedFunctionCall(loc, expr1, args),
    Not(loc, expr1),
    BitwiseNot(loc, expr1),
    Delete(loc, expr1),
    PreIncrement(loc, expr1),
    PreDecrement(loc, expr1),
    UnaryPlus(loc, expr1),
    Negate(loc, expr1),
    Power(loc, expr1, expr2),
    Multiply(loc, expr1, expr2),
    Divide(loc, expr1, expr2),
    Modulo(loc, expr1, expr2),
    Add(loc, expr1, expr2),
    Subtract(loc, expr1, expr2),
    ShiftLeft(loc, expr1, expr2),
    ShiftRight(loc, expr1, expr2),
    BitwiseAnd(loc, expr1, expr2),
    BitwiseXor(loc, expr1, expr2),
    BitwiseOr(loc, expr1, expr2),
    Less(loc, expr1, expr2),
    More(loc, expr1, expr2),
    LessEqual(loc, expr1, expr2),
    MoreEqual(loc, expr1, expr2),
    Equal(loc, expr1, expr2),
    NotEqual(loc, expr1, expr2),
    And(loc, expr1, expr2),
    Or(loc, expr1, expr2),
    ConditionalOperator(loc, expr1, expr2, expr3),
    Assign(loc, expr1, expr2),
    AssignOr(loc, expr1, expr2),
    AssignAnd(loc, expr1, expr2),
    AssignXor(loc, expr1, expr2),
    AssignShiftLeft(loc, expr1, expr2),
    AssignShiftRight(loc, expr1, expr2),
    AssignAdd(loc, expr1, expr2),
    AssignSubtract(loc, expr1, expr2),
    AssignMultiply(loc, expr1, expr2),
    AssignDivide(loc, expr1, expr2),
    AssignModulo(loc, expr1, expr2),
    BoolLiteral(loc, bool1),
    // (integer part, exponent, unit)
    NumberLiteral(loc, #[ast_eq_use(to_num)] str1, #[ast_eq_use(to_num)] str2, unit),
    // (integer part, fractional part, exponent, unit); the fraction ignores trailing zeros.
    RationalNumberLiteral(
        loc,
        #[ast_eq_use(to_num)] str1,
        #[ast_eq_use(to_num_reversed)] str2,
        #[ast_eq_use(to_num)] str3,
        unit
    ),
    HexNumberLiteral(loc, str1, unit),
    StringLiteral(strs1),
    Type(loc, ty1),
    HexLiteral(hexs1),
    AddressLiteral(loc, str1),
    Variable(ident1),
    List(loc, params1),
    ArrayLiteral(loc, exprs1),
    Parenthesis(loc, expr)
    _
}}
// Remaining enum derives: `try` catch clauses, inline-assembly (Yul) nodes, source-unit and
// contract parts, imports, and attribute/kind enums. All are compared field by field.
derive_ast_eq! { enum CatchClause {
    _
    Simple(param, ident, stmt),
    Named(loc, ident, param, stmt),
    _
}}
derive_ast_eq! { enum YulStatement {
    _
    Assign(loc, exprs, expr),
    VariableDeclaration(loc, idents, expr),
    If(loc, expr, block),
    For(yul_for),
    Switch(switch),
    Leave(loc),
    Break(loc),
    Continue(loc),
    Block(block),
    FunctionDefinition(def),
    FunctionCall(func),
    Error(loc),
    _
}}
// Unlike Solidity `Expression` literals, Yul number literals are compared without
// normalization.
derive_ast_eq! { enum YulExpression {
    _
    BoolLiteral(loc, boo, ident),
    NumberLiteral(loc, string1, string2, ident),
    HexNumberLiteral(loc, string, ident),
    HexStringLiteral(hex, ident),
    StringLiteral(string, ident),
    Variable(ident),
    FunctionCall(func),
    SuffixAccess(loc, expr, ident),
    _
}}
derive_ast_eq! { enum YulSwitchOptions {
    _
    Case(loc, expr, block),
    Default(loc, block),
    _
}}
derive_ast_eq! { enum SourceUnitPart {
    _
    ContractDefinition(def),
    PragmaDirective(loc, ident, string),
    ImportDirective(import),
    EnumDefinition(def),
    StructDefinition(def),
    EventDefinition(def),
    ErrorDefinition(def),
    FunctionDefinition(def),
    VariableDefinition(def),
    TypeDefinition(def),
    Using(using),
    StraySemicolon(loc),
    Annotation(annotation),
    _
}}
derive_ast_eq! { enum ImportPath {
    _
    Filename(lit),
    Path(path),
    _
}}
derive_ast_eq! { enum Import {
    _
    Plain(string, loc),
    GlobalSymbol(string, ident, loc),
    Rename(string, idents, loc),
    _
}}
derive_ast_eq! { enum FunctionTy {
    Constructor,
    Function,
    Fallback,
    Receive,
    Modifier,
    _
    _
}}
derive_ast_eq! { enum ContractPart {
    _
    StructDefinition(def),
    EventDefinition(def),
    EnumDefinition(def),
    ErrorDefinition(def),
    VariableDefinition(def),
    FunctionDefinition(def),
    TypeDefinition(def),
    StraySemicolon(loc),
    Using(using),
    Annotation(annotation),
    _
}}
derive_ast_eq! { enum ContractTy {
    _
    Abstract(loc),
    Contract(loc),
    Interface(loc),
    Library(loc),
    _
}}
derive_ast_eq! { enum VariableAttribute {
    _
    Visibility(visi),
    Constant(loc),
    Immutable(loc),
    Override(loc, idents),
    _
}}