Skip to content

Commit a5c1ebb

Browse files
committed
Support simple higher order pattern splices
This fixes the quoted pattern matcher runtime to never return open code. * `case '{ val x: Int = 3 ; $body }` will only match if body does not contain a reference to `x`. Same for other kind of definitions in the pattern. * `case '{ val x: Int = 3; ($f: Int => Int)(x) }` will match any body of type `Int` but will wrap it in a lambda that contains `x` as an argument. * Introduce `Expr.open` that takes a expression of a lambda and explicitly opens it temporarily an provides a way to re-close any subexpression of its body (unsafe if not used properly).
1 parent 85dd873 commit a5c1ebb

File tree

14 files changed

+270
-42
lines changed

14 files changed

+270
-42
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
580580
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure =
581581
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))
582582

583+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
584+
tpd.Lambda(tpe, rhsFn)
585+
583586
type If = tpd.If
584587

585588
def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new {
@@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
11411144

11421145
def Type_isSubType(self: Type)(that: Type)(given Context): Boolean = self <:< that
11431146

1144-
/** Widen from singleton type to its underlying non-singleton
1145-
* base type by applying one or more `underlying` dereferences,
1146-
* Also go from => T to T.
1147-
* Identity for all other types. Example:
1148-
*
1149-
* class Outer { class C ; val x: C }
1150-
* def o: Outer
1151-
* <o.x.type>.widen = o.C
1152-
*/
11531147
def Type_widen(self: Type)(given Context): Type = self.widen
11541148

1149+
def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr
1150+
11551151
def Type_dealias(self: Type)(given Context): Type = self.dealias
11561152

11571153
def Type_simplified(self: Type)(given Context): Type = self.simplified
@@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
13981394
case _ => None
13991395
}
14001396

1397+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
1398+
Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp)
1399+
14011400
def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
14021401
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
14031402
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)

library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,50 @@ trait TreeUtils
77
with SymbolOps
88
with TreeOps { self: Reflection =>
99

10+
abstract class TreeAccumulator[X] {
11+
def foldTree(x: X, tree: Tree)(given ctx: Context): X
12+
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X =
13+
throw new Exception("non-bootstraped-library")
14+
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X =
15+
throw new Exception("non-bootstraped-library")
16+
}
17+
18+
abstract class TreeTraverser extends TreeAccumulator[Unit] {
19+
def traverseTree(tree: Tree)(given ctx: Context): Unit =
20+
throw new Exception("non-bootstraped-library")
21+
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit =
22+
throw new Exception("non-bootstraped-library")
23+
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit =
24+
throw new Exception("non-bootstraped-library")
25+
}
26+
27+
abstract class TreeMap { self =>
28+
def transformTree(tree: Tree)(given ctx: Context): Tree =
29+
throw new Exception("non-bootstraped-library")
30+
def transformStatement(tree: Statement)(given ctx: Context): Statement =
31+
throw new Exception("non-bootstraped-library")
32+
def transformTerm(tree: Term)(given ctx: Context): Term =
33+
throw new Exception("non-bootstraped-library")
34+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree =
35+
throw new Exception("non-bootstraped-library")
36+
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
37+
throw new Exception("non-bootstraped-library")
38+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef =
39+
throw new Exception("non-bootstraped-library")
40+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
41+
throw new Exception("non-bootstraped-library")
42+
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
43+
throw new Exception("non-bootstraped-library")
44+
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
45+
throw new Exception("non-bootstraped-library")
46+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
47+
throw new Exception("non-bootstraped-library")
48+
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
49+
throw new Exception("non-bootstraped-library")
50+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
51+
throw new Exception("non-bootstraped-library")
52+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
53+
throw new Exception("non-bootstraped-library")
54+
}
55+
1056
}

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ private[quoted] object Matcher {
1010
class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) {
1111
// TODO improve performance
1212

13+
// TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
1314
private final val debug = false
1415

1516
import qctx.tasty.{_, given}
1617
import Matching._
1718

18-
private type Env = Set[(Symbol, Symbol)]
19+
private type Env = Map[Symbol, Symbol]
1920

2021
inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env)
2122

2223
class SymBinding(val sym: Symbol, val fromAbove: Boolean)
2324

2425
def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = {
25-
implicit val env: Env = Set.empty
26+
implicit val env: Env = Map.empty
2627
if (hasTypeSplices) {
2728
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
2829
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +43,7 @@ private[quoted] object Matcher {
4243

4344
// TODO factor out common logic with `termMatch`
4445
def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = {
45-
implicit val env: Env = Set.empty
46+
implicit val env: Env = Map.empty
4647
if (hasTypeSplices) {
4748
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
4849
val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +139,28 @@ private[quoted] object Matcher {
138139
matched(scrutinee.seal)
139140

140141
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141-
case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil))
142+
case (ClosedTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
142143
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143144
scrutinee.tpe <:< tpt.tpe =>
144145
matched(scrutinee.seal)
145146

147+
// Matches an open term and wraps it into a lambda that provides the free variables
148+
case (scrutinee, pattern @ Apply(Select(TypeApply(Ident("patternHole"), List(Inferred())), "apply"), args0 @ IdentArgs(args))) =>
149+
def bodyFn(lambdaArgs: List[Tree]): Tree = {
150+
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
151+
new TreeMap {
152+
override def transformTerm(tree: Term)(given ctx: Context): Term =
153+
tree match
154+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
155+
case tree => super.transformTerm(tree)
156+
}.transformTree(scrutinee)
157+
}
158+
val names = args.map(_.name)
159+
val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
160+
val resType = pattern.tpe
161+
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
162+
matched(res.seal)
163+
146164
//
147165
// Match two equivalent trees
148166
//
@@ -156,7 +174,7 @@ private[quoted] object Matcher {
156174
case (scrutinee, Typed(expr2, _)) =>
157175
scrutinee =?= expr2
158176

159-
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) =>
177+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
160178
matched
161179

162180
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -165,18 +183,24 @@ private[quoted] object Matcher {
165183
case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol =>
166184
matched
167185

168-
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
186+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
169187
fn1 =?= fn2 && args1 =?= args2
170188

171-
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
189+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
172190
fn1 =?= fn2 && args1 =?= args2
173191

174192
case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) =>
175193
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil)
176194
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)
177195

178196
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
179-
withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) {
197+
val newEnv = (stat1, stat2) match {
198+
case (stat1: Definition, stat2: Definition) =>
199+
summon[Env] + (stat1.symbol -> stat2.symbol)
200+
case _ =>
201+
summon[Env]
202+
}
203+
withEnv(newEnv) {
180204
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
181205
}
182206

@@ -268,7 +292,7 @@ private[quoted] object Matcher {
268292
|
269293
|${pattern.showExtractors}
270294
|
271-
|
295+
|with environment: ${summon[Env]}
272296
|
273297
|
274298
|""".stripMargin)
@@ -277,6 +301,31 @@ private[quoted] object Matcher {
277301
}
278302
end treeOps
279303

304+
private object ClosedTerm {
305+
def unapply(term: Term)(given Context, Env): Option[term.type] =
306+
if freeVars(term).isEmpty then Some(term) else None
307+
308+
def freeVars(tree: Tree)(given qctx: Context, env: Env): Set[Symbol] =
309+
val accumulator = new TreeAccumulator[Set[Symbol]] {
310+
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
311+
tree match
312+
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
313+
case _ => foldOverTree(x, tree)
314+
}
315+
accumulator.foldTree(Set.empty, tree)
316+
}
317+
318+
private object IdentArgs {
319+
def unapply(args: List[Term])(given Context): Option[List[Ident]] =
320+
args.foldRight(Option(List.empty[Ident])) {
321+
case (id: Ident, Some(acc)) => Some(id :: acc)
322+
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
323+
if params.zip(args).forall(_.symbol == _.symbol) =>
324+
Some(id :: acc)
325+
case _ => None
326+
}
327+
}
328+
280329
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = {
281330
(scrutinee, pattern) match {
282331
case (Some(x), Some(y)) => x =?= y
@@ -344,7 +393,7 @@ private[quoted] object Matcher {
344393
|
345394
|${pattern.showExtractors}
346395
|
347-
|
396+
|with environment: ${summon[Env]}
348397
|
349398
|
350399
|""".stripMargin)

library/src/scala/quoted/Expr.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,44 @@ package quoted {
204204
val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]]
205205
ofTuple(elems).cast[Tuple.InverseMap[T, Expr]]
206206
}
207-
}
208207

208+
// TODO generalize for any function arity (see Expr.betaReduce)
209+
def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = {
210+
import qctx.tasty.{given, _}
211+
val (params, bodyExpr) = paramsAndBody(f)
212+
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]])
213+
}
214+
215+
def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = {
216+
import qctx.tasty.{given, _}
217+
val (params, bodyExpr) = paramsAndBody(f)
218+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]])
219+
}
220+
221+
def open[T1, T2, T3, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = {
222+
import qctx.tasty.{given, _}
223+
val (params, bodyExpr) = paramsAndBody(f)
224+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]])
225+
}
226+
227+
private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = {
228+
import qctx.tasty.{given, _}
229+
val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand
230+
(params, body.seal.asInstanceOf[Expr[R]])
231+
}
232+
233+
private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
234+
import qctx.tasty.{given, _}
235+
val map = params.map(_.symbol).zip(args).toMap
236+
new TreeMap {
237+
override def transformTerm(tree: Term)(given ctx: Context): Term =
238+
super.transformTerm(tree) match
239+
case tree: Ident => map.getOrElse(tree.symbol, tree)
240+
case tree => tree
241+
}.transformTerm(e)
242+
}
243+
244+
}
209245
}
210246

211247
package internal {

library/src/scala/quoted/matching/Sym.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package matching
88
*/
99
class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self =>
1010

11+
override def toString: String = s"Sym($name)@${id.hashCode}"
12+
1113
override def equals(obj: Any): Boolean = obj match {
1214
case obj: Sym[_] => obj.id == id
1315
case _ => false

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ trait CompilerInterface {
443443
def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure
444444
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure
445445

446+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block
447+
446448
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
447449
type If <: Term
448450

@@ -810,6 +812,11 @@ trait CompilerInterface {
810812
*/
811813
def Type_widen(self: Type)(given ctx: Context): Type
812814

815+
/** Widen from TermRef to its underlying non-termref
816+
* base type, while also skipping Expr types.
817+
*/
818+
def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type
819+
813820
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
814821
* TypeVars until type is no longer alias type, annotated type, LazyRef,
815822
* or instantiated type variable.
@@ -992,6 +999,8 @@ trait CompilerInterface {
992999

9931000
def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType]
9941001

1002+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType
1003+
9951004
def MethodType_isErased(self: MethodType): Boolean
9961005
def MethodType_isImplicit(self: MethodType): Boolean
9971006
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ trait TreeOps extends Core {
615615

616616
case _ => None
617617
}
618+
619+
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
620+
internal.Lambda_apply(tpe, rhsFn)
621+
618622
}
619623

620624
given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,22 @@ trait TypeOrBoundsOps extends Core {
1717
/** Is this type a subtype of that type? */
1818
def <:<(that: Type)(given ctx: Context): Boolean = internal.Type_isSubType(self)(that)
1919

20+
/** Widen from singleton type to its underlying non-singleton
21+
* base type by applying one or more `underlying` dereferences,
22+
* Also go from => T to T.
23+
* Identity for all other types. Example:
24+
*
25+
* class Outer { class C ; val x: C }
26+
* def o: Outer
27+
* <o.x.type>.widen = o.C
28+
*/
2029
def widen(given ctx: Context): Type = internal.Type_widen(self)
2130

31+
/** Widen from TermRef to its underlying non-termref
32+
* base type, while also skipping Expr types.
33+
*/
34+
def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self)
35+
2236
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
2337
* TypeVars until type is no longer alias type, annotated type, LazyRef,
2438
* or instantiated type variable.
@@ -325,6 +339,9 @@ trait TypeOrBoundsOps extends Core {
325339
def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x)
326340

327341
object MethodType {
342+
def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
343+
internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp)
344+
328345
def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] =
329346
Some((x.paramNames, x.paramTypes, x.resType))
330347
}

tests/run-macros/quote-matcher-runtime.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ Pattern: {
332332
val x: scala.Int = 45
333333
x.+(scala.internal.Quoted.patternHole[scala.Int])
334334
}
335-
Result: Some(List(Expr(a)))
335+
Result: None
336336

337337
Scrutinee: {
338338
lazy val a: scala.Int = 45
@@ -622,7 +622,7 @@ Pattern: {
622622
def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int]
623623
a.+(scala.internal.Quoted.patternHole[scala.Int])
624624
}
625-
Result: Some(List(Expr(a), Expr(a)))
625+
Result: None
626626

627627
Scrutinee: {
628628
lazy val a: scala.Int = a

0 commit comments

Comments
 (0)