Skip to content

Commit c95aff2

Browse files
committed
Add PreSpecializer phase
Creates a new `PreSpecializer` phase that detects `@specialized` annotations and registers them. This is necessary because annotations are lost further down the pipeline (and before TypeSpecializer runs.)
1 parent 62e8131 commit c95aff2

File tree

8 files changed

+146
-62
lines changed

8 files changed

+146
-62
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,13 @@ class Compiler {
3838
def phases: List[List[Phase]] =
3939
List(
4040
List(new FrontEnd),
41-
List(new PostTyper),
42-
List(new Pickler),
43-
List(new FirstTransform),
44-
List(new RefChecks,
41+
List(new InstChecks),
42+
List(new FirstTransform,
43+
new SyntheticMethods),
44+
List(new SuperAccessors),
45+
List(new Pickler), // Pickler needs to come last in a group since it should not pickle trees generated later
46+
List(new PreSpecializer,
47+
new RefChecks,
4548
new ElimRepeated,
4649
new NormalizeFlags,
4750
new ExtensionMethods,

src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class ScalaSettings extends Settings.SettingGroup {
152152
val YnoDeepSubtypes = BooleanSetting("-Yno-deep-subtypes", "throw an exception on deep subtyping call stacks.")
153153
val YprintSyms = BooleanSetting("-Yprint-syms", "when printing trees print info in symbols instead of corresponding info in trees.")
154154
val YtestPickler = BooleanSetting("-Ytest-pickler", "self-test for pickling functionality; should be used with -Ystop-after:pickler")
155-
val Yspecialize = StringSetting("-Yspecialize","all","Specialize all methods.", "all") // TODO remove default value
155+
val Yspecialize = StringSetting("-Yspecialize","","Specialize all methods.", "")
156156

157157
def stop = YstopAfter
158158

src/dotty/tools/dotc/core/Phases.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ object Phases {
240240
private val explicitOuterCache = new PhaseCache(classOf[ExplicitOuter])
241241
private val gettersCache = new PhaseCache(classOf[Getters])
242242
private val genBCodeCache = new PhaseCache(classOf[GenBCode])
243+
private val specializeCache = new PhaseCache(classOf[TypeSpecializer])
243244

244245
def typerPhase = typerCache.phase
245246
def refchecksPhase = refChecksCache.phase
@@ -251,6 +252,7 @@ object Phases {
251252
def explicitOuterPhase = explicitOuterCache.phase
252253
def gettersPhase = gettersCache.phase
253254
def genBCodePhase = genBCodeCache.phase
255+
def specializePhase = specializeCache.phase
254256

255257
def isAfterTyper(phase: Phase): Boolean = phase.id > typerPhase.id
256258
}
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.ast.Trees.{Select, Ident, SeqLiteral, Typed}
4+
import dotty.tools.dotc.ast.tpd
5+
import dotty.tools.dotc.core.Annotations.Annotation
6+
import dotty.tools.dotc.core.Contexts.Context
7+
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
8+
import dotty.tools.dotc.core.StdNames._
9+
import dotty.tools.dotc.core.{Flags, Definitions, Symbols}
10+
import dotty.tools.dotc.core.Symbols.Symbol
11+
import dotty.tools.dotc.core.Types.{TermRef, TypeRef, OrType, Type}
12+
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
13+
14+
import scala.collection.mutable
15+
16+
/**
17+
* This phase runs before {what phase ?}, so as to retrieve all `@specialized`
18+
* anotations before they are thrown away, and stores them through a `PhaseCache`
19+
* for the `TypeSpecializer` phase.
20+
*/
21+
class PreSpecializer extends MiniPhaseTransform with InfoTransformer {
22+
23+
override def phaseName: String = "prespecialize"
24+
25+
private val specTypes: mutable.HashMap[Symbols.Symbol, List[Type]] = mutable.HashMap.empty
26+
27+
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
28+
29+
def getSpecTypes(sym: Symbol)(implicit ctx: Context): List[Type] = {
30+
31+
def allowedToSpecialize(sym: Symbol): Boolean = {
32+
sym.name != nme.asInstanceOf_ &&
33+
!(sym is Flags.JavaDefined) &&
34+
!sym.isConstructor//isPrimaryConstructor
35+
}
36+
37+
if (allowedToSpecialize(sym)) {
38+
val annotation = sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil)
39+
annotation match {
40+
case annot: Annotation =>
41+
val args = annot.arguments
42+
if (args.isEmpty) primitiveTypes
43+
else args.head match {
44+
case a@Typed(SeqLiteral(types), _) => types.map(t => nameToType(t.tpe))
45+
case a@Select(Ident(_), _) => {
46+
println(a)
47+
primitiveTypes
48+
}
49+
case _ => {
50+
println("Nonono")
51+
ctx.error("surprising match on specialized annotation"); Nil
52+
}
53+
}
54+
case nil => Nil
55+
}
56+
} else Nil
57+
}
58+
val st = getSpecTypes(sym)
59+
if (st.nonEmpty) {
60+
specTypes.put(sym, st)
61+
}
62+
tp
63+
}
64+
65+
private final def nameToType(name: Type)(implicit ctx: Context) =
66+
name.asInstanceOf[TermRef].name.toString match {
67+
case s if s.startsWith("Int") => defn.IntType
68+
case s if s.startsWith("Boolean") => defn.BooleanType
69+
case s if s.startsWith("Byte") => defn.ByteType
70+
case s if s.startsWith("Long") => defn.LongType
71+
case s if s.startsWith("Short") => defn.ShortType
72+
case s if s.startsWith("Float") => defn.FloatType
73+
case s if s.startsWith("Unit") => defn.UnitType
74+
case s if s.startsWith("Double") => defn.DoubleType
75+
case s if s.startsWith("Char") => defn.CharType
76+
}
77+
78+
def defn(implicit ctx: Context): Definitions = ctx.definitions
79+
80+
private def primitiveTypes(implicit ctx: Context) =
81+
List(ctx.definitions.ByteType,
82+
ctx.definitions.BooleanType,
83+
ctx.definitions.ShortType,
84+
ctx.definitions.IntType,
85+
ctx.definitions.LongType,
86+
ctx.definitions.FloatType,
87+
ctx.definitions.DoubleType,
88+
ctx.definitions.CharType,
89+
ctx.definitions.UnitType
90+
)
91+
92+
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
93+
specTypes.keys.foreach(
94+
sym => ctx.specializePhase.asInstanceOf[TypeSpecializer].registerSpecializationRequest(tree.symbol)(specTypes(sym))
95+
)
96+
tree
97+
}
98+
}

src/dotty/tools/dotc/transform/TypeSpecializer.scala

Lines changed: 15 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,6 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
2020

2121
final val maxTparamsToSpecialize = 2
2222

23-
private final def nameToSpecialisedType(implicit ctx: Context) =
24-
Map("Byte" -> ctx.definitions.ByteType,
25-
"Boolean" -> ctx.definitions.BooleanType,
26-
"Short" -> ctx.definitions.ShortType,
27-
"Int" -> ctx.definitions.IntType,
28-
"Long" -> ctx.definitions.LongType,
29-
"Float" -> ctx.definitions.FloatType,
30-
"Double" -> ctx.definitions.DoubleType,
31-
"Char" -> ctx.definitions.CharType,
32-
"Unit" -> ctx.definitions.UnitType)
33-
3423
private final def specialisedTypeToSuffix(implicit ctx: Context) =
3524
Map(ctx.definitions.ByteType -> "$mcB$sp",
3625
ctx.definitions.BooleanType -> "$mcZ$sp",
@@ -54,7 +43,7 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
5443
ctx.definitions.UnitType
5544
)
5645

57-
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[List[Type]]] = mutable.HashMap.empty
46+
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[Type]] = mutable.HashMap.empty
5847

5948
/**
6049
* A map that links symbols to their specialized variants.
@@ -63,14 +52,12 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
6352
private val newSymbolMap: mutable.HashMap[Symbol, mutable.HashMap[List[Type], Symbols.Symbol]] = mutable.HashMap.empty
6453

6554
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
66-
67-
def generateSpecializations(remainingTParams: List[Name], remainingBounds: List[TypeBounds], specTypes: List[Type])
55+
def generateSpecializations(remainingTParams: List[Name], specTypes: List[Type])
6856
(instantiations: List[Type], names: List[String], poly: PolyType, decl: Symbol)
6957
(implicit ctx: Context): List[Symbol] = {
7058
if (remainingTParams.nonEmpty) {
71-
val bounds = remainingBounds.head
7259
val specializations = (for (tpe <- specTypes) yield {
73-
generateSpecializations(remainingTParams.tail, remainingBounds.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
60+
generateSpecializations(remainingTParams.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
7461
}).flatten
7562
specializations
7663
}
@@ -96,12 +83,15 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
9683
sym.info match {
9784
case classInfo: ClassInfo =>
9885
val newDecls = classInfo.decls.filterNot(_.isConstructor/*isPrimaryConstructor*/).flatMap(decl => {
86+
if(decl.name.toString.contains("foobar")) {
87+
println("hello")
88+
}
9989
if (shouldSpecialize(decl)) {
10090
decl.info.widen match {
10191
case poly: PolyType =>
10292
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0) {
103-
val specTypes = getSpecTypes(sym)
104-
generateSpecializations(poly.paramNames, poly.paramBounds, specTypes)(List.empty, List.empty, poly, decl)
93+
val specTypes = getSpecTypes(decl).filter(tpe => poly.paramBounds.forall(_.contains(tpe)))
94+
generateSpecializations(poly.paramNames, specTypes)(List.empty, List.empty, poly, decl)
10595
}
10696
else Nil
10797
case nil => Nil
@@ -120,16 +110,11 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
120110
}
121111

122112
def getSpecTypes(sym: Symbol)(implicit ctx: Context): List[Type] = {
123-
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
124-
case annot: Annotation =>
125-
annot.arguments match {
126-
case List(SeqLiteral(types)) =>
127-
types.map(tpeTree => nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))
128-
case List() => primitiveTypes
129-
}
130-
case nil =>
131-
if(ctx.settings.Yspecialize.value == "all") primitiveTypes
132-
else Nil
113+
val requested = specializationRequests.getOrElse(sym, List())
114+
if (requested.nonEmpty) requested.toList
115+
else {
116+
if(ctx.settings.Yspecialize.value == "all") primitiveTypes
117+
else Nil
133118
}
134119
}
135120

@@ -142,35 +127,8 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
142127
if(ctx.phaseId > this.treeTransformPhase.id)
143128
assert(ctx.phaseId <= this.treeTransformPhase.id)
144129
val prev = specializationRequests.getOrElse(method, List.empty)
145-
specializationRequests.put(method, arguments :: prev)
130+
specializationRequests.put(method, arguments ::: prev)
146131
}
147-
/*
148-
def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
149-
registerSpecializationRequest(sym)(primitiveTypes)
150-
println(s"Specializing $sym for all primitive types")
151-
specializationRequests.getOrElse(sym, Nil).flatten
152-
}
153-
154-
def specializeForSome(sym: Symbols.Symbol)(annotationArgs: List[Type])(implicit ctx: Context): List[Type] = {
155-
registerSpecializationRequest(sym)(annotationArgs)
156-
println(s"specializationRequests : $specializationRequests")
157-
specializationRequests.getOrElse(sym, Nil).flatten
158-
}
159-
160-
def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
161-
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
162-
case annot: Annotation =>
163-
annot.arguments match {
164-
case List(SeqLiteral(types)) =>
165-
specializeForSome(sym)(types.map(tpeTree =>
166-
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than type names
167-
case List() => specializeForAll(sym)
168-
}
169-
case nil =>
170-
if(ctx.settings.Yspecialize.value == "all") specializeForAll(sym)
171-
else Nil
172-
}
173-
}*/
174132

175133
override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
176134

@@ -228,7 +186,7 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
228186
assert(betterDefs.length < 2) // TODO: How to select the best if there are several ?
229187

230188
if (betterDefs.nonEmpty) {
231-
println(s"method $fun rewired to specialozed variant with type (${betterDefs.head._1})")
189+
println(s"method $fun rewired to specialized variant with type (${betterDefs.head._1})")
232190
val prefix = fun match {
233191
case Select(pre, name) =>
234192
pre

test/dotc/tests.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,4 +249,11 @@ class tests extends CompilerTest {
249249
@Test def mini_more = compileFiles(miniMoreDir)//, List("-Xprint:all"))
250250
//@Test def pos_all = compileFiles(posDir)//, List("-Xprint:all"))
251251

252+
@Test def pos_mutual_spec = compileFile(posDir, "mutual_specialization", List("-Xprint:all"))
253+
//@Test def pos_mutual_spec = compileFile(posDir, "mutual_specialization")
254+
//@Test def pos_spec = compileFile(posDir, "specialization")
255+
*/
256+
@Test def pos_return_spec = compileFile(posDir, "return_specialization")
257+
// @Test def pos_si7638 = compileFile(posDir, "SI-7638", List("-Xprint:all"))
258+
252259
}

tests/pos/mutual_specialization.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object mutual_specialization {
2+
class A[T] {
3+
def foo[T](b: B[T], n: Int): Unit = if (n > 0) b.bar(this, n-1)
4+
}
5+
class B[T] {
6+
def bar[T](a: A[T], n: Int): Unit = if (n > 0) a.foo(this, n-1)
7+
}
8+
def foobar[@specialized(Int, Float, Double) T](n: T): Unit = new A[T].foo(new B[T], 5)
9+
foobar(5)
10+
}

tests/pos/return_specialization.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object return_specialization {
2+
def qwa[@specialized T](a: (String, String) => T, b: T): T = {
3+
if(a ne this) return a("1", "2")
4+
else b
5+
}
6+
}

0 commit comments

Comments
 (0)