Skip to content

Commit 9b2571a

Browse files
committed
Use polyfunction for quote hole contents
This avoids leaking references to types defined within the quotes when the contents of the holes are extracted from the quote.
1 parent f1bc0fb commit 9b2571a

File tree

17 files changed

+304
-55
lines changed

17 files changed

+304
-55
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class Compiler {
8989
new ExplicitOuter, // Add accessors to outer classes from nested ones.
9090
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
9191
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
92-
new DropBreaks) :: // Optimize local Break throws by rewriting them
92+
new DropBreaks) :: // Optimize local Break throws by rewriting them
9393
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9494
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9595
new InlinePatterns, // Remove placeholders of inlined patterns
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package dotty.tools.dotc
2+
package staging
3+
4+
import dotty.tools.dotc.ast.{tpd, untpd}
5+
import dotty.tools.dotc.core.Annotations._
6+
import dotty.tools.dotc.core.Contexts._
7+
import dotty.tools.dotc.core.Decorators._
8+
import dotty.tools.dotc.core.Flags._
9+
import dotty.tools.dotc.core.NameKinds._
10+
import dotty.tools.dotc.core.StdNames._
11+
import dotty.tools.dotc.core.Symbols._
12+
import dotty.tools.dotc.core.Types._
13+
import dotty.tools.dotc.staging.QuoteContext.*
14+
import dotty.tools.dotc.staging.StagingLevel.*
15+
import dotty.tools.dotc.staging.QuoteTypeTags.*
16+
import dotty.tools.dotc.util.Property
17+
import dotty.tools.dotc.util.Spans._
18+
import dotty.tools.dotc.util.SrcPos
19+
20+
object HealedDirectQuotedTypeRef:
21+
import tpd.*
22+
23+
/** Matches `quoted.Type.of[x.Underlying](quotes)` and extracts the TermRef to `x` */
24+
def unapply(body: Tree)(using Context): Option[TermRef] =
25+
body match
26+
case Block(List(tdef: TypeDef), tpt: TypeTree) =>
27+
tpt.tpe match
28+
case tpe: TypeRef if tpe.typeSymbol == tdef.symbol =>
29+
tdef.rhs.tpe.hiBound match
30+
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => Some(x)
31+
case _ => None
32+
case _ => None
33+
case _ => None
34+

compiler/src/dotty/tools/dotc/staging/PCPCheckAndHeal.scala

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ import dotty.tools.dotc.util.SrcPos
4444
*
4545
*/
4646
class PCPCheckAndHeal extends TreeMapWithStages {
47-
import tpd._
47+
import tpd.*
4848

4949
private val InAnnotation = Property.Key[Unit]()
5050

@@ -107,30 +107,37 @@ class PCPCheckAndHeal extends TreeMapWithStages {
107107
val stripAnnotsDeep: TypeMap = new TypeMap:
108108
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
109109

110-
val contextWithQuote =
111-
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
112-
else quoteContext
113-
val body1 = transform(body)(using contextWithQuote)
114-
val body2 =
110+
def transformBody() =
111+
val contextWithQuote =
112+
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
113+
else quoteContext
114+
val transformedBody = transform(body)(using contextWithQuote)
115115
taggedTypes.getTypeTags match
116-
case Nil => body1
117-
case tags => tpd.Block(tags, body1).withSpan(body.span)
116+
case Nil => transformedBody
117+
case tags => tpd.Block(tags, transformedBody).withSpan(body.span)
118118

119119
if body.isTerm then
120+
val transformedBody = transformBody()
120121
// `quoted.runtime.Expr.quote[T](<body>)` --> `quoted.runtime.Expr.quote[T2](<body2>)`
121122
val TypeApply(fun, targs) = quote.fun: @unchecked
122123
val targs2 = targs.map(targ => TypeTree(healType(quote.fun.srcPos)(stripAnnotsDeep(targ.tpe))))
123-
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), body2 :: Nil)
124+
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), transformedBody :: Nil)
124125
else
125-
val quotes = quote.args.mapConserve(transform)
126126
body.tpe match
127127
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
128128
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
129-
ref(x)
129+
ref(x).withSpan(quote.span)
130130
case _ =>
131-
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
132-
val TypeApply(fun, _) = quote.fun: @unchecked
133-
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, body2 :: Nil), quotes)
131+
transformBody() match
132+
case HealedDirectQuotedTypeRef(termRef) =>
133+
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
134+
tpd.ref(termRef).withSpan(quote.span)
135+
case transformedBody =>
136+
val quotes = quote.args.mapConserve(transform)
137+
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
138+
val TypeApply(fun, _) = quote.fun: @unchecked
139+
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, transformedBody :: Nil), quotes)
140+
134141
}
135142

136143
/** Transform splice
@@ -236,4 +243,5 @@ class PCPCheckAndHeal extends TreeMapWithStages {
236243
| - but the access is at level $level.$hint""", pos)
237244
tp
238245
}
246+
239247
}

compiler/src/dotty/tools/dotc/transform/BetaReduce.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ object BetaReduce:
117117
case ref @ TypeRef(NoPrefix, _) =>
118118
ref.symbol
119119
case _ =>
120-
val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, targ.tpe, coord = targ.span)).withSpan(targ.span)
120+
val binding = TypeDef(newSymbol(ctx.owner, tparam.name, EmptyFlags, TypeAlias(targ.tpe), coord = targ.span)).withSpan(targ.span)
121121
bindings += binding
122122
binding.symbol
123123

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class ExpandSAMs extends MiniPhase:
4848
tpt.tpe match {
4949
case NoType =>
5050
tree // it's a plain function
51-
case tpe if defn.isContextFunctionType(tpe) =>
51+
case tpe if defn.isFunctionOrPolyType(tpe) =>
5252
tree
5353
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)

compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -314,17 +314,28 @@ object PickleQuotes {
314314
defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)),
315315
args =>
316316
val cases = termSplices.map { case (splice, idx) =>
317-
val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked
317+
val (typeParamCount, argTypes, quotesType) = splice.tpe match
318+
case defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) => (0, argTypes, quotesType)
319+
case RefinedType(polyFun, nme.apply, pt @ PolyType(tparams, _)) if polyFun.typeSymbol.derivesFrom(defn.PolyFunctionClass) =>
320+
pt.instantiate(pt.paramInfos.map(_.hi)) match
321+
case MethodTpe(_, argTypes, defn.FunctionOf(quotesType :: _, _, _)) =>
322+
(tparams.size, argTypes, quotesType)
323+
318324
val rhs = {
319325
val spliceArgs = argTypes.zipWithIndex.map { (argType, i) =>
320326
args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType)
321327
}
322328
val Block(List(ddef: DefDef), _) = splice: @unchecked
323-
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
324-
BetaReduce(
325-
splice
326-
.select(nme.apply).appliedToArgs(spliceArgs))
327-
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
329+
330+
val typeArgs = ddef.symbol.info match
331+
case pt: PolyType => pt.paramInfos
332+
case _ => Nil
333+
334+
val sel1 = splice.changeOwner(ddef.symbol.owner, ctx.owner).select(nme.apply)
335+
val appTpe = if typeParamCount == 0 then sel1 else sel1.appliedToTypes(List.fill(typeParamCount)(defn.AnyType))
336+
val app1 = appTpe.appliedToArgs(spliceArgs)
337+
val sel2 = app1.select(nme.apply)
338+
sel2.appliedTo(args(2).asInstance(quotesType))
328339
}
329340
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
330341
}
@@ -333,18 +344,31 @@ object PickleQuotes {
333344
case _ => Match(args(0).annotated(New(ref(defn.UncheckedAnnot.typeRef))), cases)
334345
)
335346

347+
def dealiasSplicedTypes(tp: Type) = new TypeMap {
348+
def apply(tp: Type): Type = tp match
349+
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
350+
val TypeAlias(alias) = tp.info: @unchecked
351+
alias
352+
case tp1 => mapOver(tp)
353+
}.apply(tp)
354+
355+
val adaptedType =
356+
if isType then dealiasSplicedTypes(originalTp)
357+
else originalTp
358+
336359
val quoteClass = if isType then defn.QuotedTypeClass else defn.QuotedExprClass
337-
val quotedType = quoteClass.typeRef.appliedTo(originalTp)
360+
val quotedType = quoteClass.typeRef.appliedTo(adaptedType)
338361
val lambdaTpe = MethodType(defn.QuotesClass.typeRef :: Nil, quotedType)
339362
val unpickleMeth =
340363
if isType then defn.QuoteUnpickler_unpickleTypeV2
341364
else defn.QuoteUnpickler_unpickleExprV2
342365
val unpickleArgs =
343366
if isType then List(pickledQuoteStrings, types)
344367
else List(pickledQuoteStrings, types, termHoles)
368+
345369
quotes
346370
.asInstance(defn.QuoteUnpicklerClass.typeRef)
347-
.select(unpickleMeth).appliedToType(originalTp)
371+
.select(unpickleMeth).appliedToType(adaptedType)
348372
.appliedToArgs(unpickleArgs).withSpan(body.span)
349373
}
350374

compiler/src/dotty/tools/dotc/transform/Splicing.scala

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import dotty.tools.dotc.staging.PCPCheckAndHeal
2525
import dotty.tools.dotc.staging.QuoteContext.*
2626
import dotty.tools.dotc.staging.StagingLevel.*
2727
import dotty.tools.dotc.staging.QuoteTypeTags
28+
import dotty.tools.dotc.staging.HealedDirectQuotedTypeRef
2829

2930
import scala.annotation.constructorOnly
3031

@@ -133,7 +134,7 @@ class Splicing extends MacroTransform:
133134
case None =>
134135
val holeIdx = numHoles
135136
numHoles += 1
136-
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp))
137+
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp.dealias))
137138
typeHoles.put(qual.symbol, hole)
138139
hole
139140
cpy.TypeDef(tree)(rhs = hole)
@@ -155,7 +156,7 @@ class Splicing extends MacroTransform:
155156

156157
private def transformAnnotations(tree: DefTree)(using Context): Unit =
157158
tree.symbol.annotations = tree.symbol.annotations.mapconserve { annot =>
158-
val newAnnotTree = transform(annot.tree)(using ctx.withOwner(tree.symbol))
159+
val newAnnotTree = transform(annot.tree)
159160
if (annot.tree == newAnnotTree) annot
160161
else ConcreteAnnotation(newAnnotTree)
161162
}
@@ -199,10 +200,56 @@ class Splicing extends MacroTransform:
199200
val newTree = transform(tree)
200201
val (refs, bindings) = refBindingMap.values.toList.unzip
201202
val bindingsTypes = bindings.map(_.termRef.widenTermRefExpr)
202-
val methType = MethodType(bindingsTypes, newTree.tpe)
203+
val types = bindingsTypes.collect {
204+
case AppliedType(tycon, List(arg: TypeRef)) if tycon.derivesFrom(defn.QuotedTypeClass) => arg
205+
}
206+
val newTypeParams = types.map { tpe =>
207+
newSymbol(
208+
spliceOwner,
209+
(tpe.symbol.name.toString + "$tpe").toTypeName,
210+
Param,
211+
TypeBounds.empty
212+
)
213+
}
214+
val methType =
215+
if types.nonEmpty then
216+
PolyType(types.map(tp => (tp.symbol.name.toString + "$").toTypeName))(
217+
pt => types.map(_ => TypeBounds.empty),
218+
pt => {
219+
val tpParamMap = new TypeMap {
220+
private val mapping = types.map(_.typeSymbol).zip(pt.paramRefs).toMap
221+
def apply(tp: Type): Type = tp match
222+
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
223+
case tp => mapOver(tp)
224+
}
225+
MethodType(bindingsTypes.map(tpParamMap), tpParamMap(newTree.tpe))
226+
}
227+
)
228+
else MethodType(bindingsTypes, newTree.tpe)
203229
val meth = newSymbol(spliceOwner, nme.ANON_FUN, Synthetic | Method, methType)
204-
val ddef = DefDef(meth, List(bindings), newTree.tpe, newTree.changeOwner(ctx.owner, meth))
205-
val fnType = defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe)
230+
231+
def substituteTypes(tree: Tree): Tree =
232+
if types.nonEmpty then
233+
TreeTypeMap(
234+
typeMap = new TypeMap {
235+
def apply(tp: Type): Type = tp match
236+
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => tp
237+
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => tp
238+
case _: TypeRef =>
239+
val idx = types.indexWhere(_ =:= tp) // TODO performance
240+
if idx == -1 then mapOver(tp)
241+
else newTypeParams(idx).typeRef
242+
case _ => mapOver(tp)
243+
}
244+
).transform(tree)
245+
else tree
246+
val paramss =
247+
if types.nonEmpty then List(newTypeParams, bindings)
248+
else List(bindings)
249+
val ddef = substituteTypes(DefDef(meth, paramss, newTree.tpe, newTree.changeOwner(ctx.owner, meth)))
250+
val fnType =
251+
if types.isEmpty then defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe) // FIXME add type parameter?
252+
else RefinedType(defn.PolyFunctionType, nme.apply, methType)
206253
val closure = Block(ddef :: Nil, Closure(Nil, ref(meth), TypeTree(fnType)))
207254
tpd.Hole(true, holeIdx, refs, closure, TypeTree(tpe))
208255

@@ -248,6 +295,9 @@ class Splicing extends MacroTransform:
248295
if tree.symbol == defn.QuotedTypeModule_of && containsCapturedType(tpt.tpe) =>
249296
val newContent = capturedPartTypes(tpt)
250297
newContent match
298+
case HealedDirectQuotedTypeRef(termRef) =>
299+
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
300+
tpd.ref(termRef).withSpan(tpt.span)
251301
case block: Block =>
252302
inContext(ctx.withSource(tree.source)) {
253303
Apply(TypeApply(typeof, List(newContent)), List(quotes)).withSpan(tree.span)
@@ -347,7 +397,7 @@ class Splicing extends MacroTransform:
347397
private def newQuotedTypeClassBinding(tpe: Type)(using Context) =
348398
newSymbol(
349399
spliceOwner,
350-
UniqueName.fresh(nme.Type).toTermName,
400+
UniqueName.fresh(tpe.typeSymbol.name.toTermName),
351401
Param,
352402
defn.QuotedTypeClass.typeRef.appliedTo(tpe),
353403
)

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,11 @@ object TreeChecker {
533533
i"owner chain = ${tree.symbol.ownersIterator.toList}%, %, ctxOwners = ${ctx.outersIterator.map(_.owner).toList}%, %")
534534
}
535535

536+
override def typedTypeDef(tdef: untpd.TypeDef, sym: Symbol)(using Context): Tree = {
537+
assert(sym.info.isInstanceOf[ClassInfo] || sym.info.isInstanceOf[TypeBounds], i"wrong type, expect a template or type bounds for ${sym.fullName}, but found: ${sym.info}")
538+
super.typedTypeDef(tdef, sym)
539+
}
540+
536541
override def typedClassDef(cdef: untpd.TypeDef, cls: ClassSymbol)(using Context): Tree = {
537542
val TypeDef(_, impl @ Template(constr, _, _, _)) = cdef: @unchecked
538543
assert(cdef.symbol == cls)
@@ -665,6 +670,9 @@ object TreeChecker {
665670
else assert(tpt.typeOpt =:= pt)
666671

667672
// Check that the types of the args conform to the types of the contents of the hole
673+
val typeArgsTypes = args.collect { case arg if arg.isType =>
674+
arg.typeOpt
675+
}
668676
val argQuotedTypes = args.map { arg =>
669677
if arg.isTerm then
670678
val tpe = arg.typeOpt.widenTermRefExpr match
@@ -682,7 +690,28 @@ object TreeChecker {
682690
val contextualResult =
683691
defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true)
684692
val expectedContentType =
685-
defn.FunctionOf(argQuotedTypes, contextualResult)
693+
if typeArgsTypes.isEmpty then defn.FunctionOf(argQuotedTypes, contextualResult)
694+
else RefinedType(defn.PolyFunctionType, nme.apply, PolyType(typeArgsTypes.map(_ => TypeBounds.empty))(pt =>
695+
val tpParamMap = new TypeMap {
696+
private val mapping = typeArgsTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap
697+
def apply(tp: Type): Type = tp match
698+
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
699+
case tp => mapOver(tp)
700+
}
701+
MethodType(
702+
args.zipWithIndex.map { case (arg, idx) =>
703+
704+
if arg.isTerm then
705+
val tpe = arg.typeOpt.widenTermRefExpr match
706+
case _: MethodicType =>
707+
// Special erasure for captured function references
708+
// See `SpliceTransformer.transformCapturedApplication`
709+
defn.AnyType
710+
case tpe => tpe
711+
defn.QuotedExprClass.typeRef.appliedTo(tpParamMap(tpe))
712+
else defn.QuotedTypeClass.typeRef.appliedTo(tpParamMap(arg.typeOpt))
713+
},
714+
tpParamMap(contextualResult))))
686715
assert(content.typeOpt =:= expectedContentType, i"expected content of the hole to be ${expectedContentType} but got ${content.typeOpt}")
687716

688717
tree1

compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,18 +1345,22 @@ object SourceCode {
13451345
}
13461346

13471347
private def printBoundsTree(bounds: TypeBoundsTree)(using elideThis: Option[Symbol]): this.type = {
1348-
bounds.low match {
1349-
case Inferred() =>
1350-
case low =>
1351-
this += " >: "
1352-
printTypeTree(low)
1353-
}
1354-
bounds.hi match {
1355-
case Inferred() => this
1356-
case hi =>
1357-
this += " <: "
1358-
printTypeTree(hi)
1359-
}
1348+
if bounds.low.tpe =:= bounds.hi.tpe then
1349+
this += " = "
1350+
printTypeTree(bounds.low)
1351+
else
1352+
bounds.low match {
1353+
case Inferred() =>
1354+
case low =>
1355+
this += " >: "
1356+
printTypeTree(low)
1357+
}
1358+
bounds.hi match {
1359+
case Inferred() => this
1360+
case hi =>
1361+
this += " <: "
1362+
printTypeTree(hi)
1363+
}
13601364
}
13611365

13621366
private def printBounds(bounds: TypeBounds)(using elideThis: Option[Symbol]): this.type = {

0 commit comments

Comments
 (0)