Skip to content

Commit c02ce19

Browse files
committed
Implement Matcher runtime logic
This PR adds: * runtime logic for pattern matching on `'{ ... }` * `Literal` pattern to match a literal expression and extract its value
1 parent 3d19187 commit c02ce19

File tree

12 files changed

+1129
-10
lines changed

12 files changed

+1129
-10
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,6 +1730,8 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17301730
// DEFINITIONS
17311731
//
17321732

1733+
// Symbols
1734+
17331735
def Definitions_RootPackage: Symbol = defn.RootPackage
17341736
def Definitions_RootClass: Symbol = defn.RootClass
17351737

@@ -1778,6 +1780,10 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17781780
defn.FunctionClass(arity, isImplicit, isErased).asClass
17791781
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass
17801782

1783+
def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
1784+
1785+
// Types
1786+
17811787
def Definitions_UnitType: Type = defn.UnitType
17821788
def Definitions_ByteType: Type = defn.ByteType
17831789
def Definitions_ShortType: Type = defn.ShortType
Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
package scala.runtime.quoted
2+
3+
import scala.annotation.internal.sharable
4+
5+
import scala.quoted._
6+
import scala.tasty._
7+
8+
object Matcher {
9+
10+
private final val debug = false
11+
12+
/**
13+
*
14+
* @param scrutineeExpr
15+
* @param patternExpr
16+
* @param reflection
17+
* @return None if it did not match, Some(tup) if it matched where tup contains Expr[_], Type[_] or Binding[_]
18+
*/
19+
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
20+
import reflection._
21+
22+
def treeMatches(scrutinee: Tree, pattern: Tree)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
23+
24+
/** Check that both are `val` or both are `lazy val` or both are `var` **/
25+
def checkValFlags(): Boolean = {
26+
import Flags._
27+
val sFlags = scrutinee.symbol.flags
28+
val pFlags = pattern.symbol.flags
29+
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
30+
}
31+
32+
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
33+
if (scrutinees.size != patterns.size) None
34+
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
35+
36+
(scrutinee, pattern) match {
37+
// Normalize blocks without statements
38+
case (Block(Nil, expr), _) => treeMatches(expr, pattern)
39+
case (_, Block(Nil, pat)) => treeMatches(scrutinee, pat)
40+
41+
// Match
42+
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
43+
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole && scrutinee.tpe <:< tpt.tpe =>
44+
Some(Tuple1(scrutinee.seal))
45+
46+
case (Inlined(_, Nil, scr), _) =>
47+
treeMatches(scr, pattern)
48+
case (_, Inlined(_, Nil, pat)) =>
49+
treeMatches(scrutinee, pat)
50+
51+
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
52+
Some(())
53+
54+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || env((scrutinee.symbol, pattern.symbol)) =>
55+
Some(())
56+
57+
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
58+
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
59+
60+
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
61+
treeMatches(qual1, qual2)
62+
63+
case (Ident(_), Select(_, _)) if scrutinee.symbol == pattern.symbol =>
64+
Some(())
65+
66+
case (Select(_, _), Ident(_)) if scrutinee.symbol == pattern.symbol =>
67+
Some(())
68+
69+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
70+
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
71+
72+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
73+
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
74+
75+
case (Block(stats1, expr1), Block(stats2, expr2)) =>
76+
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
77+
78+
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
79+
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
80+
81+
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
82+
val lhsMatch =
83+
if (treeMatches(lhs1, lhs2).isDefined) Some(())
84+
else None
85+
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
86+
87+
case (While(cond1, body1), While(cond2, body2)) =>
88+
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
89+
90+
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
91+
treeMatches(expr1, expr2)
92+
93+
case (New(tpt1), New(tpt2)) =>
94+
treeMatches(tpt1, tpt2)
95+
96+
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
97+
Some(())
98+
99+
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
100+
treeMatches(qual1, qual2)
101+
102+
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
103+
treesMatch(elems1, elems2)
104+
105+
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
106+
Some(())
107+
108+
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
109+
Some(())
110+
111+
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
112+
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
113+
114+
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
115+
val returnTptMatch = treeMatches(tpt1, tpt2)
116+
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
117+
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
118+
foldMatchings(returnTptMatch, rhsMatchings)
119+
120+
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
121+
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
122+
val paramssMatch =
123+
if (paramss1.size != paramss2.size) None
124+
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
125+
val tptMatch = treeMatches(tpt1, tpt2)
126+
val rhsEnv =
127+
env + (scrutinee.symbol -> pattern.symbol) ++
128+
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
129+
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
130+
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
131+
132+
foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
133+
134+
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
135+
// TODO match tpt1 with tpt2?
136+
Some(())
137+
138+
case (Match(scru1, cases1), Match(scru2, cases2)) =>
139+
val scrutineeMacth = treeMatches(scru1, scru2)
140+
val casesMatch =
141+
if (cases1.size != cases2.size) None
142+
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
143+
foldMatchings(scrutineeMacth, casesMatch)
144+
145+
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
146+
val bodyMacth = treeMatches(body1, body2)
147+
val casesMatch =
148+
if (cases1.size != cases2.size) None
149+
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
150+
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
151+
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
152+
153+
case _ =>
154+
if (debug)
155+
println(
156+
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
157+
|Scrutinee
158+
| ${scrutinee.showCode}
159+
|
160+
|${scrutinee.show}
161+
|
162+
|did not match pattern
163+
| ${pattern.showCode}
164+
|
165+
|${pattern.show}
166+
|
167+
|
168+
|
169+
|
170+
|""".stripMargin)
171+
None
172+
}
173+
}
174+
175+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
176+
(scrutinee, pattern) match {
177+
case (Some(x), Some(y)) => treeMatches(x, y)
178+
case (None, None) => Some(())
179+
case _ => None
180+
}
181+
}
182+
183+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef)(implicit env: Set[(Symbol, Symbol)]): Option[Tuple] = {
184+
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
185+
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)(caseEnv)
186+
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)(caseEnv)
187+
foldMatchings(patternMatch, guardMatch, rhsMatch)
188+
}
189+
190+
def patternMatches(scrutinee: Pattern, pattern: Pattern)(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = (scrutinee, pattern) match {
191+
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
192+
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
193+
(env, Some(Tuple1(v1.seal)))
194+
195+
case (Pattern.Value(v1), Pattern.Value(v2)) =>
196+
(env, treeMatches(v1, v2))
197+
198+
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
199+
val bindEnv = env + (scrutinee.symbol -> pattern.symbol)
200+
patternMatches(body1, body2)(bindEnv)
201+
202+
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
203+
val funMatch = treeMatches(fun1, fun2)
204+
val implicitsMatch =
205+
if (implicits1.size != implicits2.size) None
206+
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
207+
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
208+
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
209+
210+
case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
211+
foldPatterns(patterns1, patterns2)
212+
213+
case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
214+
(env, treeMatches(tpt1, tpt2))
215+
216+
case _ =>
217+
if (debug)
218+
println(
219+
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
220+
|Scrutinee
221+
| ${scrutinee.showCode}
222+
|
223+
|${scrutinee.show}
224+
|
225+
|did not match pattern
226+
| ${pattern.showCode}
227+
|
228+
|${pattern.show}
229+
|
230+
|
231+
|
232+
|
233+
|""".stripMargin)
234+
(env, None)
235+
}
236+
237+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern])(implicit env: Set[(Symbol, Symbol)]): (Set[(Symbol, Symbol)], Option[Tuple]) = {
238+
if (patterns1.size != patterns2.size) (env, None)
239+
else patterns1.zip(patterns2).foldLeft((env, Option[Tuple](()))) { (acc, x) =>
240+
val (env, res) = patternMatches(x._1, x._2)(acc._1)
241+
(env, foldMatchings(acc._2, res))
242+
}
243+
}
244+
245+
treeMatches(scrutineeExpr.unseal, patternExpr.unseal)(Set.empty).asInstanceOf[Option[Tup]]
246+
}
247+
248+
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
249+
matchings.foldLeft[Option[Tuple]](Some(())) {
250+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
251+
case (_, _) => None
252+
}
253+
}
254+
255+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package scala.runtime.quoted
2+
3+
import scala.quoted._
4+
import scala.tasty._
5+
6+
object Matcher {
7+
8+
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] =
9+
throw new Exception("running on non bootstrapped library")
10+
11+
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scala.quoted.matching
2+
3+
import scala.quoted.Expr
4+
5+
import scala.tasty.Reflection // TODO do not depend on reflection directly
6+
7+
/** Matches expressions containing literal values and extracts the value.
8+
* It may match expressions of type Boolean, Byte, Short, Int, Long,
9+
* Float, Double, Char, String, ClassTag, scala.Symbol, Null and Unit.
10+
*
11+
* Usage:
12+
* ```
13+
* (x: Expr[B]) match {
14+
* case Literal(value: B) => ...
15+
* }
16+
* ```
17+
*/
18+
object Literal {
19+
20+
def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[T] = {
21+
import reflect.{Literal => LiteralTree, _} // TODO rename reflect.Literal to avoid this clash
22+
def literal(tree: Term): Option[T] = tree match {
23+
case LiteralTree(c) => Some(c.value.asInstanceOf[T])
24+
case Block(Nil, e) => literal(e)
25+
case Inlined(_, Nil, e) => literal(e)
26+
case _ => None
27+
}
28+
literal(expr.unseal)
29+
}
30+
31+
}

library/src/scala/runtime/quoted/Matcher.scala

Lines changed: 0 additions & 10 deletions
This file was deleted.

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,6 +1444,9 @@ trait Kernel {
14441444

14451445
def Definitions_TupleClass(arity: Int): Symbol
14461446

1447+
/** Symbol of scala.runtime.Quoted.patternHole */
1448+
def Definitions_InternalQuoted_patternHole: Symbol
1449+
14471450
def Definitions_UnitType: Type
14481451
def Definitions_ByteType: Type
14491452
def Definitions_ShortType: Type

0 commit comments

Comments
 (0)