diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 7bb942e9ab6a..ac0305deac42 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2404,7 +2404,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * 1. Single inheritance of classes * 2. Final classes cannot be extended * 3. ConstantTypes with distinct values are non intersecting - * 4. There is no value of type Nothing + * 4. TermRefs with distinct values are non intersecting + * 5. There is no value of type Nothing * * Note on soundness: the correctness of match types relies on on the * property that in all possible contexts, the same match type expression @@ -2412,6 +2413,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling */ def provablyDisjoint(tp1: Type, tp2: Type)(using Context): Boolean = { // println(s"provablyDisjoint(${tp1.show}, ${tp2.show})") + + def isEnumValueOrModule(ref: TermRef): Boolean = + val sym = ref.termSymbol + sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module) + /** Can we enumerate all instantiations of this type? */ def isClosedSum(tp: Symbol): Boolean = tp.is(Sealed) && tp.isOneOf(AbstractOrTrait) && !tp.hasAnonymousChild @@ -2517,6 +2523,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling provablyDisjoint(gadtBounds(tp1.symbol).hi, tp2) || provablyDisjoint(tp1.superType, tp2) case (_, tp2: NamedType) if gadtBounds(tp2.symbol) != null => provablyDisjoint(tp1, gadtBounds(tp2.symbol).hi) || provablyDisjoint(tp1, tp2.superType) + case (tp1: TermRef, tp2: TermRef) if isEnumValueOrModule(tp1) && isEnumValueOrModule(tp2) => + tp1.termSymbol != tp2.termSymbol case (tp1: TypeProxy, tp2: TypeProxy) => provablyDisjoint(matchTypeSuperType(tp1), tp2) || provablyDisjoint(tp1, matchTypeSuperType(tp2)) case (tp1: TypeProxy, _) => diff --git a/compiler/test/dotc/pos-from-tasty.blacklist b/compiler/test/dotc/pos-from-tasty.blacklist index 51fb4980f30c..b0ea2e35caec 100644 --- a/compiler/test/dotc/pos-from-tasty.blacklist +++ b/compiler/test/dotc/pos-from-tasty.blacklist @@ -8,3 +8,6 @@ t802.scala # missing position rbtree.scala + +# transitive reduction of match types +i10511.scala diff --git a/docs/docs/reference/new-types/match-types.md b/docs/docs/reference/new-types/match-types.md index 72ee3f87384c..b849b9b7e237 100644 --- a/docs/docs/reference/new-types/match-types.md +++ b/docs/docs/reference/new-types/match-types.md @@ -48,7 +48,7 @@ Recursive match type definitions can also be given an upper bound, like this: ```scala type Concat[Xs <: Tuple, +Ys <: Tuple] <: Tuple = Xs match - case Unit => Ys + case EmptyTuple => Ys case x *: xs => x *: Concat[xs, Ys] ``` @@ -126,6 +126,7 @@ Disjointness proofs rely on the following properties of Scala types: 1. Single inheritance of classes 2. Final classes cannot be extended 3. Constant types with distinct values are nonintersecting +4. Singleton paths to distinct values are nonintersecting, such as `object` definitions or singleton enum cases. Type parameters in patterns are minimally instantiated when computing `S <: Pi`. An instantiation `Is` is _minimal_ for `Xs` if all type variables in `Xs` that @@ -240,4 +241,3 @@ main differences here are: whereas match types also work for type parameters and abstract types. - Match types support direct recursion. - Conditional types distribute through union types. - diff --git a/tests/neg/i10511.scala b/tests/neg/i10511.scala new file mode 100644 index 000000000000..4ceaae4141e4 --- /dev/null +++ b/tests/neg/i10511.scala @@ -0,0 +1,18 @@ +enum Bool { + case True + case False +} + +import Bool._ + +type Not[B <: Bool] = B match { + case True.type => False.type + case False.type => True.type +} + +def not[B <: Bool & Singleton](b: B): Not[B] = b match { + case b: False.type => True // error + case b: True.type => False // error +} + +val f: Not[False.type] = False // error: Found: (Bool.False : Bool) Required: (Bool.True : Bool) diff --git a/tests/pos/i10511.scala b/tests/pos/i10511.scala new file mode 100644 index 000000000000..386899d047a7 --- /dev/null +++ b/tests/pos/i10511.scala @@ -0,0 +1,22 @@ +enum Bool { + case True + case False +} + +import Bool._ + +type Not[B <: Bool] = B match { + case True.type => False.type + case False.type => True.type +} + +val t: True.type = True +val f: False.type = False + +val g: Not[False.type] = t + +val t1: Not[f.type] = t // transitivity +val f1: Not[t.type] = f // transitivity + +val t2: Not[f1.type] = t1 // transitivity x2 +val f2: Not[t1.type] = f1 // transitivity x2 diff --git a/tests/run/i10511.scala b/tests/run/i10511.scala new file mode 100644 index 000000000000..f1548ea7482a --- /dev/null +++ b/tests/run/i10511.scala @@ -0,0 +1,31 @@ +enum Bool { + case True + case False + + // just to make sure we are using reference equality + override def equals(a: Any) = false + +} + +import Bool._ + +type Not[B <: Bool] = B match { + case True.type => False.type + case False.type => True.type +} + +def not[B <: Bool & Singleton](b: B): Not[B] = b match { + case b: True.type => False + case b: False.type => True +} + +@main def Test = + + val t: True.type = True + val f: False.type = False + + val t1: Not[False.type] = t + val f1: Not[True.type] = f + + assert(not(True).asInstanceOf[AnyRef] eq False) + assert(not(False).asInstanceOf[AnyRef] eq True)