Skip to content

Use opaque types in matcher to abstract over Matchings #6393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
209 changes: 123 additions & 86 deletions library/src-3.x/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,23 @@ object Matcher {
* @return None if it did not match, `Some(tup)` if it matched where `tup` contains `Expr[Ti]``
*/
def unapply[Tup <: Tuple](scrutineeExpr: Expr[_])(implicit patternExpr: Expr[_], reflection: Reflection): Option[Tup] = {
// TODO improve performance
import reflection.{Bind => BindPattern, _}
import Matching._

type Env = Set[(Symbol, Symbol)]

// TODO improve performance
inline def withEnv[T](env: Env)(body: => given Env => T): T = body given env

/** Check that all trees match with =#= and concatenate the results with && */
def (scrutinees: List[Tree]) =##= (patterns: List[Tree]) given Env: Matching = {
def rec(l1: List[Tree], l2: List[Tree]): Matching = (l1, l2) match {
case (x :: xs, y :: ys) => x =#= y && rec(xs, ys)
case (Nil, Nil) => matched
case _ => notMatched
}
rec(scrutinees, patterns)
}

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
Expand All @@ -45,7 +57,7 @@ object Matcher {
* @param `the[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
*/
def treeMatches(scrutinee: Tree, pattern: Tree) given Env: Option[Tuple] = {
def (scrutinee: Tree) =#= (pattern: Tree) given Env: Matching = {

/** Check that both are `val` or both are `lazy val` or both are `var` **/
def checkValFlags(): Boolean = {
Expand All @@ -56,7 +68,7 @@ object Matcher {
}

def bindingMatch(sym: Symbol) =
Some(Tuple1(new Bind(sym.name, sym)))
matched(new Bind(sym.name, sym))

def hasBindTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
Expand All @@ -67,10 +79,6 @@ object Matcher {
def hasBindAnnotation(sym: Symbol) =
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }

def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
if (scrutinees.size != patterns.size) None
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)

/** Normalieze the tree */
def normalize(tree: Tree): Tree = tree match {
case Block(Nil, expr) => normalize(expr)
Expand All @@ -85,126 +93,130 @@ object Matcher {
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(definitions.RepeatedParamClass) =>
Some(Tuple1(scrutinee.seal))
matched(scrutinee.seal)

// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (IsTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == kernel.Definitions_InternalQuoted_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
Some(Tuple1(scrutinee.seal))
matched(scrutinee.seal)

//
// Match two equivalent trees
//

case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
Some(())
matched

case (Typed(expr1, tpt1), Typed(expr2, tpt2)) =>
foldMatchings(treeMatches(expr1, expr2), treeMatches(tpt1, tpt2))
expr1 =#= expr2 && tpt1 =#= tpt2

case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || the[Env].apply((scrutinee.symbol, pattern.symbol)) =>
Some(())
matched

case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
treeMatches(qual1, qual2)
qual1 =#= qual2

case (IsRef(_), IsRef(_)) if scrutinee.symbol == pattern.symbol =>
Some(())
matched

case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
fn1 =#= fn2 && args1 =##= args2

case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
foldMatchings(treeMatches(fn1, fn2), treesMatch(args1, args2))
fn1 =#= fn2 && args1 =##= args2

case (Block(stats1, expr1), Block(stats2, expr2)) =>
foldMatchings(treesMatch(stats1, stats2), treeMatches(expr1, expr2))
withEnv(the[Env] ++ stats1.map(_.symbol).zip(stats2.map(_.symbol))) {
stats1 =##= stats2 && expr1 =#= expr2
}

case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
foldMatchings(treeMatches(cond1, cond2), treeMatches(thenp1, thenp2), treeMatches(elsep1, elsep2))
cond1 =#= cond2 && thenp1 =#= thenp2 && elsep1 =#= elsep2

case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
val lhsMatch =
if (treeMatches(lhs1, lhs2).isDefined) Some(())
else None
foldMatchings(lhsMatch, treeMatches(rhs1, rhs2))
if ((lhs1 =#= lhs2).isMatch) matched
else notMatched
lhsMatch && rhs1 =#= rhs2

case (While(cond1, body1), While(cond2, body2)) =>
foldMatchings(treeMatches(cond1, cond2), treeMatches(body1, body2))
cond1 =#= cond2 && body1 =#= body2

case (NamedArg(name1, expr1), NamedArg(name2, expr2)) if name1 == name2 =>
treeMatches(expr1, expr2)
expr1 =#= expr2

case (New(tpt1), New(tpt2)) =>
treeMatches(tpt1, tpt2)
tpt1 =#= tpt2

case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
Some(())
matched

case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
treeMatches(qual1, qual2)
qual1 =#= qual2

case (Repeated(elems1, _), Repeated(elems2, _)) if elems1.size == elems2.size =>
treesMatch(elems1, elems2)
elems1 =##= elems2

case (IsTypeTree(scrutinee @ TypeIdent(_)), IsTypeTree(pattern @ TypeIdent(_))) if scrutinee.symbol == pattern.symbol =>
Some(())
matched

case (IsInferred(scrutinee), IsInferred(pattern)) if scrutinee.tpe <:< pattern.tpe =>
Some(())
matched

case (Applied(tycon1, args1), Applied(tycon2, args2)) =>
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
tycon1 =#= tycon2 && args1 =##= args2

case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
val bindMatch =
if (hasBindAnnotation(pattern.symbol) || hasBindTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
else Some(())
val returnTptMatch = treeMatches(tpt1, tpt2)
else matched
val returnTptMatch = tpt1 =#= tpt2
val rhsEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
val rhsMatchings = treeOptMatches(rhs1, rhs2) given rhsEnv
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
bindMatch && returnTptMatch && rhsMatchings

case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
val typeParmasMatch = typeParams1 =##= typeParams2
val paramssMatch =
if (paramss1.size != paramss2.size) None
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
if (paramss1.size != paramss2.size) notMatched
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => params1 =##= params2 }: _*)
val bindMatch =
if (hasBindAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
else Some(())
val tptMatch = treeMatches(tpt1, tpt2)
else matched
val tptMatch = tpt1 =#= tpt2
val rhsEnv =
the[Env] + (scrutinee.symbol -> pattern.symbol) ++
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
val rhsMatch = treeMatches(rhs1, rhs2) given rhsEnv
val rhsMatch = (rhs1 =#= rhs2) given rhsEnv

foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
bindMatch && typeParmasMatch && paramssMatch && tptMatch && rhsMatch

case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
// TODO match tpt1 with tpt2?
Some(())
matched

case (Match(scru1, cases1), Match(scru2, cases2)) =>
val scrutineeMacth = treeMatches(scru1, scru2)
val scrutineeMacth = scru1 =#= scru2
val casesMatch =
if (cases1.size != cases2.size) None
if (cases1.size != cases2.size) notMatched
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
foldMatchings(scrutineeMacth, casesMatch)
scrutineeMacth && casesMatch

case (Try(body1, cases1, finalizer1), Try(body2, cases2, finalizer2)) =>
val bodyMacth = treeMatches(body1, body2)
val bodyMacth = body1 =#= body2
val casesMatch =
if (cases1.size != cases2.size) None
if (cases1.size != cases2.size) notMatched
else foldMatchings(cases1.zip(cases2).map(caseMatches): _*)
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
bodyMacth && casesMatch && finalizerMatch

// Ignore type annotations
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
case (Annotated(tpt, _), _) =>
tpt =#= pattern
case (_, Annotated(tpt, _)) =>
scrutinee =#= tpt

// No Match
case _ =>
Expand All @@ -225,26 +237,24 @@ object Matcher {
|
|
|""".stripMargin)
None
notMatched
}
}

def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Option[Tuple] = {
def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree]) given Env: Matching = {
(scrutinee, pattern) match {
case (Some(x), Some(y)) => treeMatches(x, y)
case (None, None) => Some(())
case _ => None
case (Some(x), Some(y)) => x =#= y
case (None, None) => matched
case _ => notMatched
}
}

def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Option[Tuple] = {
val (caseEnv, patternMatch) = patternMatches(scrutinee.pattern, pattern.pattern)

{
implied for Env = caseEnv
def caseMatches(scrutinee: CaseDef, pattern: CaseDef) given Env: Matching = {
val (caseEnv, patternMatch) = scrutinee.pattern =%= pattern.pattern
withEnv(caseEnv) {
val guardMatch = treeOptMatches(scrutinee.guard, pattern.guard)
val rhsMatch = treeMatches(scrutinee.rhs, pattern.rhs)
foldMatchings(patternMatch, guardMatch, rhsMatch)
val rhsMatch = scrutinee.rhs =#= pattern.rhs
patternMatch && guardMatch && rhsMatch
}
}

Expand All @@ -258,34 +268,34 @@ object Matcher {
* @return The new environment containing the bindings defined in this pattern tuppled with
* `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
*/
def patternMatches(scrutinee: Pattern, pattern: Pattern) given Env: (Env, Option[Tuple]) = (scrutinee, pattern) match {
def (scrutinee: Pattern) =%= (pattern: Pattern) given Env: (Env, Matching) = (scrutinee, pattern) match {
case (Pattern.Value(v1), Pattern.Unapply(TypeApply(Select(patternHole @ Ident("patternHole"), "unapply"), List(tpt)), Nil, Nil))
if patternHole.symbol.owner.fullName == "scala.runtime.quoted.Matcher$" =>
(the[Env], Some(Tuple1(v1.seal)))
(the[Env], matched(v1.seal))

case (Pattern.Value(v1), Pattern.Value(v2)) =>
(the[Env], treeMatches(v1, v2))
(the[Env], v1 =#= v2)

case (Pattern.Bind(name1, body1), Pattern.Bind(name2, body2)) =>
val bindEnv = the[Env] + (scrutinee.symbol -> pattern.symbol)
patternMatches(body1, body2) given bindEnv
(body1 =%= body2) given bindEnv

case (Pattern.Unapply(fun1, implicits1, patterns1), Pattern.Unapply(fun2, implicits2, patterns2)) =>
val funMatch = treeMatches(fun1, fun2)
val funMatch = fun1 =#= fun2
val implicitsMatch =
if (implicits1.size != implicits2.size) None
else foldMatchings(implicits1.zip(implicits2).map(treeMatches): _*)
if (implicits1.size != implicits2.size) notMatched
else foldMatchings(implicits1.zip(implicits2).map((i1, i2) => i1 =#= i2): _*)
val (patEnv, patternsMatch) = foldPatterns(patterns1, patterns2)
(patEnv, foldMatchings(funMatch, implicitsMatch, patternsMatch))
(patEnv, funMatch && implicitsMatch && patternsMatch)

case (Pattern.Alternatives(patterns1), Pattern.Alternatives(patterns2)) =>
foldPatterns(patterns1, patterns2)

case (Pattern.TypeTest(tpt1), Pattern.TypeTest(tpt2)) =>
(the[Env], treeMatches(tpt1, tpt2))
(the[Env], tpt1 =#= tpt2)

case (Pattern.WildcardPattern(), Pattern.WildcardPattern()) =>
(the[Env], Some(()))
(the[Env], matched)

case _ =>
if (debug)
Expand All @@ -305,30 +315,57 @@ object Matcher {
|
|
|""".stripMargin)
(the[Env], None)
(the[Env], notMatched)
}

def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Option[Tuple]) = {
if (patterns1.size != patterns2.size) (the[Env], None)
else patterns1.zip(patterns2).foldLeft((the[Env], Option[Tuple](()))) { (acc, x) =>
val (env, res) = patternMatches(x._1, x._2) given acc._1
(env, foldMatchings(acc._2, res))
def foldPatterns(patterns1: List[Pattern], patterns2: List[Pattern]) given Env: (Env, Matching) = {
if (patterns1.size != patterns2.size) (the[Env], notMatched)
else patterns1.zip(patterns2).foldLeft((the[Env], matched)) { (acc, x) =>
val (env, res) = (x._1 =%= x._2) given acc._1
(env, acc._2 && res)
}
}

implied for Env = Set.empty
treeMatches(scrutineeExpr.unseal, patternExpr.unseal).asInstanceOf[Option[Tup]]
(scrutineeExpr.unseal =#= patternExpr.unseal).asOptionOfTuple.asInstanceOf[Option[Tup]]
}

/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
* Otherwise the result is `Some` of the concatenation of the tupples.
*/
private def foldMatchings(matchings: Option[Tuple]*): Option[Tuple] = {
// TODO improve performance
matchings.foldLeft[Option[Tuple]](Some(())) {
case (Some(acc), Some(holes)) => Some(acc ++ holes)
case (_, _) => None
/** Result of matching a part of an expression */
private opaque type Matching = Option[Tuple]

private object Matching {

def notMatched: Matching = None
val matched: Matching = Some(())
def matched(x: Any): Matching = Some(Tuple1(x))

def (self: Matching) asOptionOfTuple: Option[Tuple] = self

/** Concatenates the contents of two sucessful matchings or return a `notMatched` */
// FIXME inline to avoid alocation of by name closure (see #6395)
/*inline*/ def (self: Matching) && (that: => Matching): Matching = self match {
case Some(x) =>
that match {
case Some(y) => Some(x ++ y)
case _ => None
}
case _ => None
}

/** Is this matching the result of a successful match */
def (self: Matching) isMatch: Boolean = self.isDefined

/** Joins the mattchings into a single matching. If any matching is `None` the result is `None`.
* Otherwise the result is `Some` of the concatenation of the tupples.
*/
def foldMatchings(matchings: Matching*): Matching = {
// TODO improve performance
matchings.foldLeft[Matching](Some(())) {
case (Some(acc), Some(holes)) => Some(acc ++ holes)
case (_, _) => None
}
}

}

}
Loading