diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index 4c33d6e3e933..3fd5a2b9f208 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer => * scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables, * in which case the subtyping relationship "heals" the type. */ - def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) { + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) { def classesMayBeCompatible: Boolean = { import Flags._ - val patClassSym = pat.classSymbol - val scrutClassSym = scrut.classSymbol - !patClassSym.exists || !scrutClassSym.exists || { - if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym) - else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym) - else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait)) - patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym) + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + !patCls.exists || !scrCls.exists || { + if (patCls.is(Final)) patCls.derivesFrom(scrCls) + else if (scrCls.is(Final)) scrCls.derivesFrom(patCls) + else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait)) + patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls) else true } } @@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer => case tp => tp } + def tryConstrainSimplePatternType(pat: Type, scrut: Type) = { + val patCls = pat.classSymbol + val scrCls = scrut.classSymbol + patCls.exists && scrCls.exists + && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)) + && constrainSimplePatternType(pat, scrut, forceInvariantRefinement) + } + def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) { // Fold a list of types into an AndType def buildAndType(xs: List[Type]): Type = { @@ -113,7 +121,7 @@ trait PatternTypeConstrainer { self: TypeComparer => val andType = buildAndType(parents) !andType.exists || constrainPatternType(pat, andType) case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass => - val patClassSym = pat.classSymbol + val patCls = pat.classSymbol // find all shared parents in the inheritance hierarchy between pat and scrut def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = { var parents = tpClassSym.info.parents @@ -121,7 +129,7 @@ trait PatternTypeConstrainer { self: TypeComparer => parents = parents.tail parents flatMap { tp => val sym = tp.classSymbol.asClass - if patClassSym.derivesFrom(sym) then List(sym) + if patCls.derivesFrom(sym) then List(sym) else allParentsSharedWithPat(tp, sym) } } @@ -135,19 +143,31 @@ trait PatternTypeConstrainer { self: TypeComparer => case _ => NoType } if (upcasted.exists) - constrainSimplePatternType(pat, upcasted, widenParams) || constrainUpcasted(upcasted) + tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted) else true } } - scrut.dealias match { + def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match { + case tp: TermRef => + // we drop TermRefs that don't have a class symbol, as they can't + // meaningfully participate in GADT reasoning and just get in the way. + // Their info could, for an example, be an AndType. One example where + // this is important is an enum case that extends its parent and an + // additional trait - argument-less enum cases desugar to vals. + // See run/enum-Tree.scala. + if tp.classSymbol.exists then tp else tp.info + case tp => tp + } + + dealiasDropNonmoduleRefs(scrut) match { case OrType(scrut1, scrut2) => either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2)) case AndType(scrut1, scrut2) => constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2) case scrut: RefinedOrRecType => constrainPatternType(pat, stripRefinement(scrut)) - case scrut => pat.dealias match { + case scrut => dealiasDropNonmoduleRefs(pat) match { case OrType(pat1, pat2) => either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut)) case AndType(pat1, pat2) => @@ -155,22 +175,23 @@ trait PatternTypeConstrainer { self: TypeComparer => case pat: RefinedOrRecType => constrainPatternType(stripRefinement(pat), scrut) case pat => - constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut) + tryConstrainSimplePatternType(pat, scrut) + || classesMayBeCompatible && constrainUpcasted(scrut) } } } /** Constrain "simple" patterns (see `constrainPatternType`). * - * This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee, - * or otherwise it cannot possibly match. In order to do that, we: - * - * 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents - * 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type. - * 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type - * 4. Check that `WidenedScrutineeType <: NarrowedPatternType` + * This function expects to receive two types (scrutinee and pattern), both + * of which have class symbols, one of which is derived from another. If the + * type "being derived from" is an applied type, it will 1) "upcast" the + * deriving type to an applied type with the same constructor and 2) infer + * constraints for the applied types' arguments that follow from both + * types being inhabited by one value (the scrutinee). * - * Importantly, note that the pattern type may contain type variables. + * Importantly, note that the pattern type may contain type variables, which + * are used to infer type arguments to Unapply trees. * * ## Invariant refinement * Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if @@ -194,7 +215,7 @@ trait PatternTypeConstrainer { self: TypeComparer => * case classes without also appropriately extending the relevant case class * (see `RefChecks#checkCaseClassInheritanceInvariant`). */ - def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, widenParams: Boolean): Boolean = { + def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = { def refinementIsInvariant(tp: Type): Boolean = tp match { case tp: SingletonType => true case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case) @@ -212,13 +233,53 @@ trait PatternTypeConstrainer { self: TypeComparer => tp } - val widePt = - if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp - else if widenParams then widenVariantParams(scrutineeTp) - else scrutineeTp - val narrowTp = SkolemType(patternTp) - trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { - isSubType(narrowTp, widePt) + val patternCls = patternTp.classSymbol + val scrutineeCls = scrutineeTp.classSymbol + + // NOTE: we already know that there is a derives-from relationship in either direction + val upcastPattern = + patternCls.derivesFrom(scrutineeCls) + + val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp + val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp + + val assumeInvariantRefinement = + migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp) + + trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") { + (tp, pt) match { + case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => + val saved = state.constraint + val savedGadt = ctx.gadt.fresh + val result = + tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => + val variance = param.paramVarianceSign + if variance != 0 && !assumeInvariantRefinement then true + else if argS.isInstanceOf[TypeBounds] || argP.isInstanceOf[TypeBounds] then + // Passing TypeBounds to isSubType on LHS or RHS does the + // incorrect thing and infers unsound constraints, while simply + // returning true is sound. However, I believe that it should + // still be possible to extract useful constraints here. + // TODO extract GADT information out of wildcard type arguments + true + else { + var res = true + if variance < 1 then res &&= isSubType(argS, argP) + if variance > -1 then res &&= isSubType(argP, argS) + res + } + } + if !result then + constraint = saved + ctx.gadt.restore(savedGadt) + result + case _ => + // Give up if we don't get AppliedType, e.g. if we upcasted to Any. + // Note that this doesn't mean that patternTp, scrutineeTp cannot possibly + // be co-inhabited, just that we cannot extract information out of them directly + // and should upcast. + false + } } } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 47f36255de81..76b4c80eb061 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1275,14 +1275,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else if tp1 eq tp2 then true else val saved = constraint + val savedGadt = ctx.gadt.fresh + inline def restore() = + state.constraint = saved + ctx.gadt.restore(savedGadt) val savedSuccessCount = successCount try recCount += 1 if recCount >= Config.LogPendingSubTypesThreshold then monitored = true val result = if monitored then monitoredIsSubType else firstTry recCount -= 1 - if !result then - state.constraint = saved + if !result then restore() else if recCount == 0 && needsGc then state.gc() needsGc = false @@ -1291,7 +1294,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling catch case NonFatal(ex) => if ex.isInstanceOf[AssertionError] then showGoal(tp1, tp2) recCount -= 1 - state.constraint = saved + restore() successCount = savedSuccessCount throw ex } @@ -2763,8 +2766,8 @@ object TypeComparer { def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type = comparing(_.dropTransparentTraits(tp, bound)) - def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true)(using Context): Boolean = - comparing(_.constrainPatternType(pat, scrut, widenParams)) + def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean = + comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement)) def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String = comparing(_.explained(op, header)) diff --git a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala index 2679fbeaf94d..c0b1b83e9a9e 100644 --- a/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala +++ b/compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala @@ -98,8 +98,10 @@ object TypeTestsCasts { // // If we perform widening, we will get X = Nothing, and we don't have // Ident[X] <:< Ident[Int] any more. - TypeComparer.constrainPatternType(P1, X, widenParams = false) - debug.println(TypeComparer.explained(_.constrainPatternType(P1, X, widenParams = false))) + TypeComparer.constrainPatternType(P1, X, forceInvariantRefinement = true) + debug.println( + TypeComparer.explained(_.constrainPatternType(P1, X, forceInvariantRefinement = true)) + ) } // Maximization of the type means we try to cover all possible values diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index cda1daafd252..d3bc77070c8a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -3840,9 +3840,15 @@ class Typer extends Namer // approximate type params with bounds def approx = new ApproximatingTypeMap { + var alreadyExpanding: List[TypeRef] = Nil def apply(tp: Type) = tp.dealias match case tp: TypeRef if !tp.symbol.isClass => - expandBounds(tp.info.bounds) + if alreadyExpanding contains tp then tp else + val saved = alreadyExpanding + alreadyExpanding ::= tp + val res = expandBounds(tp.info.bounds) + alreadyExpanding = saved + res case _ => mapOver(tp) } diff --git a/tests/neg/gadt-contradictory-pattern.scala b/tests/neg/gadt-contradictory-pattern.scala new file mode 100644 index 000000000000..561c0c23d518 --- /dev/null +++ b/tests/neg/gadt-contradictory-pattern.scala @@ -0,0 +1,13 @@ +object Test { + sealed abstract class Foo[T] + case object Bar1 extends Foo[Int] + case object Bar2 extends Foo[String] + case object Bar3 extends Foo[AnyRef] + + def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match { + case (Bar1, Bar1) => () // error // error + case (Bar2, Bar3) => () + case (Bar3, _) => () + } + +} diff --git a/tests/neg/i11103.scala b/tests/neg/i11103.scala new file mode 100644 index 000000000000..6892f9ad30b2 --- /dev/null +++ b/tests/neg/i11103.scala @@ -0,0 +1,16 @@ +@main def test: Unit = { + class Foo + class Bar + + trait UpBnd[+A] + trait P extends UpBnd[Foo] + + def pmatch[A, T <: UpBnd[A]](s: T): A = s match { + case p: P => + new Foo // error + } + + class UpBndAndB extends UpBnd[Bar] with P + // ClassCastException: Foo cannot be cast to Bar + val x = pmatch(new UpBndAndB) +} diff --git a/tests/pos/i9740c.scala b/tests/neg/i9740c.scala similarity index 92% rename from tests/pos/i9740c.scala rename to tests/neg/i9740c.scala index 968355711e19..87881c9b20d7 100644 --- a/tests/pos/i9740c.scala +++ b/tests/neg/i9740c.scala @@ -11,6 +11,6 @@ class Foo { def bar[A <: Txn[A]](x: Exp[A]): Unit = x match case IntExp(x) => case StrExp(x) => - case UnitExp => + case UnitExp => // error case Obj(o) => } diff --git a/tests/pos/i9740b.scala b/tests/neg/i9740d.scala similarity index 89% rename from tests/pos/i9740b.scala rename to tests/neg/i9740d.scala index 412e8a95dc27..9f2490b697b6 100644 --- a/tests/pos/i9740b.scala +++ b/tests/neg/i9740d.scala @@ -7,5 +7,5 @@ class Foo[U <: Int, T <: U] { def bar[A <: T](x: Exp[A]): Unit = x match case IntExp(x) => case StrExp(x) => - case UnitExp => -} \ No newline at end of file + case UnitExp => // error +} diff --git a/tests/patmat/exhausting.check b/tests/patmat/exhausting.check index ff3536046ce5..cb1662883aa1 100644 --- a/tests/patmat/exhausting.check +++ b/tests/patmat/exhausting.check @@ -3,4 +3,4 @@ 32: Pattern Match Exhaustivity: List(_, _*) 39: Pattern Match Exhaustivity: Bar3 44: Pattern Match Exhaustivity: (Bar2, Bar2) -50: Pattern Match Exhaustivity: (Bar2, Bar2) +49: Pattern Match Exhaustivity: (Bar2, Bar2) diff --git a/tests/patmat/exhausting.scala b/tests/patmat/exhausting.scala index 640c9e88b100..9f17fae9def5 100644 --- a/tests/patmat/exhausting.scala +++ b/tests/patmat/exhausting.scala @@ -42,7 +42,6 @@ object Test { } // fails for: (Bar2, Bar2) def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match { - case (Bar1, Bar1) => () case (Bar2, Bar3) => () case (Bar3, _) => () } diff --git a/tests/pos/i12476.scala b/tests/pos/i12476.scala new file mode 100644 index 000000000000..1509f82cdc5f --- /dev/null +++ b/tests/pos/i12476.scala @@ -0,0 +1,10 @@ +object test { + def foo[A, B](m: B) = { + m match { + case _: A => + m match { + case _: B => // crash with -Yno-deep-subtypes + } + } + } +}