|
| 1 | +package scala.runtime.quoted |
| 2 | + |
| 3 | +import scala.quoted._ |
| 4 | +import scala.tasty._ |
| 5 | + |
| 6 | +object Matcher { |
| 7 | + |
| 8 | + type Hole[T /* <: AnyKind */] = T |
| 9 | + |
| 10 | + def hole[T]: T = ??? |
| 11 | + |
| 12 | + def unapplySeq(scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Seq[Any]] = { |
| 13 | + import reflection._ |
| 14 | + |
| 15 | + def treeMatches(scrutinee: Tree, pattern: Tree): Option[Seq[Any]] = { |
| 16 | + import Term._ |
| 17 | + (scrutinee, pattern) match { |
| 18 | + // Normalize blocks without statements |
| 19 | + case (Block(Nil, expr), _) => treeMatches(expr, pattern) |
| 20 | + case (_, Block(Nil, pat)) => treeMatches(scrutinee, pat) |
| 21 | + |
| 22 | + case (IsTerm(scrutinee), TypeApply(Ident("hole"), tpt :: Nil)) |
| 23 | + if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.hole" && // TODO check symbol equality instead of its name |
| 24 | + scrutinee.tpe <:< tpt.tpe => |
| 25 | + Some(Seq(scrutinee)) |
| 26 | + |
| 27 | + case (Inlined(None, Nil, scr), _) => |
| 28 | + treeMatches(scr, pattern) |
| 29 | + case (_, Inlined(None, Nil, pat)) => |
| 30 | + treeMatches(scrutinee, pat) |
| 31 | + |
| 32 | + case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => |
| 33 | + Some(Seq.empty) |
| 34 | + |
| 35 | + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol => |
| 36 | + Some(Seq.empty) |
| 37 | + |
| 38 | + case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => |
| 39 | + foldMatchings(treeMatches(expr1, expr2) :: typeTreeMatches(tpt1, tpt2) :: Nil) |
| 40 | + |
| 41 | + case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => |
| 42 | + treeMatches(qual1, qual2) |
| 43 | + |
| 44 | + case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol => |
| 45 | + foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2))) |
| 46 | + |
| 47 | + case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol => |
| 48 | + foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2))) |
| 49 | + |
| 50 | + case (Block(stats1, expr1), Block(stats2, expr2)) => |
| 51 | + // TODO handle bindings |
| 52 | + foldMatchings((for ((stat1, stat2) <- stats1.zip(stats2)) yield treeMatches(stat1, stat2)) ::: treeMatches(expr1, expr2) :: Nil) |
| 53 | + |
| 54 | + case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => |
| 55 | + foldMatchings(treeMatches(cond1, cond2) :: treeMatches(thenp1, thenp2) :: treeMatches(elsep1, elsep2) :: Nil) |
| 56 | + |
| 57 | + case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => |
| 58 | + // TODO how to handle LHS? |
| 59 | + if (treeMatches(lhs1, lhs2).isDefined) treeMatches(rhs1, rhs2) |
| 60 | + else None |
| 61 | + |
| 62 | + case (While(cond1, body1), While(cond2, body2)) => |
| 63 | + foldMatchings(treeMatches(cond1, cond2) :: treeMatches(body1, body2) :: Nil) |
| 64 | + |
| 65 | + case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 => |
| 66 | + treeMatches(expr1, expr2) |
| 67 | + |
| 68 | + case (New(tpt1), New(tpt2)) => |
| 69 | + typeTreeMatches(tpt1, tpt2) |
| 70 | + |
| 71 | + case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => |
| 72 | + Some(Seq.empty) |
| 73 | + |
| 74 | + case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => |
| 75 | + foldMatchings(for ((elem1, elem2) <- elems1.zip(elems2)) yield treeMatches(elem1, elem2)) |
| 76 | + |
| 77 | + case _ => |
| 78 | +// println( |
| 79 | +// s"""Scrutinee ${scrutinee.showCode} |
| 80 | +// |${scrutinee.show} |
| 81 | +// |did not match pattern ${pattern.showCode} |
| 82 | +// |${pattern.show} |
| 83 | +// | |
| 84 | +// |""".stripMargin) |
| 85 | + None |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + def typeTreeMatches(scrutinee: TypeOrBoundsTree, pattern: TypeOrBoundsTree): Option[Seq[Any]] = { |
| 90 | + import TypeTree._ |
| 91 | + (scrutinee, pattern) match { |
| 92 | + case (IsTypeTree(scrutinee), IsTypeTree(pattern @ Applied(Ident("Hole"), IsTypeTree(tpt) :: Nil))) |
| 93 | + if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.Hole" && // TODO check symbol equality instead of its name |
| 94 | + scrutinee.tpe <:< tpt.tpe => // Is the subtype check required? |
| 95 | + Some(Seq(scrutinee)) |
| 96 | + |
| 97 | + case (IsTypeTree(scrutinee @ Ident(_)), IsTypeTree(pattern @ Ident(_))) if scrutinee.symbol == pattern.symbol => |
| 98 | + Some(Seq.empty) |
| 99 | + |
| 100 | + case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe => |
| 101 | + Some(Seq.empty) |
| 102 | + |
| 103 | + case (Applied(tycon1, args1), Applied(tycon2, args2)) => |
| 104 | + val matchings: List[Option[Seq[Any]]] = |
| 105 | + typeTreeMatches(tycon1, tycon2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield typeTreeMatches(arg1, arg2)) |
| 106 | + foldMatchings(matchings) |
| 107 | + |
| 108 | + case _ => |
| 109 | +// println( |
| 110 | +// s"""Scrutinee ${scrutinee.showCode} |
| 111 | +// |${scrutinee.show} |
| 112 | +// |did not match pattern ${pattern.showCode} |
| 113 | +// |${pattern.show} |
| 114 | +// | |
| 115 | +// |""".stripMargin) |
| 116 | + None |
| 117 | + } |
| 118 | + } |
| 119 | + |
| 120 | + treeMatches(scrutineeExpr.unseal, patternExpr.unseal) |
| 121 | + } |
| 122 | + |
| 123 | + private def foldMatchings(matchings: List[Option[Seq[Any]]]): Option[Seq[Any]] = { |
| 124 | + matchings.foldLeft[Option[Seq[Any]]](Some(Seq.empty)) { |
| 125 | + case (Some(acc), Some(holes)) => Some(acc ++ holes) |
| 126 | + case (_, _) => None |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | +} |
0 commit comments