|
| 1 | +package scala.runtime.quoted |
| 2 | + |
| 3 | +import scala.quoted._ |
| 4 | +import scala.quoted.matching.Binding |
| 5 | + |
| 6 | +import scala.tasty._ |
| 7 | + |
| 8 | +object Matcher { |
| 9 | + |
| 10 | + type Hole[T <: AnyKind] = T |
| 11 | + type bindHole[T] = T |
| 12 | + |
| 13 | + def hole[T]: T = ??? |
| 14 | + def seqHole[T <: Seq[_]]: T = ??? |
| 15 | + |
| 16 | + /** |
| 17 | + * |
| 18 | + * @param scrutineeExpr |
| 19 | + * @param patternExpr |
| 20 | + * @param reflection |
| 21 | + * @return None if it did not match, Some(seq) if it matched where seq contains Expr[_], Type[_], Binding[_] or Seq[Expr[_]] |
| 22 | + */ |
| 23 | + def unapplySeq(scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Seq[Any]] = { |
| 24 | + import reflection._ |
| 25 | + |
| 26 | + def treeMatches(scrutinee: Tree, pattern: Tree): Option[Seq[Any]] = { |
| 27 | + import Term._ |
| 28 | + import TypeTree.{Ident => TypeIdent, Select => TypeSelect, _} |
| 29 | + |
| 30 | + (scrutinee, pattern) match { |
| 31 | + // Normalize blocks without statements |
| 32 | + case (Block(Nil, expr), _) => treeMatches(expr, pattern) |
| 33 | + case (_, Block(Nil, pat)) => treeMatches(scrutinee, pat) |
| 34 | + |
| 35 | + case (IsTerm(scrutinee), TypeApply(Ident("hole"), tpt :: Nil)) |
| 36 | + if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.hole" && // TODO check symbol equality instead of its name |
| 37 | + scrutinee.tpe <:< tpt.tpe => |
| 38 | + Some(Seq(scrutinee.seal2)) |
| 39 | + |
| 40 | + case (IsTerm(scrutinee @ Term.Repeated(args, _)), TypeApply(Ident("seqHole"), tpt :: Nil)) |
| 41 | + if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.seqHole" && // TODO check symbol equality instead of its name |
| 42 | + scrutinee.tpe <:< tpt.tpe => |
| 43 | + Some(Seq(args.map(_.seal2).toSeq)) |
| 44 | + |
| 45 | + case (IsTypeTree(scrutinee), IsTypeTree(pattern @ TypeTree.Applied(TypeIdent("Hole"), IsTypeTree(tpt) :: Nil))) |
| 46 | + if pattern.symbol.fullName == "scala.runtime.quoted.Matcher$.Hole" && // TODO check symbol equality instead of its name |
| 47 | + scrutinee.tpe <:< tpt.tpe => // Is the subtype check required? |
| 48 | + Some(Seq(scrutinee)) |
| 49 | + |
| 50 | + case (ValDef(_, tpt1, rhs1), ValDef(_, bindHole @ TypeTree.Applied(TypeIdent("bindHole"), IsTypeTree(tpt2) :: Nil), rhs2)) |
| 51 | + if bindHole.symbol.fullName == "scala.runtime.quoted.Matcher$.bindHole" && // TODO check symbol equality instead of its name |
| 52 | + tpt1.tpe <:< tpt2.tpe => // Is the subtype check required? |
| 53 | + |
| 54 | + val sym = scrutinee.symbol |
| 55 | + val binding = new Binding(sym.name, sym) |
| 56 | + |
| 57 | + def rhsMatchings = (rhs1, rhs2) match { |
| 58 | + case (Some(b1), Some(b2)) => treeMatches(b1, b2) |
| 59 | + case (None, None) => Some(Seq.empty) |
| 60 | + case _ => None |
| 61 | + } |
| 62 | + |
| 63 | + foldMatchings(Some(Seq(binding)) :: treeMatches(tpt1, tpt2) :: rhsMatchings :: Nil) |
| 64 | + |
| 65 | + case (Inlined(_, Nil, scr), _) => |
| 66 | + treeMatches(scr, pattern) |
| 67 | + case (_, Inlined(_, Nil, pat)) => |
| 68 | + treeMatches(scrutinee, pat) |
| 69 | + |
| 70 | + case (Literal(constant1), Literal(constant2)) if constant1 == constant2 => |
| 71 | + Some(Seq.empty) |
| 72 | + |
| 73 | + case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol => |
| 74 | + Some(Seq.empty) |
| 75 | + |
| 76 | + case (Typed(expr1, tpt1), Typed(expr2, tpt2)) => |
| 77 | + foldMatchings(treeMatches(expr1, expr2) :: treeMatches(tpt1, tpt2) :: Nil) |
| 78 | + |
| 79 | + case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol => |
| 80 | + treeMatches(qual1, qual2) |
| 81 | + |
| 82 | + case (Ident(_), Select(_, _)) if scrutinee.symbol == pattern.symbol => |
| 83 | + Some(Seq.empty) |
| 84 | + |
| 85 | + case (Select(_, _), Ident(_)) if scrutinee.symbol == pattern.symbol => |
| 86 | + Some(Seq.empty) |
| 87 | + |
| 88 | + case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol => |
| 89 | + foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2))) |
| 90 | + |
| 91 | + case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol => |
| 92 | + foldMatchings(treeMatches(fn1, fn2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2))) |
| 93 | + |
| 94 | + case (Block(stats1, expr1), Block(stats2, expr2)) => |
| 95 | + // TODO handle bindings |
| 96 | + foldMatchings((for ((stat1, stat2) <- stats1.zip(stats2)) yield treeMatches(stat1, stat2)) ::: treeMatches(expr1, expr2) :: Nil) |
| 97 | + |
| 98 | + case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) => |
| 99 | + foldMatchings(treeMatches(cond1, cond2) :: treeMatches(thenp1, thenp2) :: treeMatches(elsep1, elsep2) :: Nil) |
| 100 | + |
| 101 | + case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) => |
| 102 | + // TODO how to handle LHS? |
| 103 | + if (treeMatches(lhs1, lhs2).isDefined) treeMatches(rhs1, rhs2) |
| 104 | + else None |
| 105 | + |
| 106 | + case (While(cond1, body1), While(cond2, body2)) => |
| 107 | + foldMatchings(treeMatches(cond1, cond2) :: treeMatches(body1, body2) :: Nil) |
| 108 | + |
| 109 | + case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 => |
| 110 | + treeMatches(expr1, expr2) |
| 111 | + |
| 112 | + case (New(tpt1), New(tpt2)) => |
| 113 | + treeMatches(tpt1, tpt2) |
| 114 | + |
| 115 | + case (This(_), This(_)) if scrutinee.symbol == pattern.symbol => |
| 116 | + Some(Seq.empty) |
| 117 | + |
| 118 | + case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 => |
| 119 | + treeMatches(qual1, qual2) |
| 120 | + |
| 121 | + case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size => |
| 122 | + foldMatchings(for ((elem1, elem2) <- elems1.zip(elems2)) yield treeMatches(elem1, elem2)) |
| 123 | + |
| 124 | + case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol => |
| 125 | + Some(Seq.empty) |
| 126 | + |
| 127 | + case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe => |
| 128 | + Some(Seq.empty) |
| 129 | + |
| 130 | + case (Applied(tycon1, args1), Applied(tycon2, args2)) => |
| 131 | + val matchings: List[Option[Seq[Any]]] = |
| 132 | + treeMatches(tycon1, tycon2) :: (for ((arg1, arg2) <- args1.zip(args2)) yield treeMatches(arg1, arg2)) |
| 133 | + foldMatchings(matchings) |
| 134 | + |
| 135 | + case (DefDef(_, typeParams1, paramss1, returnTpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, returnTpt2, Some(rhs2))) => |
| 136 | + val matchings: List[Option[Seq[Any]]] = |
| 137 | + for ((tree1: Tree, tree2: Tree) <- (typeParams1 ::: paramss1.flatten ::: rhs1 :: Nil).zip(typeParams2 ::: paramss2.flatten ::: rhs2 :: Nil)) yield treeMatches(tree1, tree2) |
| 138 | + foldMatchings(matchings) |
| 139 | + |
| 140 | + case (Term.Lambda(_, tpt1), Term.Lambda(_, tpt2)) => |
| 141 | + // TODO match tpt1 with tpt2? |
| 142 | + Some(Seq.empty) |
| 143 | + |
| 144 | + case _ => |
| 145 | +// println( |
| 146 | +// s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> |
| 147 | +// |Scrutinee |
| 148 | +// | ${scrutinee.showCode} |
| 149 | +// | |
| 150 | +// |${scrutinee.show} |
| 151 | +// | |
| 152 | +// |did not match pattern |
| 153 | +// | ${pattern.showCode} |
| 154 | +// | |
| 155 | +// |${pattern.show} |
| 156 | +// | |
| 157 | +// | |
| 158 | +// | |
| 159 | +// | |
| 160 | +// |""".stripMargin) |
| 161 | + None |
| 162 | + } |
| 163 | + } |
| 164 | + |
| 165 | + treeMatches(scrutineeExpr.unseal, patternExpr.unseal) |
| 166 | + } |
| 167 | + |
| 168 | + private def foldMatchings(matchings: List[Option[Seq[Any]]]): Option[Seq[Any]] = { |
| 169 | + matchings.foldLeft[Option[Seq[Any]]](Some(Seq.empty)) { |
| 170 | + case (Some(acc), Some(holes)) => Some(acc ++ holes) |
| 171 | + case (_, _) => None |
| 172 | + } |
| 173 | + } |
| 174 | + |
| 175 | +} |
0 commit comments