1
1
package dotty .tools .dotc .transform
2
2
3
3
import dotty .tools .dotc .ast .{tpd , TreeTypeMap }
4
- import dotty .tools .dotc .ast .Trees .{TypeApply , SeqLiteral }
5
- import dotty .tools .dotc .ast .tpd ._
4
+ import dotty .tools .dotc .ast .Trees ._
6
5
import dotty .tools .dotc .core .Annotations .Annotation
7
6
import dotty .tools .dotc .core .Contexts .Context
8
7
import dotty .tools .dotc .core .Decorators .StringDecorator
@@ -13,9 +12,10 @@ import dotty.tools.dotc.core.{Symbols, Flags}
13
12
import dotty .tools .dotc .core .Types ._
14
13
import dotty .tools .dotc .transform .TreeTransforms .{TransformerInfo , MiniPhaseTransform }
15
14
import scala .collection .mutable
15
+ import dotty .tools .dotc .core .StdNames .nme
16
16
17
17
class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
18
-
18
+ import tpd . _
19
19
override def phaseName = " specialize"
20
20
21
21
final val maxTparamsToSpecialize = 2
@@ -56,17 +56,21 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
56
56
57
57
private val specializationRequests : mutable.HashMap [Symbols .Symbol , List [List [Type ]]] = mutable.HashMap .empty
58
58
59
- private val newSymbolMap : mutable.HashMap [Symbol , List [mutable.HashMap [List [Type ], Symbols .Symbol ]]] = mutable.HashMap .empty
59
+ /**
60
+ * A map that links symbols to their specialized variants.
61
+ * Each symbol maps to another as map, from the list of specialization types to the specialized symbol.
62
+ */
63
+ private val newSymbolMap : mutable.HashMap [Symbol , mutable.HashMap [List [Type ], Symbols .Symbol ]] = mutable.HashMap .empty
60
64
61
65
override def transformInfo (tp : Type , sym : Symbol )(implicit ctx : Context ): Type = {
62
- def generateSpecializations (remainingTParams : List [Name ], remainingBounds : List [TypeBounds ])
66
+
67
+ def generateSpecializations (remainingTParams : List [Name ], remainingBounds : List [TypeBounds ], specTypes : List [Type ])
63
68
(instantiations : List [Type ], names : List [String ], poly : PolyType , decl : Symbol )
64
69
(implicit ctx : Context ): List [Symbol ] = {
65
70
if (remainingTParams.nonEmpty) {
66
71
val bounds = remainingBounds.head
67
- val specTypes = primitiveTypes.filter{ tpe => bounds.contains(tpe)}
68
72
val specializations = (for (tpe <- specTypes) yield {
69
- generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
73
+ generateSpecializations(remainingTParams.tail, remainingBounds.tail, specTypes )(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
70
74
}).flatten
71
75
specializations
72
76
}
@@ -78,25 +82,27 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
78
82
(implicit ctx : Context ): List [Symbol ] = {
79
83
val newSym =
80
84
ctx.newSymbol(decl.owner, (decl.name + names.mkString).toTermName,
81
- decl.flags | Flags .Synthetic , poly.instantiate(instantiations.toList)) // Who should the owner be ? decl.owner ? sym ? sym.owner ? ctx.owner ?
82
- // TODO I think there might be a bug in the assertion at dotty.tools.dotc.transform.TreeChecker$Checker.dotty$tools$dotc$transform$TreeChecker$Checker$$checkOwner(TreeChecker.scala:244)
83
- // Shouldn't the owner remain the original one ? In this instance, the assertion always expects the owner to be `class specialization` (the test I run), even for methods that aren't
84
- // defined by the test itself, such as `instanceOf` (to which my implementation gives owner `class Any`).
85
- val prevMaps = newSymbolMap.getOrElse(decl, List ()).reverse
86
- val newMap : mutable.HashMap [List [Type ], Symbols .Symbol ] = mutable.HashMap (instantiations -> newSym)
87
- newSymbolMap.put(decl, (newMap :: prevMaps.reverse).reverse)
88
- (newSym :: prevMaps.flatMap(_.values).reverse).reverse // All those reverse are probably useless
85
+ decl.flags | Flags .Synthetic , poly.instantiate(instantiations.toList))
86
+ val map = newSymbolMap.getOrElse(decl, mutable.HashMap .empty)
87
+ map.put(instantiations, newSym)
88
+ newSymbolMap.put(decl, map)
89
+ map.values.toList
89
90
}
90
91
91
- if ((sym ne ctx.definitions.ScalaPredefModule .moduleClass) && ! (sym is Flags .Package ) && ! sym.isAnonymousClass) {
92
+ if ((sym ne ctx.definitions.ScalaPredefModule .moduleClass) &&
93
+ ! (sym is Flags .Package ) &&
94
+ ! sym.isAnonymousClass &&
95
+ ! (sym.name == nme.asInstanceOf_)) {
92
96
sym.info match {
93
97
case classInfo : ClassInfo =>
94
- val newDecls = classInfo.decls.flatMap(decl => {
98
+ val newDecls = classInfo.decls.filterNot(_.isConstructor /* isPrimaryConstructor */ ). flatMap(decl => {
95
99
if (shouldSpecialize(decl)) {
96
100
decl.info.widen match {
97
101
case poly : PolyType =>
98
- if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0 )
99
- generateSpecializations(poly.paramNames, poly.paramBounds)(List .empty, List .empty, poly, decl)
102
+ if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0 ) {
103
+ val specTypes = getSpecTypes(sym)
104
+ generateSpecializations(poly.paramNames, poly.paramBounds, specTypes)(List .empty, List .empty, poly, decl)
105
+ }
100
106
else Nil
101
107
case nil => Nil
102
108
}
@@ -113,6 +119,20 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
113
119
} else tp
114
120
}
115
121
122
+ def getSpecTypes (sym : Symbol )(implicit ctx : Context ): List [Type ] = {
123
+ sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil ) match {
124
+ case annot : Annotation =>
125
+ annot.arguments match {
126
+ case List (SeqLiteral (types)) =>
127
+ types.map(tpeTree => nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf [TermRef ].name.toString()))
128
+ case List () => primitiveTypes
129
+ }
130
+ case nil =>
131
+ if (ctx.settings.Yspecialize .value == " all" ) primitiveTypes
132
+ else Nil
133
+ }
134
+ }
135
+
116
136
def shouldSpecialize (decl : Symbol )(implicit ctx : Context ): Boolean =
117
137
specializationRequests.contains(decl) ||
118
138
(ctx.settings.Yspecialize .value != " " && decl.name.contains(ctx.settings.Yspecialize .value)) ||
@@ -124,81 +144,104 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
124
144
val prev = specializationRequests.getOrElse(method, List .empty)
125
145
specializationRequests.put(method, arguments :: prev)
126
146
}
127
-
128
- def specializeForAll (sym : Symbols .Symbol )(implicit ctx : Context ): List [List [ Type ] ] = {
147
+ /*
148
+ def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
129
149
registerSpecializationRequest(sym)(primitiveTypes)
130
- println(" Specializing for all primitive types" )
131
- specializationRequests.getOrElse(sym, Nil )
150
+ println(s "Specializing $sym for all primitive types")
151
+ specializationRequests.getOrElse(sym, Nil).flatten
132
152
}
133
153
134
- def specializeForSome (sym : Symbols .Symbol )(annotationArgs : List [Type ])(implicit ctx : Context ): List [List [ Type ] ] = {
154
+ def specializeForSome(sym: Symbols.Symbol)(annotationArgs: List[Type])(implicit ctx: Context): List[Type] = {
135
155
registerSpecializationRequest(sym)(annotationArgs)
136
156
println(s"specializationRequests : $specializationRequests")
137
- specializationRequests.getOrElse(sym, Nil )
157
+ specializationRequests.getOrElse(sym, Nil).flatten
138
158
}
139
159
140
- def specializeFor (sym : Symbols .Symbol )(implicit ctx : Context ): List [List [ Type ] ] = {
160
+ def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
141
161
sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
142
162
case annot: Annotation =>
143
163
annot.arguments match {
144
164
case List(SeqLiteral(types)) =>
145
- specializeForSome(sym)(types.map(tpeTree => // tpeTree.tpe.widen))
146
- nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf [TermRef ].name.toString()))) // Not sure how to match TermRefs rather than types. comment on line above was an attempt.
165
+ specializeForSome(sym)(types.map(tpeTree =>
166
+ nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than type names
147
167
case List() => specializeForAll(sym)
148
168
}
149
169
case nil =>
150
- if (ctx.settings.Yspecialize .value == " all" ) {println( " Yspecialize set to all " ); specializeForAll(sym) }
170
+ if(ctx.settings.Yspecialize.value == "all") specializeForAll(sym)
151
171
else Nil
152
172
}
153
- }
173
+ }*/
154
174
155
175
override def transformDefDef (tree : DefDef )(implicit ctx : Context , info : TransformerInfo ): Tree = {
156
176
157
177
tree.tpe.widen match {
158
178
159
- case poly : PolyType if ! (tree.symbol.isPrimaryConstructor
160
- || (tree.symbol is Flags .Label )) =>
179
+ case poly : PolyType if ! (tree.symbol.isConstructor// isPrimaryConstructor
180
+ || (tree.symbol is Flags .Label ))
181
+ || (tree.symbol.name == nme.asInstanceOf_) =>
161
182
val origTParams = tree.tparams.map(_.symbol)
162
183
val origVParams = tree.vparamss.flatten.map(_.symbol)
163
- println(s " specializing ${tree.symbol} for Tparams: $origTParams" )
164
184
165
185
def specialize (decl : Symbol ): List [Tree ] = {
166
- val declSpecs = newSymbolMap(decl)
167
- val newSyms = declSpecs.map(_.values).flatten
168
- /* for (newSym <- newSyms) {
169
- println(newSym)
170
- }*/
171
- val instantiations = declSpecs.flatMap(_.keys).flatten
172
- newSyms.map{newSym =>
186
+ if (newSymbolMap.contains(decl)) {
187
+ val declSpecs = newSymbolMap(decl)
188
+ val newSyms = declSpecs.values.toList
189
+ val instantiations = declSpecs.keys.toArray
190
+ var index = - 1
191
+ println(s " specializing ${tree.symbol} for $origTParams" )
192
+ newSyms.map { newSym =>
193
+ index += 1
173
194
polyDefDef(newSym.asTerm, { tparams => vparams => {
174
195
assert(tparams.isEmpty)
175
- // println(newSym + " ; " + origVParams + " ; " + vparams + " ; " + vparams.flatten + " ; " + vparams.flatten.map(_.tpe))
176
- new TreeTypeMap ( // TODO Figure out what is happening with newSym. Why do some symbols have unmatching vparams and origVParams ?
196
+ new TreeTypeMap (
177
197
typeMap = _
178
- .substDealias(origTParams, instantiations)
198
+ .substDealias(origTParams, instantiations(index) )
179
199
.subst(origVParams, vparams.flatten.map(_.tpe)),
180
200
oldOwners = tree.symbol :: Nil ,
181
201
newOwners = newSym :: Nil
182
202
).transform(tree.rhs)
183
203
}})
184
204
}
205
+ } else Nil
185
206
}
186
- // specializeFor(tree.symbol) -> necessary ? This registers specialization requests, but do they still make sense at this point ? Symbols have already been generated
187
- val specializedMethods = newSymbolMap.keys.map(specialize).flatten.toList
207
+ val specializedMethods = specialize(tree.symbol)
188
208
Thicket (tree :: specializedMethods)
189
209
case _ => tree
190
210
}
191
211
}
192
212
193
213
override def transformTypeApply (tree : tpd.TypeApply )(implicit ctx : Context , info : TransformerInfo ): Tree = {
194
- val TypeApply (fun,args) = tree
195
- val newSymInfo = newSymbolMap(fun.symbol).flatten.toMap
196
- val specializationType : List [Type ] = args.map(_.tpe.asInstanceOf [TypeVar ].instanceOpt)
197
- val t = fun.symbol.info.decls
198
- if (t.nonEmpty) {
199
- t.cloneScope.lookupEntry(args.head.symbol.name)
200
- val newSym = newSymInfo(specializationType)
214
+
215
+ def allowedToSpecialize (sym : Symbol ): Boolean = {
216
+ sym.name != nme.asInstanceOf_ &&
217
+ ! (sym is Flags .JavaDefined ) &&
218
+ ! sym.isConstructor// isPrimaryConstructor
201
219
}
202
- tree
220
+ val TypeApply (fun,args) = tree
221
+ if (newSymbolMap.contains(fun.symbol) && allowedToSpecialize(fun.symbol)) {
222
+ val newSymInfos = newSymbolMap(fun.symbol)
223
+ val betterDefs = newSymInfos.filter(x => (x._1 zip args).forall{a =>
224
+ val specializedType = a._1
225
+ val argType = a._2
226
+ argType.tpe <:< specializedType
227
+ }).toList
228
+ assert(betterDefs.length < 2 ) // TODO: How to select the best if there are several ?
229
+
230
+ if (betterDefs.nonEmpty) {
231
+ println(s " method $fun rewired to specialozed variant with type ( ${betterDefs.head._1}) " )
232
+ val prefix = fun match {
233
+ case Select (pre, name) =>
234
+ pre
235
+ case t @ Ident (_) if t.tpe.isInstanceOf [TermRef ] =>
236
+ val tp = t.tpe.asInstanceOf [TermRef ]
237
+ if (tp.prefix ne NoPrefix )
238
+ ref(tp.prefix.termSymbol)
239
+ else EmptyTree
240
+ }
241
+ if (prefix ne EmptyTree )
242
+ prefix.select(betterDefs.head._2)
243
+ else ref(betterDefs.head._2)
244
+ } else tree
245
+ } else tree
203
246
}
204
247
}
0 commit comments