Skip to content

Commit 1053899

Browse files
committed
Dealias quoted types when staging
This change improves the code generation of the contents of splices. In the splicing phase, we have to find all types that are defined in the quote but used in the splice. When there are type aliases, we can end up with several `Type[T]` for the different aliases of `T`. By dealiasing during the staging phase (just before splicing phase) we make sure that the splicer phase will only generate one `Type[T]`. By dealiasing we also optimize some situations where a type from outside a quote is inserted in the quoted code and then used in one of its splice through an alias. In this situation we can use the outer `Type[T]` directly.
1 parent d36cd2d commit 1053899

11 files changed

+95
-41
lines changed

compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,29 @@ class CrossStageSafety extends TreeMapWithStages {
122122
val targs2 = targs.map(targ => TypeTree(healType(quote.fun.srcPos)(stripAnnotsDeep(targ.tpe))))
123123
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), body2 :: Nil)
124124
else
125-
val quotes = quote.args.mapConserve(transform)
126-
body.tpe match
127-
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
128-
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
129-
ref(x)
125+
object DirectTypeOfRef:
126+
def unapply(body: Tree): Option[Tree] =
127+
body.tpe match
128+
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
129+
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
130+
Some(ref(x).withSpan(quote.span))
131+
case _ =>
132+
body2 match
133+
case Block(List(tdef: TypeDef), tpt: TypeTree) =>
134+
tpt.tpe match
135+
case tpe: TypeRef if tpe.typeSymbol == tdef.symbol =>
136+
tdef.rhs.tpe.hiBound match
137+
case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
138+
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
139+
Some(ref(x).withSpan(quote.span))
140+
case _ => None
141+
case _ => None
142+
case _ => None
143+
144+
body match
145+
case DirectTypeOfRef(ref) => ref
130146
case _ =>
147+
val quotes = quote.args.mapConserve(transform)
131148
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
132149
val TypeApply(fun, _) = quote.fun: @unchecked
133150
cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, body2 :: Nil), quotes)

compiler/src/dotty/tools/dotc/staging/HealType.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
3232
def apply(tp: Type): Type =
3333
tp match
3434
case tp: TypeRef =>
35-
healTypeRef(tp)
35+
tp.underlying match
36+
case TypeAlias(alias)
37+
if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
38+
this.apply(alias)
39+
case _ =>
40+
healTypeRef(tp)
3641
case tp @ TermRef(NoPrefix, _) if !tp.symbol.isStatic && level > levelOf(tp.symbol) =>
3742
levelError(tp.symbol, tp, pos)
3843
case tp: AnnotatedType =>
@@ -46,11 +51,11 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
4651
checkNotWildcardSplice(tp)
4752
if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix)
4853
case prefix: TermRef if !prefix.symbol.isStatic && level > levelOf(prefix.symbol) =>
49-
dealiasAndTryHeal(prefix.symbol, tp, pos)
54+
tryHeal(prefix.symbol, tp, pos)
5055
case NoPrefix if level > levelOf(tp.symbol) && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
51-
dealiasAndTryHeal(tp.symbol, tp, pos)
56+
tryHeal(tp.symbol, tp, pos)
5257
case prefix: ThisType if level > levelOf(prefix.cls) && !tp.symbol.isStatic =>
53-
dealiasAndTryHeal(tp.symbol, tp, pos)
58+
tryHeal(tp.symbol, tp, pos)
5459
case _ =>
5560
mapOver(tp)
5661

@@ -59,11 +64,6 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
5964
case (tb: TypeBounds) :: _ => report.error(em"Cannot splice $splice because it is a wildcard type", pos)
6065
case _ =>
6166

62-
private def dealiasAndTryHeal(sym: Symbol, tp: TypeRef, pos: SrcPos): Type =
63-
val tp1 = tp.dealias
64-
if tp1 != tp then apply(tp1)
65-
else tryHeal(tp.symbol, tp, pos)
66-
6767
/** Try to heal reference to type `T` used in a higher level than its definition.
6868
* Returns a reference to a type tag generated by `QuoteTypeTags` that contains a
6969
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`.

tests/pos-macros/i8100b.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import scala.quoted.*
2+
3+
def f[T](using t: Type[T])(using Quotes) =
4+
'{
5+
// @SplicedType type t$1 = t.Underlying
6+
type T2 = T // type T2 = t$1
7+
${
8+
9+
val t0: T = ???
10+
val t1: T2 = ??? // val t1: T = ???
11+
val tp1 = Type.of[T] // val tp1 = t
12+
val tp2 = Type.of[T2] // val tp2 = t
13+
'{
14+
// @SplicedType type t$2 = t.Underlying
15+
val t3: T = ??? // val t3: t$2 = ???
16+
val t4: T2 = ??? // val t4: t$2 = ???
17+
}
18+
}
19+
}
20+
21+
def g(using Quotes) =
22+
'{
23+
type U
24+
type U2 = U
25+
${
26+
27+
val u1: U = ???
28+
val u2: U2 = ??? // val u2: U = ???
29+
30+
val tp1 = Type.of[U] // val tp1 = Type.of[U]
31+
val tp2 = Type.of[U2] // val tp2 = Type.of[U]
32+
'{
33+
val u3: U = ???
34+
val u4: U2 = ??? // val u4: U = ???
35+
}
36+
}
37+
}

tests/run-macros/i12392.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
scala.Option[scala.Predef.String] to scala.Option[scala.Int]
1+
scala.Option[java.lang.String] to scala.Option[scala.Int]
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
2-
type T = scala.Predef.String
2+
type T = java.lang.String
33
val x: java.lang.String = "foo"
4-
val z: T = x
4+
val z: java.lang.String = x
55

66
(x: java.lang.String)
77
}
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
((q: scala.quoted.Quotes) ?=> {
2-
val t: scala.quoted.Type[scala.Predef.String] = scala.quoted.Type.of[scala.Predef.String](q)
2+
val t: scala.quoted.Type[java.lang.String] = scala.quoted.Type.of[java.lang.String](q)
33

4-
(t: scala.quoted.Type[scala.Predef.String])
4+
(t: scala.quoted.Type[java.lang.String])
55
})
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
2-
type T[X] = scala.List[X]
2+
type T[X] = [A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A][X]
33
val x: java.lang.String = "foo"
4-
val z: T[scala.Predef.String] = scala.List.apply[java.lang.String](x)
4+
val z: [X >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[X][java.lang.String] = scala.List.apply[java.lang.String](x)
55

66
(x: java.lang.String)
77
}

tests/run-staging/quote-owners-2.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
def ff: scala.Int = {
33
val a: scala.collection.immutable.List[scala.Int] = {
44
type T = scala.collection.immutable.List[scala.Int]
5-
val b: T = scala.Nil.::[scala.Int](3)
5+
val b: scala.collection.immutable.List[scala.Int] = scala.Nil.::[scala.Int](3)
66

77
(b: scala.collection.immutable.List[scala.Int])
88
}

tests/run-staging/quote-unrolled-foreach.check

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
}
99
})
1010

11-
((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
11+
((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
1212
val size: scala.Int = arr.length
1313
var i: scala.Int = 0
1414
while (i.<(size)) {
@@ -18,7 +18,7 @@
1818
}
1919
})
2020

21-
((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
21+
((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
2222
val size: scala.Int = arr.length
2323
var i: scala.Int = 0
2424
while (i.<(size)) {
@@ -41,7 +41,7 @@
4141
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
4242
val size: scala.Int = arr.length
4343
var i: scala.Int = 0
44-
if (size.%(3).!=(0)) throw new scala.Exception("...") else ()
44+
if (size.%(3).!=(0)) throw new java.lang.Exception("...") else ()
4545
while (i.<(size)) {
4646
f.apply(arr.apply(i))
4747
f.apply(arr.apply(i.+(1)))
@@ -53,7 +53,7 @@
5353
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
5454
val size: scala.Int = arr.length
5555
var i: scala.Int = 0
56-
if (size.%(4).!=(0)) throw new scala.Exception("...") else ()
56+
if (size.%(4).!=(0)) throw new java.lang.Exception("...") else ()
5757
while (i.<(size)) {
5858
f.apply(arr.apply(i.+(0)))
5959
f.apply(arr.apply(i.+(1)))

tests/run-staging/shonan-hmm-simple.check

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ Complex(4,3)
66
10
77

88
((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => {
9-
if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else ()
9+
if (arr1.length.!=(arr2.length)) throw new java.lang.Exception("...") else ()
1010
var sum: scala.Int = 0
1111
var i: scala.Int = 0
1212
while (i.<(scala.Predef.intArrayOps(arr1).size)) {
@@ -22,13 +22,13 @@ Complex(4,3)
2222
10
2323

2424
((arr: scala.Array[scala.Int]) => {
25-
if (arr.length.!=(5)) throw new scala.Exception("...") else ()
25+
if (arr.length.!=(5)) throw new java.lang.Exception("...") else ()
2626
arr.apply(0).+(arr.apply(2)).+(arr.apply(4))
2727
})
2828
10
2929

3030
((arr: scala.Array[Complex[scala.Int]]) => {
31-
if (arr.length.!=(4)) throw new scala.Exception("...") else ()
31+
if (arr.length.!=(4)) throw new java.lang.Exception("...") else ()
3232
Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2)))
3333
})
3434
Complex(4,3)

tests/run-staging/shonan-hmm.check

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ List(25, 30, 20, 43, 44)
3535

3636

3737
((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
38-
if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else ()
39-
if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else ()
38+
if (3.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("3") else ()
39+
if (2.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("2") else ()
4040
vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1))))
4141
vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1))))
4242
vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1))))
@@ -95,8 +95,8 @@ List(25, 30, 20, 43, 44)
9595
array
9696
}
9797
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
98-
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
99-
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
98+
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
99+
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
100100
vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0)))
101101
vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
102102
vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
@@ -158,8 +158,8 @@ List(25, 30, 20, 43, 44)
158158
array
159159
}
160160
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
161-
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
162-
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
161+
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
162+
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
163163
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
164164
vout.update(1, v.apply(2).*(10))
165165
vout.update(2, v.apply(1).*(10))
@@ -221,8 +221,8 @@ List(25, 30, 20, 43, 44)
221221
array
222222
}
223223
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
224-
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
225-
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
224+
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
225+
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
226226
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
227227
vout.update(1, v.apply(2).*(10))
228228
vout.update(2, v.apply(1).*(10))
@@ -243,8 +243,8 @@ List(25, 30, 20, 43, 44)
243243

244244

245245
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
246-
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
247-
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
246+
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
247+
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
248248
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
249249
vout.update(1, v.apply(2).*(10))
250250
vout.update(2, v.apply(1).*(10))
@@ -282,8 +282,8 @@ List(25, 30, 20, 43, 44)
282282
array
283283
}
284284
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
285-
if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
286-
if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
285+
if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
286+
if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
287287
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
288288
vout.update(1, v.apply(2).*(10))
289289
vout.update(2, v.apply(1).*(10))

0 commit comments

Comments
 (0)