Skip to content

Commit 54a75d4

Browse files
Merge pull request #6393 from dotty-staging/refactor-matcher
Use opaque types in matcher to abstract over Matchings
2 parents fdf9ccc + 334af17 commit 54a75d4

File tree

3 files changed

+180
-86
lines changed

3 files changed

+180
-86
lines changed

library/src-3.x/scala/internal/quoted/Matcher.scala

Lines changed: 123 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,23 @@ object Matcher {
3131
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
3232
*/
3333
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
34+
// TODO improve performance
3435
import reflection.{Bind => BindPattern, _}
36+
import Matching._
3537

3638
type Env = Set[(Symbol, Symbol)]
3739

38-
// TODO improve performance
40+
inline def withEnv[T](env: Env)(body: => given Env => T): T = body given env
41+
42+
/** Check that all trees match with =#= and concatenate the results with && */
43+
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching = {
44+
def rec(l1: List[Tree], l2: List[Tree]): Matching = (l1, l2) match {
45+
case (x :: xs, y :: ys) => x =#= y && rec(xs, ys)
46+
case (Nil, Nil) => matched
47+
case _ => notMatched
48+
}
49+
rec(scrutinees, patterns)
50+
}
3951

4052
/** Check that the trees match and return the contents from the pattern holes.
4153
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
@@ -45,7 +57,7 @@ object Matcher {
4557
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
4658
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
4759
*/
48-
def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = {
60+
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {
4961

5062
/** Check that both are `val` or both are `lazy val` or both are `var` **/
5163
def checkValFlags(): Boolean = {
@@ -56,7 +68,7 @@ object Matcher {
5668
}
5769

5870
def bindingMatch(sym: Symbol) =
59-
Some(Tuple1(new Bind(sym.name, sym)))
71+
matched(new Bind(sym.name, sym))
6072

6173
def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
6274
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
@@ -67,10 +79,6 @@ object Matcher {
6779
def hasBindAnnotation(sym: Symbol) =
6880
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
6981

70-
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
71-
if (scrutinees.size != patterns.size) None
72-
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
73-
7482
/** Normalieze the tree */
7583
def normalize(tree: Tree): Tree = tree match {
7684
case Block(Nil, expr) => normalize(expr)
@@ -85,126 +93,130 @@ object Matcher {
8593
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
8694
s.tpe <:< tpt.tpe &&
8795
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
88-
Some(Tuple1(scrutinee.seal))
96+
matched(scrutinee.seal)
8997

9098
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
9199
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
92100
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
93101
scrutinee.tpe <:< tpt.tpe =>
94-
Some(Tuple1(scrutinee.seal))
102+
matched(scrutinee.seal)
95103

96104
//
97105
// Match two equivalent trees
98106
//
99107

100108
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
101-
Some(())
109+
matched
102110

103111
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
104-
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
112+
expr1 =#= expr2 && tpt1 =#= tpt2
105113

106114
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
107-
Some(())
115+
matched
108116

109117
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
110-
treeMatches(qual1, qual2)
118+
qual1 =#= qual2
111119

112120
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
113-
Some(())
121+
matched
114122

115123
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
116-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
124+
fn1 =#= fn2 && args1 =##= args2
117125

118126
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
119-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
127+
fn1 =#= fn2 && args1 =##= args2
120128

121129
case (Block(stats1, expr1), Block(stats2, expr2)) =>
122-
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
130+
withEnv(the[Env] ++ stats1.map(_.symbol).zip(stats2.map(_.symbol))) {
131+
stats1 =##= stats2 && expr1 =#= expr2
132+
}
123133

124134
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
125-
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
135+
cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2
126136

127137
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
128138
val lhsMatch =
129-
if (treeMatches(lhs1, lhs2).isDefined) Some(())
130-
else None
131-
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
139+
if ((lhs1 =#= lhs2).isMatch) matched
140+
else notMatched
141+
lhsMatch && rhs1 =#= rhs2
132142

133143
case (While(cond1, body1), While(cond2, body2)) =>
134-
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
144+
cond1 =#= cond2 && body1 =#= body2
135145

136146
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
137-
treeMatches(expr1, expr2)
147+
expr1 =#= expr2
138148

139149
case (New(tpt1), New(tpt2)) =>
140-
treeMatches(tpt1, tpt2)
150+
tpt1 =#= tpt2
141151

142152
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
143-
Some(())
153+
matched
144154

145155
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
146-
treeMatches(qual1, qual2)
156+
qual1 =#= qual2
147157

148158
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
149-
treesMatch(elems1, elems2)
159+
elems1 =##= elems2
150160

151161
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
152-
Some(())
162+
matched
153163

154164
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
155-
Some(())
165+
matched
156166

157167
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
158-
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
168+
tycon1 =#= tycon2 && args1 =##= args2
159169

160170
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
161171
val bindMatch =
162172
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
163-
else Some(())
164-
val returnTptMatch = treeMatches(tpt1, tpt2)
173+
else matched
174+
val returnTptMatch = tpt1 =#= tpt2
165175
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
166176
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
167-
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
177+
bindMatch && returnTptMatch && rhsMatchings
168178

169179
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
170-
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
180+
val typeParmasMatch = typeParams1 =##= typeParams2
171181
val paramssMatch =
172-
if (paramss1.size != paramss2.size) None
173-
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
182+
if (paramss1.size != paramss2.size) notMatched
183+
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _*)
174184
val bindMatch =
175185
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
176-
else Some(())
177-
val tptMatch = treeMatches(tpt1, tpt2)
186+
else matched
187+
val tptMatch = tpt1 =#= tpt2
178188
val rhsEnv =
179189
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
180190
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
181191
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
182-
val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv
192+
val rhsMatch = (rhs1 =#= rhs2) given rhsEnv
183193

184-
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
194+
bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch
185195

186196
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
187197
// TODO match tpt1 with tpt2?
188-
Some(())
198+
matched
189199

190200
case (Match(scru1, cases1), Match(scru2, cases2)) =>
191-
val scrutineeMacth = treeMatches(scru1, scru2)
201+
val scrutineeMacth = scru1 =#= scru2
192202
val casesMatch =
193-
if (cases1.size != cases2.size) None
203+
if (cases1.size != cases2.size) notMatched
194204
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
195-
foldMatchings(scrutineeMacth, casesMatch)
205+
scrutineeMacth && casesMatch
196206

197207
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
198-
val bodyMacth = treeMatches(body1, body2)
208+
val bodyMacth = body1 =#= body2
199209
val casesMatch =
200-
if (cases1.size != cases2.size) None
210+
if (cases1.size != cases2.size) notMatched
201211
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
202212
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
203-
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
213+
bodyMacth && casesMatch && finalizerMatch
204214

205215
// Ignore type annotations
206-
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
207-
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
216+
case (Annotated(tpt, _), _) =>
217+
tpt =#= pattern
218+
case (_, Annotated(tpt, _)) =>
219+
scrutinee =#= tpt
208220

209221
// No Match
210222
case _ =>
@@ -225,26 +237,24 @@ object Matcher {
225237
|
226238
|
227239
|""".stripMargin)
228-
None
240+
notMatched
229241
}
230242
}
231243

232-
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = {
244+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
233245
(scrutinee, pattern) match {
234-
case (Some(x), Some(y)) => treeMatches(x, y)
235-
case (None, None) => Some(())
236-
case _ => None
246+
case (Some(x), Some(y)) => x =#= y
247+
case (None, None) => matched
248+
case _ => notMatched
237249
}
238250
}
239251

240-
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = {
241-
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
242-
243-
{
244-
implied for Env = caseEnv
252+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
253+
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
254+
withEnv(caseEnv) {
245255
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
246-
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)
247-
foldMatchings(patternMatch, guardMatch, rhsMatch)
256+
val rhsMatch = scrutinee.rhs =#= pattern.rhs
257+
patternMatch && guardMatch && rhsMatch
248258
}
249259
}
250260

@@ -258,34 +268,34 @@ object Matcher {
258268
* @return The new environment containing the bindings defined in this pattern tuppled with
259269
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
260270
*/
261-
def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match {
271+
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
262272
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
263273
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
264-
(the[Env], Some(Tuple1(v1.seal)))
274+
(the[Env], matched(v1.seal))
265275

266276
case (Pattern.Value(v1), Pattern.Value(v2)) =>
267-
(the[Env], treeMatches(v1, v2))
277+
(the[Env], v1 =#= v2)
268278

269279
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
270280
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
271-
patternMatches(body1, body2) given bindEnv
281+
(body1 =%= body2) given bindEnv
272282

273283
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
274-
val funMatch = treeMatches(fun1, fun2)
284+
val funMatch = fun1 =#= fun2
275285
val implicitsMatch =
276-
if (implicits1.size != implicits2.size) None
277-
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
286+
if (implicits1.size != implicits2.size) notMatched
287+
else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*)
278288
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
279-
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
289+
(patEnv, funMatch && implicitsMatch && patternsMatch)
280290

281291
case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
282292
foldPatterns(patterns1, patterns2)
283293

284294
case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
285-
(the[Env], treeMatches(tpt1, tpt2))
295+
(the[Env], tpt1 =#= tpt2)
286296

287297
case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) =>
288-
(the[Env], Some(()))
298+
(the[Env], matched)
289299

290300
case _ =>
291301
if (debug)
@@ -305,30 +315,57 @@ object Matcher {
305315
|
306316
|
307317
|""".stripMargin)
308-
(the[Env], None)
318+
(the[Env], notMatched)
309319
}
310320

311-
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = {
312-
if (patterns1.size != patterns2.size) (the[Env], None)
313-
else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) =>
314-
val (env, res) = patternMatches(x._1, x._2) given acc._1
315-
(env, foldMatchings(acc._2, res))
321+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
322+
if (patterns1.size != patterns2.size) (the[Env], notMatched)
323+
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
324+
val (env, res) = (x._1 =%= x._2) given acc._1
325+
(env, acc._2 && res)
316326
}
317327
}
318328

319329
implied for Env = Set.empty
320-
treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]]
330+
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
321331
}
322332

323-
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
324-
* Otherwise the result is `Some` of the concatenation of the tupples.
325-
*/
326-
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
327-
// TODO improve performance
328-
matchings.foldLeft[Option[Tuple]](Some(())) {
329-
case (Some(acc), Some(holes)) => Some(acc ++ holes)
330-
case (_, _) => None
333+
/** Result of matching a part of an expression */
334+
private opaque type Matching = Option[Tuple]
335+
336+
private object Matching {
337+
338+
def notMatched: Matching = None
339+
val matched: Matching = Some(())
340+
def matched(x: Any): Matching = Some(Tuple1(x))
341+
342+
def (self: Matching) asOptionOfTuple: Option[Tuple] = self
343+
344+
/** Concatenates the contents of two sucessful matchings or return a `notMatched` */
345+
// FIXME inline to avoid alocation of by name closure (see #6395)
346+
/*inline*/ def (self: Matching) && (that: => Matching): Matching = self match {
347+
case Some(x) =>
348+
that match {
349+
case Some(y) => Some(x ++ y)
350+
case _ => None
351+
}
352+
case _ => None
331353
}
354+
355+
/** Is this matching the result of a successful match */
356+
def (self: Matching) isMatch: Boolean = self.isDefined
357+
358+
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
359+
* Otherwise the result is `Some` of the concatenation of the tupples.
360+
*/
361+
def foldMatchings(matchings: Matching*): Matching = {
362+
// TODO improve performance
363+
matchings.foldLeft[Matching](Some(())) {
364+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
365+
case (_, _) => None
366+
}
367+
}
368+
332369
}
333370

334371
}

0 commit comments

Comments
 (0)