Skip to content

Commit 62e8131

Browse files
committed
Add specialised method dispatching
Implementation of specialized methods dispatching in `transformTypeApply`. If several specialised variants fit (e.g. Nothing is a subtype of all primitive types), the compiler defaults to not specialising. Currently, all specialization is done through the Yspecialize:all setting, as some anotations (including `@specialized`) are lost earlier in the pipeline. Various bugs corrected in generation of specialized method symbols.
1 parent e682550 commit 62e8131

File tree

3 files changed

+109
-58
lines changed

3 files changed

+109
-58
lines changed

src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class ScalaSettings extends Settings.SettingGroup {
3333
val usejavacp = BooleanSetting("-usejavacp", "Utilize the java.class.path in classpath resolution.")
3434
val verbose = BooleanSetting("-verbose", "Output messages about what the compiler is doing.")
3535
val version = BooleanSetting("-version", "Print product version and exit.")
36-
val pageWidth = IntSetting("-pagewidth", "Set page width", 80)
36+
val pageWidth = IntSetting("-pagewidth", "Set page width", 160)
3737

3838
val jvmargs = PrefixSetting("-J<flag>", "-J", "Pass <flag> directly to the runtime system.")
3939
val defines = PrefixSetting("-Dproperty=value", "-D", "Pass -Dproperty=value directly to the runtime system.")
Lines changed: 96 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
package dotty.tools.dotc.transform
22

33
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}
4-
import dotty.tools.dotc.ast.Trees.{TypeApply, SeqLiteral}
5-
import dotty.tools.dotc.ast.tpd._
4+
import dotty.tools.dotc.ast.Trees._
65
import dotty.tools.dotc.core.Annotations.Annotation
76
import dotty.tools.dotc.core.Contexts.Context
87
import dotty.tools.dotc.core.Decorators.StringDecorator
@@ -13,9 +12,10 @@ import dotty.tools.dotc.core.{Symbols, Flags}
1312
import dotty.tools.dotc.core.Types._
1413
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
1514
import scala.collection.mutable
15+
import dotty.tools.dotc.core.StdNames.nme
1616

1717
class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
18-
18+
import tpd._
1919
override def phaseName = "specialize"
2020

2121
final val maxTparamsToSpecialize = 2
@@ -56,17 +56,21 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
5656

5757
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[List[Type]]] = mutable.HashMap.empty
5858

59-
private val newSymbolMap: mutable.HashMap[Symbol, List[mutable.HashMap[List[Type], Symbols.Symbol]]] = mutable.HashMap.empty
59+
/**
60+
* A map that links symbols to their specialized variants.
61+
* Each symbol maps to another as map, from the list of specialization types to the specialized symbol.
62+
*/
63+
private val newSymbolMap: mutable.HashMap[Symbol, mutable.HashMap[List[Type], Symbols.Symbol]] = mutable.HashMap.empty
6064

6165
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
62-
def generateSpecializations(remainingTParams: List[Name], remainingBounds: List[TypeBounds])
66+
67+
def generateSpecializations(remainingTParams: List[Name], remainingBounds: List[TypeBounds], specTypes: List[Type])
6368
(instantiations: List[Type], names: List[String], poly: PolyType, decl: Symbol)
6469
(implicit ctx: Context): List[Symbol] = {
6570
if (remainingTParams.nonEmpty) {
6671
val bounds = remainingBounds.head
67-
val specTypes = primitiveTypes.filter{ tpe => bounds.contains(tpe)}
6872
val specializations = (for (tpe <- specTypes) yield {
69-
generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
73+
generateSpecializations(remainingTParams.tail, remainingBounds.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
7074
}).flatten
7175
specializations
7276
}
@@ -78,25 +82,27 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
7882
(implicit ctx: Context): List[Symbol] = {
7983
val newSym =
8084
ctx.newSymbol(decl.owner, (decl.name + names.mkString).toTermName,
81-
decl.flags | Flags.Synthetic, poly.instantiate(instantiations.toList)) // Who should the owner be ? decl.owner ? sym ? sym.owner ? ctx.owner ?
82-
// TODO I think there might be a bug in the assertion at dotty.tools.dotc.transform.TreeChecker$Checker.dotty$tools$dotc$transform$TreeChecker$Checker$$checkOwner(TreeChecker.scala:244)
83-
// Shouldn't the owner remain the original one ? In this instance, the assertion always expects the owner to be `class specialization` (the test I run), even for methods that aren't
84-
//defined by the test itself, such as `instanceOf` (to which my implementation gives owner `class Any`).
85-
val prevMaps = newSymbolMap.getOrElse(decl, List()).reverse
86-
val newMap: mutable.HashMap[List[Type], Symbols.Symbol] = mutable.HashMap(instantiations -> newSym)
87-
newSymbolMap.put(decl, (newMap :: prevMaps.reverse).reverse)
88-
(newSym :: prevMaps.flatMap(_.values).reverse).reverse // All those reverse are probably useless
85+
decl.flags | Flags.Synthetic, poly.instantiate(instantiations.toList))
86+
val map = newSymbolMap.getOrElse(decl, mutable.HashMap.empty)
87+
map.put(instantiations, newSym)
88+
newSymbolMap.put(decl, map)
89+
map.values.toList
8990
}
9091

91-
if((sym ne ctx.definitions.ScalaPredefModule.moduleClass) && !(sym is Flags.Package) && !sym.isAnonymousClass) {
92+
if((sym ne ctx.definitions.ScalaPredefModule.moduleClass) &&
93+
!(sym is Flags.Package) &&
94+
!sym.isAnonymousClass &&
95+
!(sym.name == nme.asInstanceOf_)) {
9296
sym.info match {
9397
case classInfo: ClassInfo =>
94-
val newDecls = classInfo.decls.flatMap(decl => {
98+
val newDecls = classInfo.decls.filterNot(_.isConstructor/*isPrimaryConstructor*/).flatMap(decl => {
9599
if (shouldSpecialize(decl)) {
96100
decl.info.widen match {
97101
case poly: PolyType =>
98-
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0)
99-
generateSpecializations(poly.paramNames, poly.paramBounds)(List.empty, List.empty, poly, decl)
102+
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0) {
103+
val specTypes = getSpecTypes(sym)
104+
generateSpecializations(poly.paramNames, poly.paramBounds, specTypes)(List.empty, List.empty, poly, decl)
105+
}
100106
else Nil
101107
case nil => Nil
102108
}
@@ -113,6 +119,20 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
113119
} else tp
114120
}
115121

122+
def getSpecTypes(sym: Symbol)(implicit ctx: Context): List[Type] = {
123+
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
124+
case annot: Annotation =>
125+
annot.arguments match {
126+
case List(SeqLiteral(types)) =>
127+
types.map(tpeTree => nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))
128+
case List() => primitiveTypes
129+
}
130+
case nil =>
131+
if(ctx.settings.Yspecialize.value == "all") primitiveTypes
132+
else Nil
133+
}
134+
}
135+
116136
def shouldSpecialize(decl: Symbol)(implicit ctx: Context): Boolean =
117137
specializationRequests.contains(decl) ||
118138
(ctx.settings.Yspecialize.value != "" && decl.name.contains(ctx.settings.Yspecialize.value)) ||
@@ -124,81 +144,104 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
124144
val prev = specializationRequests.getOrElse(method, List.empty)
125145
specializationRequests.put(method, arguments :: prev)
126146
}
127-
128-
def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
147+
/*
148+
def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
129149
registerSpecializationRequest(sym)(primitiveTypes)
130-
println("Specializing for all primitive types")
131-
specializationRequests.getOrElse(sym, Nil)
150+
println(s"Specializing $sym for all primitive types")
151+
specializationRequests.getOrElse(sym, Nil).flatten
132152
}
133153
134-
def specializeForSome(sym: Symbols.Symbol)(annotationArgs: List[Type])(implicit ctx: Context): List[List[Type]] = {
154+
def specializeForSome(sym: Symbols.Symbol)(annotationArgs: List[Type])(implicit ctx: Context): List[Type] = {
135155
registerSpecializationRequest(sym)(annotationArgs)
136156
println(s"specializationRequests : $specializationRequests")
137-
specializationRequests.getOrElse(sym, Nil)
157+
specializationRequests.getOrElse(sym, Nil).flatten
138158
}
139159
140-
def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
160+
def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
141161
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
142162
case annot: Annotation =>
143163
annot.arguments match {
144164
case List(SeqLiteral(types)) =>
145-
specializeForSome(sym)(types.map(tpeTree => //tpeTree.tpe.widen))
146-
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than types. comment on line above was an attempt.
165+
specializeForSome(sym)(types.map(tpeTree =>
166+
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than type names
147167
case List() => specializeForAll(sym)
148168
}
149169
case nil =>
150-
if(ctx.settings.Yspecialize.value == "all") {println("Yspecialize set to all"); specializeForAll(sym) }
170+
if(ctx.settings.Yspecialize.value == "all") specializeForAll(sym)
151171
else Nil
152172
}
153-
}
173+
}*/
154174

155175
override def transformDefDef(tree: DefDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
156176

157177
tree.tpe.widen match {
158178

159-
case poly: PolyType if !(tree.symbol.isPrimaryConstructor
160-
|| (tree.symbol is Flags.Label)) =>
179+
case poly: PolyType if !(tree.symbol.isConstructor//isPrimaryConstructor
180+
|| (tree.symbol is Flags.Label))
181+
|| (tree.symbol.name == nme.asInstanceOf_) =>
161182
val origTParams = tree.tparams.map(_.symbol)
162183
val origVParams = tree.vparamss.flatten.map(_.symbol)
163-
println(s"specializing ${tree.symbol} for Tparams: $origTParams")
164184

165185
def specialize(decl : Symbol): List[Tree] = {
166-
val declSpecs = newSymbolMap(decl)
167-
val newSyms = declSpecs.map(_.values).flatten
168-
/*for (newSym <- newSyms) {
169-
println(newSym)
170-
}*/
171-
val instantiations = declSpecs.flatMap(_.keys).flatten
172-
newSyms.map{newSym =>
186+
if (newSymbolMap.contains(decl)) {
187+
val declSpecs = newSymbolMap(decl)
188+
val newSyms = declSpecs.values.toList
189+
val instantiations = declSpecs.keys.toArray
190+
var index = -1
191+
println(s"specializing ${tree.symbol} for $origTParams")
192+
newSyms.map { newSym =>
193+
index += 1
173194
polyDefDef(newSym.asTerm, { tparams => vparams => {
174195
assert(tparams.isEmpty)
175-
//println(newSym + " ; " + origVParams + " ; " + vparams + " ; " + vparams.flatten + " ; " + vparams.flatten.map(_.tpe))
176-
new TreeTypeMap( //TODO Figure out what is happening with newSym. Why do some symbols have unmatching vparams and origVParams ?
196+
new TreeTypeMap(
177197
typeMap = _
178-
.substDealias(origTParams, instantiations)
198+
.substDealias(origTParams, instantiations(index))
179199
.subst(origVParams, vparams.flatten.map(_.tpe)),
180200
oldOwners = tree.symbol :: Nil,
181201
newOwners = newSym :: Nil
182202
).transform(tree.rhs)
183203
}})
184204
}
205+
} else Nil
185206
}
186-
//specializeFor(tree.symbol) -> necessary ? This registers specialization requests, but do they still make sense at this point ? Symbols have already been generated
187-
val specializedMethods = newSymbolMap.keys.map(specialize).flatten.toList
207+
val specializedMethods = specialize(tree.symbol)
188208
Thicket(tree :: specializedMethods)
189209
case _ => tree
190210
}
191211
}
192212

193213
override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): Tree = {
194-
val TypeApply(fun,args) = tree
195-
val newSymInfo = newSymbolMap(fun.symbol).flatten.toMap
196-
val specializationType: List[Type] = args.map(_.tpe.asInstanceOf[TypeVar].instanceOpt)
197-
val t = fun.symbol.info.decls
198-
if (t.nonEmpty) {
199-
t.cloneScope.lookupEntry(args.head.symbol.name)
200-
val newSym = newSymInfo(specializationType)
214+
215+
def allowedToSpecialize(sym: Symbol): Boolean = {
216+
sym.name != nme.asInstanceOf_ &&
217+
!(sym is Flags.JavaDefined) &&
218+
!sym.isConstructor//isPrimaryConstructor
201219
}
202-
tree
220+
val TypeApply(fun,args) = tree
221+
if (newSymbolMap.contains(fun.symbol) && allowedToSpecialize(fun.symbol)) {
222+
val newSymInfos = newSymbolMap(fun.symbol)
223+
val betterDefs = newSymInfos.filter(x => (x._1 zip args).forall{a =>
224+
val specializedType = a._1
225+
val argType = a._2
226+
argType.tpe <:< specializedType
227+
}).toList
228+
assert(betterDefs.length < 2) // TODO: How to select the best if there are several ?
229+
230+
if (betterDefs.nonEmpty) {
231+
println(s"method $fun rewired to specialozed variant with type (${betterDefs.head._1})")
232+
val prefix = fun match {
233+
case Select(pre, name) =>
234+
pre
235+
case t @ Ident(_) if t.tpe.isInstanceOf[TermRef] =>
236+
val tp = t.tpe.asInstanceOf[TermRef]
237+
if (tp.prefix ne NoPrefix)
238+
ref(tp.prefix.termSymbol)
239+
else EmptyTree
240+
}
241+
if (prefix ne EmptyTree)
242+
prefix.select(betterDefs.head._2)
243+
else ref(betterDefs.head._2)
244+
} else tree
245+
} else tree
203246
}
204247
}

test/dotc/tests.scala

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class tests extends CompilerTest {
1414
// "-Xprompt",
1515
// "-explaintypes",
1616
// "-Yshow-suppressed-errors",
17-
"-pagewidth", "160")
17+
)
1818

1919
val defaultOutputDir = "./out/"
2020

@@ -31,20 +31,23 @@ class tests extends CompilerTest {
3131
val allowDeepSubtypes = defaultOptions diff List("-Yno-deep-subtypes")
3232
val allowDoubleBindings = defaultOptions diff List("-Yno-double-bindings")
3333

34+
val specialise = List("-Yspecialize:all")
35+
3436
val testsDir = "./tests/"
3537
val posDir = testsDir + "pos/"
3638
val posSpecialDir = testsDir + "pos-special/"
3739
val negDir = testsDir + "neg/"
3840
val runDir = testsDir + "run/"
3941
val newDir = testsDir + "new/"
42+
val miniMethodDir = testsDir + "method_minibox/"
43+
val miniMoreDir = testsDir + "more_minibox/"
4044

4145
val sourceDir = "./src/"
4246
val dottyDir = sourceDir + "dotty/"
4347
val toolsDir = dottyDir + "tools/"
4448
val dotcDir = toolsDir + "dotc/"
4549
val coreDir = dotcDir + "core/"
4650

47-
/*
4851
@Test def pickle_pickleOK = compileDir(testsDir, "pickling", testPickling)
4952
// This directory doesn't exist anymore
5053
// @Test def pickle_pickling = compileDir(coreDir, "pickling", testPickling)
@@ -237,8 +240,13 @@ class tests extends CompilerTest {
237240
val javaDir = "./tests/pos/java-interop/"
238241
@Test def java_all = compileFiles(javaDir, twice)
239242

240-
@Test def pos_specialization = compileFile(posDir, "specialization")
243+
//@Test def pos_specialization = compileFile(posDir, "specialization")//, specialise)
241244

242245
//@Test def dotc_compilercommand = compileFile(dotcDir + "config/", "CompilerCommand")
243-
246+
247+
//@Test def test = compileFile(posDir, "t247", List("-Xprint:all"))
248+
@Test def mini_method = compileFiles(miniMethodDir)//, List("-Xprint:all"))
249+
@Test def mini_more = compileFiles(miniMoreDir)//, List("-Xprint:all"))
250+
//@Test def pos_all = compileFiles(posDir)//, List("-Xprint:all"))
251+
244252
}

0 commit comments

Comments
 (0)