Skip to content

Commit 7f7471b

Browse files
committed
Handle singleton type unions when comparing types
1 parent 5d8ea09 commit 7f7471b

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,11 +545,18 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
545545
}
546546
compareTypeLambda
547547
case OrType(tp21, tp22) =>
548-
val tp1a = tp1.widenDealiasKeepRefiningAnnots
549-
if (tp1a ne tp1)
550-
// Follow the alias; this might avoid truncating the search space in the either below
551-
// Note that it's safe to widen here because singleton types cannot be part of `|`.
552-
return recur(tp1a, tp2)
548+
def isSingletonlessTypeUnion(tp: Type): Boolean = tp match {
549+
case OrType(a, b) => isSingletonlessTypeUnion(a) && isSingletonlessTypeUnion(b)
550+
case _ => !tp.isSingleton
551+
}
552+
if (isSingletonlessTypeUnion(tp21) && isSingletonlessTypeUnion(tp22)) {
553+
val tp1a = tp1.widenDealiasKeepRefiningAnnots
554+
if (tp1a ne tp1)
555+
// Follow the alias; this might avoid truncating the search space in the either below
556+
// Note that it's safe to widen here because we ensured that
557+
// singleton types are not top-level members of tp2
558+
return recur(tp1a, tp2)
559+
}
553560

554561
// Rewrite T1 <: (T211 & T212) | T22 to T1 <: (T211 | T22) and T1 <: (T212 | T22)
555562
// and analogously for T1 <: T21 | (T221 & T222)

tests/neg/gadt-union-subtyping.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
object `gadt-union-subtyping` {
2+
enum SUB[A, +B] { case Refl[T]() extends SUB[T, T] }
3+
4+
def foo[T](t: T, sub1: 5 SUB T, sub2: 6 SUB T): T = {
5+
(sub1, sub2) match {
6+
case (SUB.Refl(), SUB.Refl()) =>
7+
val a: 5 = t // error
8+
val _t: T = (5 : 5)
9+
???
10+
}
11+
}
12+
13+
def bar[T](t: T, sub1: 5 SUB T, sub2: 6 SUB T, sub3: String SUB T): T = {
14+
(sub1, sub2, sub3) match {
15+
case (SUB.Refl(), SUB.Refl(), SUB.Refl()) =>
16+
val a: 5 = t // error
17+
val b: 6 = t // error
18+
val c: String = t // error
19+
val _t: T = (5 : 5)
20+
???
21+
}
22+
}
23+
24+
def baz[T](t: T, sub1: String SUB T, sub2: 5 SUB T, sub3: 6 SUB T): T = {
25+
(sub1, sub2, sub3) match {
26+
case (SUB.Refl(), SUB.Refl(), SUB.Refl()) =>
27+
val a: 5 = t // error
28+
val b: 6 = t // error
29+
val c: String = t // error
30+
val _t: T = (5 : 5)
31+
???
32+
}
33+
}
34+
}

0 commit comments

Comments
 (0)