Skip to content

Commit 61e9372

Browse files
committed
Fix #1490: type test of union types via type alias
1 parent 5032f71 commit 61e9372

File tree

4 files changed

+44
-32
lines changed

4 files changed

+44
-32
lines changed

src/dotty/tools/dotc/core/TypeErasure.scala

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,10 @@ object TypeErasure {
112112
/** The standard erasure of a Scala type. Value classes are erased as normal classes.
113113
*
114114
* @param tp The type to erase.
115-
*/
116-
def erasure(tp: Type)(implicit ctx: Context): Type =
117-
erasureFn(isJava = false, semiEraseVCs = false, isConstructor = false, wildcardOK = false)(tp)(erasureCtx)
115+
* @param eraseOrType Whether erase OrType to its lowest uppper bound
116+
*/
117+
def erasure(tp: Type, eraseOrType: Boolean = true)(implicit ctx: Context): Type =
118+
erasureFn(isJava = false, semiEraseVCs = false, isConstructor = false, wildcardOK = false)(tp, eraseOrType)(erasureCtx)
118119

119120
/** The value class erasure of a Scala type, where value classes are semi-erased to
120121
* ErasedValueType (they will be fully erased in [[ElimErasedValueType]]).
@@ -337,34 +338,34 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
337338
* - For NoType or NoPrefix, the type itself.
338339
* - For any other type, exception.
339340
*/
340-
private def apply(tp: Type)(implicit ctx: Context): Type = tp match {
341+
private def apply(tp: Type, eraseOrType: Boolean = true)(implicit ctx: Context): Type = tp match {
341342
case _: ErasedValueType =>
342343
tp
343344
case tp: TypeRef =>
344345
val sym = tp.symbol
345-
if (!sym.isClass) this(tp.info)
346+
if (!sym.isClass) this(tp.info, eraseOrType)
346347
else if (semiEraseVCs && isDerivedValueClass(sym)) eraseDerivedValueClassRef(tp)
347348
else if (sym == defn.ArrayClass) apply(tp.appliedTo(TypeBounds.empty)) // i966 shows that we can hit a raw Array type.
348349
else eraseNormalClassRef(tp)
349350
case tp: RefinedType =>
350351
val parent = tp.parent
351352
if (parent isRef defn.ArrayClass) eraseArray(tp)
352-
else this(parent)
353+
else this(parent, eraseOrType)
353354
case _: TermRef | _: ThisType =>
354355
this(tp.widen)
355356
case SuperType(thistpe, supertpe) =>
356-
SuperType(this(thistpe), this(supertpe))
357+
SuperType(this(thistpe, eraseOrType), this(supertpe, eraseOrType))
357358
case ExprType(rt) =>
358359
defn.FunctionClass(0).typeRef
359360
case tp: TypeProxy =>
360-
this(tp.underlying)
361+
this(tp.underlying, eraseOrType)
361362
case AndType(tp1, tp2) =>
362-
erasedGlb(this(tp1), this(tp2), isJava)
363+
erasedGlb(this(tp1, eraseOrType), this(tp2, eraseOrType), isJava)
363364
case OrType(tp1, tp2) =>
364-
ctx.typeComparer.orType(this(tp1), this(tp2), erased = true)
365+
ctx.typeComparer.orType(this(tp1, eraseOrType), this(tp2, eraseOrType), erased = eraseOrType)
365366
case tp: MethodType =>
366367
def paramErasure(tpToErase: Type) =
367-
erasureFn(tp.isJava, semiEraseVCs, isConstructor, wildcardOK)(tpToErase)
368+
erasureFn(tp.isJava, semiEraseVCs, isConstructor, wildcardOK)(tpToErase, eraseOrType)
368369
val formals = tp.paramTypes.mapConserve(paramErasure)
369370
eraseResult(tp.resultType) match {
370371
case rt: MethodType =>
@@ -373,14 +374,14 @@ class TypeErasure(isJava: Boolean, semiEraseVCs: Boolean, isConstructor: Boolean
373374
tp.derivedMethodType(tp.paramNames, formals, rt)
374375
}
375376
case tp: PolyType =>
376-
this(tp.resultType) match {
377+
this(tp.resultType, eraseOrType) match {
377378
case rt: MethodType => rt
378379
case rt => MethodType(Nil, Nil, rt)
379380
}
380381
case tp @ ClassInfo(pre, cls, classParents, decls, _) =>
381382
if (cls is Package) tp
382383
else {
383-
def eraseTypeRef(p: TypeRef) = this(p).asInstanceOf[TypeRef]
384+
def eraseTypeRef(p: TypeRef) = this(p, eraseOrType).asInstanceOf[TypeRef]
384385
val parents: List[TypeRef] =
385386
if ((cls eq defn.ObjectClass) || cls.isPrimitiveValueClass) Nil
386387
else classParents.mapConserve(eraseTypeRef) match {

src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,19 @@ trait TypeTestsCasts {
5151
case _: SingletonType =>
5252
val cmpOp = if (argType derivesFrom defn.AnyValClass) defn.Any_equals else defn.Object_eq
5353
expr.select(cmpOp).appliedTo(singleton(argType))
54+
case OrType(tp1, tp2) =>
55+
evalOnce(qual) { fun =>
56+
val erased1 = transformIsInstanceOf(fun, tp1)
57+
val erased2 = transformIsInstanceOf(fun, tp2)
58+
erased1 match {
59+
case Literal(Constant(true)) => erased1
60+
case _ =>
61+
erased2 match {
62+
case Literal(Constant(true)) => erased2
63+
case _ => erased1 or erased2
64+
}
65+
}
66+
}
5467
case AndType(tp1, tp2) =>
5568
evalOnce(expr) { fun =>
5669
val erased1 = transformIsInstanceOf(fun, tp1)
@@ -93,26 +106,8 @@ trait TypeTestsCasts {
93106
derivedTree(qual, defn.Any_asInstanceOf, argType)
94107
}
95108

96-
/** Transform isInstanceOf OrType
97-
*
98-
* expr.isInstanceOf[A | B] ~~> expr.isInstanceOf[A] | expr.isInstanceOf[B]
99-
*
100-
* The transform happens before erasure of `argType`, thus cannot be merged
101-
* with `transformIsInstanceOf`, which depends on erased type of `argType`.
102-
*/
103-
def transformOrTypeTest(qual: Tree, argType: Type): Tree = argType match {
104-
case OrType(tp1, tp2) =>
105-
evalOnce(qual) { fun =>
106-
transformOrTypeTest(fun, tp1)
107-
.select(nme.OR)
108-
.appliedTo(transformOrTypeTest(fun, tp2))
109-
}
110-
case _ =>
111-
transformIsInstanceOf(qual, erasure(argType))
112-
}
113-
114109
if (sym eq defn.Any_isInstanceOf)
115-
transformOrTypeTest(qual, tree.args.head.tpe)
110+
transformIsInstanceOf(qual, erasure(tree.args.head.tpe, false))
116111
else if (sym eq defn.Any_asInstanceOf)
117112
transformAsInstanceOf(erasure(tree.args.head.tpe))
118113
else tree

tests/run/i1490.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
true
2+
true
3+
false

tests/run/i1490.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
class Base {
2+
type T = Int | Boolean
3+
def test(x: Object) = x.isInstanceOf[T]
4+
}
5+
6+
object Test {
7+
def main(args: Array[String]) = {
8+
val b = new Base
9+
println(b.test(Int.box(3)))
10+
println(b.test(Boolean.box(false)))
11+
println(b.test(Double.box(3.4)))
12+
}
13+
}

0 commit comments

Comments
 (0)