Skip to content

Commit f8144ef

Browse files
committed
Use polyfunction for quote hole contents
This avoids leaking references to types defined within the quotes when the contents of the holes are extracted from the quote.
1 parent d4a8600 commit f8144ef

File tree

5 files changed

+121
-16
lines changed

5 files changed

+121
-16
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class ExpandSAMs extends MiniPhase:
4848
tpt.tpe match {
4949
case NoType =>
5050
tree // it's a plain function
51-
case tpe if defn.isContextFunctionType(tpe) =>
51+
case tpe if defn.isFunctionOrPolyType(tpe) =>
5252
tree
5353
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)

compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -319,17 +319,28 @@ object PickleQuotes {
319319
defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)),
320320
args =>
321321
val cases = termSplices.map { case (splice, idx) =>
322-
val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked
322+
val (typeParamCount, argTypes, quotesType) = splice.tpe match
323+
case defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) => (0, argTypes, quotesType)
324+
case RefinedType(polyFun, nme.apply, pt @ PolyType(tparams, _)) if polyFun.typeSymbol.derivesFrom(defn.PolyFunctionClass) =>
325+
pt.instantiate(pt.paramInfos.map(_.hi)) match
326+
case MethodTpe(_, argTypes, defn.FunctionOf(quotesType :: _, _, _)) =>
327+
(tparams.size, argTypes, quotesType)
328+
323329
val rhs = {
324330
val spliceArgs = argTypes.zipWithIndex.map { (argType, i) =>
325331
args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType)
326332
}
327333
val Block(List(ddef: DefDef), _) = splice: @unchecked
328-
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
329-
BetaReduce(
330-
splice
331-
.select(nme.apply).appliedToArgs(spliceArgs))
332-
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
334+
335+
val typeArgs = ddef.symbol.info match
336+
case pt: PolyType => pt.paramInfos
337+
case _ => Nil
338+
339+
val sel1 = splice.changeOwner(ddef.symbol.owner, ctx.owner).select(nme.apply)
340+
val appTpe = if typeParamCount == 0 then sel1 else sel1.appliedToTypes(List.fill(typeParamCount)(defn.AnyType))
341+
val app1 = appTpe.appliedToArgs(spliceArgs)
342+
val sel2 = app1.select(nme.apply)
343+
sel2.appliedTo(args(2).asInstance(quotesType))
333344
}
334345
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
335346
}
@@ -338,18 +349,31 @@ object PickleQuotes {
338349
case _ => Match(args(0).annotated(New(ref(defn.UncheckedAnnot.typeRef))), cases)
339350
)
340351

352+
def dealiasSplicedTypes(tp: Type) = new TypeMap {
353+
def apply(tp: Type): Type = tp match
354+
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
355+
val TypeAlias(alias) = tp.info: @unchecked
356+
alias
357+
case tp1 => mapOver(tp)
358+
}.apply(tp)
359+
360+
val adaptedType =
361+
if isType then dealiasSplicedTypes(originalTp)
362+
else originalTp
363+
341364
val quoteClass = if isType then defn.QuotedTypeClass else defn.QuotedExprClass
342-
val quotedType = quoteClass.typeRef.appliedTo(originalTp)
365+
val quotedType = quoteClass.typeRef.appliedTo(adaptedType)
343366
val lambdaTpe = MethodType(defn.QuotesClass.typeRef :: Nil, quotedType)
344367
val unpickleMeth =
345368
if isType then defn.QuoteUnpickler_unpickleTypeV2
346369
else defn.QuoteUnpickler_unpickleExprV2
347370
val unpickleArgs =
348371
if isType then List(pickledQuoteStrings, types)
349372
else List(pickledQuoteStrings, types, termHoles)
373+
350374
quotes
351375
.asInstance(defn.QuoteUnpicklerClass.typeRef)
352-
.select(unpickleMeth).appliedToType(originalTp)
376+
.select(unpickleMeth).appliedToType(adaptedType)
353377
.appliedToArgs(unpickleArgs).withSpan(body.span)
354378
}
355379

compiler/src/dotty/tools/dotc/transform/Splicing.scala

Lines changed: 56 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import dotty.tools.dotc.config.ScalaRelease.*
2424
import dotty.tools.dotc.staging.QuoteContext.*
2525
import dotty.tools.dotc.staging.StagingLevel.*
2626
import dotty.tools.dotc.staging.QuoteTypeTags
27+
import dotty.tools.dotc.staging.DirectTypeOf
2728

2829
import scala.annotation.constructorOnly
2930

@@ -132,7 +133,7 @@ class Splicing extends MacroTransform:
132133
case None =>
133134
val holeIdx = numHoles
134135
numHoles += 1
135-
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp))
136+
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp.dealias))
136137
typeHoles.put(qual.symbol, hole)
137138
hole
138139
cpy.TypeDef(tree)(rhs = hole)
@@ -154,7 +155,7 @@ class Splicing extends MacroTransform:
154155

155156
private def transformAnnotations(tree: DefTree)(using Context): Unit =
156157
tree.symbol.annotations = tree.symbol.annotations.mapconserve { annot =>
157-
val newAnnotTree = transform(annot.tree)(using ctx.withOwner(tree.symbol))
158+
val newAnnotTree = transform(annot.tree)
158159
if (annot.tree == newAnnotTree) annot
159160
else ConcreteAnnotation(newAnnotTree)
160161
}
@@ -198,10 +199,56 @@ class Splicing extends MacroTransform:
198199
val newTree = transform(tree)
199200
val (refs, bindings) = refBindingMap.values.toList.unzip
200201
val bindingsTypes = bindings.map(_.termRef.widenTermRefExpr)
201-
val methType = MethodType(bindingsTypes, newTree.tpe)
202+
val types = bindingsTypes.collect {
203+
case AppliedType(tycon, List(arg: TypeRef)) if tycon.derivesFrom(defn.QuotedTypeClass) => arg
204+
}
205+
val newTypeParams = types.map { tpe =>
206+
newSymbol(
207+
spliceOwner,
208+
(tpe.symbol.name.toString + "$tpe").toTypeName,
209+
Param,
210+
TypeBounds.empty
211+
)
212+
}
213+
val methType =
214+
if types.nonEmpty then
215+
PolyType(types.map(tp => (tp.symbol.name.toString + "$").toTypeName))(
216+
pt => types.map(_ => TypeBounds.empty),
217+
pt => {
218+
val tpParamMap = new TypeMap {
219+
private val mapping = types.map(_.typeSymbol).zip(pt.paramRefs).toMap
220+
def apply(tp: Type): Type = tp match
221+
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
222+
case tp => mapOver(tp)
223+
}
224+
MethodType(bindingsTypes.map(tpParamMap), tpParamMap(newTree.tpe))
225+
}
226+
)
227+
else MethodType(bindingsTypes, newTree.tpe)
202228
val meth = newSymbol(spliceOwner, nme.ANON_FUN, Synthetic | Method, methType)
203-
val ddef = DefDef(meth, List(bindings), newTree.tpe, newTree.changeOwner(ctx.owner, meth))
204-
val fnType = defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe)
229+
230+
def substituteTypes(tree: Tree): Tree =
231+
if types.nonEmpty then
232+
TreeTypeMap(
233+
typeMap = new TypeMap {
234+
def apply(tp: Type): Type = tp match
235+
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => tp
236+
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => tp
237+
case _: TypeRef =>
238+
val idx = types.indexWhere(_ =:= tp) // TODO performance
239+
if idx == -1 then mapOver(tp)
240+
else newTypeParams(idx).typeRef
241+
case _ => mapOver(tp)
242+
}
243+
).transform(tree)
244+
else tree
245+
val paramss =
246+
if types.nonEmpty then List(newTypeParams, bindings)
247+
else List(bindings)
248+
val ddef = substituteTypes(DefDef(meth, paramss, newTree.tpe, newTree.changeOwner(ctx.owner, meth)))
249+
val fnType =
250+
if types.isEmpty then defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe) // FIXME add type parameter?
251+
else RefinedType(defn.PolyFunctionType, nme.apply, methType)
205252
val closure = Block(ddef :: Nil, Closure(Nil, ref(meth), TypeTree(fnType)))
206253
tpd.Hole(true, holeIdx, refs, closure, TypeTree(tpe))
207254

@@ -255,6 +302,9 @@ class Splicing extends MacroTransform:
255302
if tree.symbol == defn.QuotedTypeModule_of && containsCapturedType(tpt.tpe) =>
256303
val newContent = capturedPartTypes(tpt)
257304
newContent match
305+
case DirectTypeOf.Healed(termRef) =>
306+
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
307+
tpd.ref(termRef).withSpan(tpt.span)
258308
case block: Block =>
259309
inContext(ctx.withSource(tree.source)) {
260310
Apply(TypeApply(typeof, List(newContent)), List(quotes)).withSpan(tree.span)
@@ -354,7 +404,7 @@ class Splicing extends MacroTransform:
354404
private def newQuotedTypeClassBinding(tpe: Type)(using Context) =
355405
newSymbol(
356406
spliceOwner,
357-
UniqueName.fresh(nme.Type).toTermName,
407+
UniqueName.fresh(tpe.typeSymbol.name.toTermName),
358408
Param,
359409
defn.QuotedTypeClass.typeRef.appliedTo(tpe),
360410
)

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -670,6 +670,9 @@ object TreeChecker {
670670
else assert(tpt.typeOpt =:= pt)
671671

672672
// Check that the types of the args conform to the types of the contents of the hole
673+
val typeArgsTypes = args.collect { case arg if arg.isType =>
674+
arg.typeOpt
675+
}
673676
val argQuotedTypes = args.map { arg =>
674677
if arg.isTerm then
675678
val tpe = arg.typeOpt.widenTermRefExpr match
@@ -687,7 +690,29 @@ object TreeChecker {
687690
val contextualResult =
688691
defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true)
689692
val expectedContentType =
690-
defn.FunctionOf(argQuotedTypes, contextualResult)
693+
if typeArgsTypes.isEmpty then defn.FunctionOf(argQuotedTypes, contextualResult)
694+
else RefinedType(defn.PolyFunctionType, nme.apply, PolyType(typeArgsTypes.map(_ => TypeBounds.empty))(pt =>
695+
val tpParamMap = new TypeMap {
696+
private val mapping = typeArgsTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap
697+
def apply(tp: Type): Type = tp match
698+
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
699+
case tp => mapOver(tp)
700+
}
701+
MethodType(
702+
args.zipWithIndex.map { case (arg, idx) =>
703+
if arg.isTerm then
704+
val tpe = arg.typeOpt.widenTermRefExpr match
705+
case _: MethodicType =>
706+
// Special erasure for captured function references
707+
// See `SpliceTransformer.transformCapturedApplication`
708+
defn.AnyType
709+
case tpe => tpe
710+
defn.QuotedExprClass.typeRef.appliedTo(tpParamMap(tpe))
711+
else defn.QuotedTypeClass.typeRef.appliedTo(tpParamMap(arg.typeOpt))
712+
},
713+
tpParamMap(contextualResult))
714+
)
715+
)
691716
assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}")
692717

693718
tree1

tests/pos-macros/captured-type.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import scala.quoted.*
2+
3+
object Foo:
4+
def baz(using Quotes): Unit = '{
5+
def f[T](x: T): T = ${ identity('{ x: T }) }
6+
}

0 commit comments

Comments
 (0)