Skip to content

Commit ab1c1c5

Browse files
AlexSikiaAlexSikia
AlexSikia
authored andcommitted
Adapt instance of TreeTypeMap to map trees recursively
Solves an issue involving return values not being specialized. Test `return_specialization` illustrates this.
1 parent 10e9ef1 commit ab1c1c5

File tree

6 files changed

+92
-30
lines changed

6 files changed

+92
-30
lines changed

src/dotty/tools/dotc/transform/TypeSpecializer.scala

Lines changed: 68 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,23 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
101101
(implicit ctx: Context): List[Symbol] = {
102102
val newSym =
103103
ctx.newSymbol(decl.owner, (decl.name + names.mkString).toTermName,
104-
decl.flags | Flags.Synthetic, poly.instantiate(instantiations.toList))
104+
decl.flags | Flags.Synthetic, poly.instantiate(instantiations.toList))
105+
106+
/* The following generated symbols which kept type bounds. It served, as illustrated by the `this_specialization`
107+
* test, as a way of keeping type bounds when instantiating a `this` referring to a generic class. However,
108+
* because type bounds are not transitive, this did not work out and we introduced casts instead.
109+
*
110+
* ctx.newSymbol(decl.owner, (decl.name + names.mkString).toTermName,
111+
* decl.flags | Flags.Synthetic,
112+
* poly.derivedPolyType(poly.paramNames,
113+
* (poly.paramBounds zip instantiations).map
114+
* {case (bounds, instantiation) =>
115+
* TypeBounds(bounds.lo, AndType(bounds.hi, instantiation))},
116+
* poly.instantiate(indices, instantiations)
117+
* )
118+
* )
119+
*/
120+
105121
val map = newSymbolMap.getOrElse(decl, mutable.HashMap.empty)
106122
map.put(instantiations, newSym)
107123
newSymbolMap.put(decl, map)
@@ -153,48 +169,85 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
153169
val instantiations = declSpecs.keys.toArray
154170
var index = -1
155171
println(s"specializing ${tree.symbol} for $origTParams")
156-
newSyms.map { newSym =>
172+
newSyms.map { newSym =>
157173
index += 1
158174
polyDefDef(newSym.asTerm, { tparams => vparams => {
159-
assert(tparams.isEmpty)
175+
val tmap: (Tree => Tree) = _ match {
176+
case Return(t, from) if from.symbol == tree.symbol => Return(t, ref(newSym))
177+
case t: TypeApply => transformTypeApply(t)
178+
case t: Apply => transformApply(t)
179+
case t => t
180+
}
181+
160182
new TreeTypeMap(
183+
treeMap = tmap,
161184
typeMap = _
162-
.substDealias(origTParams, instantiations(index))
163-
.subst(origVParams, vparams.flatten.map(_.tpe)),
185+
.substDealias(origTParams, instantiations(index))
186+
.subst(origVParams, vparams.flatten.map(_.tpe))
187+
,
164188
oldOwners = tree.symbol :: Nil,
165189
newOwners = newSym :: Nil
166190
).transform(tree.rhs)
167191
}})
168192
}
169193
} else Nil
170194
}
171-
Thicket(tree :: specialize(tree.symbol))
195+
val specialized_trees = specialize(tree.symbol)
196+
Thicket(tree :: specialized_trees)
172197
case _ => tree
173198
}
174199
}
175200

176201
override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): Tree = {
202+
val TypeApply(fun, _) = tree
203+
if (fun.tpe.isParameterless) rewireTree(tree)
204+
tree
205+
}
177206

207+
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo): Tree = {
208+
val Apply(fun, args) = tree
209+
fun match {
210+
case fun: TypeApply => {
211+
println(
212+
s"""
213+
|args -> ${args}
178214

179-
def allowedToSpecialize(sym: Symbol): Boolean = {
180-
sym.name != nme.asInstanceOf_ &&
181-
!(sym is Flags.JavaDefined) &&
182-
!sym.isConstructor//isPrimaryConstructor
215+
|f.fun -> ${fun.fun.tree}
216+
""".stripMargin)
217+
218+
val newFun = rewireTree(fun)
219+
if (fun ne newFun) {
220+
val b = (args zip newFun.tpe.firstParamTypes)
221+
val a = b.map{
222+
case (arg, tpe) =>
223+
arg.ensureConforms(tpe)
224+
}
225+
Apply(newFun,a)
226+
/* zip (instantiations zip paramTypes)).map{
227+
case (argType, (specType, castType)) => argType.ensureConforms(specType)})*/
228+
} else tree
229+
}
230+
case _ => tree
183231
}
232+
}
184233

234+
def rewireTree(tree: Tree)(implicit ctx: Context): Tree = {
235+
assert(tree.isInstanceOf[TypeApply])
185236
val TypeApply(fun,args) = tree
186237
if (newSymbolMap.contains(fun.symbol)){
187238
val newSymInfos = newSymbolMap(fun.symbol)
188239
val betterDefs = newSymInfos.filter(x => (x._1 zip args).forall{a =>
189-
val specializedType = a._1
190-
val argType = a._2
240+
val specializedType = a._1
241+
val argType = a._2
191242
argType.tpe <:< specializedType
192243
}).toList
193244

194-
if (betterDefs.length > 1) ctx.debuglog("Several specialized variants fit.")
195-
//assert(betterDefs.length < 2) // TODO: How to select the best if there are several ?
245+
if (betterDefs.length > 1) {
246+
ctx.debuglog("Several specialized variants fit.")
247+
tree
248+
}
196249

197-
if (betterDefs.nonEmpty) {
250+
else if (betterDefs.nonEmpty) {
198251
val best = betterDefs.head
199252
println(s"method ${fun.symbol.name} of ${fun.symbol.owner} rewired to specialized variant with type (${best._1})")
200253
val prefix = fun match {
@@ -213,4 +266,3 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
213266
} else tree
214267
}
215268
}
216-

test/dotc/tests.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,16 +240,18 @@ class tests extends CompilerTest {
240240
val javaDir = "./tests/pos/java-interop/"
241241
@Test def java_all = compileFiles(javaDir, twice)
242242

243-
//@Test def specialization = compileFile(specialDir, "specialization")//, specialise)
244-
//@Test def mutual_spec = compileFile(specialDir, "mutual_specialization")
243+
*/
244+
//@Test def specialization = compileFile(specialDir, "specialization")
245+
//@Test def mutual_spec = compileFile(specialDir, "mutual_specialization", List("-Xprint:all"))
245246
//@Test def return_spec = compileFile(specialDir, "return_specialization")
246-
//@Test def nothing_spec = compileFile(specialDir, "nothing_specialization")
247-
//@Test def method_in_class_spec = compileFile(specialDir, "method_in_class_specialization")
248-
//@Test def method_in_method_spec = compileFile(specialDir, "method_in_method_specialization", List("-Xprint:all"))
249-
@Test def type_check_spec = compileFile(specialDir, "type_check_specialization")
250-
//@Test def bounds_spec = compileFile(specialDir, "bounds_specialization", List("-Xprint:all"))
251-
//@Test def multi_spec = compileFile(specialDir, "multi_specialization", List("-Xprint:all"))
252-
//@Test def pos_spec_all = compileFiles(specialDir)
247+
// @Test def nothing_spec = compileFile(specialDir, "nothing_specialization")
248+
// @Test def method_in_class_spec = compileFile(specialDir, "method_in_class_specialization")
249+
// @Test def method_in_method_spec = compileFile(specialDir, "method_in_method_specialization")
250+
// @Test def pos_type_check = compileFile(specialDir, "type_test")
251+
// @Test def bounds_spec = compileFile(specialDir, "bounds_specialization")
252+
// @Test def multi_spec = compileFile(specialDir, "multi_specialization")
253+
// @Test def pos_spec_all = compileFiles(specialDir)
254+
@Test def pos_this_specialization = compileFile(specialDir, "this_specialization", List("-Xprint:specialize"))
253255

254256
//@Test def mini_method = compileFiles(miniMethodDir)//, List("-Xprint:all"))
255257
//@Test def mini_more = compileFiles(miniMoreDir)//, List("-Xprint:all"))
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
object return_specialization {
2-
def qwa[@specialized T](a: (String, String) => T, b: T): T = {
3-
if(a ne this) return a("1", "2")
2+
def qwa[@specialized(Int) T](a: (T, T) => T, b: T): T = {
3+
if(a ne this) return a(b, b)
44
else b
55
}
66
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
sealed abstract class Foo[@specialized +A] {
2+
def bop[@specialized B >: A]: Foo[B] = new Bar[B](this)
3+
//def bip[@specialized C >: A, @specialized D >: A]: Foo[D] = new Cho[D, C](new Bar[C](this))
4+
}
5+
6+
case class Bar[@specialized a](tl: Foo[a]) extends Foo[a]
7+
8+
//case class Cho[@specialized c, @specialized d](tl: Bar[d]) extends Foo[c]

tests/pos/specialization/type_check_specialization.scala

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
object type_test {
2+
def typeTest(i: Char): Unit = i.isInstanceOf[Int]
3+
}

0 commit comments

Comments
 (0)