Skip to content

Commit e682550

Browse files
committed
Add specialized methods dispatch
1 parent aaf07bd commit e682550

File tree

4 files changed

+137
-86
lines changed

4 files changed

+137
-86
lines changed

src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ class ScalaSettings extends Settings.SettingGroup {
152152
val YnoDeepSubtypes = BooleanSetting("-Yno-deep-subtypes", "throw an exception on deep subtyping call stacks.")
153153
val YprintSyms = BooleanSetting("-Yprint-syms", "when printing trees print info in symbols instead of corresponding info in trees.")
154154
val YtestPickler = BooleanSetting("-Ytest-pickler", "self-test for pickling functionality; should be used with -Ystop-after:pickler")
155-
val Yspecialize = StringSetting("-Yspecialize","","Specialize all methods.","") // How should the second and fourth paramerters be initialised ?
155+
val Yspecialize = StringSetting("-Yspecialize","all","Specialize all methods.", "all") // TODO remove default value
156+
156157
def stop = YstopAfter
157158

158159
/** Area-specific debug output.
Lines changed: 117 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
package dotty.tools.dotc.transform
22

3-
import dotty.tools.dotc.ast.TreeTypeMap
4-
import dotty.tools.dotc.ast.Trees.SeqLiteral
3+
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}
4+
import dotty.tools.dotc.ast.Trees.{TypeApply, SeqLiteral}
55
import dotty.tools.dotc.ast.tpd._
66
import dotty.tools.dotc.core.Annotations.Annotation
77
import dotty.tools.dotc.core.Contexts.Context
88
import dotty.tools.dotc.core.Decorators.StringDecorator
99
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
10-
import dotty.tools.dotc.core.Names.TermName
10+
import dotty.tools.dotc.core.Names.Name
1111
import dotty.tools.dotc.core.Symbols.Symbol
1212
import dotty.tools.dotc.core.{Symbols, Flags}
1313
import dotty.tools.dotc.core.Types._
@@ -42,51 +42,82 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
4242
ctx.definitions.CharType -> "$mcC$sp",
4343
ctx.definitions.UnitType -> "$mcV$sp")
4444

45+
private def primitiveTypes(implicit ctx: Context) =
46+
List(ctx.definitions.ByteType,
47+
ctx.definitions.BooleanType,
48+
ctx.definitions.ShortType,
49+
ctx.definitions.IntType,
50+
ctx.definitions.LongType,
51+
ctx.definitions.FloatType,
52+
ctx.definitions.DoubleType,
53+
ctx.definitions.CharType,
54+
ctx.definitions.UnitType
55+
)
56+
4557
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[List[Type]]] = mutable.HashMap.empty
4658

47-
private val newSymbolMap: mutable.HashMap[TermName, (List[Symbols.TermSymbol], List[Type])] = mutable.HashMap.empty // Why does the typechecker require TermSymbol ?
59+
private val newSymbolMap: mutable.HashMap[Symbol, List[mutable.HashMap[List[Type], Symbols.Symbol]]] = mutable.HashMap.empty
4860

4961
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
62+
def generateSpecializations(remainingTParams: List[Name], remainingBounds: List[TypeBounds])
63+
(instantiations: List[Type], names: List[String], poly: PolyType, decl: Symbol)
64+
(implicit ctx: Context): List[Symbol] = {
65+
if (remainingTParams.nonEmpty) {
66+
val bounds = remainingBounds.head
67+
val specTypes = primitiveTypes.filter{ tpe => bounds.contains(tpe)}
68+
val specializations = (for (tpe <- specTypes) yield {
69+
generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
70+
}).flatten
71+
specializations
72+
}
73+
else {
74+
generateSpecializedSymbols(instantiations.reverse, names.reverse, poly, decl)
75+
}
76+
}
77+
def generateSpecializedSymbols(instantiations: List[Type], names: List[String], poly: PolyType, decl: Symbol)
78+
(implicit ctx: Context): List[Symbol] = {
79+
val newSym =
80+
ctx.newSymbol(decl.owner, (decl.name + names.mkString).toTermName,
81+
decl.flags | Flags.Synthetic, poly.instantiate(instantiations.toList)) // Who should the owner be ? decl.owner ? sym ? sym.owner ? ctx.owner ?
82+
// TODO I think there might be a bug in the assertion at dotty.tools.dotc.transform.TreeChecker$Checker.dotty$tools$dotc$transform$TreeChecker$Checker$$checkOwner(TreeChecker.scala:244)
83+
// Shouldn't the owner remain the original one ? In this instance, the assertion always expects the owner to be `class specialization` (the test I run), even for methods that aren't
84+
//defined by the test itself, such as `instanceOf` (to which my implementation gives owner `class Any`).
85+
val prevMaps = newSymbolMap.getOrElse(decl, List()).reverse
86+
val newMap: mutable.HashMap[List[Type], Symbols.Symbol] = mutable.HashMap(instantiations -> newSym)
87+
newSymbolMap.put(decl, (newMap :: prevMaps.reverse).reverse)
88+
(newSym :: prevMaps.flatMap(_.values).reverse).reverse // All those reverse are probably useless
89+
}
5090

51-
tp.widen match {
52-
case poly: PolyType if !(sym.isPrimaryConstructor
53-
|| (sym is Flags.Label)) =>
54-
55-
def generateSpecializations(remainingTParams: List[Type], remainingBounds: List[TypeBounds])
56-
(instantiations: List[Type], names: List[String])(implicit ctx: Context): Unit = {
57-
if (remainingTParams.nonEmpty) {
58-
val typeToSpecialize = remainingTParams.head
59-
val bounds = remainingBounds.head
60-
val a = shouldSpecializeFor(typeToSpecialize.typeSymbol) // TODO returns Nil because no annotations are found - elucidate
61-
a.flatten
62-
.filter { tpe =>
63-
bounds.contains(tpe)
64-
}.foreach({ tpe =>
65-
val nme = specialisedTypeToSuffix(ctx)(tpe)
66-
generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instantiations, nme :: names)
67-
})
68-
}
69-
else {
70-
generateSpecializedSymbols(instantiations.reverse, names.reverse)
91+
if((sym ne ctx.definitions.ScalaPredefModule.moduleClass) && !(sym is Flags.Package) && !sym.isAnonymousClass) {
92+
sym.info match {
93+
case classInfo: ClassInfo =>
94+
val newDecls = classInfo.decls.flatMap(decl => {
95+
if (shouldSpecialize(decl)) {
96+
decl.info.widen match {
97+
case poly: PolyType =>
98+
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0)
99+
generateSpecializations(poly.paramNames, poly.paramBounds)(List.empty, List.empty, poly, decl)
100+
else Nil
101+
case nil => Nil
102+
}
103+
} else Nil
104+
})
105+
if (newDecls.nonEmpty) {
106+
val decls = classInfo.decls.cloneScope
107+
newDecls.foreach(decls.enter)
108+
classInfo.derivedClassInfo(decls = decls)
71109
}
72-
}
73-
74-
def generateSpecializedSymbols(instantiations : List[Type], names: List[String])(implicit ctx: Context): Unit = {
75-
val newSym = ctx.newSymbol(sym.owner, (sym.name + names.mkString).toTermName, sym.flags | Flags.Synthetic, poly.instantiate(instantiations.toList))
76-
ctx.enter(newSym) // TODO check frozen flag ?
77-
val prev = newSymbolMap.getOrElse(sym.name.toTermName, (Nil, Nil))
78-
val newSyms = newSym :: prev._1
79-
newSymbolMap.put(sym.name.toTermName, (newSyms, instantiations)) // Could `.put(...)` bring up (mutability) issues ?
80-
}
81-
val origTParams = poly.resType.paramTypess.flatten // Is this really what is needed ?
82-
val bounds = poly.paramBounds
83-
generateSpecializations(origTParams, bounds)(List.empty, List.empty)
84-
tp
85-
case _ =>
86-
tp
87-
}
110+
case nil =>
111+
}
112+
tp
113+
} else tp
88114
}
89115

116+
def shouldSpecialize(decl: Symbol)(implicit ctx: Context): Boolean =
117+
specializationRequests.contains(decl) ||
118+
(ctx.settings.Yspecialize.value != "" && decl.name.contains(ctx.settings.Yspecialize.value)) ||
119+
ctx.settings.Yspecialize.value == "all"
120+
90121
def registerSpecializationRequest(method: Symbols.Symbol)(arguments: List[Type])(implicit ctx: Context) = {
91122
if(ctx.phaseId > this.treeTransformPhase.id)
92123
assert(ctx.phaseId <= this.treeTransformPhase.id)
@@ -95,7 +126,7 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
95126
}
96127

97128
def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
98-
registerSpecializationRequest(sym)(specialisedTypeToSuffix.keys.toList)
129+
registerSpecializationRequest(sym)(primitiveTypes)
99130
println("Specializing for all primitive types")
100131
specializationRequests.getOrElse(sym, Nil)
101132
}
@@ -106,63 +137,68 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
106137
specializationRequests.getOrElse(sym, Nil)
107138
}
108139

109-
def shouldSpecializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
110-
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
111-
case annot: Annotation =>
112-
annot.arguments match {
113-
case List(SeqLiteral(types)) =>
114-
specializeForSome(sym)(types.map(tpeTree =>
115-
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString())))
116-
case List() => specializeForAll(sym)
117-
}
118-
case nil =>
119-
if(ctx.settings.Yspecialize.value == "all") specializeForAll(sym)
120-
else Nil
121-
}
140+
def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
141+
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
142+
case annot: Annotation =>
143+
annot.arguments match {
144+
case List(SeqLiteral(types)) =>
145+
specializeForSome(sym)(types.map(tpeTree => //tpeTree.tpe.widen))
146+
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than types. comment on line above was an attempt.
147+
case List() => specializeForAll(sym)
148+
}
149+
case nil =>
150+
if(ctx.settings.Yspecialize.value == "all") {println("Yspecialize set to all"); specializeForAll(sym) }
151+
else Nil
152+
}
122153
}
123154

124155
override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
125156

126157
tree.tpe.widen match {
127158

128159
case poly: PolyType if !(tree.symbol.isPrimaryConstructor
129-
|| (tree.symbol is Flags.Label)) =>
160+
|| (tree.symbol is Flags.Label)) =>
130161
val origTParams = tree.tparams.map(_.symbol)
131162
val origVParams = tree.vparamss.flatten.map(_.symbol)
132163
println(s"specializing ${tree.symbol} for Tparams: $origTParams")
133164

134-
def specialize(instantiations: List[Type]): List[Tree] = {
135-
newSymbolMap(tree.name) match {
136-
case newSyms: (List[Symbol], List[Type]) =>
137-
newSyms._1.map{newSym =>
138-
polyDefDef(newSym, { tparams => vparams => {
139-
assert(tparams.isEmpty)
140-
new TreeTypeMap(
141-
typeMap = _
142-
.substDealias(origTParams, instantiations.toList)
143-
.subst(origVParams, vparams.flatten.map(_.tpe)),
144-
oldOwners = tree.symbol :: Nil,
145-
newOwners = newSym :: Nil
146-
).transform(tree.rhs)
147-
}
148-
})}
149-
case nil =>
150-
List()
165+
def specialize(decl : Symbol): List[Tree] = {
166+
val declSpecs = newSymbolMap(decl)
167+
val newSyms = declSpecs.map(_.values).flatten
168+
/*for (newSym <- newSyms) {
169+
println(newSym)
170+
}*/
171+
val instantiations = declSpecs.flatMap(_.keys).flatten
172+
newSyms.map{newSym =>
173+
polyDefDef(newSym.asTerm, { tparams => vparams => {
174+
assert(tparams.isEmpty)
175+
//println(newSym + " ; " + origVParams + " ; " + vparams + " ; " + vparams.flatten + " ; " + vparams.flatten.map(_.tpe))
176+
new TreeTypeMap( //TODO Figure out what is happening with newSym. Why do some symbols have unmatching vparams and origVParams ?
177+
typeMap = _
178+
.substDealias(origTParams, instantiations)
179+
.subst(origVParams, vparams.flatten.map(_.tpe)),
180+
oldOwners = tree.symbol :: Nil,
181+
newOwners = newSym :: Nil
182+
).transform(tree.rhs)
183+
}})
151184
}
152185
}
153-
154-
val specializedMethods: List[Tree] = (for (inst <- newSymbolMap.keys) yield specialize(newSymbolMap(inst)._2)).flatten.toList
186+
//specializeFor(tree.symbol) -> necessary ? This registers specialization requests, but do they still make sense at this point ? Symbols have already been generated
187+
val specializedMethods = newSymbolMap.keys.map(specialize).flatten.toList
155188
Thicket(tree :: specializedMethods)
156-
157189
case _ => tree
158190
}
159191
}
160192

161-
def transformTypeOfTree(tree: Tree): Tree = {
193+
override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): Tree = {
194+
val TypeApply(fun,args) = tree
195+
val newSymInfo = newSymbolMap(fun.symbol).flatten.toMap
196+
val specializationType: List[Type] = args.map(_.tpe.asInstanceOf[TypeVar].instanceOpt)
197+
val t = fun.symbol.info.decls
198+
if (t.nonEmpty) {
199+
t.cloneScope.lookupEntry(args.head.symbol.name)
200+
val newSym = newSymInfo(specializationType)
201+
}
162202
tree
163203
}
164-
165-
override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = transformTypeOfTree(tree)
166-
override def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = transformTypeOfTree(tree)
167-
168-
}
204+
}

test/dotc/tests.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ class tests extends CompilerTest {
156156
157157
@ Test def dotc_core_pickling = compileDir(dotcDir + "tools/dotc/core/pickling", failedOther)(allowDeepSubtypes) // Cannot emit primitive conversion from V to Z
158158
159-
@ Test def dotc_transform = compileDir(dotcDir + "tools/dotc/transform", failedbyName)
159+
//@ Test def dotc_transform = compileDir(dotcDir + "tools/dotc/transform", failedbyName)
160160
161161
@ Test def dotc_parsing = compileDir(dotcDir + "tools/dotc/parsing", failedOther)
162162
// Expected primitive types I - Ljava/lang/Object
@@ -197,10 +197,12 @@ class tests extends CompilerTest {
197197
198198
val javaDir = "./tests/pos/java-interop/"
199199
@Test def java_all = compileFiles(javaDir)
200-
*/
200+
201201
@Test def pos_specialization = compileFile(posDir, "specialization")
202202
203203
//@Test def dotc_compilercommand = compileFile(dotcDir + "tools/dotc/config/", "CompilerCommand")
204+
//@ Test def dotc_transform = compileDir(dotcDir + "tools/dotc/transform", failedbyName)
205+
204206
205207
@Test def dotc_parsing = compileDir(dotcDir, "parsing") // twice omitted to make tests run faster
206208
@@ -234,6 +236,9 @@ class tests extends CompilerTest {
234236
235237
val javaDir = "./tests/pos/java-interop/"
236238
@Test def java_all = compileFiles(javaDir, twice)
239+
240+
@Test def pos_specialization = compileFile(posDir, "specialization")
241+
237242
//@Test def dotc_compilercommand = compileFile(dotcDir + "config/", "CompilerCommand")
238243
239244
}

tests/pos/specialization.scala

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
class specialization {
2-
def printer[@specialized(Int, Long) T, U](a: T, b:U) = {
2+
def printer1[@specialized(Int, Long) T](a: T) = {
3+
println(a.toString)
4+
}
5+
6+
def printer2[@specialized(Int, Long) T, U](a: T, b: U) = {
37
println(a.toString + b.toString)
48
}
9+
def print(a: Int) = {
10+
printer1(a)
11+
println(" ---- ")
12+
printer2(a,a)
13+
}
14+
print(9)
515
}
6-

0 commit comments

Comments
 (0)