Skip to content

Commit 8410346

Browse files
committed
Support simple higher order pattern splices
This fixes the quoted pattern matcher runtime to never return open code. * `case '{ val x: Int = 3 ; $body }` will only match if body does not contain a reference to `x`. Same for other kind of definitions in the pattern. * `case '{ val x: Int = 3; ($f: Int => Int)(x) }` will match any body of type `Int` but will wrap it in a lambda that contains `x` as an argument. * Introduce `Expr.open` that takes a expression of a lambda and explicitly opens it temporarily an provides a way to re-close any subexpression of its body (unsafe if not used properly).
1 parent 3cc436a commit 8410346

File tree

14 files changed

+271
-42
lines changed

14 files changed

+271
-42
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
580580
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure =
581581
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))
582582

583+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
584+
tpd.Lambda(tpe, rhsFn)
585+
583586
type If = tpd.If
584587

585588
def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new {
@@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
11411144

11421145
def `Type_<:<`(self: Type)(that: Type)(given Context): Boolean = self <:< that
11431146

1144-
/** Widen from singleton type to its underlying non-singleton
1145-
* base type by applying one or more `underlying` dereferences,
1146-
* Also go from => T to T.
1147-
* Identity for all other types. Example:
1148-
*
1149-
* class Outer { class C ; val x: C }
1150-
* def o: Outer
1151-
* <o.x.type>.widen = o.C
1152-
*/
11531147
def Type_widen(self: Type)(given Context): Type = self.widen
11541148

1149+
def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr
1150+
11551151
def Type_dealias(self: Type)(given Context): Type = self.dealias
11561152

11571153
def Type_simplified(self: Type)(given Context): Type = self.simplified
@@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
13981394
case _ => None
13991395
}
14001396

1397+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
1398+
Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp)
1399+
14011400
def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
14021401
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
14031402
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)

library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,50 @@ trait TreeUtils
77
with SymbolOps
88
with TreeOps { self: Reflection =>
99

10+
abstract class TreeAccumulator[X] {
11+
def foldTree(x: X, tree: Tree)(given ctx: Context): X
12+
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X =
13+
throw new Exception("non-bootstraped-library")
14+
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X =
15+
throw new Exception("non-bootstraped-library")
16+
}
17+
18+
abstract class TreeTraverser extends TreeAccumulator[Unit] {
19+
def traverseTree(tree: Tree)(given ctx: Context): Unit =
20+
throw new Exception("non-bootstraped-library")
21+
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit =
22+
throw new Exception("non-bootstraped-library")
23+
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit =
24+
throw new Exception("non-bootstraped-library")
25+
}
26+
27+
abstract class TreeMap { self =>
28+
def transformTree(tree: Tree)(given ctx: Context): Tree =
29+
throw new Exception("non-bootstraped-library")
30+
def transformStatement(tree: Statement)(given ctx: Context): Statement =
31+
throw new Exception("non-bootstraped-library")
32+
def transformTerm(tree: Term)(given ctx: Context): Term =
33+
throw new Exception("non-bootstraped-library")
34+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree =
35+
throw new Exception("non-bootstraped-library")
36+
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
37+
throw new Exception("non-bootstraped-library")
38+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef =
39+
throw new Exception("non-bootstraped-library")
40+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
41+
throw new Exception("non-bootstraped-library")
42+
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
43+
throw new Exception("non-bootstraped-library")
44+
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
45+
throw new Exception("non-bootstraped-library")
46+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
47+
throw new Exception("non-bootstraped-library")
48+
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
49+
throw new Exception("non-bootstraped-library")
50+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
51+
throw new Exception("non-bootstraped-library")
52+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
53+
throw new Exception("non-bootstraped-library")
54+
}
55+
1056
}

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ private[quoted] object Matcher {
1010
class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) {
1111
// TODO improve performance
1212

13+
// TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
1314
private final val debug = false
1415

1516
import qctx.tasty.{_, given}
1617
import Matching._
1718

18-
private type Env = Set[(Symbol, Symbol)]
19+
private type Env = Map[Symbol, Symbol]
1920

2021
inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env)
2122

2223
class SymBinding(val sym: Symbol, val fromAbove: Boolean)
2324

2425
def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = {
25-
implicit val env: Env = Set.empty
26+
implicit val env: Env = Map.empty
2627
if (hasTypeSplices) {
2728
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
2829
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +43,7 @@ private[quoted] object Matcher {
4243

4344
// TODO factor out common logic with `termMatch`
4445
def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = {
45-
implicit val env: Env = Set.empty
46+
implicit val env: Env = Map.empty
4647
if (hasTypeSplices) {
4748
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
4849
val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +139,28 @@ private[quoted] object Matcher {
138139
matched(scrutinee.seal)
139140

140141
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141-
case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil))
142+
case (ClosedTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
142143
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143144
scrutinee.tpe <:< tpt.tpe =>
144145
matched(scrutinee.seal)
145146

147+
// Matches an open term and wraps it into a lambda that provides the free variables
148+
case (scrutinee, pattern @ Apply(Select(TypeApply(Ident("patternHole"), List(Inferred())), "apply"), args0 @ IdentArgs(args))) =>
149+
def bodyFn(lambdaArgs: List[Tree]): Tree = {
150+
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
151+
new TreeMap {
152+
override def transformTerm(tree: Term)(given ctx: Context): Term =
153+
tree match
154+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
155+
case tree => super.transformTerm(tree)
156+
}.transformTree(scrutinee)
157+
}
158+
val names = args.map(_.name)
159+
val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
160+
val resType = pattern.tpe
161+
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
162+
matched(res.seal)
163+
146164
//
147165
// Match two equivalent trees
148166
//
@@ -156,7 +174,7 @@ private[quoted] object Matcher {
156174
case (scrutinee, Typed(expr2, _)) =>
157175
scrutinee =?= expr2
158176

159-
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) =>
177+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
160178
matched
161179

162180
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -165,18 +183,24 @@ private[quoted] object Matcher {
165183
case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol =>
166184
matched
167185

168-
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
186+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
169187
fn1 =?= fn2 && args1 =?= args2
170188

171-
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
189+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
172190
fn1 =?= fn2 && args1 =?= args2
173191

174192
case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) =>
175193
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil)
176194
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)
177195

178196
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
179-
withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) {
197+
val newEnv = (stat1, stat2) match {
198+
case (stat1: Definition, stat2: Definition) =>
199+
summon[Env] + (stat1.symbol -> stat2.symbol)
200+
case _ =>
201+
summon[Env]
202+
}
203+
withEnv(newEnv) {
180204
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
181205
}
182206

@@ -268,7 +292,7 @@ private[quoted] object Matcher {
268292
|
269293
|${pattern.showExtractors}
270294
|
271-
|
295+
|with environment: ${summon[Env]}
272296
|
273297
|
274298
|""".stripMargin)
@@ -277,6 +301,31 @@ private[quoted] object Matcher {
277301
}
278302
}
279303

304+
private object ClosedTerm {
305+
def unapply(term: Term)(given Context, Env): Option[term.type] =
306+
if freeVars(term).isEmpty then Some(term) else None
307+
308+
def freeVars(tree: Tree)(given qctx: Context, env: Env): Set[Symbol] =
309+
val accumulator = new TreeAccumulator[Set[Symbol]] {
310+
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
311+
tree match
312+
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
313+
case _ => foldOverTree(x, tree)
314+
}
315+
accumulator.foldTree(Set.empty, tree)
316+
}
317+
318+
private object IdentArgs {
319+
def unapply(args: List[Term])(given Context): Option[List[Ident]] =
320+
args.foldRight(Option(List.empty[Ident])) {
321+
case (id: Ident, Some(acc)) => Some(id :: acc)
322+
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
323+
if params.zip(args).forall(_.symbol == _.symbol) =>
324+
Some(id :: acc)
325+
case _ => None
326+
}
327+
}
328+
280329
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = {
281330
(scrutinee, pattern) match {
282331
case (Some(x), Some(y)) => x =?= y
@@ -344,7 +393,7 @@ private[quoted] object Matcher {
344393
|
345394
|${pattern.showExtractors}
346395
|
347-
|
396+
|with environment: ${summon[Env]}
348397
|
349398
|
350399
|""".stripMargin)

library/src/scala/quoted/Expr.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,44 @@ package quoted {
184184
val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]]
185185
ofTuple(elems).cast[Tuple.InverseMap[T, Expr]]
186186
}
187-
}
188187

188+
// TODO generalize for any function arity (see Expr.betaReduce)
189+
def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = {
190+
import qctx.tasty.{given, _}
191+
val (params, bodyExpr) = paramsAndBody(f)
192+
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]])
193+
}
194+
195+
def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = {
196+
import qctx.tasty.{given, _}
197+
val (params, bodyExpr) = paramsAndBody(f)
198+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]])
199+
}
200+
201+
def open[T1, T2, T3, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = {
202+
import qctx.tasty.{given, _}
203+
val (params, bodyExpr) = paramsAndBody(f)
204+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]])
205+
}
206+
207+
private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = {
208+
import qctx.tasty.{given, _}
209+
val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand
210+
(params, body.seal.asInstanceOf[Expr[R]])
211+
}
212+
213+
private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
214+
import qctx.tasty.{given, _}
215+
val map = params.map(_.symbol).zip(args).toMap
216+
new TreeMap {
217+
override def transformTerm(tree: Term)(given ctx: Context): Term =
218+
super.transformTerm(tree) match
219+
case tree: Ident => map.getOrElse(tree.symbol, tree)
220+
case tree => tree
221+
}.transformTerm(e)
222+
}
223+
224+
}
189225
}
190226

191227
package internal {

library/src/scala/quoted/matching/Sym.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package matching
88
*/
99
class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self =>
1010

11+
override def toString: String = s"Sym($name)@${id.hashCode}"
12+
1113
override def equals(obj: Any): Boolean = obj match {
1214
case obj: Sym[_] => obj.id == id
1315
case _ => false

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ trait CompilerInterface {
443443
def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure
444444
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure
445445

446+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block
447+
446448
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
447449
type If <: Term
448450

@@ -805,6 +807,11 @@ trait CompilerInterface {
805807
*/
806808
def Type_widen(self: Type)(given ctx: Context): Type
807809

810+
/** Widen from TermRef to its underlying non-termref
811+
* base type, while also skipping Expr types.
812+
*/
813+
def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type
814+
808815
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
809816
* TypeVars until type is no longer alias type, annotated type, LazyRef,
810817
* or instantiated type variable.
@@ -987,6 +994,8 @@ trait CompilerInterface {
987994

988995
def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType]
989996

997+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType
998+
990999
def MethodType_isErased(self: MethodType): Boolean
9911000
def MethodType_isImplicit(self: MethodType): Boolean
9921001
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ trait TreeOps extends Core {
615615

616616
case _ => None
617617
}
618+
619+
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
620+
internal.Lambda_apply(tpe, rhsFn)
621+
618622
}
619623

620624
given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,23 @@ trait TypeOrBoundsOps extends Core {
1010
given TypeOps: (self: Type) {
1111
def =:=(that: Type)(given ctx: Context): Boolean = internal.`Type_=:=`(self)(that)
1212
def <:<(that: Type)(given ctx: Context): Boolean = internal.`Type_<:<`(self)(that)
13+
14+
/** Widen from singleton type to its underlying non-singleton
15+
* base type by applying one or more `underlying` dereferences,
16+
* Also go from => T to T.
17+
* Identity for all other types. Example:
18+
*
19+
* class Outer { class C ; val x: C }
20+
* def o: Outer
21+
* <o.x.type>.widen = o.C
22+
*/
1323
def widen(given ctx: Context): Type = internal.Type_widen(self)
1424

25+
/** Widen from TermRef to its underlying non-termref
26+
* base type, while also skipping Expr types.
27+
*/
28+
def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self)
29+
1530
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
1631
* TypeVars until type is no longer alias type, annotated type, LazyRef,
1732
* or instantiated type variable.
@@ -318,6 +333,9 @@ trait TypeOrBoundsOps extends Core {
318333
def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x)
319334

320335
object MethodType {
336+
def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
337+
internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp)
338+
321339
def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] =
322340
Some((x.paramNames, x.paramTypes, x.resType))
323341
}

tests/run-macros/quote-matcher-runtime.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ Pattern: {
332332
val x: scala.Int = 45
333333
x.+(scala.internal.Quoted.patternHole[scala.Int])
334334
}
335-
Result: Some(List(Expr(a)))
335+
Result: None
336336

337337
Scrutinee: {
338338
lazy val a: scala.Int = 45
@@ -622,7 +622,7 @@ Pattern: {
622622
def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int]
623623
a.+(scala.internal.Quoted.patternHole[scala.Int])
624624
}
625-
Result: Some(List(Expr(a), Expr(a)))
625+
Result: None
626626

627627
Scrutinee: {
628628
lazy val a: scala.Int = a

0 commit comments

Comments
 (0)