Skip to content

Commit abafaea

Browse files
committed
Support simple higher order pattern splices
This fixes the quoted pattern matcher runtime to never return open code. * `case '{ val x: Int = 3 ; $body }` will only match if body does not contain a reference to `x`. Same for other kind of definitions in the pattern. * `case '{ val x: Int = 3; ($f: Int => Int)(x) }` will match any body of type `Int` but will wrap it in a lambda that contains `x` as an argument. * Introduce `Expr.open` that takes a expression of a lambda and explicitly opens it temporarily an provides a way to re-close any subexpression of its body (unsafe if not used properly).
1 parent c1e0e04 commit abafaea

File tree

14 files changed

+239
-38
lines changed

14 files changed

+239
-38
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
580580
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure =
581581
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))
582582

583+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
584+
tpd.Lambda(tpe, rhsFn)
585+
583586
type If = tpd.If
584587

585588
def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new {
@@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
11411144

11421145
def `Type_<:<`(self: Type)(that: Type)(given Context): Boolean = self <:< that
11431146

1144-
/** Widen from singleton type to its underlying non-singleton
1145-
* base type by applying one or more `underlying` dereferences,
1146-
* Also go from => T to T.
1147-
* Identity for all other types. Example:
1148-
*
1149-
* class Outer { class C ; val x: C }
1150-
* def o: Outer
1151-
* <o.x.type>.widen = o.C
1152-
*/
11531147
def Type_widen(self: Type)(given Context): Type = self.widen
11541148

1149+
def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr
1150+
11551151
def Type_dealias(self: Type)(given Context): Type = self.dealias
11561152

11571153
def Type_simplified(self: Type)(given Context): Type = self.simplified
@@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
13981394
case _ => None
13991395
}
14001396

1397+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
1398+
Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp)
1399+
14011400
def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
14021401
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
14031402
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)

library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,50 @@ trait TreeUtils
77
with SymbolOps
88
with TreeOps { self: Reflection =>
99

10+
abstract class TreeAccumulator[X] {
11+
def foldTree(x: X, tree: Tree)(given ctx: Context): X
12+
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X =
13+
throw new Exception("non-bootstraped-library")
14+
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X =
15+
throw new Exception("non-bootstraped-library")
16+
}
17+
18+
abstract class TreeTraverser extends TreeAccumulator[Unit] {
19+
def traverseTree(tree: Tree)(given ctx: Context): Unit =
20+
throw new Exception("non-bootstraped-library")
21+
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit =
22+
throw new Exception("non-bootstraped-library")
23+
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit =
24+
throw new Exception("non-bootstraped-library")
25+
}
26+
27+
abstract class TreeMap { self =>
28+
def transformTree(tree: Tree)(given ctx: Context): Tree =
29+
throw new Exception("non-bootstraped-library")
30+
def transformStatement(tree: Statement)(given ctx: Context): Statement =
31+
throw new Exception("non-bootstraped-library")
32+
def transformTerm(tree: Term)(given ctx: Context): Term =
33+
throw new Exception("non-bootstraped-library")
34+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree =
35+
throw new Exception("non-bootstraped-library")
36+
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
37+
throw new Exception("non-bootstraped-library")
38+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef =
39+
throw new Exception("non-bootstraped-library")
40+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
41+
throw new Exception("non-bootstraped-library")
42+
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
43+
throw new Exception("non-bootstraped-library")
44+
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
45+
throw new Exception("non-bootstraped-library")
46+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
47+
throw new Exception("non-bootstraped-library")
48+
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
49+
throw new Exception("non-bootstraped-library")
50+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
51+
throw new Exception("non-bootstraped-library")
52+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
53+
throw new Exception("non-bootstraped-library")
54+
}
55+
1056
}

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,20 @@ private[quoted] object Matcher {
1010
class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) {
1111
// TODO improve performance
1212

13+
// TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
1314
private final val debug = false
1415

1516
import qctx.tasty.{_, given}
1617
import Matching._
1718

18-
private type Env = Set[(Symbol, Symbol)]
19+
private type Env = Map[Symbol, Symbol]
1920

2021
inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env)
2122

2223
class SymBinding(val sym: Symbol, val fromAbove: Boolean)
2324

2425
def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = {
25-
implicit val env: Env = Set.empty
26+
implicit val env: Env = Map.empty
2627
if (hasTypeSplices) {
2728
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
2829
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +43,7 @@ private[quoted] object Matcher {
4243

4344
// TODO factor out common logic with `termMatch`
4445
def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = {
45-
implicit val env: Env = Set.empty
46+
implicit val env: Env = Map.empty
4647
if (hasTypeSplices) {
4748
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
4849
val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +139,28 @@ private[quoted] object Matcher {
138139
matched(scrutinee.seal)
139140

140141
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141-
case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil))
142+
case (ClosedTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
142143
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143144
scrutinee.tpe <:< tpt.tpe =>
144145
matched(scrutinee.seal)
145146

147+
// Matches an open term and wraps it into a lambda that provides the free variables
148+
case (scrutinee, pattern @ Apply(Select(TypeApply(Ident("patternHole"), List(Inferred())), "apply"), IdentArgs(args))) =>
149+
def bodyFn(lambdaArgs: List[Tree]): Tree = {
150+
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
151+
new TreeMap {
152+
override def transformTerm(tree: Term)(given ctx: Context): Term =
153+
super.transformTerm(tree) match
154+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
155+
case tree => tree
156+
}.transformTree(scrutinee)
157+
}
158+
val names = args.map(_.name)
159+
val argTypes = args.map(_.tpe.widenTermRefExpr)
160+
val resType = pattern.tpe
161+
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
162+
matched(res.seal)
163+
146164
//
147165
// Match two equivalent trees
148166
//
@@ -156,7 +174,7 @@ private[quoted] object Matcher {
156174
case (scrutinee, Typed(expr2, _)) =>
157175
scrutinee =?= expr2
158176

159-
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) =>
177+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
160178
matched
161179

162180
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -176,7 +194,13 @@ private[quoted] object Matcher {
176194
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)
177195

178196
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
179-
withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) {
197+
val newEnv = (stat1, stat2) match {
198+
case (stat1: Definition, stat2: Definition) =>
199+
summon[Env] + (stat1.symbol -> stat2.symbol)
200+
case _ =>
201+
summon[Env]
202+
}
203+
withEnv(newEnv) {
180204
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
181205
}
182206

@@ -277,6 +301,28 @@ private[quoted] object Matcher {
277301
}
278302
}
279303

304+
private object ClosedTerm {
305+
def unapply(term: Term)(given Context, Env): Option[term.type] =
306+
if freeVars(term).isEmpty then Some(term) else None
307+
308+
def freeVars(tree: Tree)(given qctx: Context, env: Env): Set[Symbol] =
309+
val accumulator = new TreeAccumulator[Set[Symbol]] {
310+
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
311+
tree match
312+
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
313+
case _ => foldOverTree(x, tree)
314+
}
315+
accumulator.foldTree(Set.empty, tree)
316+
}
317+
318+
private object IdentArgs {
319+
def unapply(args: List[Term])(given Context): Option[List[Ident]] =
320+
args.foldLeft(Option(List.empty[Ident])) {
321+
case (Some(acc), id: Ident) => Some(id :: acc)
322+
case _ => None
323+
}
324+
}
325+
280326
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = {
281327
(scrutinee, pattern) match {
282328
case (Some(x), Some(y)) => x =?= y

library/src/scala/quoted/Expr.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,25 @@ package quoted {
173173
val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]]
174174
ofTuple(elems).cast[Tuple.InverseMap[T, Expr]]
175175
}
176-
}
177176

177+
// TODO generalize for any function arity (see Expr.betaReduce)
178+
def open[T, U, X](f: Expr[T => U])(content: (Expr[U], [t] => Expr[t] => Expr[T] => Expr[t]) => X)(given qctx: QuoteContext): X = {
179+
import qctx.tasty.{given, _}
180+
f.unseal.etaExpand match
181+
case Block(List(DefDef("$anonfun", Nil, List(List(param)), _, Some(body))), Closure(Ident("$anonfun"), None)) =>
182+
val bodyExpr = body.seal.asInstanceOf[Expr[U]]
183+
def bodyFn[V](e: Expr[V])(v: Expr[T]): Expr[V] = {
184+
new TreeMap {
185+
override def transformTerm(tree: Term)(given ctx: Context): Term =
186+
super.transformTerm(tree) match
187+
case tree: Ident if tree.symbol == param.symbol => v.unseal
188+
case tree => tree
189+
}.transformTerm(e.unseal).seal.asInstanceOf[Expr[V]]
190+
}
191+
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T]) => bodyFn[t](e)(v))
192+
}
193+
194+
}
178195
}
179196

180197
package internal {

library/src/scala/quoted/matching/Sym.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package matching
88
*/
99
class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self =>
1010

11+
override def toString: String = s"Sym($name)@${id.hashCode}"
12+
1113
override def equals(obj: Any): Boolean = obj match {
1214
case obj: Sym[_] => obj.id == id
1315
case _ => false

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ trait CompilerInterface {
443443
def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure
444444
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure
445445

446+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block
447+
446448
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
447449
type If <: Term
448450

@@ -805,6 +807,11 @@ trait CompilerInterface {
805807
*/
806808
def Type_widen(self: Type)(given ctx: Context): Type
807809

810+
/** Widen from TermRef to its underlying non-termref
811+
* base type, while also skipping Expr types.
812+
*/
813+
def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type
814+
808815
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
809816
* TypeVars until type is no longer alias type, annotated type, LazyRef,
810817
* or instantiated type variable.
@@ -987,6 +994,8 @@ trait CompilerInterface {
987994

988995
def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType]
989996

997+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType
998+
990999
def MethodType_isErased(self: MethodType): Boolean
9911000
def MethodType_isImplicit(self: MethodType): Boolean
9921001
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ trait TreeOps extends Core {
615615

616616
case _ => None
617617
}
618+
619+
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
620+
internal.Lambda_apply(tpe, rhsFn)
621+
618622
}
619623

620624
given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,23 @@ trait TypeOrBoundsOps extends Core {
1010
given TypeOps: (self: Type) {
1111
def =:=(that: Type)(given ctx: Context): Boolean = internal.`Type_=:=`(self)(that)
1212
def <:<(that: Type)(given ctx: Context): Boolean = internal.`Type_<:<`(self)(that)
13+
14+
/** Widen from singleton type to its underlying non-singleton
15+
* base type by applying one or more `underlying` dereferences,
16+
* Also go from => T to T.
17+
* Identity for all other types. Example:
18+
*
19+
* class Outer { class C ; val x: C }
20+
* def o: Outer
21+
* <o.x.type>.widen = o.C
22+
*/
1323
def widen(given ctx: Context): Type = internal.Type_widen(self)
1424

25+
/** Widen from TermRef to its underlying non-termref
26+
* base type, while also skipping Expr types.
27+
*/
28+
def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self)
29+
1530
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
1631
* TypeVars until type is no longer alias type, annotated type, LazyRef,
1732
* or instantiated type variable.
@@ -318,6 +333,9 @@ trait TypeOrBoundsOps extends Core {
318333
def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x)
319334

320335
object MethodType {
336+
def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
337+
internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp)
338+
321339
def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] =
322340
Some((x.paramNames, x.paramTypes, x.resType))
323341
}

tests/run-macros/quote-matcher-runtime.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ Pattern: {
332332
val x: scala.Int = 45
333333
x.+(scala.internal.Quoted.patternHole[scala.Int])
334334
}
335-
Result: Some(List(Expr(a)))
335+
Result: None
336336

337337
Scrutinee: {
338338
lazy val a: scala.Int = 45
@@ -622,7 +622,7 @@ Pattern: {
622622
def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int]
623623
a.+(scala.internal.Quoted.patternHole[scala.Int])
624624
}
625-
Result: Some(List(Expr(a), Expr(a)))
625+
Result: None
626626

627627
Scrutinee: {
628628
lazy val a: scala.Int = a

tests/run-macros/quote-matcher-symantics-2/quoted_1.scala

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ object Macros {
1111

1212
private def impl[T: Type](sym: Symantics[T], a: Expr[DSL])(given qctx: QuoteContext): Expr[T] = {
1313

14-
def lift(e: Expr[DSL])(implicit env: Map[Sym[DSL], Expr[T]]): Expr[T] = e match {
14+
def lift(e: Expr[DSL])(implicit env: Map[Int, Expr[T]]): Expr[T] = e match {
1515

1616
case '{ LitDSL(${Const(c)}) } => sym.value(c)
1717

@@ -21,23 +21,32 @@ object Macros {
2121

2222
case '{ ($f: DSL => DSL)($x: DSL) } => sym.app(liftFun(f), lift(x))
2323

24-
case '{ val $x: DSL = $value; $body: DSL } => lift(body)(env + (x -> lift(value)))
24+
case '{ val x: DSL = $value; ($bodyFn: DSL => DSL)(x) } =>
25+
Expr.open(bodyFn) { (body1, close) =>
26+
val (i, nEnvVar) = freshEnvVar()
27+
lift(close(body1)(nEnvVar))(env + (i -> lift(value)))
28+
}
2529

26-
case Sym(b) if env.contains(b) => env(b)
30+
case '{ envVar(${Const(i)}) } => env(i)
31+
// case Sym(b) if env.contains(b) => env(b)
2732

2833
case _ =>
2934
import qctx.tasty.{_, given}
30-
error("Expected explicit DSL", e.unseal.pos)
35+
error("Expected explicit DSL " + e.show, e.unseal.pos)
3136
???
3237
}
3338

34-
def liftFun(e: Expr[DSL => DSL])(implicit env: Map[Sym[DSL], Expr[T]]): Expr[T => T] = e match {
35-
case '{ ($x: DSL) => ($body: DSL) } =>
36-
sym.lam((y: Expr[T]) => lift(body)(env + (x -> y)))
37-
39+
def liftFun(e: Expr[DSL => DSL])(implicit env: Map[Int, Expr[T]]): Expr[T => T] = e match {
40+
case '{ (x: DSL) => ($bodyFn: DSL => DSL)(x) } =>
41+
sym.lam((y: Expr[T]) =>
42+
Expr.open(bodyFn) { (body1, close) =>
43+
val (i, nEnvVar) = freshEnvVar()
44+
lift(close(body1)(nEnvVar))(env + (i -> y))
45+
}
46+
)
3847
case _ =>
3948
import qctx.tasty.{_, given}
40-
error("Expected explicit DSL => DSL", e.unseal.pos)
49+
error("Expected explicit DSL => DSL " + e.show, e.unseal.pos)
4150
???
4251
}
4352

@@ -46,6 +55,13 @@ object Macros {
4655

4756
}
4857

58+
def freshEnvVar()(given QuoteContext): (Int, Expr[DSL]) = {
59+
v += 1
60+
(v, '{envVar(${Expr(v)})})
61+
}
62+
var v = 0
63+
def envVar(i: Int): DSL = ???
64+
4965
//
5066
// DSL in which the user write the code
5167
//

0 commit comments

Comments
 (0)