Skip to content

Deskolemize PatternTypeConstrainer #12506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 7, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 82 additions & 22 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
* in which case the subtyping relationship "heals" the type.
*/
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {

def classesMayBeCompatible: Boolean = {
import Flags._
val patClassSym = pat.classSymbol
val scrutClassSym = scrut.classSymbol
!patClassSym.exists || !scrutClassSym.exists || {
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
val patCls = pat.classSymbol
val scrCls = scrut.classSymbol
!patCls.exists || !scrCls.exists || {
if (patCls.is(Final)) patCls.derivesFrom(scrCls)
else if (scrCls.is(Final)) scrCls.derivesFrom(patCls)
else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait))
patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)
else true
}
}
Expand All @@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
case tp => tp
}

def tryConstrainSimplePatternType(pat: Type, scrut: Type) = {
val patCls = pat.classSymbol
val scrCls = scrut.classSymbol
patCls.exists && scrCls.exists
&& (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
&& constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
}

def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
// Fold a list of types into an AndType
def buildAndType(xs: List[Type]): Type = {
Expand All @@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
val andType = buildAndType(parents)
!andType.exists || constrainPatternType(pat, andType)
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
val patClassSym = pat.classSymbol
val patCls = pat.classSymbol
// find all shared parents in the inheritance hierarchy between pat and scrut
def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = {
var parents = tpClassSym.info.parents
if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
parents = parents.tail
parents flatMap { tp =>
val sym = tp.classSymbol.asClass
if patClassSym.derivesFrom(sym) then List(sym)
if patCls.derivesFrom(sym) then List(sym)
else allParentsSharedWithPat(tp, sym)
}
}
Expand All @@ -135,27 +143,39 @@ trait PatternTypeConstrainer { self: TypeComparer =>
case _ => NoType
}
if (upcasted.exists)
constrainSimplePatternType(pat, upcasted, widenParams) || constrainUpcasted(upcasted)
tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
else true
}
}

scrut.dealias match {
def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
case tp: TermRef =>
// we drop TermRefs that don't have a class symbol, as they can't
// meaningfully participate in GADT reasoning and just get in the way.
// Their info could, for an example, be an AndType. One example where
// this is important is an enum case that extends its parent and an
// additional trait - argument-less enum cases desugar to vals.
if tp.classSymbol.exists then tp else tp.info
case tp => tp
}

dealiasDropNonmoduleRefs(scrut) match {
case OrType(scrut1, scrut2) =>
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
case AndType(scrut1, scrut2) =>
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
case scrut: RefinedOrRecType =>
constrainPatternType(pat, stripRefinement(scrut))
case scrut => pat.dealias match {
case scrut => dealiasDropNonmoduleRefs(pat) match {
case OrType(pat1, pat2) =>
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
case AndType(pat1, pat2) =>
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
case pat: RefinedOrRecType =>
constrainPatternType(stripRefinement(pat), scrut)
case pat =>
constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
tryConstrainSimplePatternType(pat, scrut)
|| classesMayBeCompatible && constrainUpcasted(scrut)
}
}
}
Expand Down Expand Up @@ -194,7 +214,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
* case classes without also appropriately extending the relevant case class
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
*/
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, widenParams: Boolean): Boolean = {
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
def refinementIsInvariant(tp: Type): Boolean = tp match {
case tp: SingletonType => true
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
Expand All @@ -212,13 +232,53 @@ trait PatternTypeConstrainer { self: TypeComparer =>
tp
}

val widePt =
if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
else if widenParams then widenVariantParams(scrutineeTp)
else scrutineeTp
val narrowTp = SkolemType(patternTp)
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
isSubType(narrowTp, widePt)
val patternCls = patternTp.classSymbol
val scrutineeCls = scrutineeTp.classSymbol

// NOTE: we already know that there is a derives-from relationship in either direction
val upcastPattern =
patternCls.derivesFrom(scrutineeCls)

val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp

val assumeInvariantRefinement =
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)

trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
val saved = state.constraint
val savedGadt = ctx.gadt.fresh
val result =
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance != 0 && !assumeInvariantRefinement then true
else if argS.isInstanceOf[TypeBounds] || argP.isInstanceOf[TypeBounds] then
// Passing TypeBounds to isSubType on LHS or RHS does the
// incorrect thing and infers unsound constraints, while simply
// returning true is sound. However, I believe that it should
// still be possible to extract useful constraints here.
// TODO extract GADT information out of wildcard type arguments
true
else {
var res = true
if variance < 1 then res &&= isSubType(argS, argP)
if variance > -1 then res &&= isSubType(argP, argS)
res
}
}
if !result then
constraint = saved
ctx.gadt.restore(savedGadt)
result
case _ =>
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
// Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
// be co-inhabited, just that we cannot extract information out of them directly
// and should upcast.
false
}
}
}
}
13 changes: 8 additions & 5 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1275,14 +1275,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else if tp1 eq tp2 then true
else
val saved = constraint
val savedGadt = ctx.gadt.fresh
inline def restore() =
state.constraint = saved
ctx.gadt.restore(savedGadt)
val savedSuccessCount = successCount
try
recCount += 1
if recCount >= Config.LogPendingSubTypesThreshold then monitored = true
val result = if monitored then monitoredIsSubType else firstTry
recCount -= 1
if !result then
state.constraint = saved
if !result then restore()
else if recCount == 0 && needsGc then
state.gc()
needsGc = false
Expand All @@ -1291,7 +1294,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
catch case NonFatal(ex) =>
if ex.isInstanceOf[AssertionError] then showGoal(tp1, tp2)
recCount -= 1
state.constraint = saved
restore()
successCount = savedSuccessCount
throw ex
}
Expand Down Expand Up @@ -2763,8 +2766,8 @@ object TypeComparer {
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
comparing(_.dropTransparentTraits(tp, bound))

def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut, widenParams))
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))

def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
comparing(_.explained(op, header))
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ object TypeTestsCasts {
//
// If we perform widening, we will get X = Nothing, and we don't have
// Ident[X] <:< Ident[Int] any more.
TypeComparer.constrainPatternType(P1, X, widenParams = false)
debug.println(TypeComparer.explained(_.constrainPatternType(P1, X, widenParams = false)))
TypeComparer.constrainPatternType(P1, X, forceInvariantRefinement = true)
debug.println(
TypeComparer.explained(_.constrainPatternType(P1, X, forceInvariantRefinement = true))
)
}

// Maximization of the type means we try to cover all possible values
Expand Down
8 changes: 7 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3840,9 +3840,15 @@ class Typer extends Namer

// approximate type params with bounds
def approx = new ApproximatingTypeMap {
var alreadyExpanding: List[TypeRef] = Nil
def apply(tp: Type) = tp.dealias match
case tp: TypeRef if !tp.symbol.isClass =>
expandBounds(tp.info.bounds)
if alreadyExpanding contains tp then tp else
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How is this connected to the rest of the changes? Did the rest trigger a SO here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is the source of the SO that happened after restoring GADT constraints on failed subtype checks.

val saved = alreadyExpanding
alreadyExpanding ::= tp
val res = expandBounds(tp.info.bounds)
alreadyExpanding = saved
res
case _ =>
mapOver(tp)
}
Expand Down
13 changes: 13 additions & 0 deletions tests/neg/gadt-contradictory-pattern.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Test {
sealed abstract class Foo[T]
case object Bar1 extends Foo[Int]
case object Bar2 extends Foo[String]
case object Bar3 extends Foo[AnyRef]

def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match {
case (Bar1, Bar1) => () // error // error
case (Bar2, Bar3) => ()
case (Bar3, _) => ()
}

}
16 changes: 16 additions & 0 deletions tests/neg/i11103.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
@main def test: Unit = {
class Foo
class Bar

trait UpBnd[+A]
trait P extends UpBnd[Foo]

def pmatch[A, T <: UpBnd[A]](s: T): A = s match {
case p: P =>
new Foo // error
}

class UpBndAndB extends UpBnd[Bar] with P
// ClassCastException: Foo cannot be cast to Bar
val x = pmatch(new UpBndAndB)
}
2 changes: 1 addition & 1 deletion tests/pos/i9740c.scala → tests/neg/i9740c.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ class Foo {
def bar[A <: Txn[A]](x: Exp[A]): Unit = x match
case IntExp(x) =>
case StrExp(x) =>
case UnitExp =>
case UnitExp => // error
case Obj(o) =>
}
4 changes: 2 additions & 2 deletions tests/pos/i9740b.scala → tests/neg/i9740d.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ class Foo[U <: Int, T <: U] {
def bar[A <: T](x: Exp[A]): Unit = x match
case IntExp(x) =>
case StrExp(x) =>
case UnitExp =>
}
case UnitExp => // error
}
2 changes: 1 addition & 1 deletion tests/patmat/exhausting.check
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
32: Pattern Match Exhaustivity: List(_, _*)
39: Pattern Match Exhaustivity: Bar3
44: Pattern Match Exhaustivity: (Bar2, Bar2)
50: Pattern Match Exhaustivity: (Bar2, Bar2)
49: Pattern Match Exhaustivity: (Bar2, Bar2)
1 change: 0 additions & 1 deletion tests/patmat/exhausting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ object Test {
}
// fails for: (Bar2, Bar2)
def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match {
case (Bar1, Bar1) => ()
case (Bar2, Bar3) => ()
case (Bar3, _) => ()
}
Expand Down
10 changes: 10 additions & 0 deletions tests/pos/i12476.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
object test {
def foo[A, B](m: B) = {
m match {
case _: A =>
m match {
case _: B => // crash with -Yno-deep-subtypes
}
}
}
}