Skip to content

Support simple higher order pattern splices #7591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure =
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))

def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
tpd.Lambda(tpe, rhsFn)

type If = tpd.If

def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new {
Expand Down Expand Up @@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend

def Type_isSubType(self: Type)(that: Type)(given Context): Boolean = self <:< that

/** Widen from singleton type to its underlying non-singleton
* base type by applying one or more `underlying` dereferences,
* Also go from => T to T.
* Identity for all other types. Example:
*
* class Outer { class C ; val x: C }
* def o: Outer
* <o.x.type>.widen = o.C
*/
def Type_widen(self: Type)(given Context): Type = self.widen

def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr

def Type_dealias(self: Type)(given Context): Type = self.dealias

def Type_simplified(self: Type)(given Context): Type = self.simplified
Expand Down Expand Up @@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
case _ => None
}

def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp)

def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)
Expand Down
46 changes: 46 additions & 0 deletions library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,50 @@ trait TreeUtils
with SymbolOps
with TreeOps { self: Reflection =>

abstract class TreeAccumulator[X] {
def foldTree(x: X, tree: Tree)(given ctx: Context): X
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X =
throw new Exception("non-bootstraped-library")
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X =
throw new Exception("non-bootstraped-library")
}

abstract class TreeTraverser extends TreeAccumulator[Unit] {
def traverseTree(tree: Tree)(given ctx: Context): Unit =
throw new Exception("non-bootstraped-library")
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit =
throw new Exception("non-bootstraped-library")
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit =
throw new Exception("non-bootstraped-library")
}

abstract class TreeMap { self =>
def transformTree(tree: Tree)(given ctx: Context): Tree =
throw new Exception("non-bootstraped-library")
def transformStatement(tree: Statement)(given ctx: Context): Statement =
throw new Exception("non-bootstraped-library")
def transformTerm(tree: Term)(given ctx: Context): Term =
throw new Exception("non-bootstraped-library")
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree =
throw new Exception("non-bootstraped-library")
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
throw new Exception("non-bootstraped-library")
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef =
throw new Exception("non-bootstraped-library")
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
throw new Exception("non-bootstraped-library")
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
throw new Exception("non-bootstraped-library")
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
throw new Exception("non-bootstraped-library")
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
throw new Exception("non-bootstraped-library")
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
throw new Exception("non-bootstraped-library")
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
throw new Exception("non-bootstraped-library")
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
throw new Exception("non-bootstraped-library")
}

}
79 changes: 69 additions & 10 deletions library/src/scala/internal/quoted/Matcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,27 @@ private[quoted] object Matcher {
class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) {
// TODO improve performance

// TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
private final val debug = false

import qctx.tasty.{_, given}
import Matching._

private type Env = Set[(Symbol, Symbol)]
/** A map relating equivalent symbols from the scrutinee and the pattern
* For example in
* ```
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
* ```
* when matching `a * a` with `x * x` the enviroment will contain `Map(a -> x)`.
*/
private type Env = Map[Symbol, Symbol]

inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env)

class SymBinding(val sym: Symbol, val fromAbove: Boolean)

def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = {
implicit val env: Env = Set.empty
implicit val env: Env = Map.empty
if (hasTypeSplices) {
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
Expand All @@ -42,7 +50,7 @@ private[quoted] object Matcher {

// TODO factor out common logic with `termMatch`
def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = {
implicit val env: Env = Set.empty
implicit val env: Env = Map.empty
if (hasTypeSplices) {
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
val matchings = scrutineeTypeTree =?= patternTypeTree
Expand Down Expand Up @@ -138,11 +146,29 @@ private[quoted] object Matcher {
matched(scrutinee.seal)

// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil))
case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
scrutinee.tpe <:< tpt.tpe =>
matched(scrutinee.seal)

// Matches an open term and wraps it into a lambda that provides the free variables
case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args)))
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
def bodyFn(lambdaArgs: List[Tree]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
new TreeMap {
override def transformTerm(tree: Term)(given ctx: Context): Term =
tree match
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transformTerm(tree)
}.transformTree(scrutinee)
}
val names = args.map(_.name)
val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
val resType = pattern.tpe
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
matched(res.seal)

//
// Match two equivalent trees
//
Expand All @@ -156,7 +182,7 @@ private[quoted] object Matcher {
case (scrutinee, Typed(expr2, _)) =>
scrutinee =?= expr2

case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) =>
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
matched

case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
Expand All @@ -165,18 +191,24 @@ private[quoted] object Matcher {
case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol =>
matched

case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
fn1 =?= fn2 && args1 =?= args2

case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
fn1 =?= fn2 && args1 =?= args2

case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) =>
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil)
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)

case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) {
val newEnv = (stat1, stat2) match {
case (stat1: Definition, stat2: Definition) =>
summon[Env] + (stat1.symbol -> stat2.symbol)
case _ =>
summon[Env]
}
withEnv(newEnv) {
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
}

Expand Down Expand Up @@ -268,7 +300,7 @@ private[quoted] object Matcher {
|
|${pattern.showExtractors}
|
|
|with environment: ${summon[Env]}
|
|
|""".stripMargin)
Expand All @@ -277,6 +309,33 @@ private[quoted] object Matcher {
}
end treeOps

private object ClosedPatternTerm {
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
def unapply(term: Term)(given Context, Env): Option[term.type] =
if freePatternVars(term).isEmpty then Some(term) else None

/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] =
val accumulator = new TreeAccumulator[Set[Symbol]] {
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
tree match
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
case _ => foldOverTree(x, tree)
}
accumulator.foldTree(Set.empty, term)
}

private object IdentArgs {
def unapply(args: List[Term])(given Context): Option[List[Ident]] =
args.foldRight(Option(List.empty[Ident])) {
case (id: Ident, Some(acc)) => Some(id :: acc)
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
if params.zip(args).forall(_.symbol == _.symbol) =>
Some(id :: acc)
case _ => None
}
}

private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = {
(scrutinee, pattern) match {
case (Some(x), Some(y)) => x =?= y
Expand Down Expand Up @@ -344,7 +403,7 @@ private[quoted] object Matcher {
|
|${pattern.showExtractors}
|
|
|with environment: ${summon[Env]}
|
|
|""".stripMargin)
Expand Down
38 changes: 37 additions & 1 deletion library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,44 @@ package quoted {
val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]]
ofTuple(elems).cast[Tuple.InverseMap[T, Expr]]
}
}

// TODO generalize for any function arity (see Expr.betaReduce)
def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = {
import qctx.tasty.{given, _}
val (params, bodyExpr) = paramsAndBody(f)
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]])
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear how open/close is useful. I'd suggest removing them for now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I cannot remove it without breaking some tests. I particular those that test the code infrastructure for what @biboudis is working on.


def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = {
import qctx.tasty.{given, _}
val (params, bodyExpr) = paramsAndBody(f)
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]])
}

def open[T1, T2, T3, R, X](f: Expr[(T1, T2, T3) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = {
import qctx.tasty.{given, _}
val (params, bodyExpr) = paramsAndBody(f)
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]])
}

private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = {
import qctx.tasty.{given, _}
val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand
(params, body.seal.asInstanceOf[Expr[R]])
}

private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
import qctx.tasty.{given, _}
val map = params.map(_.symbol).zip(args).toMap
new TreeMap {
override def transformTerm(tree: Term)(given ctx: Context): Term =
super.transformTerm(tree) match
case tree: Ident => map.getOrElse(tree.symbol, tree)
case tree => tree
}.transformTerm(e)
}

}
}

package internal {
Expand Down
2 changes: 2 additions & 0 deletions library/src/scala/quoted/matching/Sym.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ package matching
*/
class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self =>

override def toString: String = s"Sym($name)@${id.hashCode}"

override def equals(obj: Any): Boolean = obj match {
case obj: Sym[_] => obj.id == id
case _ => false
Expand Down
9 changes: 9 additions & 0 deletions library/src/scala/tasty/reflect/CompilerInterface.scala
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,8 @@ trait CompilerInterface {
def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure

def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block

/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
type If <: Term

Expand Down Expand Up @@ -810,6 +812,11 @@ trait CompilerInterface {
*/
def Type_widen(self: Type)(given ctx: Context): Type

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
*/
def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type

/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
* TypeVars until type is no longer alias type, annotated type, LazyRef,
* or instantiated type variable.
Expand Down Expand Up @@ -992,6 +999,8 @@ trait CompilerInterface {

def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType]

def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType

def MethodType_isErased(self: MethodType): Boolean
def MethodType_isImplicit(self: MethodType): Boolean
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]
Expand Down
4 changes: 4 additions & 0 deletions library/src/scala/tasty/reflect/TreeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,10 @@ trait TreeOps extends Core {

case _ => None
}

def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
internal.Lambda_apply(tpe, rhsFn)

}

given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf
Expand Down
17 changes: 17 additions & 0 deletions library/src/scala/tasty/reflect/TypeOrBoundsOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,22 @@ trait TypeOrBoundsOps extends Core {
/** Is this type a subtype of that type? */
def <:<(that: Type)(given ctx: Context): Boolean = internal.Type_isSubType(self)(that)

/** Widen from singleton type to its underlying non-singleton
* base type by applying one or more `underlying` dereferences,
* Also go from => T to T.
* Identity for all other types. Example:
*
* class Outer { class C ; val x: C }
* def o: Outer
* <o.x.type>.widen = o.C
*/
def widen(given ctx: Context): Type = internal.Type_widen(self)

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping `=>T` types.
*/
def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self)

/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
* TypeVars until type is no longer alias type, annotated type, LazyRef,
* or instantiated type variable.
Expand Down Expand Up @@ -325,6 +339,9 @@ trait TypeOrBoundsOps extends Core {
def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x)

object MethodType {
def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp)

def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] =
Some((x.paramNames, x.paramTypes, x.resType))
}
Expand Down
4 changes: 2 additions & 2 deletions tests/run-macros/quote-matcher-runtime.check
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ Pattern: {
val x: scala.Int = 45
x.+(scala.internal.Quoted.patternHole[scala.Int])
}
Result: Some(List(Expr(a)))
Result: None

Scrutinee: {
lazy val a: scala.Int = 45
Expand Down Expand Up @@ -622,7 +622,7 @@ Pattern: {
def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int]
a.+(scala.internal.Quoted.patternHole[scala.Int])
}
Result: Some(List(Expr(a), Expr(a)))
Result: None

Scrutinee: {
lazy val a: scala.Int = a
Expand Down
Loading