Skip to content

Commit 10e9ef1

Browse files
committed
Specialize methods defined inside of other methods
Specialize inner methods. Also adds some tests of specialization.
1 parent c95aff2 commit 10e9ef1

12 files changed

+194
-98
lines changed

src/dotty/tools/dotc/transform/PreSpecializer.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTran
1414
import scala.collection.mutable
1515

1616
/**
17-
* This phase runs before {what phase ?}, so as to retrieve all `@specialized`
18-
* anotations before they are thrown away, and stores them through a `PhaseCache`
19-
* for the `TypeSpecializer` phase.
17+
* This phase retrieves all `@specialized` anotations before they are thrown away,
18+
* and stores them for the `TypeSpecializer` phase.
2019
*/
2120
class PreSpecializer extends MiniPhaseTransform with InfoTransformer {
2221

@@ -30,6 +29,7 @@ class PreSpecializer extends MiniPhaseTransform with InfoTransformer {
3029

3130
def allowedToSpecialize(sym: Symbol): Boolean = {
3231
sym.name != nme.asInstanceOf_ &&
32+
sym.name != nme.isInstanceOf_ &&
3333
!(sym is Flags.JavaDefined) &&
3434
!sym.isConstructor//isPrimaryConstructor
3535
}
@@ -41,23 +41,17 @@ class PreSpecializer extends MiniPhaseTransform with InfoTransformer {
4141
val args = annot.arguments
4242
if (args.isEmpty) primitiveTypes
4343
else args.head match {
44-
case a@Typed(SeqLiteral(types), _) => types.map(t => nameToType(t.tpe))
45-
case a@Select(Ident(_), _) => {
46-
println(a)
47-
primitiveTypes
48-
}
49-
case _ => {
50-
println("Nonono")
51-
ctx.error("surprising match on specialized annotation"); Nil
52-
}
44+
case a@Typed(SeqLiteral(types), _) => types.map(t => nameToType(t.tpe)) // Matches the expected `@specialized(...)` annotations
45+
case a@Select(Ident(_), _) => primitiveTypes // Matches `Select(Ident(Specializable), Primitives)` which is used in several instances
46+
case _ => ctx.error("surprising match on specialized annotation"); Nil
5347
}
5448
case nil => Nil
5549
}
5650
} else Nil
5751
}
5852
val st = getSpecTypes(sym)
5953
if (st.nonEmpty) {
60-
specTypes.put(sym, st)
54+
specTypes.put(sym.owner, st)
6155
}
6256
tp
6357
}
@@ -90,9 +84,8 @@ class PreSpecializer extends MiniPhaseTransform with InfoTransformer {
9084
)
9185

9286
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = {
93-
specTypes.keys.foreach(
94-
sym => ctx.specializePhase.asInstanceOf[TypeSpecializer].registerSpecializationRequest(tree.symbol)(specTypes(sym))
95-
)
87+
val st = specTypes.getOrElse(tree.symbol, List())
88+
if (st.nonEmpty) ctx.specializePhase.asInstanceOf[TypeSpecializer].registerSpecializationRequest(tree.symbol)(st)
9689
tree
9790
}
9891
}

src/dotty/tools/dotc/transform/TypeSpecializer.scala

Lines changed: 77 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -21,28 +21,30 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
2121
final val maxTparamsToSpecialize = 2
2222

2323
private final def specialisedTypeToSuffix(implicit ctx: Context) =
24-
Map(ctx.definitions.ByteType -> "$mcB$sp",
25-
ctx.definitions.BooleanType -> "$mcZ$sp",
26-
ctx.definitions.ShortType -> "$mcS$sp",
27-
ctx.definitions.IntType -> "$mcI$sp",
28-
ctx.definitions.LongType -> "$mcJ$sp",
29-
ctx.definitions.FloatType -> "$mcF$sp",
30-
ctx.definitions.DoubleType -> "$mcD$sp",
31-
ctx.definitions.CharType -> "$mcC$sp",
32-
ctx.definitions.UnitType -> "$mcV$sp")
24+
Map(defn.ByteType -> "$mcB$sp",
25+
defn.BooleanType -> "$mcZ$sp",
26+
defn.ShortType -> "$mcS$sp",
27+
defn.IntType -> "$mcI$sp",
28+
defn.LongType -> "$mcJ$sp",
29+
defn.FloatType -> "$mcF$sp",
30+
defn.DoubleType -> "$mcD$sp",
31+
defn.CharType -> "$mcC$sp",
32+
defn.UnitType -> "$mcV$sp")
3333

3434
private def primitiveTypes(implicit ctx: Context) =
35-
List(ctx.definitions.ByteType,
36-
ctx.definitions.BooleanType,
37-
ctx.definitions.ShortType,
38-
ctx.definitions.IntType,
39-
ctx.definitions.LongType,
40-
ctx.definitions.FloatType,
41-
ctx.definitions.DoubleType,
42-
ctx.definitions.CharType,
43-
ctx.definitions.UnitType
35+
List(defn.ByteType,
36+
defn.BooleanType,
37+
defn.ShortType,
38+
defn.IntType,
39+
defn.LongType,
40+
defn.FloatType,
41+
defn.DoubleType,
42+
defn.CharType,
43+
defn.UnitType
4444
)
4545

46+
private def defn(implicit ctx:Context) = ctx.definitions
47+
4648
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[Type]] = mutable.HashMap.empty
4749

4850
/**
@@ -51,15 +53,45 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
5153
*/
5254
private val newSymbolMap: mutable.HashMap[Symbol, mutable.HashMap[List[Type], Symbols.Symbol]] = mutable.HashMap.empty
5355

56+
def allowedToSpecialize(sym: Symbol, numOfTypes: Int)(implicit ctx: Context): Boolean = {
57+
numOfTypes <= maxTparamsToSpecialize &&
58+
numOfTypes > 0 &&
59+
sym.name != nme.asInstanceOf_ &&
60+
sym.name != nme.isInstanceOf_ &&
61+
!(sym is Flags.JavaDefined) &&
62+
!sym.isConstructor &&
63+
!sym.name.toString.contains("Function2")
64+
}
65+
66+
def getSpecTypes(sym: Symbol, poly: PolyType)(implicit ctx: Context): List[Type] = {
67+
val requested = specializationRequests.getOrElse(sym, List())
68+
if (requested.nonEmpty) requested.toList
69+
else {
70+
if(ctx.settings.Yspecialize.value == "all") primitiveTypes
71+
else Nil
72+
}.filter(tpe => poly.paramBounds.forall(_.contains(tpe)))
73+
}
74+
75+
def requestedSpecialization(decl: Symbol)(implicit ctx: Context): Boolean =
76+
specializationRequests.contains(decl) ||
77+
(ctx.settings.Yspecialize.value != "" && decl.name.contains(ctx.settings.Yspecialize.value)) ||
78+
ctx.settings.Yspecialize.value == "all"
79+
80+
def registerSpecializationRequest(method: Symbols.Symbol)(arguments: List[Type])(implicit ctx: Context) = {
81+
if(ctx.phaseId > this.treeTransformPhase.id)
82+
assert(ctx.phaseId <= this.treeTransformPhase.id)
83+
val prev = specializationRequests.getOrElse(method, List.empty)
84+
specializationRequests.put(method, (arguments ::: prev).toSet.toList)
85+
}
86+
5487
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
5588
def generateSpecializations(remainingTParams: List[Name], specTypes: List[Type])
5689
(instantiations: List[Type], names: List[String], poly: PolyType, decl: Symbol)
5790
(implicit ctx: Context): List[Symbol] = {
5891
if (remainingTParams.nonEmpty) {
59-
val specializations = (for (tpe <- specTypes) yield {
92+
(for (tpe <- specTypes) yield {
6093
generateSpecializations(remainingTParams.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
6194
}).flatten
62-
specializations
6395
}
6496
else {
6597
generateSpecializedSymbols(instantiations.reverse, names.reverse, poly, decl)
@@ -76,65 +108,39 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
76108
map.values.toList
77109
}
78110

79-
if((sym ne ctx.definitions.ScalaPredefModule.moduleClass) &&
111+
if((sym ne defn.ScalaPredefModule.moduleClass) &&
80112
!(sym is Flags.Package) &&
81-
!sym.isAnonymousClass &&
82-
!(sym.name == nme.asInstanceOf_)) {
113+
!sym.isAnonymousClass) {
83114
sym.info match {
84115
case classInfo: ClassInfo =>
85-
val newDecls = classInfo.decls.filterNot(_.isConstructor/*isPrimaryConstructor*/).flatMap(decl => {
86-
if(decl.name.toString.contains("foobar")) {
87-
println("hello")
88-
}
89-
if (shouldSpecialize(decl)) {
116+
val newDecls = classInfo.decls
117+
.filterNot(_.isConstructor)
118+
.filter(requestedSpecialization)
119+
.flatMap(decl => {
90120
decl.info.widen match {
91-
case poly: PolyType =>
92-
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0) {
93-
val specTypes = getSpecTypes(decl).filter(tpe => poly.paramBounds.forall(_.contains(tpe)))
94-
generateSpecializations(poly.paramNames, specTypes)(List.empty, List.empty, poly, decl)
95-
}
96-
else Nil
121+
case poly: PolyType if allowedToSpecialize(decl.symbol, poly.paramNames.length) =>
122+
generateSpecializations(poly.paramNames, getSpecTypes(decl, poly))(List.empty, List.empty, poly, decl)
97123
case nil => Nil
98124
}
99-
} else Nil
100125
})
101-
if (newDecls.nonEmpty) {
102126
val decls = classInfo.decls.cloneScope
103127
newDecls.foreach(decls.enter)
104128
classInfo.derivedClassInfo(decls = decls)
105-
}
129+
case poly: PolyType if !newSymbolMap.contains(sym) &&
130+
requestedSpecialization(sym) &&
131+
allowedToSpecialize(sym, poly.paramNames.length)=>
132+
generateSpecializations(poly.paramNames, getSpecTypes(sym, poly))(List.empty, List.empty, poly, sym)
106133
case nil =>
107134
}
108135
tp
109136
} else tp
110137
}
111138

112-
def getSpecTypes(sym: Symbol)(implicit ctx: Context): List[Type] = {
113-
val requested = specializationRequests.getOrElse(sym, List())
114-
if (requested.nonEmpty) requested.toList
115-
else {
116-
if(ctx.settings.Yspecialize.value == "all") primitiveTypes
117-
else Nil
118-
}
119-
}
120-
121-
def shouldSpecialize(decl: Symbol)(implicit ctx: Context): Boolean =
122-
specializationRequests.contains(decl) ||
123-
(ctx.settings.Yspecialize.value != "" && decl.name.contains(ctx.settings.Yspecialize.value)) ||
124-
ctx.settings.Yspecialize.value == "all"
125-
126-
def registerSpecializationRequest(method: Symbols.Symbol)(arguments: List[Type])(implicit ctx: Context) = {
127-
if(ctx.phaseId > this.treeTransformPhase.id)
128-
assert(ctx.phaseId <= this.treeTransformPhase.id)
129-
val prev = specializationRequests.getOrElse(method, List.empty)
130-
specializationRequests.put(method, arguments ::: prev)
131-
}
132-
133139
override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
134140

135141
tree.tpe.widen match {
136142

137-
case poly: PolyType if !(tree.symbol.isConstructor//isPrimaryConstructor
143+
case poly: PolyType if !(tree.symbol.isConstructor
138144
|| (tree.symbol is Flags.Label))
139145
|| (tree.symbol.name == nme.asInstanceOf_) =>
140146
val origTParams = tree.tparams.map(_.symbol)
@@ -162,31 +168,35 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
162168
}
163169
} else Nil
164170
}
165-
val specializedMethods = specialize(tree.symbol)
166-
Thicket(tree :: specializedMethods)
171+
Thicket(tree :: specialize(tree.symbol))
167172
case _ => tree
168173
}
169174
}
170175

171176
override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): Tree = {
172177

178+
173179
def allowedToSpecialize(sym: Symbol): Boolean = {
174180
sym.name != nme.asInstanceOf_ &&
175181
!(sym is Flags.JavaDefined) &&
176182
!sym.isConstructor//isPrimaryConstructor
177183
}
184+
178185
val TypeApply(fun,args) = tree
179-
if (newSymbolMap.contains(fun.symbol) && allowedToSpecialize(fun.symbol)) {
186+
if (newSymbolMap.contains(fun.symbol)){
180187
val newSymInfos = newSymbolMap(fun.symbol)
181188
val betterDefs = newSymInfos.filter(x => (x._1 zip args).forall{a =>
182189
val specializedType = a._1
183190
val argType = a._2
184191
argType.tpe <:< specializedType
185192
}).toList
186-
assert(betterDefs.length < 2) // TODO: How to select the best if there are several ?
193+
194+
if (betterDefs.length > 1) ctx.debuglog("Several specialized variants fit.")
195+
//assert(betterDefs.length < 2) // TODO: How to select the best if there are several ?
187196

188197
if (betterDefs.nonEmpty) {
189-
println(s"method $fun rewired to specialized variant with type (${betterDefs.head._1})")
198+
val best = betterDefs.head
199+
println(s"method ${fun.symbol.name} of ${fun.symbol.owner} rewired to specialized variant with type (${best._1})")
190200
val prefix = fun match {
191201
case Select(pre, name) =>
192202
pre
@@ -197,9 +207,10 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
197207
else EmptyTree
198208
}
199209
if (prefix ne EmptyTree)
200-
prefix.select(betterDefs.head._2)
201-
else ref(betterDefs.head._2)
210+
prefix.select(best._2)
211+
else ref(best._2)
202212
} else tree
203213
} else tree
204214
}
205215
}
216+

test/dotc/tests.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class tests extends CompilerTest {
3939
val negDir = testsDir + "neg/"
4040
val runDir = testsDir + "run/"
4141
val newDir = testsDir + "new/"
42+
val specialDir = posDir + "specialization/"
4243
val miniMethodDir = testsDir + "method_minibox/"
4344
val miniMoreDir = testsDir + "more_minibox/"
4445

@@ -101,8 +102,7 @@ class tests extends CompilerTest {
101102
@Test def pos_SI7638 = compileFile(posDir, "SI-7638")
102103
@Test def pos_SI7638a = compileFile(posDir, "SI-7638a")
103104

104-
105-
@Test def new_all = compileFiles(newDir, twice)
105+
//@Test def new_all = compileFiles(newDir, twice)
106106

107107
@Test def neg_blockescapes() = compileFile(negDir, "blockescapesNeg", xerrors = 1)
108108
@Test def neg_typedapply() = compileFile(negDir, "typedapply", xerrors = 4)
@@ -240,20 +240,21 @@ class tests extends CompilerTest {
240240
val javaDir = "./tests/pos/java-interop/"
241241
@Test def java_all = compileFiles(javaDir, twice)
242242

243-
//@Test def pos_specialization = compileFile(posDir, "specialization")//, specialise)
244-
245-
//@Test def dotc_compilercommand = compileFile(dotcDir + "config/", "CompilerCommand")
246-
247-
//@Test def test = compileFile(posDir, "t247", List("-Xprint:all"))
248-
@Test def mini_method = compileFiles(miniMethodDir)//, List("-Xprint:all"))
249-
@Test def mini_more = compileFiles(miniMoreDir)//, List("-Xprint:all"))
243+
//@Test def specialization = compileFile(specialDir, "specialization")//, specialise)
244+
//@Test def mutual_spec = compileFile(specialDir, "mutual_specialization")
245+
//@Test def return_spec = compileFile(specialDir, "return_specialization")
246+
//@Test def nothing_spec = compileFile(specialDir, "nothing_specialization")
247+
//@Test def method_in_class_spec = compileFile(specialDir, "method_in_class_specialization")
248+
//@Test def method_in_method_spec = compileFile(specialDir, "method_in_method_specialization", List("-Xprint:all"))
249+
@Test def type_check_spec = compileFile(specialDir, "type_check_specialization")
250+
//@Test def bounds_spec = compileFile(specialDir, "bounds_specialization", List("-Xprint:all"))
251+
//@Test def multi_spec = compileFile(specialDir, "multi_specialization", List("-Xprint:all"))
252+
//@Test def pos_spec_all = compileFiles(specialDir)
253+
254+
//@Test def mini_method = compileFiles(miniMethodDir)//, List("-Xprint:all"))
255+
//@Test def mini_more = compileFiles(miniMoreDir)//, List("-Xprint:all"))
250256
//@Test def pos_all = compileFiles(posDir)//, List("-Xprint:all"))
251257

252-
@Test def pos_mutual_spec = compileFile(posDir, "mutual_specialization", List("-Xprint:all"))
253-
//@Test def pos_mutual_spec = compileFile(posDir, "mutual_specialization")
254-
//@Test def pos_spec = compileFile(posDir, "specialization")
255-
*/
256-
@Test def pos_return_spec = compileFile(posDir, "return_specialization")
257-
// @Test def pos_si7638 = compileFile(posDir, "SI-7638", List("-Xprint:all"))
258-
258+
//@Test def pos_si7638 = compileFile(posDir, "SI-7638", List("-Xprint:all"))
259+
//@Test def test = compileFile(posDir, "t247", List("-Xprint:all"))
259260
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
object bounds_specialization {
2+
/*class Foo[@specialized K] {
3+
def bar[@specialized U](u: U) {
4+
def dough[@specialized V](v: V) {
5+
println("innerMethod")
6+
}
7+
dough(1.toShort)
8+
dough('c')
9+
}
10+
bar(2.toShort)
11+
bar('d')
12+
}
13+
*/
14+
def kung[@specialized(Int, Double) T <: AnyRef](t: T): T = {
15+
t
16+
}
17+
18+
def fu[@specialized(Int, Double) T >: Nothing](t: T): T = {
19+
t
20+
}
21+
}

0 commit comments

Comments
 (0)