diff --git a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala index 0bfc444e0997..b9abcb9cffdc 100644 --- a/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala +++ b/compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala @@ -48,7 +48,7 @@ class ExpandSAMs extends MiniPhase: tpt.tpe match { case NoType => tree // it's a plain function - case tpe if defn.isContextFunctionType(tpe) => + case tpe if defn.isFunctionOrPolyType(tpe) => tree case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) => val tpe1 = checkRefinements(tpe, fn) diff --git a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala index 62174c806f09..722c685d4b96 100644 --- a/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala +++ b/compiler/src/dotty/tools/dotc/transform/PickleQuotes.scala @@ -319,17 +319,27 @@ object PickleQuotes { defn.QuotedExprClass.typeRef.appliedTo(defn.AnyType)), args => val cases = termSplices.map { case (splice, idx) => - val defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) = splice.tpe: @unchecked + val (typeParamCount, argTypes, quotesType) = splice.tpe match + case defn.FunctionOf(argTypes, defn.FunctionOf(quotesType :: _, _, _), _) => (0, argTypes, quotesType) + case RefinedType(polyFun, nme.apply, pt @ PolyType(tparams, _)) if polyFun.typeSymbol.derivesFrom(defn.PolyFunctionClass) => + pt.instantiate(pt.paramInfos.map(_.hi)) match + case MethodTpe(_, argTypes, defn.FunctionOf(quotesType :: _, _, _)) => + (tparams.size, argTypes, quotesType) + val rhs = { val spliceArgs = argTypes.zipWithIndex.map { (argType, i) => args(1).select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argType) } val Block(List(ddef: DefDef), _) = splice: @unchecked - // TODO: beta reduce inner closure? Or wait until BetaReduce phase? - BetaReduce( - splice - .select(nme.apply).appliedToArgs(spliceArgs)) - .select(nme.apply).appliedTo(args(2).asInstance(quotesType)) + val quotes = args(2).asInstance(quotesType) + val dummyTargs = List.fill(typeParamCount)(defn.AnyType) + // Generate: .apply[*](*).apply() + splice.changeOwner(ddef.symbol.owner, ctx.owner) + .select(nme.apply) + .appliedToTypes(dummyTargs) + .appliedToArgs(spliceArgs) + .select(nme.apply) + .appliedTo(quotes) } CaseDef(Literal(Constant(idx)), EmptyTree, rhs) } @@ -338,8 +348,20 @@ object PickleQuotes { case _ => Match(args(0).annotated(New(ref(defn.UncheckedAnnot.typeRef))), cases) ) + def dealiasSplicedTypes(tp: Type) = new TypeMap { + def apply(tp: Type): Type = tp match + case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => + val TypeAlias(alias) = tp.info: @unchecked + alias + case tp1 => mapOver(tp) + }.apply(tp) + + val adaptedType = + if isType then dealiasSplicedTypes(originalTp) + else originalTp + val quoteClass = if isType then defn.QuotedTypeClass else defn.QuotedExprClass - val quotedType = quoteClass.typeRef.appliedTo(originalTp) + val quotedType = quoteClass.typeRef.appliedTo(adaptedType) val lambdaTpe = MethodType(defn.QuotesClass.typeRef :: Nil, quotedType) val unpickleMeth = if isType then defn.QuoteUnpickler_unpickleTypeV2 @@ -347,9 +369,10 @@ object PickleQuotes { val unpickleArgs = if isType then List(pickledQuoteStrings, types) else List(pickledQuoteStrings, types, termHoles) + quotes .asInstance(defn.QuoteUnpicklerClass.typeRef) - .select(unpickleMeth).appliedToType(originalTp) + .select(unpickleMeth).appliedToType(adaptedType) .appliedToArgs(unpickleArgs).withSpan(body.span) } diff --git a/compiler/src/dotty/tools/dotc/transform/Splicing.scala b/compiler/src/dotty/tools/dotc/transform/Splicing.scala index bb82fba32a7c..6f3a3828e01d 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicing.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicing.scala @@ -24,6 +24,7 @@ import dotty.tools.dotc.config.ScalaRelease.* import dotty.tools.dotc.staging.QuoteContext.* import dotty.tools.dotc.staging.StagingLevel.* import dotty.tools.dotc.staging.QuoteTypeTags +import dotty.tools.dotc.staging.DirectTypeOf import scala.annotation.constructorOnly @@ -38,18 +39,19 @@ object Splicing: * * After this phase we have the invariant where all splices have the following shape * ``` - * {{{ | | * | (*) => }}} + * {{{ | | * | [*] => (*) => }}} * ``` * where `` does not contain any free references to quoted definitions and `*` * contains the quotes with references to all cross-quote references. There are some special rules * for references in the LHS of assignments and cross-quote method references. * - * In the following code example `x1` and `x2` are cross-quote references. + * In the following code example `x1`, `x2` and `U` are cross-quote references. * ``` * '{ ... - * val x1: T1 = ??? - * val x2: T2 = ??? - * ${ (q: Quotes) ?=> f('{ g(x1, x2) }) }: T3 + * type U + * val x1: T = ??? + * val x2: U = ??? + * ${ (q: Quotes) ?=> f('{ g[U](x1, x2) }) }: T3 * } * ``` * @@ -60,15 +62,15 @@ object Splicing: * '{ ... * val x1: T1 = ??? * val x2: T2 = ??? - * {{{ 0 | T3 | x1, x2 | - * (x1$: Expr[T1], x2$: Expr[T2]) => // body of this lambda does not contain references to x1 or x2 - * (q: Quotes) ?=> f('{ g(${x1$}, ${x2$}) }) + * {{{ 0 | T3 | U, x1, x2 | + * [U$1] => (U$2: Type[U$1], x1$: Expr[T], x2$: Expr[U$1]) => // body of this lambda does not contain references to U, x1 or x2 + * (q: Quotes) ?=> f('{ @SplicedType type U$3 = [[[ 0 | U$2 | | U$1 ]]]; g[U$3](${x1$}, ${x2$}) }) * * }}} * } * ``` * - * and then performs the same transformation on `'{ g(${x1$}, ${x2$}) }`. + * and then performs the same transformation on `'{ @SplicedType type U$3 = [[[ 0 | U$2 | | U$1 ]]]; g[U$3](${x1$}, ${x2$}) }`. * */ class Splicing extends MacroTransform: @@ -132,7 +134,7 @@ class Splicing extends MacroTransform: case None => val holeIdx = numHoles numHoles += 1 - val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp)) + val hole = tpd.Hole(false, holeIdx, Nil, ref(qual), TypeTree(tp.dealias)) typeHoles.put(qual.symbol, hole) hole cpy.TypeDef(tree)(rhs = hole) @@ -154,7 +156,7 @@ class Splicing extends MacroTransform: private def transformAnnotations(tree: DefTree)(using Context): Unit = tree.symbol.annotations = tree.symbol.annotations.mapconserve { annot => - val newAnnotTree = transform(annot.tree)(using ctx.withOwner(tree.symbol)) + val newAnnotTree = transform(annot.tree) if (annot.tree == newAnnotTree) annot else ConcreteAnnotation(newAnnotTree) } @@ -184,7 +186,7 @@ class Splicing extends MacroTransform: * ``` * is transformed into * ```scala - * {{{ | T2 | x, X | (x$1: Expr[T1], X$1: Type[X]) => (using Quotes) ?=> {... ${x$1} ... X$1.Underlying ...} }}} + * {{{ | T2 | x, X | [X$2] => (x$1: Expr[T1], X$1: Type[X$2]) => (using Quotes) ?=> {... ${x$1} ... X$1.Underlying ...} }}} * ``` */ private class SpliceTransformer(spliceOwner: Symbol, isCaptured: Symbol => Boolean) extends Transformer: @@ -198,10 +200,57 @@ class Splicing extends MacroTransform: val newTree = transform(tree) val (refs, bindings) = refBindingMap.values.toList.unzip val bindingsTypes = bindings.map(_.termRef.widenTermRefExpr) - val methType = MethodType(bindingsTypes, newTree.tpe) + val capturedTypes = bindingsTypes.collect { + case AppliedType(tycon, List(arg: TypeRef)) if tycon.derivesFrom(defn.QuotedTypeClass) => arg + } + val newTypeParams = capturedTypes.map { tpe => + newSymbol( + spliceOwner, + UniqueName.fresh(tpe.symbol.name.toTypeName), + Param, + TypeBounds.empty + ) + } + val methType = + if capturedTypes.nonEmpty then + PolyType(capturedTypes.map(tp => UniqueName.fresh(tp.symbol.name.toTypeName)))( + pt => capturedTypes.map(_ => TypeBounds.empty), + pt => { + val tpParamMap = new TypeMap { + private val mapping = capturedTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap + def apply(tp: Type): Type = tp match + case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp) + case tp => mapOver(tp) + } + MethodType(bindingsTypes.map(tpParamMap), tpParamMap(newTree.tpe)) + } + ) + else MethodType(bindingsTypes, newTree.tpe) val meth = newSymbol(spliceOwner, nme.ANON_FUN, Synthetic | Method, methType) - val ddef = DefDef(meth, List(bindings), newTree.tpe, newTree.changeOwner(ctx.owner, meth)) - val fnType = defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe) + + def substituteTypes(tree: Tree): Tree = + if capturedTypes.nonEmpty then + val typeIndex = capturedTypes.zipWithIndex.toMap + TreeTypeMap( + typeMap = new TypeMap { + def apply(tp: Type): Type = tp match + case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => tp + case tp: TypeRef if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => tp + case tp: TypeRef => + typeIndex.get(tp) match + case Some(idx) => newTypeParams(idx).typeRef + case None => mapOver(tp) + case _ => mapOver(tp) + } + ).transform(tree) + else tree + val paramss = + if capturedTypes.nonEmpty then List(newTypeParams, bindings) + else List(bindings) + val ddef = substituteTypes(DefDef(meth, paramss, newTree.tpe, newTree.changeOwner(ctx.owner, meth))) + val fnType = + if capturedTypes.isEmpty then defn.FunctionType(bindings.size, isContextual = false).appliedTo(bindingsTypes :+ newTree.tpe) + else RefinedType(defn.PolyFunctionType, nme.apply, methType) val closure = Block(ddef :: Nil, Closure(Nil, ref(meth), TypeTree(fnType))) tpd.Hole(true, holeIdx, refs, closure, TypeTree(tpe)) @@ -255,6 +304,9 @@ class Splicing extends MacroTransform: if tree.symbol == defn.QuotedTypeModule_of && containsCapturedType(tpt.tpe) => val newContent = capturedPartTypes(tpt) newContent match + case DirectTypeOf.Healed(termRef) => + // Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x` + tpd.ref(termRef).withSpan(tpt.span) case block: Block => inContext(ctx.withSource(tree.source)) { Apply(TypeApply(typeof, List(newContent)), List(quotes)).withSpan(tree.span) @@ -354,7 +406,7 @@ class Splicing extends MacroTransform: private def newQuotedTypeClassBinding(tpe: Type)(using Context) = newSymbol( spliceOwner, - UniqueName.fresh(nme.Type).toTermName, + UniqueName.fresh(tpe.typeSymbol.name.toTermName), Param, defn.QuotedTypeClass.typeRef.appliedTo(tpe), ) diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index f9240d6091c4..79ca6e6df2d8 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -669,26 +669,52 @@ object TreeChecker { if isTermHole then assert(tpt.typeOpt <:< pt) else assert(tpt.typeOpt =:= pt) - // Check that the types of the args conform to the types of the contents of the hole - val argQuotedTypes = args.map { arg => - if arg.isTerm then - val tpe = arg.typeOpt.widenTermRefExpr match - case _: MethodicType => - // Special erasure for captured function references - // See `SpliceTransformer.transformCapturedApplication` - defn.AnyType - case tpe => tpe - defn.QuotedExprClass.typeRef.appliedTo(tpe) - else defn.QuotedTypeClass.typeRef.appliedTo(arg.typeOpt.widenTermRefExpr) - } - val expectedResultType = - if isTermHole then defn.QuotedExprClass.typeRef.appliedTo(tpt.typeOpt) - else defn.QuotedTypeClass.typeRef.appliedTo(tpt.typeOpt) - val contextualResult = - defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true) - val expectedContentType = - defn.FunctionOf(argQuotedTypes, contextualResult) - assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}") + if content != EmptyTree then + // Check that the types of the args conform to the types of the contents of the hole + val typeArgsTypes = args.collect { case arg if arg.isType => + arg.typeOpt + } + val argQuotedTypes = args.map { arg => + if arg.isTerm then + val tpe = arg.typeOpt.widenTermRefExpr match + case _: MethodicType => + // Special erasure for captured function references + // See `SpliceTransformer.transformCapturedApplication` + defn.AnyType + case tpe => tpe + defn.QuotedExprClass.typeRef.appliedTo(tpe) + else defn.QuotedTypeClass.typeRef.appliedTo(arg.typeOpt.widenTermRefExpr) + } + val expectedResultType = + if isTermHole then defn.QuotedExprClass.typeRef.appliedTo(tpt.typeOpt) + else defn.QuotedTypeClass.typeRef.appliedTo(tpt.typeOpt) + val contextualResult = + defn.FunctionOf(List(defn.QuotesClass.typeRef), expectedResultType, isContextual = true) + val expectedContentType = + if typeArgsTypes.isEmpty then defn.FunctionOf(argQuotedTypes, contextualResult) + else RefinedType(defn.PolyFunctionType, nme.apply, PolyType(typeArgsTypes.map(_ => TypeBounds.empty))(pt => + val tpParamMap = new TypeMap { + private val mapping = typeArgsTypes.map(_.typeSymbol).zip(pt.paramRefs).toMap + def apply(tp: Type): Type = tp match + case tp: TypeRef => mapping.getOrElse(tp.typeSymbol, tp) + case tp => mapOver(tp) + } + MethodType( + args.zipWithIndex.map { case (arg, idx) => + if arg.isTerm then + val tpe = arg.typeOpt.widenTermRefExpr match + case _: MethodicType => + // Special erasure for captured function references + // See `SpliceTransformer.transformCapturedApplication` + defn.AnyType + case tpe => tpe + defn.QuotedExprClass.typeRef.appliedTo(tpParamMap(tpe)) + else defn.QuotedTypeClass.typeRef.appliedTo(tpParamMap(arg.typeOpt)) + }, + tpParamMap(contextualResult)) + ) + ) + assert(content.typeOpt =:= expectedContentType, i"unexpected content of hole\nexpected: ${expectedContentType}\nwas: ${content.typeOpt}") tree1 } diff --git a/tests/pos-macros/captured-type.scala b/tests/pos-macros/captured-type.scala new file mode 100644 index 000000000000..54e59aec7a54 --- /dev/null +++ b/tests/pos-macros/captured-type.scala @@ -0,0 +1,6 @@ +import scala.quoted.* + +object Foo: + def baz(using Quotes): Unit = '{ + def f[T](x: T): T = ${ identity('{ x: T }) } + }