Skip to content

Commit b1d9162

Browse files
authored
Merge pull request #4045 from dotty-staging/fix-3324
Fix #3324: add `isInstanceOf` check
2 parents 62be430 + 2b00b7a commit b1d9162

21 files changed

+384
-7
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class Compiler {
8181
new CrossCastAnd, // Normalize selections involving intersection types.
8282
new Splitter) :: // Expand selections involving union types into conditionals
8383
List(new ErasedDecls, // Removes all erased defs and vals decls (except for parameters)
84+
new IsInstanceOfChecker, // check runtime realisability for `isInstanceOf`
8485
new VCInlineMethods, // Inlines calls to value class methods
8586
new SeqLiterals, // Express vararg arguments as arrays
8687
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -805,7 +805,7 @@ object Trees {
805805
def unforced: AnyRef
806806
protected def force(x: AnyRef): Unit
807807
def forceIfLazy(implicit ctx: Context): T = unforced match {
808-
case lzy: Lazy[T] =>
808+
case lzy: Lazy[T @unchecked] =>
809809
val x = lzy.complete
810810
force(x)
811811
x
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import util.Positions._
5+
import MegaPhase.MiniPhase
6+
import core._
7+
import Contexts.Context, Types._, Decorators._, Symbols._, typer._, ast._, NameKinds._
8+
import TypeUtils._, Flags._
9+
import config.Printers.{ transforms => debug }
10+
11+
/** Check runtime realizability of type test, see the documentation for `Checkable`.
12+
*/
13+
class IsInstanceOfChecker extends MiniPhase {
14+
15+
import ast.tpd._
16+
17+
val phaseName = "isInstanceOfChecker"
18+
19+
override def transformTypeApply(tree: TypeApply)(implicit ctx: Context): Tree = {
20+
def ensureCheckable(qual: Tree, pt: Tree): Tree = {
21+
if (!Checkable.checkable(qual.tpe, pt.tpe, tree.pos))
22+
ctx.warning(
23+
s"the type test for ${pt.show} cannot be checked at runtime",
24+
tree.pos
25+
)
26+
27+
tree
28+
}
29+
30+
tree.fun match {
31+
case fn: Select
32+
if fn.symbol == defn.Any_typeTest || fn.symbol == defn.Any_isInstanceOf =>
33+
ensureCheckable(fn.qualifier, tree.args.head)
34+
case _ => tree
35+
}
36+
}
37+
}
38+
39+
object Checkable {
40+
import Inferencing._
41+
import ProtoTypes._
42+
43+
/** Whether `(x:X).isInstanceOf[P]` can be checked at runtime?
44+
*
45+
* First do the following substitution:
46+
* (a) replace `T @unchecked` and pattern binder types (e.g., `_$1`) in P with WildcardType
47+
* (b) replace pattern binder types (e.g., `_$1`) in X:
48+
* - variance = 1 : hiBound
49+
* - variance = -1 : loBound
50+
* - variance = 0 : OrType(Any, Nothing) // TODO: use original type param bounds
51+
*
52+
* Then check:
53+
*
54+
* 1. if `X <:< P`, TRUE
55+
* 2. if `P` is a singleton type, TRUE
56+
* 3. if `P` refers to an abstract type member or type parameter, FALSE
57+
* 4. if `P = Array[T]`, checkable(E, T) where `E` is the element type of `X`, defaults to `Any`.
58+
* 5. if `P` is `pre.F[Ts]` and `pre.F` refers to a class which is not `Array`:
59+
* (a) replace `Ts` with fresh type variables `Xs`
60+
* (b) constrain `Xs` with `pre.F[Xs] <:< X`
61+
* (c) instantiate Xs and check `pre.F[Xs] <:< P`
62+
* 6. if `P = T1 | T2` or `P = T1 & T2`, checkable(X, T1) && checkable(X, T2).
63+
* 7. if `P` is a refinement type, FALSE
64+
* 8. otherwise, TRUE
65+
*/
66+
def checkable(X: Type, P: Type, pos: Position)(implicit ctx: Context): Boolean = {
67+
def isAbstract(P: Type) = !P.dealias.typeSymbol.isClass
68+
def isPatternTypeSymbol(sym: Symbol) = !sym.isClass && sym.is(Case)
69+
70+
def replaceP(implicit ctx: Context) = new TypeMap {
71+
def apply(tp: Type) = tp match {
72+
case tref: TypeRef
73+
if isPatternTypeSymbol(tref.typeSymbol) => WildcardType
74+
case AnnotatedType(_, annot)
75+
if annot.symbol == defn.UncheckedAnnot => WildcardType
76+
case _ => mapOver(tp)
77+
}
78+
}
79+
80+
def replaceX(implicit ctx: Context) = new TypeMap {
81+
def apply(tp: Type) = tp match {
82+
case tref: TypeRef
83+
if isPatternTypeSymbol(tref.typeSymbol) =>
84+
if (variance == 1) tref.info.hiBound
85+
else if (variance == -1) tref.info.loBound
86+
else OrType(defn.AnyType, defn.NothingType)
87+
case _ => mapOver(tp)
88+
}
89+
}
90+
91+
def isClassDetermined(X: Type, P: AppliedType)(implicit ctx: Context) = {
92+
val AppliedType(tycon, _) = P
93+
val typeLambda = tycon.ensureLambdaSub.asInstanceOf[TypeLambda]
94+
val tvars = constrained(typeLambda, untpd.EmptyTree, alwaysAddTypeVars = true)._2.map(_.tpe)
95+
val P1 = tycon.appliedTo(tvars)
96+
97+
debug.println("P : " + P.show)
98+
debug.println("P1 : " + P1.show)
99+
debug.println("X : " + X.show)
100+
101+
P1 <:< X // may fail, ignore
102+
103+
val res = isFullyDefined(P1, ForceDegree.noBottom) && P1 <:< P
104+
debug.println("P1 : " + P1)
105+
debug.println("P1 <:< P = " + res)
106+
107+
res
108+
}
109+
110+
def recur(X: Type, P: Type): Boolean = (X <:< P) || (P match {
111+
case _: SingletonType => true
112+
case _: TypeProxy
113+
if isAbstract(P) => false
114+
case defn.ArrayOf(tpT) =>
115+
X match {
116+
case defn.ArrayOf(tpE) => recur(tpE, tpT)
117+
case _ => recur(defn.AnyType, tpT)
118+
}
119+
case tpe: AppliedType => isClassDetermined(X, tpe)(ctx.fresh.setNewTyperState())
120+
case AndType(tp1, tp2) => recur(X, tp1) && recur(X, tp2)
121+
case OrType(tp1, tp2) => recur(X, tp1) && recur(X, tp2)
122+
case AnnotatedType(t, _) => recur(X, t)
123+
case _: RefinedType => false
124+
case _ => true
125+
})
126+
127+
val res = recur(replaceX.apply(X.widen), replaceP.apply(P))
128+
129+
debug.println(i"checking ${X.show} isInstanceOf ${P} = $res")
130+
131+
res
132+
}
133+
}

compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import ast.Trees._
1111
import ast.untpd
1212
import Decorators._
1313
import NameOps._
14+
import Annotations.Annotation
1415
import ValueClasses.isDerivedValueClass
1516
import scala.collection.mutable.ListBuffer
1617
import scala.language.postfixOps
@@ -152,17 +153,20 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
152153
* def equals(that: Any): Boolean =
153154
* (this eq that) || {
154155
* that match {
155-
* case x$0 @ (_: C) => this.x == this$0.x && this.y == x$0.y
156+
* case x$0 @ (_: C @unchecked) => this.x == this$0.x && this.y == x$0.y
156157
* case _ => false
157158
* }
158159
* ```
159160
*
160161
* If `C` is a value class the initial `eq` test is omitted.
162+
*
163+
* `@unchecked` is needed for parametric case classes.
164+
*
161165
*/
162166
def equalsBody(that: Tree)(implicit ctx: Context): Tree = {
163167
val thatAsClazz = ctx.newSymbol(ctx.owner, nme.x_0, Synthetic, clazzType, coord = ctx.owner.pos) // x$0
164168
def wildcardAscription(tp: Type) = Typed(Underscore(tp), TypeTree(tp))
165-
val pattern = Bind(thatAsClazz, wildcardAscription(clazzType)) // x$0 @ (_: C)
169+
val pattern = Bind(thatAsClazz, wildcardAscription(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot)))) // x$0 @ (_: C @unchecked)
166170
val comparisons = accessors map { accessor =>
167171
This(clazz).select(accessor).equal(ref(thatAsClazz).select(accessor)) }
168172
val rhs = // this.x == this$0.x && this.y == x$0.y
@@ -250,10 +254,12 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
250254
* gets the `canEqual` method
251255
*
252256
* ```
253-
* def canEqual(that: Any) = that.isInstanceOf[C]
257+
* def canEqual(that: Any) = that.isInstanceOf[C @unchecked]
254258
* ```
259+
*
260+
* `@unchecked` is needed for parametric case classes.
255261
*/
256-
def canEqualBody(that: Tree): Tree = that.isInstance(clazzType)
262+
def canEqualBody(that: Tree): Tree = that.isInstance(AnnotatedType(clazzType, Annotation(defn.UncheckedAnnot)))
257263

258264
symbolsToSynthesize flatMap syntheticDefIfMissing
259265
}

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class CompilationTests extends ParallelTesting {
7878
) +
7979
compileFilesInDir("tests/pos-special/spec-t5545", defaultOptions) +
8080
compileFilesInDir("tests/pos-special/strawman-collections", defaultOptions) +
81+
compileFilesInDir("tests/pos-special/isInstanceOf", allowDeepSubtypes.and("-Xfatal-warnings")) +
8182
compileFile("scala2-library/src/library/scala/collection/immutable/IndexedSeq.scala", defaultOptions) +
8283
compileFile("scala2-library/src/library/scala/collection/parallel/mutable/ParSetLike.scala", defaultOptions) +
8384
compileList(
@@ -190,6 +191,7 @@ class CompilationTests extends ParallelTesting {
190191
compileFile("tests/neg-custom-args/noimports2.scala", defaultOptions.and("-Yno-imports")) +
191192
compileFile("tests/neg-custom-args/i3882.scala", allowDeepSubtypes) +
192193
compileFile("tests/neg-custom-args/i1754.scala", allowDeepSubtypes) +
194+
compileFilesInDir("tests/neg-custom-args/isInstanceOf", allowDeepSubtypes and "-Xfatal-warnings") +
193195
compileFile("tests/neg-custom-args/i3627.scala", allowDeepSubtypes)
194196
}.checkExpectedErrors()
195197

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class Test {
2+
def remove[S](a: S | Int, f: Int => S):S = a match {
3+
case a: S => a // error
4+
case a: Int => f(a)
5+
}
6+
7+
val t: Int | String = 5
8+
val t1 = remove[String](t, _.toString)
9+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
class C[T] {
2+
val x: Any = ???
3+
if (x.isInstanceOf[List[String]]) // error: unchecked
4+
if (x.isInstanceOf[T]) // error: unchecked
5+
x match {
6+
case x: List[String] => // error: unchecked
7+
case x: T => // error: unchecked
8+
}
9+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
trait C[T]
2+
class D[T]
3+
4+
class Test {
5+
def foo[T](x: C[T]) = x match {
6+
case _: D[T] => // error
7+
}
8+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Test {
2+
trait A[+T]
3+
class B[T] extends A[T]
4+
class C[T] extends B[Any] with A[T]
5+
6+
def foo[T](c: C[T]): Unit = c match {
7+
case _: B[T] => // error
8+
}
9+
10+
def bar[T](b: B[T]): Unit = b match {
11+
case _: A[T] =>
12+
}
13+
14+
def quux[T](a: A[T]): Unit = a match {
15+
case _: B[T] => // should be an error!!
16+
}
17+
18+
quux(new C[Int])
19+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object Test {
2+
trait Foo
3+
case class One[+T](fst: T)
4+
5+
def bad[T <: Foo](e: One[T])(x: T) = e match {
6+
case foo: One[a] =>
7+
x.isInstanceOf[a] // error
8+
val y: Any = ???
9+
y.isInstanceOf[a] // error
10+
}
11+
}
12+
13+
object Test2 {
14+
case class One[T](fst: T)
15+
16+
def bad[T](e: One[T])(x: T) = e match {
17+
case foo: One[a] =>
18+
x.isInstanceOf[a] // error
19+
val y: Any = ???
20+
y.isInstanceOf[a] // error
21+
}
22+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
sealed trait Exp[T]
2+
case class Fun[A, B](f: Exp[A => B]) extends Exp[A => B]
3+
4+
class Test {
5+
def eval[T](e: Exp[T]) = e match {
6+
case Fun(x: Fun[Int, Double]) => ??? // error
7+
case Fun(x: Exp[Int => String]) => ??? // error
8+
}
9+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class Foo {
2+
def foo(x: Any): Boolean =
3+
x.isInstanceOf[List[String]] // error
4+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
sealed trait A[T]
2+
class B[T] extends A[T]
3+
4+
class Test {
5+
def f(x: B[Int]) = x match { case _: A[Int] if true => }
6+
7+
def g(x: A[Int]) = x match { case _: B[Int] => }
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
class Test {
2+
val x: Any = ???
3+
4+
x match {
5+
case _: List[Int @unchecked] => 5
6+
case _: List[Int] @unchecked => 5
7+
}
8+
}
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class C[T] {
2+
val x: T = ???
3+
x.isInstanceOf[T]
4+
5+
val y: Array[T] = ???
6+
7+
y match {
8+
case x: Array[T] =>
9+
}
10+
11+
type F[X]
12+
13+
val z: F[T] = ???
14+
z match {
15+
case x: F[T] =>
16+
}
17+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
object Test {
2+
trait Marker
3+
def foo[T](x: T) = x match {
4+
case _: (T & Marker) => // no warning
5+
case _: T with Marker => // scalac emits a warning
6+
case _ =>
7+
}
8+
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
object p {
2+
3+
// test parametric case classes, which synthesis `canEqual` and `equals`
4+
enum Result[+T, +E] {
5+
case OK [T](x: T) extends Result[T, Nothing]
6+
case Err[E](e: E) extends Result[Nothing, E]
7+
}
8+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import scala.reflect.ClassTag
2+
3+
object IsInstanceOfClassTag {
4+
def safeCast[T: ClassTag](x: Any): Option[T] = {
5+
x match {
6+
case x: T => Some(x)
7+
case _ => None
8+
}
9+
}
10+
}

0 commit comments

Comments
 (0)