Skip to content

Commit aaf07bd

Browse files
committed
Add InfoTransformer Trait to TypeSpecializer
Moves the new Symbols generation to the `transformInfo` method call, and builds a map of those new symbols. `transfomDefDef` later uses them to actually generate the specialized methods.
1 parent 2578036 commit aaf07bd

File tree

3 files changed

+107
-59
lines changed

3 files changed

+107
-59
lines changed

src/dotty/tools/dotc/transform/TypeSpecializer.scala

Lines changed: 98 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -5,28 +5,22 @@ import dotty.tools.dotc.ast.Trees.SeqLiteral
55
import dotty.tools.dotc.ast.tpd._
66
import dotty.tools.dotc.core.Annotations.Annotation
77
import dotty.tools.dotc.core.Contexts.Context
8+
import dotty.tools.dotc.core.Decorators.StringDecorator
9+
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
10+
import dotty.tools.dotc.core.Names.TermName
11+
import dotty.tools.dotc.core.Symbols.Symbol
812
import dotty.tools.dotc.core.{Symbols, Flags}
913
import dotty.tools.dotc.core.Types._
1014
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
11-
import dotty.tools.dotc.core.Decorators._
1215
import scala.collection.mutable
1316

14-
class TypeSpecializer extends MiniPhaseTransform {
17+
class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
1518

1619
override def phaseName = "specialize"
1720

1821
final val maxTparamsToSpecialize = 2
1922

20-
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[List[Type]]] = mutable.HashMap.empty
21-
22-
def registerSpecializationRequest(method: Symbols.Symbol)(arguments: List[Type])(implicit ctx: Context) = {
23-
if(ctx.phaseId > this.treeTransformPhase.id)
24-
assert(ctx.phaseId <= this.treeTransformPhase.id)
25-
val prev = specializationRequests.getOrElse(method, List.empty)
26-
specializationRequests.put(method, arguments :: prev)
27-
}
28-
29-
private final def name2SpecialisedType(implicit ctx: Context) =
23+
private final def nameToSpecialisedType(implicit ctx: Context) =
3024
Map("Byte" -> ctx.definitions.ByteType,
3125
"Boolean" -> ctx.definitions.BooleanType,
3226
"Short" -> ctx.definitions.ShortType,
@@ -37,19 +31,71 @@ class TypeSpecializer extends MiniPhaseTransform {
3731
"Char" -> ctx.definitions.CharType,
3832
"Unit" -> ctx.definitions.UnitType)
3933

40-
private final def specialisedType2Suffix(implicit ctx: Context) =
34+
private final def specialisedTypeToSuffix(implicit ctx: Context) =
4135
Map(ctx.definitions.ByteType -> "$mcB$sp",
42-
ctx.definitions.BooleanType -> "$mcZ$sp",
43-
ctx.definitions.ShortType -> "$mcS$sp",
44-
ctx.definitions.IntType -> "$mcI$sp",
45-
ctx.definitions.LongType -> "$mcJ$sp",
46-
ctx.definitions.FloatType -> "$mcF$sp",
47-
ctx.definitions.DoubleType -> "$mcD$sp",
48-
ctx.definitions.CharType -> "$mcC$sp",
49-
ctx.definitions.UnitType -> "$mcV$sp")
36+
ctx.definitions.BooleanType -> "$mcZ$sp",
37+
ctx.definitions.ShortType -> "$mcS$sp",
38+
ctx.definitions.IntType -> "$mcI$sp",
39+
ctx.definitions.LongType -> "$mcJ$sp",
40+
ctx.definitions.FloatType -> "$mcF$sp",
41+
ctx.definitions.DoubleType -> "$mcD$sp",
42+
ctx.definitions.CharType -> "$mcC$sp",
43+
ctx.definitions.UnitType -> "$mcV$sp")
44+
45+
private val specializationRequests: mutable.HashMap[Symbols.Symbol, List[List[Type]]] = mutable.HashMap.empty
46+
47+
private val newSymbolMap: mutable.HashMap[TermName, (List[Symbols.TermSymbol], List[Type])] = mutable.HashMap.empty // Why does the typechecker require TermSymbol ?
48+
49+
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type = {
50+
51+
tp.widen match {
52+
case poly: PolyType if !(sym.isPrimaryConstructor
53+
|| (sym is Flags.Label)) =>
54+
55+
def generateSpecializations(remainingTParams: List[Type], remainingBounds: List[TypeBounds])
56+
(instantiations: List[Type], names: List[String])(implicit ctx: Context): Unit = {
57+
if (remainingTParams.nonEmpty) {
58+
val typeToSpecialize = remainingTParams.head
59+
val bounds = remainingBounds.head
60+
val a = shouldSpecializeFor(typeToSpecialize.typeSymbol) // TODO returns Nil because no annotations are found - elucidate
61+
a.flatten
62+
.filter { tpe =>
63+
bounds.contains(tpe)
64+
}.foreach({ tpe =>
65+
val nme = specialisedTypeToSuffix(ctx)(tpe)
66+
generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instantiations, nme :: names)
67+
})
68+
}
69+
else {
70+
generateSpecializedSymbols(instantiations.reverse, names.reverse)
71+
}
72+
}
73+
74+
def generateSpecializedSymbols(instantiations : List[Type], names: List[String])(implicit ctx: Context): Unit = {
75+
val newSym = ctx.newSymbol(sym.owner, (sym.name + names.mkString).toTermName, sym.flags | Flags.Synthetic, poly.instantiate(instantiations.toList))
76+
ctx.enter(newSym) // TODO check frozen flag ?
77+
val prev = newSymbolMap.getOrElse(sym.name.toTermName, (Nil, Nil))
78+
val newSyms = newSym :: prev._1
79+
newSymbolMap.put(sym.name.toTermName, (newSyms, instantiations)) // Could `.put(...)` bring up (mutability) issues ?
80+
}
81+
val origTParams = poly.resType.paramTypess.flatten // Is this really what is needed ?
82+
val bounds = poly.paramBounds
83+
generateSpecializations(origTParams, bounds)(List.empty, List.empty)
84+
tp
85+
case _ =>
86+
tp
87+
}
88+
}
89+
90+
def registerSpecializationRequest(method: Symbols.Symbol)(arguments: List[Type])(implicit ctx: Context) = {
91+
if(ctx.phaseId > this.treeTransformPhase.id)
92+
assert(ctx.phaseId <= this.treeTransformPhase.id)
93+
val prev = specializationRequests.getOrElse(method, List.empty)
94+
specializationRequests.put(method, arguments :: prev)
95+
}
5096

5197
def specializeForAll(sym: Symbols.Symbol)(implicit ctx: Context): List[List[Type]] = {
52-
registerSpecializationRequest(sym)(specialisedType2Suffix.keys.toList)
98+
registerSpecializationRequest(sym)(specialisedTypeToSuffix.keys.toList)
5399
println("Specializing for all primitive types")
54100
specializationRequests.getOrElse(sym, Nil)
55101
}
@@ -66,7 +112,7 @@ class TypeSpecializer extends MiniPhaseTransform {
66112
annot.arguments match {
67113
case List(SeqLiteral(types)) =>
68114
specializeForSome(sym)(types.map(tpeTree =>
69-
name2SpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString())))
115+
nameToSpecialisedType(ctx)(tpeTree.tpe.asInstanceOf[TermRef].name.toString())))
70116
case List() => specializeForAll(sym)
71117
}
72118
case nil =>
@@ -80,46 +126,43 @@ class TypeSpecializer extends MiniPhaseTransform {
80126
tree.tpe.widen match {
81127

82128
case poly: PolyType if !(tree.symbol.isPrimaryConstructor
83-
|| (tree.symbol is Flags.Label)) => {
129+
|| (tree.symbol is Flags.Label)) =>
84130
val origTParams = tree.tparams.map(_.symbol)
85131
val origVParams = tree.vparamss.flatten.map(_.symbol)
86132
println(s"specializing ${tree.symbol} for Tparams: $origTParams")
87133

88-
def specialize(instatiations: List[Type], names: List[String]): Tree = {
89-
val newSym = ctx.newSymbol(tree.symbol.owner, (tree.name + names.mkString).toTermName, tree.symbol.flags | Flags.Synthetic, poly.instantiate(instatiations.toList))
90-
polyDefDef(newSym, { tparams => vparams => {
91-
assert(tparams.isEmpty)
92-
new TreeTypeMap(
93-
typeMap = _
94-
.substDealias(origTParams, instatiations.toList)
95-
.subst(origVParams, vparams.flatten.map(_.tpe)),
96-
oldOwners = tree.symbol :: Nil,
97-
newOwners = newSym :: Nil
98-
).transform(tree.rhs)
134+
def specialize(instantiations: List[Type]): List[Tree] = {
135+
newSymbolMap(tree.name) match {
136+
case newSyms: (List[Symbol], List[Type]) =>
137+
newSyms._1.map{newSym =>
138+
polyDefDef(newSym, { tparams => vparams => {
139+
assert(tparams.isEmpty)
140+
new TreeTypeMap(
141+
typeMap = _
142+
.substDealias(origTParams, instantiations.toList)
143+
.subst(origVParams, vparams.flatten.map(_.tpe)),
144+
oldOwners = tree.symbol :: Nil,
145+
newOwners = newSym :: Nil
146+
).transform(tree.rhs)
147+
}
148+
})}
149+
case nil =>
150+
List()
99151
}
100-
})
101152
}
102153

103-
def generateSpecializations(remainingTParams: List[TypeDef], remainingBounds: List[TypeBounds])
104-
(instatiations: List[Type],
105-
names: List[String]): Iterable[Tree] = {
106-
if (remainingTParams.nonEmpty) {
107-
val typeToSpecialize = remainingTParams.head
108-
val bounds = remainingBounds.head
109-
shouldSpecializeFor(typeToSpecialize.symbol)
110-
.flatten
111-
.filter{ tpe =>
112-
bounds.contains(tpe)
113-
}.flatMap { tpe =>
114-
val nme = specialisedType2Suffix(ctx)(tpe)
115-
generateSpecializations(remainingTParams.tail, remainingBounds.tail)(tpe :: instatiations, nme :: names)
116-
}
117-
} else
118-
List(specialize(instatiations.reverse, names.reverse))
119-
}
120-
Thicket(tree :: generateSpecializations(tree.tparams, poly.paramBounds)(List.empty, List.empty).toList)
121-
}
154+
val specializedMethods: List[Tree] = (for (inst <- newSymbolMap.keys) yield specialize(newSymbolMap(inst)._2)).flatten.toList
155+
Thicket(tree :: specializedMethods)
156+
122157
case _ => tree
123158
}
124159
}
160+
161+
def transformTypeOfTree(tree: Tree): Tree = {
162+
tree
163+
}
164+
165+
override def transformIdent(tree: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = transformTypeOfTree(tree)
166+
override def transformSelect(tree: Select)(implicit ctx: Context, info: TransformerInfo): Tree = transformTypeOfTree(tree)
167+
125168
}

test/dotc/tests.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,16 @@ class tests extends CompilerTest {
8686
@Test def pos_t2613 = compileFile(posSpecialDir, "t2613")(allowDeepSubtypes)
8787
@Test def pos_packageObj = compileFile(posDir, "i0239")
8888
@Test def pos_anonClassSubtyping = compileFile(posDir, "anonClassSubtyping")
89+
8990
@Test def pos_specialization = compileFile(posDir, "specialization")
9091
9192
@Test def pos_all = compileFiles(posDir) // twice omitted to make tests run faster
9293
9394
@Test def pos_specialization = compileFile(posDir, "specialization")
9495
95-
// contains buggy tests
9696
@Test def pos_all = compileFiles(posDir, failedOther)
9797
98-
*/@Test def pos_SI7638 = compileFile(posDir, "SI-7638")/*
98+
@Test def pos_SI7638 = compileFile(posDir, "SI-7638")
9999
@Test def pos_SI7638a = compileFile(posDir, "SI-7638a")
100100
101101
@@ -195,7 +195,11 @@ class tests extends CompilerTest {
195195
196196
@Test def dotc_transform = compileDir(dotcDir, "transform")// twice omitted to make tests run faster
197197
198+
val javaDir = "./tests/pos/java-interop/"
199+
@Test def java_all = compileFiles(javaDir)
198200
*/
201+
@Test def pos_specialization = compileFile(posDir, "specialization")
202+
199203
//@Test def dotc_compilercommand = compileFile(dotcDir + "tools/dotc/config/", "CompilerCommand")
200204

201205
@Test def dotc_parsing = compileDir(dotcDir, "parsing") // twice omitted to make tests run faster
@@ -231,4 +235,5 @@ class tests extends CompilerTest {
231235
val javaDir = "./tests/pos/java-interop/"
232236
@Test def java_all = compileFiles(javaDir, twice)
233237
//@Test def dotc_compilercommand = compileFile(dotcDir + "config/", "CompilerCommand")
238+
234239
}

tests/pos/specialization.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class specialization {
2-
def printer[@specialized(Int, Long) T](a: T) = {
3-
println(a)
2+
def printer[@specialized(Int, Long) T, U](a: T, b:U) = {
3+
println(a.toString + b.toString)
44
}
55
}
66

0 commit comments

Comments
 (0)