Skip to content

Commit 3da6139

Browse files
committed
Add support for pattern matching on definition identifiers
Introduces `Binding[T]` which can be used to match a check is an `Expr` is a reference to some other binding defined in scope. ```scala case '{ val $x: Int = ($body: Int) } => // where x: Binding[Int] case '{ (`$x`: Int) => ($body: Int) } => // where x: Binding[Int] case Binding(b) => // where b: Binding[Int] ```
1 parent 449008e commit 3da6139

File tree

17 files changed

+408
-26
lines changed

17 files changed

+408
-26
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ object desugar {
139139
* def x: Int = expr
140140
* def x_=($1: <TypeTree()>): Unit = ()
141141
*/
142-
def valDef(vdef: ValDef)(implicit ctx: Context): Tree = {
142+
def valDef(vdef0: ValDef)(implicit ctx: Context): Tree = {
143+
val vdef = transformQuotedPatternName(vdef0)
143144
val ValDef(name, tpt, rhs) = vdef
144145
val mods = vdef.mods
145146
val setterNeeded =
@@ -164,6 +165,14 @@ object desugar {
164165
else vdef
165166
}
166167

168+
def transformQuotedPatternName(vdef: ValDef)(implicit ctx: Context): ValDef = {
169+
if (ctx.mode.is(Mode.QuotedPattern) && vdef.name.startsWith("$")) {
170+
val name = vdef.name.toString.substring(1).toTermName
171+
val mods = vdef.mods.withAddedAnnotation(New(ref(defn.InternalQuoted_patternBindHoleAnnot.typeRef)).withSpan(vdef.span))
172+
cpy.ValDef(vdef)(name).withMods(mods)
173+
} else vdef
174+
}
175+
167176
def makeImplicitParameters(tpts: List[Tree], contextualFlag: FlagSet = EmptyFlags, forPrimaryConstructor: Boolean = false)(implicit ctx: Context): List[ValDef] =
168177
for (tpt <- tpts) yield {
169178
val paramFlags: FlagSet = if (forPrimaryConstructor) PrivateLocalParamAccessor else Param

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,7 @@ class Definitions {
722722
def InternalQuoted_typeQuote(implicit ctx: Context): Symbol = InternalQuoted_typeQuoteR.symbol
723723
lazy val InternalQuoted_patternHoleR: TermRef = InternalQuotedModule.requiredMethodRef("patternHole")
724724
def InternalQuoted_patternHole(implicit ctx: Context): Symbol = InternalQuoted_patternHoleR.symbol
725+
lazy val InternalQuoted_patternBindHoleAnnot: ClassSymbol = InternalQuotedModule.requiredClass("patternBindHole")
725726

726727
lazy val InternalQuotedMatcherModuleRef: TermRef = ctx.requiredModuleRef("scala.internal.quoted.Matcher")
727728
def InternalQuotedMatcherModule(implicit ctx: Context): Symbol = InternalQuotedMatcherModuleRef.symbol
@@ -741,6 +742,9 @@ class Definitions {
741742
lazy val QuotedTypeModuleRef: TermRef = ctx.requiredModuleRef("scala.quoted.Type")
742743
def QuotedTypeModule(implicit ctx: Context): Symbol = QuotedTypeModuleRef.symbol
743744

745+
lazy val QuotedMatchingBindingType: TypeRef = ctx.requiredClassRef("scala.quoted.matching.Binding")
746+
def QuotedMatchingBindingClass(implicit ctx: Context): ClassSymbol = QuotedMatchingBindingType.symbol.asClass
747+
744748
def Unpickler_unpickleExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleExpr")
745749
def Unpickler_liftedExpr: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.liftedExpr")
746750
def Unpickler_unpickleType: TermSymbol = ctx.requiredMethod("scala.runtime.quoted.Unpickler.unpickleType")

compiler/src/dotty/tools/dotc/tastyreflect/KernelImpl.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1786,6 +1786,7 @@ class KernelImpl(val rootContext: core.Contexts.Context, val rootPosition: util.
17861786
def Definitions_TupleClass(arity: Int): Symbol = defn.TupleType(arity).classSymbol.asClass
17871787

17881788
def Definitions_InternalQuoted_patternHole: Symbol = defn.InternalQuoted_patternHole
1789+
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol = defn.InternalQuoted_patternBindHoleAnnot
17891790

17901791
// Types
17911792

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1957,6 +1957,7 @@ class Typer extends Namer
19571957
}
19581958

19591959
def splitQuotePattern(quoted: Tree)(implicit ctx: Context): (Tree, List[Tree]) = {
1960+
val ctx0 = ctx
19601961
object splitter extends tpd.TreeMap {
19611962
val patBuf = new mutable.ListBuffer[Tree]
19621963
override def transform(tree: Tree)(implicit ctx: Context) = tree match {
@@ -1966,6 +1967,13 @@ class Typer extends Namer
19661967
case Splice(pat) =>
19671968
try patternHole(tree)
19681969
finally patBuf += pat
1970+
case vdef: ValDef =>
1971+
if (vdef.symbol.annotations.exists(_.symbol == defn.InternalQuoted_patternBindHoleAnnot)) {
1972+
val tpe = AppliedType(defn.QuotedMatchingBindingType, vdef.tpt.tpe :: Nil)
1973+
val sym = ctx0.newPatternBoundSymbol(vdef.name, tpe, vdef.span)
1974+
patBuf += Bind(sym, untpd.Ident(nme.WILDCARD).withType(tpe)).withSpan(vdef.span)
1975+
}
1976+
super.transform(tree)
19691977
case _ =>
19701978
super.transform(tree)
19711979
}

library/src-bootstrapped/scala/internal/Quoted.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package scala.internal
22

3+
import scala.annotation.Annotation
34
import scala.quoted._
45

56
object Quoted {
@@ -19,4 +20,8 @@ object Quoted {
1920
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
2021
def patternHole[T]: T =
2122
throw new Error("Internal error: this method call should have been replaced by the compiler")
23+
24+
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
25+
class patternBindHole extends Annotation
26+
2227
}

library/src-bootstrapped/scala/internal/quoted/Matcher.scala

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package scala.internal.quoted
33
import scala.annotation.internal.sharable
44

55
import scala.quoted._
6+
import scala.quoted.matching.Binding
67
import scala.tasty._
78

89
object Matcher {
@@ -51,6 +52,18 @@ object Matcher {
5152
sFlags.is(Lazy) == pFlags.is(Lazy) && sFlags.is(Mutable) == pFlags.is(Mutable)
5253
}
5354

55+
def bindingMatch(sym: Symbol) =
56+
Some(Tuple1(new Binding(sym.name, sym)))
57+
58+
def hasBindingTypeAnnotation(tpt: TypeTree): Boolean = tpt match {
59+
case Annotated(tpt2, Apply(Select(New(TypeIdent("patternBindHole")), "<init>"), Nil)) => true
60+
case Annotated(tpt2, _) => hasBindingTypeAnnotation(tpt2)
61+
case _ => false
62+
}
63+
64+
def hasBindingAnnotation(sym: Symbol) =
65+
sym.annots.exists { case Apply(Select(New(TypeIdent("patternBindHole")),"<init>"),List()) => true; case _ => true }
66+
5467
def treesMatch(scrutinees: List[Tree], patterns: List[Tree]): Option[Tuple] =
5568
if (scrutinees.size != patterns.size) None
5669
else foldMatchings(scrutinees.zip(patterns).map(treeMatches): _*)
@@ -134,24 +147,30 @@ object Matcher {
134147
foldMatchings(treeMatches(tycon1, tycon2), treesMatch(args1, args2))
135148

136149
case (ValDef(_, tpt1, rhs1), ValDef(_, tpt2, rhs2)) if checkValFlags() =>
150+
val bindMatch =
151+
if (hasBindingAnnotation(pattern.symbol) || hasBindingTypeAnnotation(tpt2)) bindingMatch(scrutinee.symbol)
152+
else Some(())
137153
val returnTptMatch = treeMatches(tpt1, tpt2)
138154
val rhsEnv = env + (scrutinee.symbol -> pattern.symbol)
139155
val rhsMatchings = treeOptMatches(rhs1, rhs2)(rhsEnv)
140-
foldMatchings(returnTptMatch, rhsMatchings)
156+
foldMatchings(bindMatch, returnTptMatch, rhsMatchings)
141157

142158
case (DefDef(_, typeParams1, paramss1, tpt1, Some(rhs1)), DefDef(_, typeParams2, paramss2, tpt2, Some(rhs2))) =>
143159
val typeParmasMatch = treesMatch(typeParams1, typeParams2)
144160
val paramssMatch =
145161
if (paramss1.size != paramss2.size) None
146162
else foldMatchings(paramss1.zip(paramss2).map { (params1, params2) => treesMatch(params1, params2) }: _*)
163+
val bindMatch =
164+
if (hasBindingAnnotation(pattern.symbol)) bindingMatch(scrutinee.symbol)
165+
else Some(())
147166
val tptMatch = treeMatches(tpt1, tpt2)
148167
val rhsEnv =
149168
env + (scrutinee.symbol -> pattern.symbol) ++
150-
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
151-
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
169+
typeParams1.zip(typeParams2).map((tparam1, tparam2) => tparam1.symbol -> tparam2.symbol) ++
170+
paramss1.flatten.zip(paramss2.flatten).map((param1, param2) => param1.symbol -> param2.symbol)
152171
val rhsMatch = treeMatches(rhs1, rhs2)(rhsEnv)
153172

154-
foldMatchings(typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
173+
foldMatchings(bindMatch, typeParmasMatch, paramssMatch, tptMatch, rhsMatch)
155174

156175
case (Lambda(_, tpt1), Lambda(_, tpt2)) =>
157176
// TODO match tpt1 with tpt2?
@@ -172,6 +191,10 @@ object Matcher {
172191
val finalizerMatch = treeOptMatches(finalizer1, finalizer2)
173192
foldMatchings(bodyMacth, casesMatch, finalizerMatch)
174193

194+
// Ignore type annotations
195+
case (Annotated(tpt, _), _) => treeMatches(tpt, pattern)
196+
case (_, Annotated(tpt, _)) => treeMatches(scrutinee, tpt)
197+
175198
// No Match
176199
case _ =>
177200
if (debug)

library/src-non-bootstrapped/scala/internal/Quoted.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package scala.internal
22

3+
import scala.annotation.Annotation
34
import scala.quoted._
45

56
object Quoted {
@@ -16,4 +17,11 @@ object Quoted {
1617
def typeQuote[T/* <: AnyKind */]: Type[T] =
1718
throw new Error("Internal error: this method call should have been replaced by the compiler")
1819

20+
/** A splice in a quoted pattern is desugared by the compiler into a call to this method */
21+
def patternHole[T]: T =
22+
throw new Error("Internal error: this method call should have been replaced by the compiler")
23+
24+
/** A splice of a name in a quoted pattern is desugared by wrapping getting this annotation */
25+
class patternBindHole extends Annotation
26+
1927
}
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package scala.quoted
2+
package matching
3+
4+
import scala.tasty.Reflection // TODO do not depend on reflection directly
5+
6+
/** Binding of an Expr[T] used to know if some Expr[T] is a reference to the binding
7+
*
8+
* @param name string name of this binding
9+
* @param id unique id used for equality
10+
*/
11+
class Binding[-T] private[scala](val name: String, private[Binding] val id: Object) { self =>
12+
13+
override def equals(obj: Any): Boolean = obj match {
14+
case obj: Binding[_] => obj.id == id
15+
case _ => false
16+
}
17+
18+
override def hashCode(): Int = id.hashCode()
19+
20+
}
21+
22+
object Binding {
23+
24+
def unapply[T](expr: Expr[T])(implicit reflect: Reflection): Option[Binding[T]] = {
25+
import reflect._
26+
expr.unseal match {
27+
case IsIdent(ref) =>
28+
val sym = ref.symbol
29+
Some(new Binding[T](sym.name, sym))
30+
case _ => None
31+
}
32+
}
33+
34+
}

library/src/scala/tasty/reflect/Kernel.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1446,9 +1446,12 @@ trait Kernel {
14461446

14471447
def Definitions_TupleClass(arity: Int): Symbol
14481448

1449-
/** Symbol of scala.runtime.Quoted.patternHole */
1449+
/** Symbol of scala.internal.Quoted.patternHole */
14501450
def Definitions_InternalQuoted_patternHole: Symbol
14511451

1452+
/** Symbol of scala.internal.Quoted.patternBindHole */
1453+
def Definitions_InternalQuoted_patternBindHoleAnnot: Symbol
1454+
14521455
def Definitions_UnitType: Type
14531456
def Definitions_ByteType: Type
14541457
def Definitions_ShortType: Type

tests/pos/quotedPatterns.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,15 @@ object Test {
1111
case '{g($y, $z)} => '{$y * $z}
1212
case '{ ((a: Int) => 3)($y) } => y
1313
case '{ 1 + ($y: Int)} => y
14+
case '{ val a = 1 + ($y: Int); 3 } => y
1415
// currently gives an unreachable case warning
1516
// but only when used in conjunction with the others.
1617
// I believe this is because implicit arguments are not taken
1718
// into account when checking whether we have already seen an `unapply` before.
19+
case '{ val $y: Int = $z; 1 } => z
20+
case '{ ((`$y`: Int) => 1 + y + ($z: Int))(2) } => z
21+
// TODO support syntax
22+
// case '{ (($y: Int) => 1 + y + ($z: Int))(2) } => z
1823
case _ => '{1}
1924
}
2025
}

0 commit comments

Comments
 (0)