Skip to content

Properly encode splice hole using PolyFunction #17072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ExpandSAMs extends MiniPhase:
tpt.tpe match {
case NoType =>
tree // it's a plain function
case tpe if defn.isContextFunctionType(tpe) =>
case tpe if defn.isFunctionOrPolyType(tpe) =>
tree
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
val tpe1 = checkRefinements(tpe, fn)
Expand Down
39 changes: 31 additions & 8 deletions compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -319,17 +319,27 @@ object PickleQuotes {
defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)),
args =>
val cases = termSplices.map { case (splice, idx) =>
val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked
val (typeParamCount, argTypes, quotesType) = splice.tpe match
case defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) => (0, argTypes, quotesType)
case RefinedType(polyFun, nme.apply, pt @ PolyType(tparams, _)) if polyFun.typeSymbol.derivesFrom(defn.PolyFunctionClass) =>
pt.instantiate(pt.paramInfos.map(_.hi)) match
case MethodTpe(_, argTypes, defn.FunctionOf(quotesType :: _, _, _)) =>
(tparams.size, argTypes, quotesType)

val rhs = {
val spliceArgs = argTypes.zipWithIndex.map { (argType, i) =>
args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType)
}
val Block(List(ddef: DefDef), _) = splice: @unchecked
// TODO: beta reduce inner closure? Or wait until BetaReduce phase?
BetaReduce(
splice
.select(nme.apply).appliedToArgs(spliceArgs))
.select(nme.apply).appliedTo(args(2).asInstance(quotesType))
val quotes = args(2).asInstance(quotesType)
val dummyTargs = List.fill(typeParamCount)(defn.AnyType)
// Generate: <splice>.apply[<dummyTargs>*](<spliceArgs>*).apply(<quotes>)
splice.changeOwner(ddef.symbol.owner, ctx.owner)
.select(nme.apply)
.appliedToTypes(dummyTargs)
.appliedToArgs(spliceArgs)
.select(nme.apply)
.appliedTo(quotes)
}
CaseDef(Literal(Constant(idx)), EmptyTree, rhs)
}
Expand All @@ -338,18 +348,31 @@ object PickleQuotes {
case _ => Match(args(0).annotated(New(ref(defn.UncheckedAnnot.typeRef))), cases)
)

def dealiasSplicedTypes(tp: Type) = new TypeMap {
def apply(tp: Type): Type = tp match
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
val TypeAlias(alias) = tp.info: @unchecked
alias
case tp1 => mapOver(tp)
}.apply(tp)

val adaptedType =
if isType then dealiasSplicedTypes(originalTp)
else originalTp

val quoteClass = if isType then defn.QuotedTypeClass else defn.QuotedExprClass
val quotedType = quoteClass.typeRef.appliedTo(originalTp)
val quotedType = quoteClass.typeRef.appliedTo(adaptedType)
val lambdaTpe = MethodType(defn.QuotesClass.typeRef :: Nil, quotedType)
val unpickleMeth =
if isType then defn.QuoteUnpickler_unpickleTypeV2
else defn.QuoteUnpickler_unpickleExprV2
val unpickleArgs =
if isType then List(pickledQuoteStrings, types)
else List(pickledQuoteStrings, types, termHoles)

quotes
.asInstance(defn.QuoteUnpicklerClass.typeRef)
.select(unpickleMeth).appliedToType(originalTp)
.select(unpickleMeth).appliedToType(adaptedType)
.appliedToArgs(unpickleArgs).withSpan(body.span)
}

Expand Down
84 changes: 68 additions & 16 deletions compiler/src/dotty/tools/dotc/transform/Splicing.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import dotty.tools.dotc.config.ScalaRelease.*
import dotty.tools.dotc.staging.QuoteContext.*
import dotty.tools.dotc.staging.StagingLevel.*
import dotty.tools.dotc.staging.QuoteTypeTags
import dotty.tools.dotc.staging.DirectTypeOf

import scala.annotation.constructorOnly

Expand All @@ -38,18 +39,19 @@ object Splicing:
*
* After this phase we have the invariant where all splices have the following shape
* ```
* {{{ <holeIdx> | <holeType> | <captures>* | (<capturedTerms>*) => <spliceContent> }}}
* {{{ <holeIdx> | <holeType> | <captures>* | [<capturedTypes>*] => (<capturedTerms>*) => <spliceContent> }}}
* ```
* where `<spliceContent>` does not contain any free references to quoted definitions and `<captures>*`
* contains the quotes with references to all cross-quote references. There are some special rules
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the description a couple of lines above
(<capturedTerms>*) =>
should probably be
[<capturedTypes>*] => (<capturedTerms>*) =>

* for references in the LHS of assignments and cross-quote method references.
*
* In the following code example `x1` and `x2` are cross-quote references.
* In the following code example `x1`, `x2` and `U` are cross-quote references.
* ```
* '{ ...
* val x1: T1 = ???
* val x2: T2 = ???
* ${ (q: Quotes) ?=> f('{ g(x1, x2) }) }: T3
* type U
* val x1: T = ???
* val x2: U = ???
* ${ (q: Quotes) ?=> f('{ g[U](x1, x2) }) }: T3
* }
* ```
*
Expand All @@ -60,15 +62,15 @@ object Splicing:
* '{ ...
* val x1: T1 = ???
* val x2: T2 = ???
* {{{ 0 | T3 | x1, x2 |
* (x1$: Expr[T1], x2$: Expr[T2]) => // body of this lambda does not contain references to x1 or x2
* (q: Quotes) ?=> f('{ g(${x1$}, ${x2$}) })
* {{{ 0 | T3 | U, x1, x2 |
* [U$1] => (U$2: Type[U$1], x1$: Expr[T], x2$: Expr[U$1]) => // body of this lambda does not contain references to U, x1 or x2
* (q: Quotes) ?=> f('{ @SplicedType type U$3 = [[[ 0 | U$2 | | U$1 ]]]; g[U$3](${x1$}, ${x2$}) })
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the [[[ ... ]]] syntax used here documented somewhere?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. It is the syntax defined in the RefinedPrinter. I plan to update this syntax in #17144.

*
* }}}
* }
* ```
*
* and then performs the same transformation on `'{ g(${x1$}, ${x2$}) }`.
* and then performs the same transformation on `'{ @SplicedType type U$3 = [[[ 0 | U$2 | | U$1 ]]]; g[U$3](${x1$}, ${x2$}) }`.
*
*/
class Splicing extends MacroTransform:
Expand Down Expand Up @@ -132,7 +134,7 @@ class Splicing extends MacroTransform:
case None =>
val holeIdx = numHoles
numHoles += 1
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp))
val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp.dealias))
typeHoles.put(qual.symbol, hole)
hole
cpy.TypeDef(tree)(rhs = hole)
Expand All @@ -154,7 +156,7 @@ class Splicing extends MacroTransform:

private def transformAnnotations(tree: DefTree)(using Context): Unit =
tree.symbol.annotations = tree.symbol.annotations.mapconserve { annot =>
val newAnnotTree = transform(annot.tree)(using ctx.withOwner(tree.symbol))
val newAnnotTree = transform(annot.tree)
if (annot.tree == newAnnotTree) annot
else ConcreteAnnotation(newAnnotTree)
}
Expand Down Expand Up @@ -184,7 +186,7 @@ class Splicing extends MacroTransform:
* ```
* is transformed into
* ```scala
* {{{ <holeIdx++> | T2 | x, X | (x$1: Expr[T1], X$1: Type[X]) => (using Quotes) ?=> {... ${x$1} ... X$1.Underlying ...} }}}
* {{{ <holeIdx++> | T2 | x, X | [X$2] => (x$1: Expr[T1], X$1: Type[X$2]) => (using Quotes) ?=> {... ${x$1} ... X$1.Underlying ...} }}}
* ```
*/
private class SpliceTransformer(spliceOwner: Symbol, isCaptured: Symbol => Boolean) extends Transformer:
Expand All @@ -198,10 +200,57 @@ class Splicing extends MacroTransform:
val newTree = transform(tree)
val (refs, bindings) = refBindingMap.values.toList.unzip
val bindingsTypes = bindings.map(_.termRef.widenTermRefExpr)
val methType = MethodType(bindingsTypes, newTree.tpe)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The code example in the doc comment above SpliceTransformer also needs to be updated to use a polymorphic function.

val capturedTypes = bindingsTypes.collect {
case AppliedType(tycon, List(arg: TypeRef)) if tycon.derivesFrom(defn.QuotedTypeClass) => arg
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a dealias might be needed in case the user defines an alias such as type TInt = Type[Int]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These types are created in the SpliceTransformer. They should always have this shape.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if they have the same shape, they might not be referentially equal (they are only referentially equal if they are cached by Uniques.scala, but not all types are cacheable).

}
val newTypeParams = capturedTypes.map { tpe =>
newSymbol(
spliceOwner,
UniqueName.fresh(tpe.symbol.name.toTypeName),
Param,
TypeBounds.empty
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the type bounds of the type params match the type bounds of the original type for the encoding to always typecheck?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems that I am missing something. I will add some test cases and update these bounds.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Speaking of bounds, if types can refer to terms in their bounds (val x = ...; type U <: x.T; ${ x.foo: U }) then the encoding where types always appear before terms wouldn't work if we need to also store these bounds.

)
}
val methType =
if capturedTypes.nonEmpty then
PolyType(capturedTypes.map(tp => UniqueName.fresh(tp.symbol.name.toTypeName)))(
pt => capturedTypes.map(_ => TypeBounds.empty),
pt => {
val tpParamMap = new TypeMap {
private val mapping = capturedTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap
def apply(tp: Type): Type = tp match
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
case tp => mapOver(tp)
}
MethodType(bindingsTypes.map(tpParamMap), tpParamMap(newTree.tpe))
}
)
else MethodType(bindingsTypes, newTree.tpe)
val meth = newSymbol(spliceOwner, nme.ANON_FUN, Synthetic | Method, methType)
val ddef = DefDef(meth, List(bindings), newTree.tpe, newTree.changeOwner(ctx.owner, meth))
val fnType = defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe)

def substituteTypes(tree: Tree): Tree =
if capturedTypes.nonEmpty then
val typeIndex = capturedTypes.zipWithIndex.toMap
TreeTypeMap(
typeMap = new TypeMap {
def apply(tp: Type): Type = tp match
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => tp
case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => tp
case tp: TypeRef =>
typeIndex.get(tp) match
case Some(idx) => newTypeParams(idx).typeRef
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having a Map where the keys are Types might be a little bit fragile since we could have two equivalent Types which are not referentially equal, is this not a problem here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should only contain TypeRefs (and TermRefs after some other fix). This worked in the tests I have. As far as I can see, both of those are cached and should have referential equality.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the prefix of a TypeRef could be an uncacheable type, for example if it contains a LazyRef

case None => mapOver(tp)
case _ => mapOver(tp)
}
).transform(tree)
else tree
val paramss =
if capturedTypes.nonEmpty then List(newTypeParams, bindings)
else List(bindings)
val ddef = substituteTypes(DefDef(meth, paramss, newTree.tpe, newTree.changeOwner(ctx.owner, meth)))
val fnType =
if capturedTypes.isEmpty then defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe)
else RefinedType(defn.PolyFunctionType, nme.apply, methType)
val closure = Block(ddef :: Nil, Closure(Nil, ref(meth), TypeTree(fnType)))
tpd.Hole(true, holeIdx, refs, closure, TypeTree(tpe))

Expand Down Expand Up @@ -255,6 +304,9 @@ class Splicing extends MacroTransform:
if tree.symbol == defn.QuotedTypeModule_of && containsCapturedType(tpt.tpe) =>
val newContent = capturedPartTypes(tpt)
newContent match
case DirectTypeOf.Healed(termRef) =>
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
tpd.ref(termRef).withSpan(tpt.span)
case block: Block =>
inContext(ctx.withSource(tree.source)) {
Apply(TypeApply(typeof, List(newContent)), List(quotes)).withSpan(tree.span)
Expand Down Expand Up @@ -354,7 +406,7 @@ class Splicing extends MacroTransform:
private def newQuotedTypeClassBinding(tpe: Type)(using Context) =
newSymbol(
spliceOwner,
UniqueName.fresh(nme.Type).toTermName,
UniqueName.fresh(tpe.typeSymbol.name.toTermName),
Param,
defn.QuotedTypeClass.typeRef.appliedTo(tpe),
)
Expand Down
66 changes: 46 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -669,26 +669,52 @@ object TreeChecker {
if isTermHole then assert(tpt.typeOpt <:< pt)
else assert(tpt.typeOpt =:= pt)

// Check that the types of the args conform to the types of the contents of the hole
val argQuotedTypes = args.map { arg =>
if arg.isTerm then
val tpe = arg.typeOpt.widenTermRefExpr match
case _: MethodicType =>
// Special erasure for captured function references
// See `SpliceTransformer.transformCapturedApplication`
defn.AnyType
case tpe => tpe
defn.QuotedExprClass.typeRef.appliedTo(tpe)
else defn.QuotedTypeClass.typeRef.appliedTo(arg.typeOpt.widenTermRefExpr)
}
val expectedResultType =
if isTermHole then defn.QuotedExprClass.typeRef.appliedTo(tpt.typeOpt)
else defn.QuotedTypeClass.typeRef.appliedTo(tpt.typeOpt)
val contextualResult =
defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true)
val expectedContentType =
defn.FunctionOf(argQuotedTypes, contextualResult)
assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}")
if content != EmptyTree then
// Check that the types of the args conform to the types of the contents of the hole
val typeArgsTypes = args.collect { case arg if arg.isType =>
arg.typeOpt
}
val argQuotedTypes = args.map { arg =>
if arg.isTerm then
val tpe = arg.typeOpt.widenTermRefExpr match
case _: MethodicType =>
// Special erasure for captured function references
// See `SpliceTransformer.transformCapturedApplication`
defn.AnyType
case tpe => tpe
defn.QuotedExprClass.typeRef.appliedTo(tpe)
else defn.QuotedTypeClass.typeRef.appliedTo(arg.typeOpt.widenTermRefExpr)
}
val expectedResultType =
if isTermHole then defn.QuotedExprClass.typeRef.appliedTo(tpt.typeOpt)
else defn.QuotedTypeClass.typeRef.appliedTo(tpt.typeOpt)
val contextualResult =
defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true)
val expectedContentType =
if typeArgsTypes.isEmpty then defn.FunctionOf(argQuotedTypes, contextualResult)
else RefinedType(defn.PolyFunctionType, nme.apply, PolyType(typeArgsTypes.map(_ => TypeBounds.empty))(pt =>
val tpParamMap = new TypeMap {
private val mapping = typeArgsTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap
def apply(tp: Type): Type = tp match
case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp)
case tp => mapOver(tp)
}
MethodType(
args.zipWithIndex.map { case (arg, idx) =>
if arg.isTerm then
val tpe = arg.typeOpt.widenTermRefExpr match
case _: MethodicType =>
// Special erasure for captured function references
// See `SpliceTransformer.transformCapturedApplication`
defn.AnyType
case tpe => tpe
defn.QuotedExprClass.typeRef.appliedTo(tpParamMap(tpe))
else defn.QuotedTypeClass.typeRef.appliedTo(tpParamMap(arg.typeOpt))
},
tpParamMap(contextualResult))
)
)
assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}")

tree1
}
Expand Down
6 changes: 6 additions & 0 deletions tests/pos-macros/captured-type.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import scala.quoted.*

object Foo:
def baz(using Quotes): Unit = '{
def f[T](x: T): T = ${ identity('{ x: T }) }
}