Skip to content

Commit ef6102a

Browse files
committed
Test TypeTestsCasts.foundClasses & handle nesting in AndTypes
1 parent 7f51a86 commit ef6102a

File tree

2 files changed

+54
-6
lines changed

2 files changed

+54
-6
lines changed

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ object TypeTestsCasts {
154154
case tp2: RefinedType => recur(X, tp2.parent) && TypeComparer.hasMatchingMember(tp2.refinedName, X, tp2)
155155
case tp2: RecType => recur(X, tp2.parent)
156156
case _
157-
if P.classSymbol.isLocal && foundClasses(X, Nil).exists(P.classSymbol.isInaccessibleChildOf) => // 8
157+
if P.classSymbol.isLocal && foundClasses(X).exists(P.classSymbol.isInaccessibleChildOf) => // 8
158158
false
159159
case _ => true
160160
})
@@ -246,7 +246,7 @@ object TypeTestsCasts {
246246
if expr.tpe.isBottomType then
247247
report.warning(TypeTestAlwaysDiverges(expr.tpe, testType), tree.srcPos)
248248
val nestedCtx = ctx.fresh.setNewTyperState()
249-
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
249+
val foundClsSyms = foundClasses(expr.tpe.widen)
250250
val sensical = checkSensical(foundClsSyms)(using nestedCtx)
251251
if (!sensical) {
252252
nestedCtx.typerState.commit()
@@ -267,7 +267,7 @@ object TypeTestsCasts {
267267
def transformAsInstanceOf(testType: Type): Tree = {
268268
def testCls = effectiveClass(testType.widen)
269269
def foundClsSymPrimitive = {
270-
val foundClsSyms = foundClasses(expr.tpe.widen, Nil)
270+
val foundClsSyms = foundClasses(expr.tpe.widen)
271271
foundClsSyms.size == 1 && foundClsSyms.head.isPrimitiveValueClass
272272
}
273273
if (erasure(expr.tpe) <:< testType)
@@ -373,7 +373,10 @@ object TypeTestsCasts {
373373
else if tp.isRef(defn.AnyValClass) then defn.AnyClass
374374
else tp.classSymbol
375375

376-
private def foundClasses(tp: Type, acc: List[Symbol])(using Context): List[Symbol] = tp.dealias match
377-
case OrType(tp1, tp2) => foundClasses(tp2, foundClasses(tp1, acc))
378-
case _ => effectiveClass(tp) :: acc
376+
private[transform] def foundClasses(tp: Type)(using Context): List[Symbol] =
377+
def go(tp: Type, acc: List[Type])(using Context): List[Type] = tp.dealias match
378+
case OrType(tp1, tp2) => go(tp2, go(tp1, acc))
379+
case AndType(tp1, tp2) => (for t1 <- go(tp1, Nil); t2 <- go(tp2, Nil); yield AndType(t1, t2)) ::: acc
380+
case _ => tp :: acc
381+
go(tp, Nil).map(effectiveClass)
379382
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import core.*
6+
import Contexts.*, Decorators.*, Denotations.*, SymDenotations.*, Symbols.*, Types.*
7+
import Annotations.*
8+
9+
import org.junit.Test
10+
import org.junit.Assert.*
11+
12+
class TypeTestsCastsTest extends DottyTest:
13+
val defn = ctx.definitions; import defn.*
14+
15+
@Test def orL = checkFound(List(StringType, LongType), OrType(LongType, StringType, false))
16+
@Test def orR = checkFound(List(LongType, StringType), OrType(StringType, LongType, false))
17+
18+
@Test def annot = checkFound(List(StringType, LongType), AnnotatedType(OrType(LongType, StringType, false), Annotation(defn.UncheckedAnnot)))
19+
20+
@Test def andL = checkFound(List(StringType), AndType(StringType, AnyType))
21+
@Test def andR = checkFound(List(StringType), AndType(AnyType, StringType))
22+
@Test def andX = checkFound(List(NoType), AndType(StringType, BooleanType))
23+
24+
// (A | B) & C => {(A & B), (A & C)}
25+
// A & (B | C) => {(A & B), (A & C)}
26+
// (A | B) & (C | D) => {(A & C), (A & D), (B & C), (B & D)}
27+
@Test def orInAndL = checkFound(List(StringType, LongType), AndType(OrType(LongType, StringType, false), AnyType))
28+
@Test def orInAndR = checkFound(List(StringType, LongType), AndType(AnyType, OrType(LongType, StringType, false)))
29+
@Test def orInAndZ =
30+
// (Throwable | Exception) & (RuntimeException | Any) =
31+
// Throwable & RuntimeException = RuntimeException
32+
// Throwable & Any = Throwable
33+
// Exception & RuntimeException = RuntimeException
34+
// Exception & Any = Exception
35+
val ExceptionType = defn.ExceptionClass.typeRef
36+
val RuntimeExceptionType = defn.RuntimeExceptionClass.typeRef
37+
val tp = AndType(OrType(ThrowableType, ExceptionType, false), OrType(RuntimeExceptionType, AnyType, false))
38+
val exp = List(ExceptionType, RuntimeExceptionType, ThrowableType, RuntimeExceptionType)
39+
checkFound(exp, tp)
40+
41+
def checkFound(found: List[Type], tp: Type) =
42+
val expected = found.map(_.classSymbol)
43+
val obtained = TypeTestsCasts.foundClasses(tp)
44+
assertEquals(expected, obtained)
45+
end TypeTestsCastsTest

0 commit comments

Comments
 (0)