Skip to content

Commit a4dfdec

Browse files
AlexSikiaDarkDimius
authored andcommitted
Add casts, and debug implementation
An issue occurs when trying to specialize certain methods when relying on typer only - this is described by scala#592 , and occured in test `this_specialization`. # with '#' will be ignored, and an empty message aborts the commit.
1 parent 349b4e6 commit a4dfdec

File tree

5 files changed

+38
-40
lines changed

5 files changed

+38
-40
lines changed

src/dotty/tools/dotc/transform/TypeSpecializer.scala

Lines changed: 20 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package dotty.tools.dotc.transform
22

33
import dotty.tools.dotc.ast.{tpd, TreeTypeMap}
44
import dotty.tools.dotc.ast.Trees._
5-
import dotty.tools.dotc.core.Annotations.Annotation
65
import dotty.tools.dotc.core.Contexts.Context
76
import dotty.tools.dotc.core.Decorators.StringDecorator
87
import dotty.tools.dotc.core.DenotTransformers.InfoTransformer
@@ -13,6 +12,7 @@ import dotty.tools.dotc.core.Types._
1312
import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, MiniPhaseTransform}
1413
import scala.collection.mutable
1514
import dotty.tools.dotc.core.StdNames.nme
15+
import dotty.tools._
1616

1717
class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
1818
import tpd._
@@ -175,36 +175,34 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
175175
val tmap: (Tree => Tree) = _ match {
176176
case Return(t, from) if from.symbol == tree.symbol => Return(t, ref(newSym))
177177
case t: TypeApply => transformTypeApply(t)
178-
case t: Apply =>
179-
transformApply(t)
178+
case t: Apply => transformApply(t)
180179
case t => t
181180
}
182-
val tp = new TreeMap() {
183-
// needed to workaround https://github.com/lampepfl/dotty/issues/592
184-
override def transform(t: Tree)(implicit ctx: Context) = super.transform(t) match {
185-
case t @ Apply(fun, args) =>
186-
val newArgs = (args zip fun.tpe.firstParamTypes).map{case(t, tpe) => t.ensureConforms(tpe)}
187-
if (sameTypes(args, newArgs)) {
188-
t
189-
} else tpd.Apply(fun, newArgs)
190-
case t: ValDef =>
191-
cpy.ValDef(t)(rhs = t.rhs.ensureConforms(t.tpe.widen))
192-
case t: DefDef =>
193-
cpy.DefDef(t)(rhs = t.rhs.ensureConforms(t.tpe.finalResultType))
194-
case t => t
195-
}
196-
}
197181

198182
val typesReplaced = new TreeTypeMap(
199183
treeMap = tmap,
200184
typeMap = _
201185
.substDealias(origTParams, instantiations(index))
202-
.subst(origVParams, vparams.flatten.map(_.tpe))
203-
,
186+
.subst(origVParams, vparams.flatten.map(_.tpe)),
204187
oldOwners = tree.symbol :: Nil,
205188
newOwners = newSym :: Nil
206189
).transform(tree.rhs)
207190

191+
val tp = new TreeMap() {
192+
// needed to workaround https://github.com/lampepfl/dotty/issues/592
193+
override def transform(t: Tree)(implicit ctx: Context) = super.transform(t) match {
194+
case t @ Apply(fun, args) =>
195+
assert(sameLength(args, fun.tpe.widen.firstParamTypes))
196+
val newArgs = (args zip fun.tpe.widen.firstParamTypes).map{case(t, tpe) => t.ensureConforms(tpe)}
197+
if (sameTypes(args, newArgs)) {
198+
t
199+
} else tpd.Apply(fun, newArgs)
200+
case t: ValDef =>
201+
cpy.ValDef(t)(rhs = if(t.rhs.isEmpty) EmptyTree else t.rhs.ensureConforms(t.tpt.tpe))
202+
case t: DefDef =>
203+
cpy.DefDef(t)(rhs = if(t.rhs.isEmpty) EmptyTree else t.rhs.ensureConforms(t.tpt.tpe))
204+
case t => t
205+
}}
208206
val expectedTypeFixed = tp.transform(typesReplaced)
209207
expectedTypeFixed.ensureConforms(newSym.info.widen.finalResultType)
210208
}})
@@ -227,23 +225,13 @@ class TypeSpecializer extends MiniPhaseTransform with InfoTransformer {
227225
val Apply(fun, args) = tree
228226
fun match {
229227
case fun: TypeApply => {
230-
println(
231-
s"""
232-
|args -> ${args}
233-
234-
|f.fun -> ${fun.fun.tree}
235-
""".stripMargin)
236-
237228
val newFun = rewireTree(fun)
238229
if (fun ne newFun) {
239-
val b = (args zip newFun.tpe.firstParamTypes)
240-
val a = b.map{
230+
val as = (args zip newFun.tpe.widen.firstParamTypes).map{
241231
case (arg, tpe) =>
242232
arg.ensureConforms(tpe)
243233
}
244-
Apply(newFun,a)
245-
/* zip (instantiations zip paramTypes)).map{
246-
case (argType, (specType, castType)) => argType.ensureConforms(specType)})*/
234+
Apply(newFun,as)
247235
} else tree
248236
}
249237
case _ => tree

tests/pos/specialization/method_in_method_specialization.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,10 @@ object method_in_method_specialization {
1010

1111
outer(2)
1212
outer('d')
13+
14+
def outer2[@specialized(Int) O](o: O): Int = {
15+
def inner2[@specialized(Int) I] (i: I) = 1
16+
inner2(42)
17+
}
18+
outer2(1)
1319
}

tests/pos/specialization/multi_specialization.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@ object multi_specialization {
22
def one[@specialized T](n: T): T = n
33
def two[@specialized T, U](n: T, m: U): (T,U) = (n,m)
44
def three[@specialized T, U, V](n: T, m: U, o: V): (T,U,V) = (n,m,o)
5+
6+
one(1)
7+
two(1,2)
8+
two('a', null)
59
}
Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
object nothing_specialization {
2-
def ret_nothing[@specialized T] = {
2+
def ret_nothing[@specialized(Char) T] = {
33
//val a: List[T] = List[Nothing]()
4-
def apply[@specialized X](xs : X*) : List[X] = List(xs:_*)
5-
def apply6[@specialized X](xs : Nothing*) : List[Nothing] = List(xs: _*)
4+
def apply[@specialized(Char) X](xs : X*) : List[X] = List(xs:_*)
5+
def apply6[@specialized(Char) X](xs : Nothing*) : List[Nothing] = List(xs: _*)
6+
def apply2[@specialized(Long) U] = 1.asInstanceOf[U]
67
}
78
}
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
1-
class specialization {
1+
trait specialization {
22
def printer1[@specialized(Int, Long) T](a: T) = {
33
println(a.toString)
44
}
5-
65
def printer2[@specialized(Int, Long) T, U](a: T, b: U) = {
76
println(a.toString + b.toString)
87
}
9-
def print(a: Int) = {
10-
printer1(a)
8+
def print(i: Int) = {
9+
printer1(i)
1110
println(" ---- ")
12-
printer2(a,a)
11+
printer2(i,i)
1312
}
1413
print(9)
1514
}

0 commit comments

Comments
 (0)