Skip to content

Add quote pattern bindings in patterns #6274

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,8 @@ object desugar {
* def x: Int = expr
* def x_=($1: <TypeTree()>): Unit = ()
*/
def valDef(vdef: ValDef)(implicit ctx: Context): Tree = {
val ValDef(name, tpt, rhs) = vdef
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
val vdef @ ValDef(name, tpt, rhs) = transformQuotedPatternName(vdef0)
val mods = vdef.mods
val setterNeeded =
(mods is Mutable) && ctx.owner.isClass && (!(mods is PrivateLocal) || (ctx.owner is Trait))
Expand Down Expand Up @@ -197,8 +197,8 @@ object desugar {
* ==>
* inline def f(x: Boolean): Any = (if (x) 1 else ""): Any
*/
private def defDef(meth: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = {
val DefDef(_, tparams, vparamss, tpt, rhs) = meth
private def defDef(meth0: DefDef, isPrimaryConstructor: Boolean = false)(implicit ctx: Context): Tree = {
val meth @ DefDef(_, tparams, vparamss, tpt, rhs) = transformQuotedPatternName(meth0)
val methName = normalizeName(meth, tpt).asTermName
val mods = meth.mods
val epbuf = new ListBuffer[ValDef]
Expand Down Expand Up @@ -272,6 +272,32 @@ object desugar {
}
}

/** Transforms a definition with a name starting with a `$` in a quoted pattern into a `quoted.binding.Binding` splice.
*
* The desugaring consists in renaming the the definition and adding the `@patternBindHole` annotation. This
* annotation is used during typing to perform the full transformation.
*
* A definition
* ```scala
* case '{ def $a(...) = ... a() ...; ... a() ... }
* ```
* into
* ```scala
* case '{ @patternBindHole def a(...) = ... a() ...; ... a() ... }
* ```
*/
def transformQuotedPatternName(tree: ValOrDefDef)(implicit ctx: Context): ValOrDefDef = {
if (ctx.mode.is(Mode.QuotedPattern) && !tree.isBackquoted && tree.name != nme.ANON_FUN && tree.name.startsWith("$")) {
val name = tree.name.toString.substring(1).toTermName
val newTree: ValOrDefDef = tree match {
case tree: ValDef => cpy.ValDef(tree)(name)
case tree: DefDef => cpy.DefDef(tree)(name)
}
val mods = tree.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(tree.span))
newTree.withMods(mods)
} else tree
}

// Add all evidence parameters in `params` as implicit parameters to `meth` */
private def addEvidenceParams(meth: DefDef, params: List[ValDef])(implicit ctx: Context): DefDef =
params match {
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,14 @@ object Trees {

/** A ValDef or DefDef tree */
abstract class ValOrDefDef[-T >: Untyped](implicit @constructorOnly src: SourceFile) extends MemberDef[T] with WithLazyField[Tree[T]] {
type ThisTree[-T >: Untyped] <: ValOrDefDef[T]
def name: TermName
def tpt: Tree[T]
def unforcedRhs: LazyTree = unforced
def rhs(implicit ctx: Context): Tree[T] = forceIfLazy

/** Is this a `BackquotedValDef` or `BackquotedDefDef` ? */
def isBackquoted: Boolean = false
}

// ----------- Tree case classes ------------------------------------
Expand Down Expand Up @@ -706,6 +710,12 @@ object Trees {
protected def force(x: AnyRef): Unit = preRhs = x
}

class BackquotedValDef[-T >: Untyped] private[ast] (name: TermName, tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
extends ValDef[T](name, tpt, preRhs) {
override def isBackquoted: Boolean = true
override def productPrefix: String = "BackquotedValDef"
}

/** mods def name[tparams](vparams_1)...(vparams_n): tpt = rhs */
case class DefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]],
vparamss: List[List[ValDef[T]]], tpt: Tree[T], private var preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
Expand All @@ -716,6 +726,13 @@ object Trees {
protected def force(x: AnyRef): Unit = preRhs = x
}

class BackquotedDefDef[-T >: Untyped] private[ast] (name: TermName, tparams: List[TypeDef[T]],
vparamss: List[List[ValDef[T]]], tpt: Tree[T], preRhs: LazyTree)(implicit @constructorOnly src: SourceFile)
extends DefDef[T](name, tparams, vparamss, tpt, preRhs) {
override def isBackquoted: Boolean = true
override def productPrefix: String = "BackquotedDefDef"
}

/** mods class name template or
* mods trait name template or
* mods type name = rhs or
Expand Down Expand Up @@ -932,7 +949,9 @@ object Trees {
type Alternative = Trees.Alternative[T]
type UnApply = Trees.UnApply[T]
type ValDef = Trees.ValDef[T]
type BackquotedValDef = Trees.BackquotedValDef[T]
type DefDef = Trees.DefDef[T]
type BackquotedDefDef = Trees.BackquotedDefDef[T]
type TypeDef = Trees.TypeDef[T]
type Template = Trees.Template[T]
type Import = Trees.Import[T]
Expand Down Expand Up @@ -1125,10 +1144,16 @@ object Trees {
case _ => finalize(tree, untpd.UnApply(fun, implicits, patterns)(sourceFile(tree)))
}
def ValDef(tree: Tree)(name: TermName, tpt: Tree, rhs: LazyTree)(implicit ctx: Context): ValDef = tree match {
case tree: BackquotedValDef =>
if ((name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree
else finalize(tree, untpd.BackquotedValDef(name, tpt, rhs)(sourceFile(tree)))
case tree: ValDef if (name == tree.name) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree
case _ => finalize(tree, untpd.ValDef(name, tpt, rhs)(sourceFile(tree)))
}
def DefDef(tree: Tree)(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit ctx: Context): DefDef = tree match {
case tree: BackquotedDefDef =>
if ((name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs)) tree
else finalize(tree, untpd.BackquotedDefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree)))
case tree: DefDef if (name == tree.name) && (tparams eq tree.tparams) && (vparamss eq tree.vparamss) && (tpt eq tree.tpt) && (rhs eq tree.unforcedRhs) => tree
case _ => finalize(tree, untpd.DefDef(name, tparams, vparamss, tpt, rhs)(sourceFile(tree)))
}
Expand Down
10 changes: 8 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def Alternative(trees: List[Tree])(implicit src: SourceFile): Alternative = new Alternative(trees)
def UnApply(fun: Tree, implicits: List[Tree], patterns: List[Tree])(implicit src: SourceFile): UnApply = new UnApply(fun, implicits, patterns)
def ValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new ValDef(name, tpt, rhs)
def BackquotedValDef(name: TermName, tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): ValDef = new BackquotedValDef(name, tpt, rhs)
def DefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new DefDef(name, tparams, vparamss, tpt, rhs)
def BackquotedDefDef(name: TermName, tparams: List[TypeDef], vparamss: List[List[ValDef]], tpt: Tree, rhs: LazyTree)(implicit src: SourceFile): DefDef = new BackquotedDefDef(name, tparams, vparamss, tpt, rhs)
def TypeDef(name: TypeName, rhs: Tree)(implicit src: SourceFile): TypeDef = new TypeDef(name, rhs)
def Template(constr: DefDef, parents: List[Tree], derived: List[Tree], self: ValDef, body: LazyTreeList)(implicit src: SourceFile): Template =
if (derived.isEmpty) new Template(constr, parents, self, body)
Expand Down Expand Up @@ -406,8 +408,12 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def makeAndType(left: Tree, right: Tree)(implicit ctx: Context): AppliedTypeTree =
AppliedTypeTree(ref(defn.andType.typeRef), left :: right :: Nil)

def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers)(implicit ctx: Context): ValDef =
ValDef(pname, tpe, EmptyTree).withMods(mods | Param)
def makeParameter(pname: TermName, tpe: Tree, mods: Modifiers = EmptyModifiers, isBackquoted: Boolean = false)(implicit ctx: Context): ValDef = {
val vdef =
if (isBackquoted) BackquotedValDef(pname, tpe, EmptyTree)
else ValDef(pname, tpe, EmptyTree)
vdef.withMods(mods | Param)
}

def makeSyntheticParameter(n: Int = 1, tpt: Tree = null, flags: FlagSet = EmptyFlags)(implicit ctx: Context): ValDef =
ValDef(nme.syntheticParamName(n), if (tpt == null) TypeTree() else tpt, EmptyTree)
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,11 @@ class Definitions {
def InternalQuoted_typeQuote(implicit ctx: Context): Symbol = InternalQuoted_typeQuoteR.symbol
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole")
lazy val InternalQuoted_patternMatchBindHoleModuleR: TermRef = InternalQuotedModule.requiredValueRef("patternMatchBindHole".toTermName)
def InternalQuoted_patternMatchBindHoleModule: Symbol = InternalQuoted_patternMatchBindHoleModuleR.symbol
lazy val InternalQuoted_patternMatchBindHole_unapplyR: TermRef = InternalQuoted_patternMatchBindHoleModule.requiredMethodRef("unapply")
def InternalQuoted_patternMatchBindHole_unapply(implicit ctx: Context): Symbol = InternalQuoted_patternMatchBindHole_unapplyR.symbol

lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
Expand All @@ -741,6 +746,9 @@ class Definitions {
lazy val QuotedTypeModuleRef: TermRef = ctx.requiredModuleRef("scala.quoted.Type")
def QuotedTypeModule(implicit ctx: Context): Symbol = QuotedTypeModuleRef.symbol

lazy val QuotedMatchingBindingType: TypeRef = ctx.requiredClassRef("scala.quoted.matching.Bind")
def QuotedMatchingBindingClass(implicit ctx: Context): ClassSymbol = QuotedMatchingBindingType.symbol.asClass

def Unpickler_unpickleExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleExpr")
def Unpickler_liftedExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.liftedExpr")
def Unpickler_unpickleType: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleType")
Expand Down
22 changes: 14 additions & 8 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -406,10 +406,12 @@ object Parsers {
/** Convert tree to formal parameter
*/
def convertToParam(tree: Tree, expected: String = "formal parameter"): ValDef = tree match {
case Ident(name) =>
makeParameter(name.asTermName, TypeTree()).withSpan(tree.span)
case Typed(Ident(name), tpt) =>
makeParameter(name.asTermName, tpt).withSpan(tree.span)
case id @ Ident(name) =>
makeParameter(name.asTermName, TypeTree(), isBackquoted = id.isBackquoted).withSpan(tree.span)
case Typed(id @ Ident(name), tpt) =>
makeParameter(name.asTermName, tpt, isBackquoted = id.isBackquoted).withSpan(tree.span)
case Typed(Splice(Ident(name)), tpt) =>
makeParameter(("$" + name).toTermName, tpt).withSpan(tree.span)
case _ =>
syntaxError(s"not a legal $expected", tree.span)
makeParameter(nme.ERROR, tree)
Expand Down Expand Up @@ -2370,7 +2372,9 @@ object Parsers {
}
} else EmptyTree
lhs match {
case (id @ Ident(name: TermName)) :: Nil => {
case (id: BackquotedIdent) :: Nil if id.name.isTermName =>
finalizeDef(BackquotedValDef(id.name.asTermName, tpt, rhs), mods, start)
case Ident(name: TermName) :: Nil => {
finalizeDef(ValDef(name, tpt, rhs), mods, start)
} case _ =>
PatDef(mods, lhs, tpt, rhs)
Expand Down Expand Up @@ -2414,10 +2418,10 @@ object Parsers {
else
(Nil, Method)
val mods1 = addFlag(mods, flags)
val name = ident()
val ident = termIdent()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val vparamss = paramClauses() match {
case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(name) =>
case rparams :: rparamss if leadingParamss.nonEmpty && !isLeftAssoc(ident.name) =>
rparams :: leadingParamss ::: rparamss
case rparamss =>
leadingParamss ::: rparamss
Expand Down Expand Up @@ -2447,7 +2451,9 @@ object Parsers {
accept(EQUALS)
expr()
}
finalizeDef(DefDef(name, tparams, vparamss, tpt, rhs), mods1, start)

if (ident.isBackquoted) finalizeDef(BackquotedDefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start)
else finalizeDef(DefDef(ident.name.asTermName, tparams, vparamss, tpt, rhs), mods1, start)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass

def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol = defn.InternalQuoted_patternBindHoleAnnot

// Types

Expand Down
12 changes: 11 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,17 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
tree
}

def typedUnApply(tree: untpd.Apply, selType: Type)(implicit ctx: Context): Tree = track("typedUnApply") {
def typedUnApply(tree0: untpd.Apply, selType: Type)(implicit ctx: Context): Tree = track("typedUnApply") {
val tree =
if (ctx.mode.is(Mode.QuotedPattern)) { // TODO move to desugar
val Apply(qual0, args0) = tree0
val args1 = args0 map {
case arg: untpd.Ident if arg.name.startsWith("$") =>
untpd.Apply(untpd.ref(defn.InternalQuoted_patternMatchBindHoleModuleR), untpd.Ident(arg.name.toString.substring(1).toTermName) :: Nil)
case arg => arg
}
untpd.cpy.Apply(tree0)(qual0, args1)
} else tree0
val Apply(qual, args) = tree

def notAnExtractor(tree: Tree) =
Expand Down
31 changes: 31 additions & 0 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,10 @@ class Typer extends Namer
}

def typedBind(tree: untpd.Bind, pt: Type)(implicit ctx: Context): Tree = track("typedBind") {
if (ctx.mode.is(Mode.QuotedPattern) && tree.name.startsWith("$")) {
val bind1 = untpd.cpy.Bind(tree)(tree.name.toString.substring(1).toTermName, tree.body)
return typed(untpd.Apply(untpd.ref(defn.InternalQuoted_patternMatchBindHoleModuleR), bind1 :: Nil).withSpan(tree.span), pt)
}
val pt1 = fullyDefinedType(pt, "pattern variable", tree.span)
val body1 = typed(tree.body, pt1)
body1 match {
Expand Down Expand Up @@ -1959,6 +1963,14 @@ class Typer extends Namer
}

def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
val ctx0 = ctx

def bindExpr(name: Name, tpe: Type, span: Span): Tree = {
val exprTpe = AppliedType(defn.QuotedMatchingBindingType, tpe :: Nil)
val sym = ctx0.newPatternBoundSymbol(name, exprTpe, span)
Bind(sym, untpd.Ident(nme.WILDCARD).withType(exprTpe)).withSpan(span)
}

object splitter extends tpd.TreeMap {
val patBuf = new mutable.ListBuffer[Tree]
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
Expand All @@ -1973,6 +1985,25 @@ class Typer extends Namer
val pat1 = if (patType eq patType1) pat else pat.withType(patType1)
patBuf += pat1
}
case ddef: ValOrDefDef =>
if (ddef.symbol.annotations.exists(_.symbol == defn.InternalQuoted_patternBindHoleAnnot)) {
val tpe = ddef.symbol.info match {
case t: ExprType => t.resType
case t: MethodType => t.toFunctionType()
case t: PolyType =>
HKTypeLambda(t.paramNames)(
x => t.paramInfos.mapConserve(_.subst(t, x).asInstanceOf[TypeBounds]),
x => t.resType.subst(t, x).toFunctionType())
case t => t
}
val exprTpe = AppliedType(defn.QuotedMatchingBindingType, tpe :: Nil)
val sym = ctx0.newPatternBoundSymbol(ddef.name, exprTpe, ddef.span)
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(exprTpe)).withSpan(ddef.span)
}
super.transform(tree)
case tree @ UnApply(_, _, (bind: Bind) :: Nil) if tree.fun.symbol == defn.InternalQuoted_patternMatchBindHole_unapply =>
patBuf += bindExpr(bind.name, bind.tpe.widen, bind.span)
cpy.UnApply(tree)(patterns = untpd.Ident(nme.WILDCARD).withType(bind.tpe.widen) :: Nil)
case _ =>
super.transform(tree)
}
Expand Down
11 changes: 11 additions & 0 deletions library/src-bootstrapped/scala/internal/Quoted.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package scala.internal

import scala.annotation.Annotation
import scala.quoted._

object Quoted {
Expand All @@ -19,4 +20,14 @@ object Quoted {
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
def patternHole[T]: T =
throw new Error("Internal error: this method call should have been replaced by the compiler")

/** A splice of a name in a quoted pattern is desugared by adding this annotation */
class patternBindHole extends Annotation

/** A splice of a name in a quoted pattern in pattern position is desugared by wrapping it in this extractor */
object patternMatchBindHole {
def unapply(x: Any): Some[x.type] =
throw new Error("Internal error: this method call should have been replaced by the compiler")
}

}
Loading