Skip to content

Commit 49b32d1

Browse files
committed
Use opaque types in matcher to abstract over Matchings
1 parent ea670d9 commit 49b32d1

File tree

1 file changed

+107
-78
lines changed

1 file changed

+107
-78
lines changed

library/src-3.x/scala/internal/quoted/Matcher.scala

Lines changed: 107 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ object Matcher {
3232
*/
3333
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
3434
import reflection.{Bind => BindPattern, _}
35+
import Matching._
3536

3637
type Env = Set[(Symbol, Symbol)]
3738

@@ -45,7 +46,7 @@ object Matcher {
4546
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
4647
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
4748
*/
48-
def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = {
49+
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {
4950

5051
/** Check that both are `val` or both are `lazy val` or both are `var` **/
5152
def checkValFlags(): Boolean = {
@@ -56,7 +57,7 @@ object Matcher {
5657
}
5758

5859
def bindingMatch(sym: Symbol) =
59-
Some(Tuple1(new Bind(sym.name, sym)))
60+
matched(new Bind(sym.name, sym))
6061

6162
def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
6263
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
@@ -67,9 +68,9 @@ object Matcher {
6768
def hasBindAnnotation(sym: Symbol) =
6869
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
6970

70-
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
71-
if (scrutinees.size != patterns.size) None
72-
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
71+
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Matching =
72+
if (scrutinees.size != patterns.size) notMatched
73+
else foldMatchings(scrutinees.zip(patterns).map((s, p) => s =#= p): _*)
7374

7475
/** Normalieze the tree */
7576
def normalize(tree: Tree): Tree = tree match {
@@ -85,126 +86,128 @@ object Matcher {
8586
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
8687
s.tpe <:< tpt.tpe &&
8788
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
88-
Some(Tuple1(scrutinee.seal))
89+
matched(scrutinee.seal)
8990

9091
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
9192
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
9293
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
9394
scrutinee.tpe <:< tpt.tpe =>
94-
Some(Tuple1(scrutinee.seal))
95+
matched(scrutinee.seal)
9596

9697
//
9798
// Match two equivalent trees
9899
//
99100

100101
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
101-
Some(())
102+
matched
102103

103104
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
104-
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
105+
expr1 =#= expr2 && tpt1 =#= tpt2
105106

106107
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
107-
Some(())
108+
matched
108109

109110
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
110-
treeMatches(qual1, qual2)
111+
qual1 =#= qual2
111112

112113
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
113-
Some(())
114+
matched
114115

115116
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
116-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
117+
fn1 =#= fn2 && treesMatch(args1, args2)
117118

118119
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
119-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
120+
fn1 =#= fn2 && treesMatch(args1, args2)
120121

121122
case (Block(stats1, expr1), Block(stats2, expr2)) =>
122-
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
123+
treesMatch(stats1, stats2) && expr1 =#= expr2
123124

124125
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
125-
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
126+
cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2
126127

127128
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
128129
val lhsMatch =
129-
if (treeMatches(lhs1, lhs2).isDefined) Some(())
130-
else None
131-
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
130+
if ((lhs1 =#= lhs2).isMatch) matched
131+
else notMatched
132+
lhsMatch && rhs1 =#= rhs2
132133

133134
case (While(cond1, body1), While(cond2, body2)) =>
134-
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
135+
cond1 =#= cond2 && body1 =#= body2
135136

136137
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
137-
treeMatches(expr1, expr2)
138+
expr1 =#= expr2
138139

139140
case (New(tpt1), New(tpt2)) =>
140-
treeMatches(tpt1, tpt2)
141+
tpt1 =#= tpt2
141142

142143
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
143-
Some(())
144+
matched
144145

145146
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
146-
treeMatches(qual1, qual2)
147+
qual1 =#= qual2
147148

148149
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
149150
treesMatch(elems1, elems2)
150151

151152
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
152-
Some(())
153+
matched
153154

154155
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
155-
Some(())
156+
matched
156157

157158
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
158-
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
159+
tycon1 =#= tycon2 && treesMatch(args1, args2)
159160

160161
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
161162
val bindMatch =
162163
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
163-
else Some(())
164-
val returnTptMatch = treeMatches(tpt1, tpt2)
164+
else matched
165+
val returnTptMatch = tpt1 =#= tpt2
165166
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
166167
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
167-
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
168+
bindMatch && returnTptMatch && rhsMatchings
168169

169170
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
170171
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
171172
val paramssMatch =
172-
if (paramss1.size != paramss2.size) None
173+
if (paramss1.size != paramss2.size) notMatched
173174
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
174175
val bindMatch =
175176
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
176-
else Some(())
177-
val tptMatch = treeMatches(tpt1, tpt2)
177+
else matched
178+
val tptMatch = tpt1 =#= tpt2
178179
val rhsEnv =
179180
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
180181
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
181182
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
182-
val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv
183+
val rhsMatch = (rhs1 =#= rhs2) given rhsEnv
183184

184-
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
185+
bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch
185186

186187
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
187188
// TODO match tpt1 with tpt2?
188-
Some(())
189+
matched
189190

190191
case (Match(scru1, cases1), Match(scru2, cases2)) =>
191-
val scrutineeMacth = treeMatches(scru1, scru2)
192+
val scrutineeMacth = scru1 =#= scru2
192193
val casesMatch =
193-
if (cases1.size != cases2.size) None
194+
if (cases1.size != cases2.size) notMatched
194195
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
195-
foldMatchings(scrutineeMacth, casesMatch)
196+
scrutineeMacth && casesMatch
196197

197198
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
198-
val bodyMacth = treeMatches(body1, body2)
199+
val bodyMacth = body1 =#= body2
199200
val casesMatch =
200-
if (cases1.size != cases2.size) None
201+
if (cases1.size != cases2.size) notMatched
201202
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
202203
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
203-
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
204+
bodyMacth && casesMatch && finalizerMatch
204205

205206
// Ignore type annotations
206-
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
207-
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
207+
case (Annotated(tpt, _), _) =>
208+
tpt =#= pattern
209+
case (_, Annotated(tpt, _)) =>
210+
scrutinee =#= tpt
208211

209212
// No Match
210213
case _ =>
@@ -225,26 +228,26 @@ object Matcher {
225228
|
226229
|
227230
|""".stripMargin)
228-
None
231+
notMatched
229232
}
230233
}
231234

232-
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = {
235+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
233236
(scrutinee, pattern) match {
234-
case (Some(x), Some(y)) => treeMatches(x, y)
235-
case (None, None) => Some(())
236-
case _ => None
237+
case (Some(x), Some(y)) => x =#= y
238+
case (None, None) => matched
239+
case _ => notMatched
237240
}
238241
}
239242

240-
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = {
241-
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
243+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
244+
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
242245

243246
{
244247
implied for Env = caseEnv
245248
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
246-
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)
247-
foldMatchings(patternMatch, guardMatch, rhsMatch)
249+
val rhsMatch = scrutinee.rhs =#= pattern.rhs
250+
patternMatch && guardMatch && rhsMatch
248251
}
249252
}
250253

@@ -258,34 +261,34 @@ object Matcher {
258261
* @return The new environment containing the bindings defined in this pattern tuppled with
259262
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
260263
*/
261-
def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match {
264+
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
262265
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
263266
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
264-
(the[Env], Some(Tuple1(v1.seal)))
267+
(the[Env], matched(v1.seal))
265268

266269
case (Pattern.Value(v1), Pattern.Value(v2)) =>
267-
(the[Env], treeMatches(v1, v2))
270+
(the[Env], v1 =#= v2)
268271

269272
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
270273
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
271-
patternMatches(body1, body2) given bindEnv
274+
(body1 =%= body2) given bindEnv
272275

273276
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
274-
val funMatch = treeMatches(fun1, fun2)
277+
val funMatch = fun1 =#= fun2
275278
val implicitsMatch =
276-
if (implicits1.size != implicits2.size) None
277-
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
279+
if (implicits1.size != implicits2.size) notMatched
280+
else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*)
278281
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
279-
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
282+
(patEnv, funMatch && implicitsMatch && patternsMatch)
280283

281284
case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
282285
foldPatterns(patterns1, patterns2)
283286

284287
case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
285-
(the[Env], treeMatches(tpt1, tpt2))
288+
(the[Env], tpt1 =#= tpt2)
286289

287290
case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) =>
288-
(the[Env], Some(()))
291+
(the[Env], matched)
289292

290293
case _ =>
291294
if (debug)
@@ -305,30 +308,56 @@ object Matcher {
305308
|
306309
|
307310
|""".stripMargin)
308-
(the[Env], None)
311+
(the[Env], notMatched)
309312
}
310313

311-
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = {
312-
if (patterns1.size != patterns2.size) (the[Env], None)
313-
else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) =>
314-
val (env, res) = patternMatches(x._1, x._2) given acc._1
315-
(env, foldMatchings(acc._2, res))
314+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
315+
if (patterns1.size != patterns2.size) (the[Env], notMatched)
316+
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
317+
val (env, res) = (x._1 =%= x._2) given acc._1
318+
(env, acc._2 && res)
316319
}
317320
}
318321

319322
implied for Env = Set.empty
320-
treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]]
323+
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
321324
}
322325

323-
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
324-
* Otherwise the result is `Some` of the concatenation of the tupples.
325-
*/
326-
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
327-
// TODO improve performance
328-
matchings.foldLeft[Option[Tuple]](Some(())) {
329-
case (Some(acc), Some(holes)) => Some(acc ++ holes)
330-
case (_, _) => None
326+
/** Result of matching a part of an expression */
327+
private opaque type Matching = Option[Tuple]
328+
329+
private object Matching {
330+
331+
def notMatched: Matching = None
332+
val matched: Matching = Some(())
333+
def matched(x: Any): Matching = Some(Tuple1(x))
334+
335+
def (self: Matching) asOptionOfTuple: Option[Tuple] = self
336+
337+
/** Concatenates the contents of two sucessful matchings or return a `notMatched` */
338+
def (self: Matching) && (that: Matching): Matching = self match {
339+
case Some(x) =>
340+
that match {
341+
case Some(y) => Some(x ++ y)
342+
case _ => None
343+
}
344+
case _ => None
331345
}
346+
347+
/** Is this matching the result of a successful match */
348+
def (self: Matching) isMatch: Boolean = self.isDefined
349+
350+
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
351+
* Otherwise the result is `Some` of the concatenation of the tupples.
352+
*/
353+
def foldMatchings(matchings: Matching*): Matching = {
354+
// TODO improve performance
355+
matchings.foldLeft[Matching](Some(())) {
356+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
357+
case (_, _) => None
358+
}
359+
}
360+
332361
}
333362

334363
}

0 commit comments

Comments
 (0)