Skip to content

Commit ca6fd5c

Browse files
committed
Use opaque types in matcher to abstract over Matchings
1 parent ea670d9 commit ca6fd5c

File tree

1 file changed

+111
-81
lines changed

1 file changed

+111
-81
lines changed

library/src-3.x/scala/internal/quoted/Matcher.scala

Lines changed: 111 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ object Matcher {
3232
*/
3333
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
3434
import reflection.{Bind => BindPattern, _}
35+
import Matching._
3536

3637
type Env = Set[(Symbol, Symbol)]
3738

@@ -45,7 +46,7 @@ object Matcher {
4546
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
4647
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
4748
*/
48-
def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = {
49+
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {
4950

5051
/** Check that both are `val` or both are `lazy val` or both are `var` **/
5152
def checkValFlags(): Boolean = {
@@ -56,7 +57,7 @@ object Matcher {
5657
}
5758

5859
def bindingMatch(sym: Symbol) =
59-
Some(Tuple1(new Bind(sym.name, sym)))
60+
matched(new Bind(sym.name, sym))
6061

6162
def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
6263
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
@@ -67,9 +68,9 @@ object Matcher {
6768
def hasBindAnnotation(sym: Symbol) =
6869
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
6970

70-
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
71-
if (scrutinees.size != patterns.size) None
72-
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
71+
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]): Matching =
72+
if (scrutinees.size != patterns.size) notMatched
73+
else foldMatchings(scrutinees.zip(patterns).map((s, p) => s =#= p): _*)
7374

7475
/** Normalieze the tree */
7576
def normalize(tree: Tree): Tree = tree match {
@@ -85,126 +86,129 @@ object Matcher {
8586
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
8687
s.tpe <:< tpt.tpe &&
8788
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
88-
Some(Tuple1(scrutinee.seal))
89+
matched(scrutinee.seal)
8990

9091
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
9192
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
9293
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
9394
scrutinee.tpe <:< tpt.tpe =>
94-
Some(Tuple1(scrutinee.seal))
95+
matched(scrutinee.seal)
9596

9697
//
9798
// Match two equivalent trees
9899
//
99100

100101
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
101-
Some(())
102+
matched
102103

103104
case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
104-
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
105+
expr1 =#= expr2 && tpt1 =#= tpt2
105106

106107
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
107-
Some(())
108+
matched
108109

109110
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
110-
treeMatches(qual1, qual2)
111+
qual1 =#= qual2
111112

112113
case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
113-
Some(())
114+
matched
114115

115116
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
116-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
117+
fn1 =#= fn2 && args1 =##= args2
117118

118119
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
119-
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
120+
fn1 =#= fn2 && args1 =##= args2
120121

121122
case (Block(stats1, expr1), Block(stats2, expr2)) =>
122-
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
123+
// FIXME update the environment
124+
stats1 =##= stats2 && expr1 =#= expr2
123125

124126
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
125-
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
127+
cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2
126128

127129
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
128130
val lhsMatch =
129-
if (treeMatches(lhs1, lhs2).isDefined) Some(())
130-
else None
131-
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
131+
if ((lhs1 =#= lhs2).isMatch) matched
132+
else notMatched
133+
lhsMatch && rhs1 =#= rhs2
132134

133135
case (While(cond1, body1), While(cond2, body2)) =>
134-
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
136+
cond1 =#= cond2 && body1 =#= body2
135137

136138
case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
137-
treeMatches(expr1, expr2)
139+
expr1 =#= expr2
138140

139141
case (New(tpt1), New(tpt2)) =>
140-
treeMatches(tpt1, tpt2)
142+
tpt1 =#= tpt2
141143

142144
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
143-
Some(())
145+
matched
144146

145147
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
146-
treeMatches(qual1, qual2)
148+
qual1 =#= qual2
147149

148150
case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
149-
treesMatch(elems1, elems2)
151+
elems1 =##= elems2
150152

151153
case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
152-
Some(())
154+
matched
153155

154156
case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
155-
Some(())
157+
matched
156158

157159
case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
158-
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
160+
tycon1 =#= tycon2 && args1 =##= args2
159161

160162
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
161163
val bindMatch =
162164
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
163-
else Some(())
164-
val returnTptMatch = treeMatches(tpt1, tpt2)
165+
else matched
166+
val returnTptMatch = tpt1 =#= tpt2
165167
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
166168
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
167-
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
169+
bindMatch && returnTptMatch && rhsMatchings
168170

169171
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
170-
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
172+
val typeParmasMatch = typeParams1 =##= typeParams2
171173
val paramssMatch =
172-
if (paramss1.size != paramss2.size) None
173-
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
174+
if (paramss1.size != paramss2.size) notMatched
175+
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _*)
174176
val bindMatch =
175177
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
176-
else Some(())
177-
val tptMatch = treeMatches(tpt1, tpt2)
178+
else matched
179+
val tptMatch = tpt1 =#= tpt2
178180
val rhsEnv =
179181
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
180182
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
181183
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
182-
val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv
184+
val rhsMatch = (rhs1 =#= rhs2) given rhsEnv
183185

184-
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
186+
bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch
185187

186188
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
187189
// TODO match tpt1 with tpt2?
188-
Some(())
190+
matched
189191

190192
case (Match(scru1, cases1), Match(scru2, cases2)) =>
191-
val scrutineeMacth = treeMatches(scru1, scru2)
193+
val scrutineeMacth = scru1 =#= scru2
192194
val casesMatch =
193-
if (cases1.size != cases2.size) None
195+
if (cases1.size != cases2.size) notMatched
194196
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
195-
foldMatchings(scrutineeMacth, casesMatch)
197+
scrutineeMacth && casesMatch
196198

197199
case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
198-
val bodyMacth = treeMatches(body1, body2)
200+
val bodyMacth = body1 =#= body2
199201
val casesMatch =
200-
if (cases1.size != cases2.size) None
202+
if (cases1.size != cases2.size) notMatched
201203
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
202204
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
203-
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
205+
bodyMacth && casesMatch && finalizerMatch
204206

205207
// Ignore type annotations
206-
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
207-
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
208+
case (Annotated(tpt, _), _) =>
209+
tpt =#= pattern
210+
case (_, Annotated(tpt, _)) =>
211+
scrutinee =#= tpt
208212

209213
// No Match
210214
case _ =>
@@ -225,26 +229,26 @@ object Matcher {
225229
|
226230
|
227231
|""".stripMargin)
228-
None
232+
notMatched
229233
}
230234
}
231235

232-
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = {
236+
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
233237
(scrutinee, pattern) match {
234-
case (Some(x), Some(y)) => treeMatches(x, y)
235-
case (None, None) => Some(())
236-
case _ => None
238+
case (Some(x), Some(y)) => x =#= y
239+
case (None, None) => matched
240+
case _ => notMatched
237241
}
238242
}
239243

240-
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = {
241-
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)
244+
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
245+
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
242246

243247
{
244248
implied for Env = caseEnv
245249
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
246-
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)
247-
foldMatchings(patternMatch, guardMatch, rhsMatch)
250+
val rhsMatch = scrutinee.rhs =#= pattern.rhs
251+
patternMatch && guardMatch && rhsMatch
248252
}
249253
}
250254

@@ -258,34 +262,34 @@ object Matcher {
258262
* @return The new environment containing the bindings defined in this pattern tuppled with
259263
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
260264
*/
261-
def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match {
265+
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
262266
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
263267
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
264-
(the[Env], Some(Tuple1(v1.seal)))
268+
(the[Env], matched(v1.seal))
265269

266270
case (Pattern.Value(v1), Pattern.Value(v2)) =>
267-
(the[Env], treeMatches(v1, v2))
271+
(the[Env], v1 =#= v2)
268272

269273
case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
270274
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
271-
patternMatches(body1, body2) given bindEnv
275+
(body1 =%= body2) given bindEnv
272276

273277
case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
274-
val funMatch = treeMatches(fun1, fun2)
278+
val funMatch = fun1 =#= fun2
275279
val implicitsMatch =
276-
if (implicits1.size != implicits2.size) None
277-
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
280+
if (implicits1.size != implicits2.size) notMatched
281+
else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*)
278282
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
279-
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
283+
(patEnv, funMatch && implicitsMatch && patternsMatch)
280284

281285
case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
282286
foldPatterns(patterns1, patterns2)
283287

284288
case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
285-
(the[Env], treeMatches(tpt1, tpt2))
289+
(the[Env], tpt1 =#= tpt2)
286290

287291
case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) =>
288-
(the[Env], Some(()))
292+
(the[Env], matched)
289293

290294
case _ =>
291295
if (debug)
@@ -305,30 +309,56 @@ object Matcher {
305309
|
306310
|
307311
|""".stripMargin)
308-
(the[Env], None)
312+
(the[Env], notMatched)
309313
}
310314

311-
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = {
312-
if (patterns1.size != patterns2.size) (the[Env], None)
313-
else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) =>
314-
val (env, res) = patternMatches(x._1, x._2) given acc._1
315-
(env, foldMatchings(acc._2, res))
315+
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
316+
if (patterns1.size != patterns2.size) (the[Env], notMatched)
317+
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
318+
val (env, res) = (x._1 =%= x._2) given acc._1
319+
(env, acc._2 && res)
316320
}
317321
}
318322

319323
implied for Env = Set.empty
320-
treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]]
324+
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
321325
}
322326

323-
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
324-
* Otherwise the result is `Some` of the concatenation of the tupples.
325-
*/
326-
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
327-
// TODO improve performance
328-
matchings.foldLeft[Option[Tuple]](Some(())) {
329-
case (Some(acc), Some(holes)) => Some(acc ++ holes)
330-
case (_, _) => None
327+
/** Result of matching a part of an expression */
328+
private opaque type Matching = Option[Tuple]
329+
330+
private object Matching {
331+
332+
def notMatched: Matching = None
333+
val matched: Matching = Some(())
334+
def matched(x: Any): Matching = Some(Tuple1(x))
335+
336+
def (self: Matching) asOptionOfTuple: Option[Tuple] = self
337+
338+
/** Concatenates the contents of two sucessful matchings or return a `notMatched` */
339+
def (self: Matching) && (that: Matching): Matching = self match {
340+
case Some(x) =>
341+
that match {
342+
case Some(y) => Some(x ++ y)
343+
case _ => None
344+
}
345+
case _ => None
331346
}
347+
348+
/** Is this matching the result of a successful match */
349+
def (self: Matching) isMatch: Boolean = self.isDefined
350+
351+
/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
352+
* Otherwise the result is `Some` of the concatenation of the tupples.
353+
*/
354+
def foldMatchings(matchings: Matching*): Matching = {
355+
// TODO improve performance
356+
matchings.foldLeft[Matching](Some(())) {
357+
case (Some(acc), Some(holes)) => Some(acc ++ holes)
358+
case (_, _) => None
359+
}
360+
}
361+
332362
}
333363

334364
}

0 commit comments

Comments
 (0)