Skip to content

Commit 9d07d52

Browse files
authored
Merge pull request #15134 from dwijnand/local-classes-are-uncheckable
Local classes are uncheckable (type tests)
2 parents 6a62bb7 + 3ca5b70 commit 9d07d52

File tree

5 files changed

+214
-15
lines changed

5 files changed

+214
-15
lines changed

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ object TypeTestsCasts {
3131
import typer.Inferencing.maximizeType
3232
import typer.ProtoTypes.constrained
3333

34-
/** Whether `(x:X).isInstanceOf[P]` can be checked at runtime?
34+
/** Whether `(x: X).isInstanceOf[P]` can be checked at runtime?
3535
*
3636
* First do the following substitution:
3737
* (a) replace `T @unchecked` and pattern binder types (e.g., `_$1`) in P with WildcardType
@@ -48,7 +48,8 @@ object TypeTestsCasts {
4848
* (c) maximize `pre.F[Xs]` and check `pre.F[Xs] <:< P`
4949
* 6. if `P = T1 | T2` or `P = T1 & T2`, checkable(X, T1) && checkable(X, T2).
5050
* 7. if `P` is a refinement type, FALSE
51-
* 8. otherwise, TRUE
51+
* 8. if `P` is a local class which is not statically reachable from the scope where `X` is defined, FALSE
52+
* 9. otherwise, TRUE
5253
*/
5354
def checkable(X: Type, P: Type, span: Span)(using Context): Boolean = atPhase(Phases.refchecksPhase.next) {
5455
// Run just before ElimOpaque transform (which follows RefChecks)
@@ -152,10 +153,13 @@ object TypeTestsCasts {
152153
case AnnotatedType(t, _) => recur(X, t)
153154
case tp2: RefinedType => recur(X, tp2.parent) && TypeComparer.hasMatchingMember(tp2.refinedName, X, tp2)
154155
case tp2: RecType => recur(X, tp2.parent)
156+
case _
157+
if P.classSymbol.isLocal && foundClasses(X).exists(P.classSymbol.isInaccessibleChildOf) => // 8
158+
false
155159
case _ => true
156160
})
157161

158-
val res = recur(X.widen, replaceP(P))
162+
val res = X.widenTermRefExpr.hasAnnotation(defn.UncheckedAnnot) || recur(X.widen, replaceP(P))
159163

160164
debug.println(i"checking ${X.show} isInstanceOf ${P} = $res")
161165

@@ -174,15 +178,6 @@ object TypeTestsCasts {
174178
def derivedTree(expr1: Tree, sym: Symbol, tp: Type) =
175179
cpy.TypeApply(tree)(expr1.select(sym).withSpan(expr.span), List(TypeTree(tp)))
176180

177-
def effectiveClass(tp: Type): Symbol =
178-
if tp.isRef(defn.PairClass) then effectiveClass(erasure(tp))
179-
else if tp.isRef(defn.AnyValClass) then defn.AnyClass
180-
else tp.classSymbol
181-
182-
def foundClasses(tp: Type, acc: List[Symbol]): List[Symbol] = tp.dealias match
183-
case OrType(tp1, tp2) => foundClasses(tp2, foundClasses(tp1, acc))
184-
case _ => effectiveClass(tp) :: acc
185-
186181
def inMatch =
187182
tree.fun.symbol == defn.Any_typeTest || // new scheme
188183
expr.symbol.is(Case) // old scheme
@@ -251,7 +246,7 @@ object TypeTestsCasts {
251246
if expr.tpe.isBottomType then
252247
report.warning(TypeTestAlwaysDiverges(expr.tpe, testType), tree.srcPos)
253248
val nestedCtx = ctx.fresh.setNewTyperState()
254-
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
249+
val foundClsSyms = foundClasses(expr.tpe.widen)
255250
val sensical = checkSensical(foundClsSyms)(using nestedCtx)
256251
if (!sensical) {
257252
nestedCtx.typerState.commit()
@@ -272,7 +267,7 @@ object TypeTestsCasts {
272267
def transformAsInstanceOf(testType: Type): Tree = {
273268
def testCls = effectiveClass(testType.widen)
274269
def foundClsSymPrimitive = {
275-
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
270+
val foundClsSyms = foundClasses(expr.tpe.widen)
276271
foundClsSyms.size == 1 && foundClsSyms.head.isPrimitiveValueClass
277272
}
278273
if (erasure(expr.tpe) <:< testType)
@@ -372,4 +367,16 @@ object TypeTestsCasts {
372367
}
373368
interceptWith(expr)
374369
}
370+
371+
private def effectiveClass(tp: Type)(using Context): Symbol =
372+
if tp.isRef(defn.PairClass) then effectiveClass(erasure(tp))
373+
else if tp.isRef(defn.AnyValClass) then defn.AnyClass
374+
else tp.classSymbol
375+
376+
private[transform] def foundClasses(tp: Type)(using Context): List[Symbol] =
377+
def go(tp: Type, acc: List[Type])(using Context): List[Type] = tp.dealias match
378+
case OrType(tp1, tp2) => go(tp2, go(tp1, acc))
379+
case AndType(tp1, tp2) => (for t1 <- go(tp1, Nil); t2 <- go(tp2, Nil); yield AndType(t1, t2)) ::: acc
380+
case _ => tp :: acc
381+
go(tp, Nil).map(effectiveClass)
375382
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import core.*
6+
import Contexts.*, Decorators.*, Denotations.*, SymDenotations.*, Symbols.*, Types.*
7+
import Annotations.*
8+
9+
import org.junit.Test
10+
import org.junit.Assert.*
11+
12+
class TypeTestsCastsTest extends DottyTest:
13+
val defn = ctx.definitions; import defn.*
14+
15+
@Test def orL = checkFound(List(StringType, LongType), OrType(LongType, StringType, false))
16+
@Test def orR = checkFound(List(LongType, StringType), OrType(StringType, LongType, false))
17+
18+
@Test def annot = checkFound(List(StringType, LongType), AnnotatedType(OrType(LongType, StringType, false), Annotation(defn.UncheckedAnnot)))
19+
20+
@Test def andL = checkFound(List(StringType), AndType(StringType, AnyType))
21+
@Test def andR = checkFound(List(StringType), AndType(AnyType, StringType))
22+
@Test def andX = checkFound(List(NoType), AndType(StringType, BooleanType))
23+
24+
// (A | B) & C => {(A & B), (A & C)}
25+
// A & (B | C) => {(A & B), (A & C)}
26+
// (A | B) & (C | D) => {(A & C), (A & D), (B & C), (B & D)}
27+
@Test def orInAndL = checkFound(List(StringType, LongType), AndType(OrType(LongType, StringType, false), AnyType))
28+
@Test def orInAndR = checkFound(List(StringType, LongType), AndType(AnyType, OrType(LongType, StringType, false)))
29+
@Test def orInAndZ =
30+
// (Throwable | Exception) & (RuntimeException | Any) =
31+
// Throwable & RuntimeException = RuntimeException
32+
// Throwable & Any = Throwable
33+
// Exception & RuntimeException = RuntimeException
34+
// Exception & Any = Exception
35+
val ExceptionType = defn.ExceptionClass.typeRef
36+
val RuntimeExceptionType = defn.RuntimeExceptionClass.typeRef
37+
val tp = AndType(OrType(ThrowableType, ExceptionType, false), OrType(RuntimeExceptionType, AnyType, false))
38+
val exp = List(ExceptionType, RuntimeExceptionType, ThrowableType, RuntimeExceptionType)
39+
checkFound(exp, tp)
40+
41+
def checkFound(found: List[Type], tp: Type) =
42+
val expected = found.map(_.classSymbol)
43+
val obtained = TypeTestsCasts.foundClasses(tp)
44+
assertEquals(expected, obtained)
45+
end TypeTestsCastsTest

tests/neg/i4812.check

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
-- Error: tests/neg/i4812.scala:8:11 -----------------------------------------------------------------------------------
2+
8 | case prev: A => // error: the type test for A cannot be checked at runtime
3+
| ^
4+
| the type test for A cannot be checked at runtime
5+
-- Error: tests/neg/i4812.scala:18:11 ----------------------------------------------------------------------------------
6+
18 | case prev: A => // error: the type test for A cannot be checked at runtime
7+
| ^
8+
| the type test for A cannot be checked at runtime
9+
-- Error: tests/neg/i4812.scala:28:11 ----------------------------------------------------------------------------------
10+
28 | case prev: A => // error: the type test for A cannot be checked at runtime
11+
| ^
12+
| the type test for A cannot be checked at runtime
13+
-- Error: tests/neg/i4812.scala:38:11 ----------------------------------------------------------------------------------
14+
38 | case prev: A => // error: the type test for A cannot be checked at runtime
15+
| ^
16+
| the type test for A cannot be checked at runtime
17+
-- Error: tests/neg/i4812.scala:50:13 ----------------------------------------------------------------------------------
18+
50 | case prev: A => // error: the type test for A cannot be checked at runtime
19+
| ^
20+
| the type test for A cannot be checked at runtime
21+
-- Error: tests/neg/i4812.scala:60:11 ----------------------------------------------------------------------------------
22+
60 | case prev: A => // error: the type test for A cannot be checked at runtime
23+
| ^
24+
| the type test for A cannot be checked at runtime
25+
-- Error: tests/neg/i4812.scala:96:11 ----------------------------------------------------------------------------------
26+
96 | case x: B => // error: the type test for B cannot be checked at runtime
27+
| ^
28+
| the type test for B cannot be checked at runtime

tests/neg/i4812.scala

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
// scalac: -Werror
2+
object Test:
3+
var prev: Any = _
4+
5+
def test[T](x: T): T =
6+
class A(val elem: (T, Boolean))
7+
prev match
8+
case prev: A => // error: the type test for A cannot be checked at runtime
9+
prev.elem._1
10+
case _ =>
11+
prev = new A((x, true))
12+
x
13+
14+
def test2[T](x: T): T =
15+
abstract class Parent(_elem: T) { def elem: T = _elem }
16+
class A extends Parent(x)
17+
prev match
18+
case prev: A => // error: the type test for A cannot be checked at runtime
19+
prev.elem
20+
case _ =>
21+
prev = new A
22+
x
23+
24+
def test3[T](x: T): T =
25+
class Holder(val elem: T)
26+
class A(val holder: Holder)
27+
prev match
28+
case prev: A => // error: the type test for A cannot be checked at runtime
29+
prev.holder.elem
30+
case _ =>
31+
prev = new A(new Holder(x))
32+
x
33+
34+
def test4[T](x: T): T =
35+
class Holder(val elem: (Int, (Unit, (T, Boolean))))
36+
class A { var holder: Holder = null }
37+
prev match
38+
case prev: A => // error: the type test for A cannot be checked at runtime
39+
prev.holder.elem._2._2._1
40+
case _ =>
41+
val a = new A
42+
a.holder = new Holder((42, ((), (x, true))))
43+
prev = a
44+
x
45+
46+
class Foo[U]:
47+
def test5(x: U): U =
48+
class A(val elem: U)
49+
prev match
50+
case prev: A => // error: the type test for A cannot be checked at runtime
51+
prev.elem
52+
case _ =>
53+
prev = new A(x)
54+
x
55+
56+
def test6[T](x: T): T =
57+
class A { var b: B = null }
58+
class B { var a: A = null; var elem: T = _ }
59+
prev match
60+
case prev: A => // error: the type test for A cannot be checked at runtime
61+
prev.b.elem
62+
case _ =>
63+
val a = new A
64+
val b = new B
65+
b.elem = x
66+
a.b = b
67+
prev = a
68+
x
69+
70+
def test7[T](x: T): T =
71+
class A(val elem: T)
72+
prev match
73+
case prev: A @unchecked => prev.elem
74+
case _ => prev = new A(x); x
75+
76+
def test8[T](x: T): T =
77+
class A(val elem: T)
78+
val p = prev
79+
(p: @unchecked) match
80+
case prev: A => prev.elem
81+
case _ => prev = new A(x); x
82+
83+
def test9 =
84+
trait A
85+
class B extends A
86+
val x: A = new B
87+
x match
88+
case x: B => x
89+
90+
sealed class A
91+
var prevA: A = _
92+
def test10: A =
93+
val methodCallId = System.nanoTime()
94+
class B(val id: Long) extends A
95+
prevA match
96+
case x: B => // error: the type test for B cannot be checked at runtime
97+
x.ensuring(x.id == methodCallId, s"Method call id $methodCallId != ${x.id}")
98+
case _ =>
99+
val x = new B(methodCallId)
100+
prevA = x
101+
x
102+
103+
def test11 =
104+
trait A
105+
trait B
106+
class C extends A with B
107+
val x: A = new C
108+
x match
109+
case x: B => x
110+
111+
def test12 =
112+
class Foo
113+
class Bar
114+
val x: Foo | Bar = new Foo
115+
x.isInstanceOf[Foo]
116+
117+
def main(args: Array[String]): Unit =
118+
test(1)
119+
val x: String = test("") // was: ClassCastException: java.lang.Integer cannot be cast to java.lang.String

0 commit comments

Comments
 (0)