Skip to content

Dealias quoted types when staging #17059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,30 +107,37 @@ class CrossStageSafety extends TreeMapWithStages {
val stripAnnotsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)

val contextWithQuote =
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
else quoteContext
val body1 = transform(body)(using contextWithQuote)
val body2 =
def transformBody() =
val contextWithQuote =
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
else quoteContext
val transformedBody = transform(body)(using contextWithQuote)
taggedTypes.getTypeTags match
case Nil => body1
case tags => tpd.Block(tags, body1).withSpan(body.span)
case Nil => transformedBody
case tags => tpd.Block(tags, transformedBody).withSpan(body.span)

if body.isTerm then
val transformedBody = transformBody()
// `quoted.runtime.Expr.quote[T](<body>)` --> `quoted.runtime.Expr.quote[T2](<body2>)`
val TypeApply(fun, targs) = quote.fun: @unchecked
val targs2 = targs.map(targ => TypeTree(healType(quote.fun.srcPos)(stripAnnotsDeep(targ.tpe))))
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), body2 :: Nil)
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), transformedBody :: Nil)
else
val quotes = quote.args.mapConserve(transform)
body.tpe match
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
case DirectTypeOf(termRef) =>
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
ref(x)
ref(termRef).withSpan(quote.span)
case _ =>
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
val TypeApply(fun, _) = quote.fun: @unchecked
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, body2 :: Nil), quotes)
transformBody() match
case DirectTypeOf.Healed(termRef) =>
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
ref(termRef).withSpan(quote.span)
case transformedBody =>
val quotes = quote.args.mapConserve(transform)
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
val TypeApply(fun, _) = quote.fun: @unchecked
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, transformedBody :: Nil), quotes)

}

/** Transform splice
Expand Down
25 changes: 25 additions & 0 deletions compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package dotty.tools.dotc.staging

import dotty.tools.dotc.ast.{tpd, untpd}
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.Types._

object DirectTypeOf:
import tpd.*

/** Matches `x.Underlying` and extracts the TermRef to `x` */
def unapply(tpe: Type)(using Context): Option[TermRef] = tpe match
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => Some(x)
case _ => None

object Healed:
/** Matches `{ @SplicedType type T = x.Underlying; T }` and extracts the TermRef to `x` */
def unapply(body: Tree)(using Context): Option[TermRef] =
body match
case Block(List(tdef: TypeDef), tpt: TypeTree) =>
tpt.tpe match
case tpe: TypeRef if tpe.typeSymbol == tdef.symbol =>
DirectTypeOf.unapply(tdef.rhs.tpe.hiBound)
case _ => None
case _ => None
18 changes: 9 additions & 9 deletions compiler/src/dotty/tools/dotc/staging/HealType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
def apply(tp: Type): Type =
tp match
case tp: TypeRef =>
healTypeRef(tp)
tp.underlying match
case TypeAlias(alias)
if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
this.apply(alias)
case _ =>
healTypeRef(tp)
case tp @ TermRef(NoPrefix, _) if !tp.symbol.isStatic && level > levelOf(tp.symbol) =>
levelError(tp.symbol, tp, pos)
case tp: AnnotatedType =>
Expand All @@ -46,11 +51,11 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
checkNotWildcardSplice(tp)
if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix)
case prefix: TermRef if !prefix.symbol.isStatic && level > levelOf(prefix.symbol) =>
dealiasAndTryHeal(prefix.symbol, tp, pos)
tryHeal(prefix.symbol, tp, pos)
case NoPrefix if level > levelOf(tp.symbol) && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
dealiasAndTryHeal(tp.symbol, tp, pos)
tryHeal(tp.symbol, tp, pos)
case prefix: ThisType if level > levelOf(prefix.cls) && !tp.symbol.isStatic =>
dealiasAndTryHeal(tp.symbol, tp, pos)
tryHeal(tp.symbol, tp, pos)
case _ =>
mapOver(tp)

Expand All @@ -59,11 +64,6 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
case (tb: TypeBounds) :: _ => report.error(em"Cannot splice $splice because it is a wildcard type", pos)
case _ =>

private def dealiasAndTryHeal(sym: Symbol, tp: TypeRef, pos: SrcPos): Type =
val tp1 = tp.dealias
if tp1 != tp then apply(tp1)
else tryHeal(tp.symbol, tp, pos)

/** Try to heal reference to type `T` used in a higher level than its definition.
* Returns a reference to a type tag generated by `QuoteTypeTags` that contains a
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`.
Expand Down
37 changes: 37 additions & 0 deletions tests/pos-macros/i8100b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import scala.quoted.*

def f[T](using t: Type[T])(using Quotes) =
'{
// @SplicedType type t$1 = t.Underlying
type T2 = T // type T2 = t$1
${

val t0: T = ???
val t1: T2 = ??? // val t1: T = ???
val tp1 = Type.of[T] // val tp1 = t
val tp2 = Type.of[T2] // val tp2 = t
'{
// @SplicedType type t$2 = t.Underlying
val t3: T = ??? // val t3: t$2 = ???
val t4: T2 = ??? // val t4: t$2 = ???
}
}
}

def g(using Quotes) =
'{
type U
type U2 = U
${

val u1: U = ???
val u2: U2 = ??? // val u2: U = ???

val tp1 = Type.of[U] // val tp1 = Type.of[U]
val tp2 = Type.of[U2] // val tp2 = Type.of[U]
'{
val u3: U = ???
val u4: U2 = ??? // val u4: U = ???
}
}
}
2 changes: 1 addition & 1 deletion tests/run-macros/i12392.check
Original file line number Diff line number Diff line change
@@ -1 +1 @@
scala.Option[scala.Predef.String] to scala.Option[scala.Int]
scala.Option[java.lang.String] to scala.Option[scala.Int]
4 changes: 2 additions & 2 deletions tests/run-staging/quote-nested-3.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
type T = scala.Predef.String
type T = java.lang.String
val x: java.lang.String = "foo"
val z: T = x
val z: java.lang.String = x

(x: java.lang.String)
}
4 changes: 2 additions & 2 deletions tests/run-staging/quote-nested-4.check
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
((q: scala.quoted.Quotes) ?=> {
val t: scala.quoted.Type[scala.Predef.String] = scala.quoted.Type.of[scala.Predef.String](q)
val t: scala.quoted.Type[java.lang.String] = scala.quoted.Type.of[java.lang.String](q)

(t: scala.quoted.Type[scala.Predef.String])
(t: scala.quoted.Type[java.lang.String])
})
4 changes: 2 additions & 2 deletions tests/run-staging/quote-nested-6.check
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
type T[X] = scala.List[X]
type T[X] = [A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A][X]
val x: java.lang.String = "foo"
val z: T[scala.Predef.String] = scala.List.apply[java.lang.String](x)
val z: [X >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[X][java.lang.String] = scala.List.apply[java.lang.String](x)

(x: java.lang.String)
}
2 changes: 1 addition & 1 deletion tests/run-staging/quote-owners-2.check
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
def ff: scala.Int = {
val a: scala.collection.immutable.List[scala.Int] = {
type T = scala.collection.immutable.List[scala.Int]
val b: T = scala.Nil.::[scala.Int](3)
val b: scala.collection.immutable.List[scala.Int] = scala.Nil.::[scala.Int](3)

(b: scala.collection.immutable.List[scala.Int])
}
Expand Down
8 changes: 4 additions & 4 deletions tests/run-staging/quote-unrolled-foreach.check
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
}
})

((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
Expand All @@ -18,7 +18,7 @@
}
})

((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
Expand All @@ -41,7 +41,7 @@
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
if (size.%(3).!=(0)) throw new scala.Exception("...") else ()
if (size.%(3).!=(0)) throw new java.lang.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i))
f.apply(arr.apply(i.+(1)))
Expand All @@ -53,7 +53,7 @@
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
if (size.%(4).!=(0)) throw new scala.Exception("...") else ()
if (size.%(4).!=(0)) throw new java.lang.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i.+(0)))
f.apply(arr.apply(i.+(1)))
Expand Down
6 changes: 3 additions & 3 deletions tests/run-staging/shonan-hmm-simple.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ Complex(4,3)
10

((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => {
if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else ()
if (arr1.length.!=(arr2.length)) throw new java.lang.Exception("...") else ()
var sum: scala.Int = 0
var i: scala.Int = 0
while (i.<(scala.Predef.intArrayOps(arr1).size)) {
Expand All @@ -22,13 +22,13 @@ Complex(4,3)
10

((arr: scala.Array[scala.Int]) => {
if (arr.length.!=(5)) throw new scala.Exception("...") else ()
if (arr.length.!=(5)) throw new java.lang.Exception("...") else ()
arr.apply(0).+(arr.apply(2)).+(arr.apply(4))
})
10

((arr: scala.Array[Complex[scala.Int]]) => {
if (arr.length.!=(4)) throw new scala.Exception("...") else ()
if (arr.length.!=(4)) throw new java.lang.Exception("...") else ()
Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2)))
})
Complex(4,3)
Expand Down
24 changes: 12 additions & 12 deletions tests/run-staging/shonan-hmm.check
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ List(25, 30, 20, 43, 44)


((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else ()
if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else ()
if (3.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("3") else ()
if (2.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("2") else ()
vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1))))
vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1))))
vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1))))
Expand Down Expand Up @@ -95,8 +95,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0)))
vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
Expand Down Expand Up @@ -158,8 +158,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
Expand Down Expand Up @@ -221,8 +221,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
Expand All @@ -243,8 +243,8 @@ List(25, 30, 20, 43, 44)


((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
Expand Down Expand Up @@ -282,8 +282,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
Expand Down