Skip to content

Commit 01dc942

Browse files
Merge pull request #7591 from dotty-staging/quoted-pattern-open-holes
Support simple higher order pattern splices
2 parents 19e7a28 + 0db41ae commit 01dc942

File tree

17 files changed

+318
-42
lines changed

17 files changed

+318
-42
lines changed

compiler/src/dotty/tools/dotc/tastyreflect/ReflectionCompilerInterface.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
580580
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given Context): Closure =
581581
tpd.cpy.Closure(original)(Nil, meth, tpe.map(tpd.TypeTree(_)).getOrElse(tpd.EmptyTree))
582582

583+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
584+
tpd.Lambda(tpe, rhsFn)
585+
583586
type If = tpd.If
584587

585588
def isInstanceOfIf(given ctx: Context): IsInstanceOf[If] = new {
@@ -1141,17 +1144,10 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
11411144

11421145
def Type_isSubType(self: Type)(that: Type)(given Context): Boolean = self <:< that
11431146

1144-
/** Widen from singleton type to its underlying non-singleton
1145-
* base type by applying one or more `underlying` dereferences,
1146-
* Also go from => T to T.
1147-
* Identity for all other types. Example:
1148-
*
1149-
* class Outer { class C ; val x: C }
1150-
* def o: Outer
1151-
* <o.x.type>.widen = o.C
1152-
*/
11531147
def Type_widen(self: Type)(given Context): Type = self.widen
11541148

1149+
def Type_widenTermRefExpr(self: Type)(given Context): Type = self.widenTermRefExpr
1150+
11551151
def Type_dealias(self: Type)(given Context): Type = self.dealias
11561152

11571153
def Type_simplified(self: Type)(given Context): Type = self.simplified
@@ -1398,6 +1394,9 @@ class ReflectionCompilerInterface(val rootContext: core.Contexts.Context) extend
13981394
case _ => None
13991395
}
14001396

1397+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
1398+
Types.MethodType(paramNames.map(_.toTermName))(paramInfosExp, resultTypeExp)
1399+
14011400
def MethodType_isErased(self: MethodType): Boolean = self.isErasedMethod
14021401
def MethodType_isImplicit(self: MethodType): Boolean = self.isImplicitMethod
14031402
def MethodType_paramNames(self: MethodType)(given Context): List[String] = self.paramNames.map(_.toString)

library/src-non-bootstrapped/scala/tasty/reflect/TreeUtils.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,50 @@ trait TreeUtils
77
with SymbolOps
88
with TreeOps { self: Reflection =>
99

10+
abstract class TreeAccumulator[X] {
11+
def foldTree(x: X, tree: Tree)(given ctx: Context): X
12+
def foldTrees(x: X, trees: Iterable[Tree])(given ctx: Context): X =
13+
throw new Exception("non-bootstraped-library")
14+
def foldOverTree(x: X, tree: Tree)(given ctx: Context): X =
15+
throw new Exception("non-bootstraped-library")
16+
}
17+
18+
abstract class TreeTraverser extends TreeAccumulator[Unit] {
19+
def traverseTree(tree: Tree)(given ctx: Context): Unit =
20+
throw new Exception("non-bootstraped-library")
21+
def foldTree(x: Unit, tree: Tree)(given ctx: Context): Unit =
22+
throw new Exception("non-bootstraped-library")
23+
protected def traverseTreeChildren(tree: Tree)(given ctx: Context): Unit =
24+
throw new Exception("non-bootstraped-library")
25+
}
26+
27+
abstract class TreeMap { self =>
28+
def transformTree(tree: Tree)(given ctx: Context): Tree =
29+
throw new Exception("non-bootstraped-library")
30+
def transformStatement(tree: Statement)(given ctx: Context): Statement =
31+
throw new Exception("non-bootstraped-library")
32+
def transformTerm(tree: Term)(given ctx: Context): Term =
33+
throw new Exception("non-bootstraped-library")
34+
def transformTypeTree(tree: TypeTree)(given ctx: Context): TypeTree =
35+
throw new Exception("non-bootstraped-library")
36+
def transformCaseDef(tree: CaseDef)(given ctx: Context): CaseDef =
37+
throw new Exception("non-bootstraped-library")
38+
def transformTypeCaseDef(tree: TypeCaseDef)(given ctx: Context): TypeCaseDef =
39+
throw new Exception("non-bootstraped-library")
40+
def transformStats(trees: List[Statement])(given ctx: Context): List[Statement] =
41+
throw new Exception("non-bootstraped-library")
42+
def transformTrees(trees: List[Tree])(given ctx: Context): List[Tree] =
43+
throw new Exception("non-bootstraped-library")
44+
def transformTerms(trees: List[Term])(given ctx: Context): List[Term] =
45+
throw new Exception("non-bootstraped-library")
46+
def transformTypeTrees(trees: List[TypeTree])(given ctx: Context): List[TypeTree] =
47+
throw new Exception("non-bootstraped-library")
48+
def transformCaseDefs(trees: List[CaseDef])(given ctx: Context): List[CaseDef] =
49+
throw new Exception("non-bootstraped-library")
50+
def transformTypeCaseDefs(trees: List[TypeCaseDef])(given ctx: Context): List[TypeCaseDef] =
51+
throw new Exception("non-bootstraped-library")
52+
def transformSubTrees[Tr <: Tree](trees: List[Tr])(given ctx: Context): List[Tr] =
53+
throw new Exception("non-bootstraped-library")
54+
}
55+
1056
}

library/src/scala/internal/quoted/Matcher.scala

Lines changed: 69 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,27 @@ private[quoted] object Matcher {
1010
class QuoteMatcher[QCtx <: QuoteContext & Singleton](given val qctx: QCtx) {
1111
// TODO improve performance
1212

13+
// TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
1314
private final val debug = false
1415

1516
import qctx.tasty.{_, given}
1617
import Matching._
1718

18-
private type Env = Set[(Symbol, Symbol)]
19+
/** A map relating equivalent symbols from the scrutinee and the pattern
20+
* For example in
21+
* ```
22+
* '{val a = 4; a * a} match case '{ val x = 4; x * x }
23+
* ```
24+
* when matching `a * a` with `x * x` the enviroment will contain `Map(a -> x)`.
25+
*/
26+
private type Env = Map[Symbol, Symbol]
1927

2028
inline private def withEnv[T](env: Env)(body: => (given Env) => T): T = body(given env)
2129

2230
class SymBinding(val sym: Symbol, val fromAbove: Boolean)
2331

2432
def termMatch(scrutineeTerm: Term, patternTerm: Term, hasTypeSplices: Boolean): Option[Tuple] = {
25-
implicit val env: Env = Set.empty
33+
implicit val env: Env = Map.empty
2634
if (hasTypeSplices) {
2735
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
2836
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +50,7 @@ private[quoted] object Matcher {
4250

4351
// TODO factor out common logic with `termMatch`
4452
def typeTreeMatch(scrutineeTypeTree: TypeTree, patternTypeTree: TypeTree, hasTypeSplices: Boolean): Option[Tuple] = {
45-
implicit val env: Env = Set.empty
53+
implicit val env: Env = Map.empty
4654
if (hasTypeSplices) {
4755
implicit val ctx: Context = internal.Context_GADT_setFreshGADTBounds(rootContext)
4856
val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +146,29 @@ private[quoted] object Matcher {
138146
matched(scrutinee.seal)
139147

140148
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141-
case (scrutinee: Term, TypeApply(patternHole, tpt :: Nil))
149+
case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
142150
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143151
scrutinee.tpe <:< tpt.tpe =>
144152
matched(scrutinee.seal)
145153

154+
// Matches an open term and wraps it into a lambda that provides the free variables
155+
case (scrutinee, pattern @ Apply(Select(TypeApply(patternHole, List(Inferred())), "apply"), args0 @ IdentArgs(args)))
156+
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole =>
157+
def bodyFn(lambdaArgs: List[Tree]): Tree = {
158+
val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf[List[Term]]).toMap
159+
new TreeMap {
160+
override def transformTerm(tree: Term)(given ctx: Context): Term =
161+
tree match
162+
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
163+
case tree => super.transformTerm(tree)
164+
}.transformTree(scrutinee)
165+
}
166+
val names = args.map(_.name)
167+
val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
168+
val resType = pattern.tpe
169+
val res = Lambda(MethodType(names)(_ => argTypes, _ => resType), bodyFn)
170+
matched(res.seal)
171+
146172
//
147173
// Match two equivalent trees
148174
//
@@ -156,7 +182,7 @@ private[quoted] object Matcher {
156182
case (scrutinee, Typed(expr2, _)) =>
157183
scrutinee =?= expr2
158184

159-
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].apply((scrutinee.symbol, pattern.symbol)) =>
185+
case (Ident(_), Ident(_)) if scrutinee.symbol == pattern.symbol || summon[Env].get(scrutinee.symbol).contains(pattern.symbol) =>
160186
matched
161187

162188
case (Select(qual1, _), Select(qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -165,18 +191,24 @@ private[quoted] object Matcher {
165191
case (_: Ref, _: Ref) if scrutinee.symbol == pattern.symbol =>
166192
matched
167193

168-
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol =>
194+
case (Apply(fn1, args1), Apply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
169195
fn1 =?= fn2 && args1 =?= args2
170196

171-
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol =>
197+
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) if fn1.symbol == fn2.symbol || summon[Env].get(fn1.symbol).contains(fn2.symbol) =>
172198
fn1 =?= fn2 && args1 =?= args2
173199

174200
case (Block(stats1, expr1), Block(binding :: stats2, expr2)) if isTypeBinding(binding) =>
175201
qctx.tasty.internal.Context_GADT_addToConstraint(summon[Context])(binding.symbol :: Nil)
176202
matched(new SymBinding(binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block(stats1, expr1) =?= Block(stats2, expr2)
177203

178204
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
179-
withEnv(summon[Env] + (stat1.symbol -> stat2.symbol)) {
205+
val newEnv = (stat1, stat2) match {
206+
case (stat1: Definition, stat2: Definition) =>
207+
summon[Env] + (stat1.symbol -> stat2.symbol)
208+
case _ =>
209+
summon[Env]
210+
}
211+
withEnv(newEnv) {
180212
stat1 =?= stat2 && Block(stats1, expr1) =?= Block(stats2, expr2)
181213
}
182214

@@ -268,7 +300,7 @@ private[quoted] object Matcher {
268300
|
269301
|${pattern.showExtractors}
270302
|
271-
|
303+
|with environment: ${summon[Env]}
272304
|
273305
|
274306
|""".stripMargin)
@@ -277,6 +309,33 @@ private[quoted] object Matcher {
277309
}
278310
end treeOps
279311

312+
private object ClosedPatternTerm {
313+
/** Matches a term that does not contain free variables defined in the pattern (i.e. not defined in `Env`) */
314+
def unapply(term: Term)(given Context, Env): Option[term.type] =
315+
if freePatternVars(term).isEmpty then Some(term) else None
316+
317+
/** Return all free variables of the term defined in the pattern (i.e. defined in `Env`) */
318+
def freePatternVars(term: Term)(given qctx: Context, env: Env): Set[Symbol] =
319+
val accumulator = new TreeAccumulator[Set[Symbol]] {
320+
def foldTree(x: Set[Symbol], tree: Tree)(given ctx: Context): Set[Symbol] =
321+
tree match
322+
case tree: Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
323+
case _ => foldOverTree(x, tree)
324+
}
325+
accumulator.foldTree(Set.empty, term)
326+
}
327+
328+
private object IdentArgs {
329+
def unapply(args: List[Term])(given Context): Option[List[Ident]] =
330+
args.foldRight(Option(List.empty[Ident])) {
331+
case (id: Ident, Some(acc)) => Some(id :: acc)
332+
case (Block(List(DefDef("$anonfun", Nil, List(params), Inferred(), Some(Apply(id: Ident, args)))), Closure(Ident("$anonfun"), None)), Some(acc))
333+
if params.zip(args).forall(_.symbol == _.symbol) =>
334+
Some(id :: acc)
335+
case _ => None
336+
}
337+
}
338+
280339
private def treeOptMatches(scrutinee: Option[Tree], pattern: Option[Tree])(given Context, Env): Matching = {
281340
(scrutinee, pattern) match {
282341
case (Some(x), Some(y)) => x =?= y
@@ -344,7 +403,7 @@ private[quoted] object Matcher {
344403
|
345404
|${pattern.showExtractors}
346405
|
347-
|
406+
|with environment: ${summon[Env]}
348407
|
349408
|
350409
|""".stripMargin)

library/src/scala/quoted/Expr.scala

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,8 +204,44 @@ package quoted {
204204
val elems: Seq[Expr[_]] = tup.asInstanceOf[Product].productIterator.toSeq.asInstanceOf[Seq[Expr[_]]]
205205
ofTuple(elems).cast[Tuple.InverseMap[T, Expr]]
206206
}
207-
}
208207

208+
// TODO generalize for any function arity (see Expr.betaReduce)
209+
def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = {
210+
import qctx.tasty.{given, _}
211+
val (params, bodyExpr) = paramsAndBody(f)
212+
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]])
213+
}
214+
215+
def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = {
216+
import qctx.tasty.{given, _}
217+
val (params, bodyExpr) = paramsAndBody(f)
218+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]])
219+
}
220+
221+
def open[T1, T2, T3, R, X](f: Expr[(T1, T2, T3) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = {
222+
import qctx.tasty.{given, _}
223+
val (params, bodyExpr) = paramsAndBody(f)
224+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]])
225+
}
226+
227+
private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = {
228+
import qctx.tasty.{given, _}
229+
val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand
230+
(params, body.seal.asInstanceOf[Expr[R]])
231+
}
232+
233+
private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
234+
import qctx.tasty.{given, _}
235+
val map = params.map(_.symbol).zip(args).toMap
236+
new TreeMap {
237+
override def transformTerm(tree: Term)(given ctx: Context): Term =
238+
super.transformTerm(tree) match
239+
case tree: Ident => map.getOrElse(tree.symbol, tree)
240+
case tree => tree
241+
}.transformTerm(e)
242+
}
243+
244+
}
209245
}
210246

211247
package internal {

library/src/scala/quoted/matching/Sym.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ package matching
88
*/
99
class Sym[T <: AnyKind] private[scala](val name: String, private[Sym] val id: Object) { self =>
1010

11+
override def toString: String = s"Sym($name)@${id.hashCode}"
12+
1113
override def equals(obj: Any): Boolean = obj match {
1214
case obj: Sym[_] => obj.id == id
1315
case _ => false

library/src/scala/tasty/reflect/CompilerInterface.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ trait CompilerInterface {
443443
def Closure_apply(meth: Term, tpe: Option[Type])(given ctx: Context): Closure
444444
def Closure_copy(original: Tree)(meth: Tree, tpe: Option[Type])(given ctx: Context): Closure
445445

446+
def Lambda_apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block
447+
446448
/** Tree representing an if/then/else `if (...) ... else ...` in the source code */
447449
type If <: Term
448450

@@ -810,6 +812,11 @@ trait CompilerInterface {
810812
*/
811813
def Type_widen(self: Type)(given ctx: Context): Type
812814

815+
/** Widen from TermRef to its underlying non-termref
816+
* base type, while also skipping Expr types.
817+
*/
818+
def Type_widenTermRefExpr(self: Type)(given ctx: Context): Type
819+
813820
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
814821
* TypeVars until type is no longer alias type, annotated type, LazyRef,
815822
* or instantiated type variable.
@@ -992,6 +999,8 @@ trait CompilerInterface {
992999

9931000
def isInstanceOfMethodType(given ctx: Context): IsInstanceOf[MethodType]
9941001

1002+
def MethodType_apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType
1003+
9951004
def MethodType_isErased(self: MethodType): Boolean
9961005
def MethodType_isImplicit(self: MethodType): Boolean
9971006
def MethodType_paramNames(self: MethodType)(given ctx: Context): List[String]

library/src/scala/tasty/reflect/TreeOps.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,10 @@ trait TreeOps extends Core {
615615

616616
case _ => None
617617
}
618+
619+
def apply(tpe: MethodType, rhsFn: List[Tree] => Tree)(implicit ctx: Context): Block =
620+
internal.Lambda_apply(tpe, rhsFn)
621+
618622
}
619623

620624
given (given Context): IsInstanceOf[If] = internal.isInstanceOfIf

library/src/scala/tasty/reflect/TypeOrBoundsOps.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,22 @@ trait TypeOrBoundsOps extends Core {
1717
/** Is this type a subtype of that type? */
1818
def <:<(that: Type)(given ctx: Context): Boolean = internal.Type_isSubType(self)(that)
1919

20+
/** Widen from singleton type to its underlying non-singleton
21+
* base type by applying one or more `underlying` dereferences,
22+
* Also go from => T to T.
23+
* Identity for all other types. Example:
24+
*
25+
* class Outer { class C ; val x: C }
26+
* def o: Outer
27+
* <o.x.type>.widen = o.C
28+
*/
2029
def widen(given ctx: Context): Type = internal.Type_widen(self)
2130

31+
/** Widen from TermRef to its underlying non-termref
32+
* base type, while also skipping `=>T` types.
33+
*/
34+
def widenTermRefExpr(given ctx: Context): Type = internal.Type_widenTermRefExpr(self)
35+
2236
/** Follow aliases and dereferences LazyRefs, annotated types and instantiated
2337
* TypeVars until type is no longer alias type, annotated type, LazyRef,
2438
* or instantiated type variable.
@@ -325,6 +339,9 @@ trait TypeOrBoundsOps extends Core {
325339
def unapply(x: MethodType)(given ctx: Context): Option[MethodType] = Some(x)
326340

327341
object MethodType {
342+
def apply(paramNames: List[String])(paramInfosExp: MethodType => List[Type], resultTypeExp: MethodType => Type): MethodType =
343+
internal.MethodType_apply(paramNames)(paramInfosExp, resultTypeExp)
344+
328345
def unapply(x: MethodType)(given ctx: Context): Option[(List[String], List[Type], Type)] =
329346
Some((x.paramNames, x.paramTypes, x.resType))
330347
}

tests/run-macros/quote-matcher-runtime.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ Pattern: {
332332
val x: scala.Int = 45
333333
x.+(scala.internal.Quoted.patternHole[scala.Int])
334334
}
335-
Result: Some(List(Expr(a)))
335+
Result: None
336336

337337
Scrutinee: {
338338
lazy val a: scala.Int = 45
@@ -622,7 +622,7 @@ Pattern: {
622622
def a: scala.Int = scala.internal.Quoted.patternHole[scala.Int]
623623
a.+(scala.internal.Quoted.patternHole[scala.Int])
624624
}
625-
Result: Some(List(Expr(a), Expr(a)))
625+
Result: None
626626

627627
Scrutinee: {
628628
lazy val a: scala.Int = a

0 commit comments

Comments
 (0)