Skip to content

Commit 9e430a4

Browse files
committed
Extend pattern type constraining to closed hierarchies
1 parent f313d16 commit 9e430a4

File tree

4 files changed

+49
-20
lines changed

4 files changed

+49
-20
lines changed

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -231,23 +231,20 @@ trait PatternTypeConstrainer { self: TypeComparer =>
231231
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
232232
*/
233233
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
234+
def refinementIsInvariantCls(cls: Symbol): Boolean =
235+
cls.isOneOf(Final | Case)
236+
|| cls.is(Sealed) && cls.children.forall { c =>
237+
c == cls // means the `cls` has (at least one) anonymous child - which is effectively final, so that's ok
238+
|| refinementIsInvariantCls(c)
239+
}
240+
234241
def refinementIsInvariant(tp: Type): Boolean = tp match {
235242
case tp: SingletonType => true
236-
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
243+
case tp: ClassInfo => refinementIsInvariantCls(tp.cls)
237244
case tp: TypeProxy => refinementIsInvariant(tp.superType)
238245
case _ => false
239246
}
240247

241-
def widenVariantParams(tp: Type) = tp match {
242-
case tp @ AppliedType(tycon, args) =>
243-
val args1 = args.zipWithConserve(tycon.typeParams)((arg, tparam) =>
244-
if (tparam.paramVarianceSign != 0) TypeBounds.empty else arg
245-
)
246-
tp.derivedAppliedType(tycon, args1)
247-
case tp =>
248-
tp
249-
}
250-
251248
val patternCls = patternTp.classSymbol
252249
val scrutineeCls = scrutineeTp.classSymbol
253250

@@ -258,12 +255,10 @@ trait PatternTypeConstrainer { self: TypeComparer =>
258255
val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
259256
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
260257

261-
val assumeInvariantRefinement =
262-
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
263-
264258
trace(i"constraining simple pattern type $tp >:< $pt", gadts, (res: Boolean) => i"$res gadt = ${ctx.gadt}") {
265259
(tp, pt) match {
266260
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
261+
val assumeInvariantRefinement = migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
267262
val saved = state.nn.constraint
268263
val result =
269264
ctx.gadtState.rollbackGadtUnless {

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,16 @@ class Namer { typer: Typer =>
892892
val sym = denot.symbol
893893

894894
def register(child: Symbol, parentCls: ClassSymbol) = {
895-
if (parentCls.is(Sealed))
895+
// The class `*:` (aka "Pair") is weird:
896+
// it's patched as the parent of the 22 tuple classes
897+
// but it's also sealed so it should know its children.
898+
// But what can eventually happens is it's queried
899+
// and then later another child is loaded, which blows up:
900+
// "children of class *: were already queried before class Tuple10 was discovered."
901+
// I think it's sealed only so it's not inherited by users.
902+
// So let's just say it has no known sealed children (only subclasses).
903+
// So skip adding any of its subclasses as its children.
904+
if parentCls.is(Sealed) && parentCls != defn.PairClass then
896905
if ((child.isInaccessibleChildOf(parentCls) || child.isAnonymousClass) && !sym.hasAnonymousChild)
897906
addChild(parentCls, parentCls)
898907
else if (!parentCls.is(ChildrenQueried))

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -807,11 +807,15 @@ object RefChecks {
807807
* can assume invariant refinement for case classes in `constrainPatternType`.
808808
*/
809809
def checkCaseClassInheritanceInvariant() =
810-
for (caseCls <- clazz.info.baseClasses.tail.find(_.is(Case)))
811-
for (baseCls <- caseCls.info.baseClasses.tail)
812-
if (baseCls.typeParams.exists(_.paramVarianceSign != 0))
813-
for (problem <- variantInheritanceProblems(baseCls, caseCls, "non-variant", "case "))
814-
report.errorOrMigrationWarning(problem, clazz.srcPos, from = `3.0`)
810+
for
811+
middle <- clazz.info.baseClasses.tail
812+
if middle.isOneOf(Case | Sealed)
813+
baseCls <- middle.info.baseClasses.tail
814+
if baseCls.typeParams.exists(_.paramVarianceSign != 0)
815+
middleStr = if middle.is(Case) then "case " else if middle.is(Sealed) then "sealed " else ""
816+
problem <- variantInheritanceProblems(baseCls, middle, "non-variant", middleStr)
817+
do
818+
report.errorOrMigrationWarning(problem, clazz.srcPos, from = `3.0`)
815819
checkNoAbstractMembers()
816820
if (abstractErrors.isEmpty)
817821
checkNoAbstractDecls(clazz)

tests/pos/i4790.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
class Test:
2+
def foo(as: Seq[Int]) =
3+
val List(_, bs: _*) = as: @unchecked
4+
val cs: Seq[Int] = bs
5+
6+
class Test2:
7+
def foo(as: SSeq[Int]) =
8+
val LList(_, tail) = as: @unchecked
9+
val cs: SSeq[Int] = tail
10+
11+
trait SSeq[+A]
12+
sealed trait LList[+A] extends SSeq[A]
13+
final case class CCons[+A](head: A, tail: LList[A]) extends LList[A]
14+
case object NNil extends LList[Nothing]
15+
object LList:
16+
def unapply[A](xs: LList[A]): Extractor[A] = Extractor[A](xs)
17+
final class Extractor[A](private val xs: LList[A]) extends AnyVal:
18+
def get: this.type = this
19+
def isEmpty: Boolean = xs.isInstanceOf[CCons[?]]
20+
def _1: A = xs.asInstanceOf[CCons[A]].head
21+
def _2: SSeq[A] = xs.asInstanceOf[CCons[A]].tail

0 commit comments

Comments
 (0)