@@ -20,17 +20,6 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
20
20
21
21
final val maxTparamsToSpecialize = 2
22
22
23
- private final def nameToSpecialisedType (implicit ctx : Context ) =
24
- Map (" Byte" -> ctx.definitions.ByteType ,
25
- " Boolean" -> ctx.definitions.BooleanType ,
26
- " Short" -> ctx.definitions.ShortType ,
27
- " Int" -> ctx.definitions.IntType ,
28
- " Long" -> ctx.definitions.LongType ,
29
- " Float" -> ctx.definitions.FloatType ,
30
- " Double" -> ctx.definitions.DoubleType ,
31
- " Char" -> ctx.definitions.CharType ,
32
- " Unit" -> ctx.definitions.UnitType )
33
-
34
23
private final def specialisedTypeToSuffix (implicit ctx : Context ) =
35
24
Map (ctx.definitions.ByteType -> " $mcB$sp" ,
36
25
ctx.definitions.BooleanType -> " $mcZ$sp" ,
@@ -54,7 +43,7 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
54
43
ctx.definitions.UnitType
55
44
)
56
45
57
- private val specializationRequests : mutable.HashMap [Symbols .Symbol , List [List [ Type ] ]] = mutable.HashMap .empty
46
+ private val specializationRequests : mutable.HashMap [Symbols .Symbol , List [Type ]] = mutable.HashMap .empty
58
47
59
48
/**
60
49
* A map that links symbols to their specialized variants.
@@ -63,14 +52,12 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
63
52
private val newSymbolMap : mutable.HashMap [Symbol , mutable.HashMap [List [Type ], Symbols .Symbol ]] = mutable.HashMap .empty
64
53
65
54
override def transformInfo (tp : Type , sym : Symbol )(implicit ctx : Context ): Type = {
66
-
67
- def generateSpecializations (remainingTParams : List [Name ], remainingBounds : List [TypeBounds ], specTypes : List [Type ])
55
+ def generateSpecializations (remainingTParams : List [Name ], specTypes : List [Type ])
68
56
(instantiations : List [Type ], names : List [String ], poly : PolyType , decl : Symbol )
69
57
(implicit ctx : Context ): List [Symbol ] = {
70
58
if (remainingTParams.nonEmpty) {
71
- val bounds = remainingBounds.head
72
59
val specializations = (for (tpe <- specTypes) yield {
73
- generateSpecializations(remainingTParams.tail, remainingBounds.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
60
+ generateSpecializations(remainingTParams.tail, specTypes)(tpe :: instantiations, specialisedTypeToSuffix(ctx)(tpe) :: names, poly, decl)
74
61
}).flatten
75
62
specializations
76
63
}
@@ -96,12 +83,15 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
96
83
sym.info match {
97
84
case classInfo : ClassInfo =>
98
85
val newDecls = classInfo.decls.filterNot(_.isConstructor/* isPrimaryConstructor*/ ).flatMap(decl => {
86
+ if (decl.name.toString.contains(" foobar" )) {
87
+ println(" hello" )
88
+ }
99
89
if (shouldSpecialize(decl)) {
100
90
decl.info.widen match {
101
91
case poly : PolyType =>
102
92
if (poly.paramNames.length <= maxTparamsToSpecialize && poly.paramNames.length > 0 ) {
103
- val specTypes = getSpecTypes(sym )
104
- generateSpecializations(poly.paramNames, poly.paramBounds, specTypes)(List .empty, List .empty, poly, decl)
93
+ val specTypes = getSpecTypes(decl).filter(tpe => poly.paramBounds.forall(_.contains(tpe)) )
94
+ generateSpecializations(poly.paramNames, specTypes)(List .empty, List .empty, poly, decl)
105
95
}
106
96
else Nil
107
97
case nil => Nil
@@ -120,16 +110,11 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
120
110
}
121
111
122
112
def getSpecTypes (sym : Symbol )(implicit ctx : Context ): List [Type ] = {
123
- sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil ) match {
124
- case annot : Annotation =>
125
- annot.arguments match {
126
- case List (SeqLiteral (types)) =>
127
- types.map(tpeTree => nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf [TermRef ].name.toString()))
128
- case List () => primitiveTypes
129
- }
130
- case nil =>
131
- if (ctx.settings.Yspecialize .value == " all" ) primitiveTypes
132
- else Nil
113
+ val requested = specializationRequests.getOrElse(sym, List ())
114
+ if (requested.nonEmpty) requested.toList
115
+ else {
116
+ if (ctx.settings.Yspecialize .value == " all" ) primitiveTypes
117
+ else Nil
133
118
}
134
119
}
135
120
@@ -142,35 +127,8 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
142
127
if (ctx.phaseId > this .treeTransformPhase.id)
143
128
assert(ctx.phaseId <= this .treeTransformPhase.id)
144
129
val prev = specializationRequests.getOrElse(method, List .empty)
145
- specializationRequests.put(method, arguments :: prev)
130
+ specializationRequests.put(method, arguments ::: prev)
146
131
}
147
- /*
148
- def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
149
- registerSpecializationRequest(sym)(primitiveTypes)
150
- println(s"Specializing $sym for all primitive types")
151
- specializationRequests.getOrElse(sym, Nil).flatten
152
- }
153
-
154
- def specializeForSome(sym: Symbols.Symbol)(annotationArgs: List[Type])(implicit ctx: Context): List[Type] = {
155
- registerSpecializationRequest(sym)(annotationArgs)
156
- println(s"specializationRequests : $specializationRequests")
157
- specializationRequests.getOrElse(sym, Nil).flatten
158
- }
159
-
160
- def specializeFor(sym: Symbols.Symbol)(implicit ctx: Context): List[Type] = {
161
- sym.denot.getAnnotation(ctx.definitions.specializedAnnot).getOrElse(Nil) match {
162
- case annot: Annotation =>
163
- annot.arguments match {
164
- case List(SeqLiteral(types)) =>
165
- specializeForSome(sym)(types.map(tpeTree =>
166
- nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString()))) // Not sure how to match TermRefs rather than type names
167
- case List() => specializeForAll(sym)
168
- }
169
- case nil =>
170
- if(ctx.settings.Yspecialize.value == "all") specializeForAll(sym)
171
- else Nil
172
- }
173
- }*/
174
132
175
133
override def transformDefDef (tree : DefDef )(implicit ctx : Context , info : TransformerInfo ): Tree = {
176
134
@@ -228,7 +186,7 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
228
186
assert(betterDefs.length < 2 ) // TODO: How to select the best if there are several ?
229
187
230
188
if (betterDefs.nonEmpty) {
231
- println(s " method $fun rewired to specialozed variant with type ( ${betterDefs.head._1}) " )
189
+ println(s " method $fun rewired to specialized variant with type ( ${betterDefs.head._1}) " )
232
190
val prefix = fun match {
233
191
case Select (pre, name) =>
234
192
pre
0 commit comments