Skip to content

Commit 849598f

Browse files
committed
Extend pattern type constraining to closed hierarchies
1 parent f313d16 commit 849598f

File tree

4 files changed

+42
-24
lines changed

4 files changed

+42
-24
lines changed

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,8 @@ object Flags {
552552
val JavaOrPrivateOrSynthetic: FlagSet = Artifact | JavaDefined | Private | Synthetic
553553
val PrivateOrSynthetic: FlagSet = Artifact | Private | Synthetic
554554
val EnumCase: FlagSet = Case | Enum
555+
val CaseOrFinalOrSealed: FlagSet = Case | Final | Sealed
556+
val CaseOrSealed: FlagSet = Case | Sealed
555557
val CovariantLocal: FlagSet = Covariant | Local // A covariant type parameter
556558
val ContravariantLocal: FlagSet = Contravariant | Local // A contravariant type parameter
557559
val EffectivelyErased = ConstructorProxy | Erased

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -225,29 +225,19 @@ trait PatternTypeConstrainer { self: TypeComparer =>
225225
*
226226
* It'd be unsound for us to say that `t <: T`, even though that follows from `D[t] <: C[T]`.
227227
* Note, however, that if `D` was a final class, we *could* rely on that relationship.
228-
* To support typical case classes, we also assume that this relationship holds for them and their parent traits.
229-
* This is enforced by checking that classes inheriting from case classes do not extend the parent traits of those
230-
* case classes without also appropriately extending the relevant case class
231-
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
228+
* Case classes and sealed traits (and sealed classes) are supported,
229+
* by assuming that this relationship holds for them and their parent traits.
230+
* This is enforced by checking no subclass of them mixes in any parent trait with a different type argument.
231+
* (see `RefChecks#checkVariantInheritanceProblems`).
232232
*/
233233
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
234234
def refinementIsInvariant(tp: Type): Boolean = tp match {
235235
case tp: SingletonType => true
236-
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
236+
case tp: ClassInfo => tp.cls.isOneOf(CaseOrFinalOrSealed)
237237
case tp: TypeProxy => refinementIsInvariant(tp.superType)
238238
case _ => false
239239
}
240240

241-
def widenVariantParams(tp: Type) = tp match {
242-
case tp @ AppliedType(tycon, args) =>
243-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244-
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
245-
)
246-
tp.derivedAppliedType(tycon, args1)
247-
case tp =>
248-
tp
249-
}
250-
251241
val patternCls = patternTp.classSymbol
252242
val scrutineeCls = scrutineeTp.classSymbol
253243

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -802,16 +802,21 @@ object RefChecks {
802802
}
803803
}
804804

805-
/** Check that inheriting a case class does not constitute a variant refinement
805+
/** Check that inheriting a case class or a sealed trait (or sealed class) does not constitute a variant refinement
806806
* of a base type of the case class. It is because of this restriction that we
807-
* can assume invariant refinement for case classes in `constrainPatternType`.
807+
* can assume invariant refinement for these classes in `constrainSimplePatternType`.
808808
*/
809-
def checkCaseClassInheritanceInvariant() =
810-
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
811-
for (baseCls <- caseCls.info.baseClasses.tail)
812-
if (baseCls.typeParams.exists(_.paramVarianceSign != 0))
813-
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
814-
report.errorOrMigrationWarning(problem, clazz.srcPos, from = `3.0`)
809+
def checkVariantInheritanceProblems() =
810+
for
811+
middle <- clazz.info.baseClasses.tail
812+
if middle.isOneOf(CaseOrSealed)
813+
baseCls <- middle.info.baseClasses.tail
814+
if baseCls.typeParams.exists(_.paramVarianceSign != 0)
815+
middleStr = if middle.is(Case) then "case " else if middle.is(Sealed) then "sealed " else ""
816+
problem <- variantInheritanceProblems(baseCls, middle, "non-variant", middleStr)
817+
do
818+
report.errorOrMigrationWarning(problem, clazz.srcPos, from = `3.0`)
819+
815820
checkNoAbstractMembers()
816821
if (abstractErrors.isEmpty)
817822
checkNoAbstractDecls(clazz)
@@ -820,7 +825,7 @@ object RefChecks {
820825
report.error(abstractErrorMessage, clazz.srcPos)
821826

822827
checkMemberTypesOK()
823-
checkCaseClassInheritanceInvariant()
828+
checkVariantInheritanceProblems()
824829
}
825830

826831
if (!clazz.is(Trait)) {

tests/pos/i4790.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
class Test:
2+
def foo(as: Seq[Int]) =
3+
val List(_, bs: _*) = as: @unchecked
4+
val cs: Seq[Int] = bs
5+
6+
class Test2:
7+
def foo(as: SSeq[Int]) =
8+
val LList(_, tail) = as: @unchecked
9+
val cs: SSeq[Int] = tail
10+
11+
trait SSeq[+A]
12+
sealed trait LList[+A] extends SSeq[A]
13+
final case class CCons[+A](head: A, tail: LList[A]) extends LList[A]
14+
case object NNil extends LList[Nothing]
15+
object LList:
16+
def unapply[A](xs: LList[A]): Extractor[A] = Extractor[A](xs)
17+
final class Extractor[A](private val xs: LList[A]) extends AnyVal:
18+
def get: this.type = this
19+
def isEmpty: Boolean = xs.isInstanceOf[CCons[?]]
20+
def _1: A = xs.asInstanceOf[CCons[A]].head
21+
def _2: SSeq[A] = xs.asInstanceOf[CCons[A]].tail

0 commit comments

Comments
 (0)