Skip to content

Commit dd3f2fb

Browse files
authored
Merge pull request #12506 from dotty-staging/deskolemize-gadts
Deskolemize PatternTypeConstrainer
2 parents f1252d8 + cb10bb7 commit dd3f2fb

11 files changed

+153
-43
lines changed

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 91 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474
* in which case the subtyping relationship "heals" the type.
7575
*/
76-
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
76+
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
7777

7878
def classesMayBeCompatible: Boolean = {
7979
import Flags._
80-
val patClassSym = pat.classSymbol
81-
val scrutClassSym = scrut.classSymbol
82-
!patClassSym.exists || !scrutClassSym.exists || {
83-
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
84-
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
85-
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
86-
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
80+
val patCls = pat.classSymbol
81+
val scrCls = scrut.classSymbol
82+
!patCls.exists || !scrCls.exists || {
83+
if (patCls.is(Final)) patCls.derivesFrom(scrCls)
84+
else if (scrCls.is(Final)) scrCls.derivesFrom(patCls)
85+
else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait))
86+
patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)
8787
else true
8888
}
8989
}
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
9393
case tp => tp
9494
}
9595

96+
def tryConstrainSimplePatternType(pat: Type, scrut: Type) = {
97+
val patCls = pat.classSymbol
98+
val scrCls = scrut.classSymbol
99+
patCls.exists && scrCls.exists
100+
&& (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101+
&& constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102+
}
103+
96104
def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
97105
// Fold a list of types into an AndType
98106
def buildAndType(xs: List[Type]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113121
val andType = buildAndType(parents)
114122
!andType.exists || constrainPatternType(pat, andType)
115123
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
116-
val patClassSym = pat.classSymbol
124+
val patCls = pat.classSymbol
117125
// find all shared parents in the inheritance hierarchy between pat and scrut
118126
def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = {
119127
var parents = tpClassSym.info.parents
120128
if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121129
parents = parents.tail
122130
parents flatMap { tp =>
123131
val sym = tp.classSymbol.asClass
124-
if patClassSym.derivesFrom(sym) then List(sym)
132+
if patCls.derivesFrom(sym) then List(sym)
125133
else allParentsSharedWithPat(tp, sym)
126134
}
127135
}
@@ -135,42 +143,55 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135143
case _ => NoType
136144
}
137145
if (upcasted.exists)
138-
constrainSimplePatternType(pat, upcasted, widenParams) || constrainUpcasted(upcasted)
146+
tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
139147
else true
140148
}
141149
}
142150

143-
scrut.dealias match {
151+
def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
152+
case tp: TermRef =>
153+
// we drop TermRefs that don't have a class symbol, as they can't
154+
// meaningfully participate in GADT reasoning and just get in the way.
155+
// Their info could, for an example, be an AndType. One example where
156+
// this is important is an enum case that extends its parent and an
157+
// additional trait - argument-less enum cases desugar to vals.
158+
// See run/enum-Tree.scala.
159+
if tp.classSymbol.exists then tp else tp.info
160+
case tp => tp
161+
}
162+
163+
dealiasDropNonmoduleRefs(scrut) match {
144164
case OrType(scrut1, scrut2) =>
145165
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146166
case AndType(scrut1, scrut2) =>
147167
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148168
case scrut: RefinedOrRecType =>
149169
constrainPatternType(pat, stripRefinement(scrut))
150-
case scrut => pat.dealias match {
170+
case scrut => dealiasDropNonmoduleRefs(pat) match {
151171
case OrType(pat1, pat2) =>
152172
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153173
case AndType(pat1, pat2) =>
154174
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155175
case pat: RefinedOrRecType =>
156176
constrainPatternType(stripRefinement(pat), scrut)
157177
case pat =>
158-
constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
178+
tryConstrainSimplePatternType(pat, scrut)
179+
|| classesMayBeCompatible && constrainUpcasted(scrut)
159180
}
160181
}
161182
}
162183

163184
/** Constrain "simple" patterns (see `constrainPatternType`).
164185
*
165-
* This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
166-
* or otherwise it cannot possibly match. In order to do that, we:
167-
*
168-
* 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
169-
* 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
170-
* 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
171-
* 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
186+
* This function expects to receive two types (scrutinee and pattern), both
187+
* of which have class symbols, one of which is derived from another. If the
188+
* type "being derived from" is an applied type, it will 1) "upcast" the
189+
* deriving type to an applied type with the same constructor and 2) infer
190+
* constraints for the applied types' arguments that follow from both
191+
* types being inhabited by one value (the scrutinee).
172192
*
173-
* Importantly, note that the pattern type may contain type variables.
193+
* Importantly, note that the pattern type may contain type variables, which
194+
* are used to infer type arguments to Unapply trees.
174195
*
175196
* ## Invariant refinement
176197
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
@@ -194,7 +215,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194215
* case classes without also appropriately extending the relevant case class
195216
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
196217
*/
197-
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, widenParams: Boolean): Boolean = {
218+
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
198219
def refinementIsInvariant(tp: Type): Boolean = tp match {
199220
case tp: SingletonType => true
200221
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
@@ -212,13 +233,53 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212233
tp
213234
}
214235

215-
val widePt =
216-
if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217-
else if widenParams then widenVariantParams(scrutineeTp)
218-
else scrutineeTp
219-
val narrowTp = SkolemType(patternTp)
220-
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
221-
isSubType(narrowTp, widePt)
236+
val patternCls = patternTp.classSymbol
237+
val scrutineeCls = scrutineeTp.classSymbol
238+
239+
// NOTE: we already know that there is a derives-from relationship in either direction
240+
val upcastPattern =
241+
patternCls.derivesFrom(scrutineeCls)
242+
243+
val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
244+
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
245+
246+
val assumeInvariantRefinement =
247+
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
248+
249+
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
250+
(tp, pt) match {
251+
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
252+
val saved = state.constraint
253+
val savedGadt = ctx.gadt.fresh
254+
val result =
255+
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
256+
val variance = param.paramVarianceSign
257+
if variance != 0 && !assumeInvariantRefinement then true
258+
else if argS.isInstanceOf[TypeBounds] || argP.isInstanceOf[TypeBounds] then
259+
// Passing TypeBounds to isSubType on LHS or RHS does the
260+
// incorrect thing and infers unsound constraints, while simply
261+
// returning true is sound. However, I believe that it should
262+
// still be possible to extract useful constraints here.
263+
// TODO extract GADT information out of wildcard type arguments
264+
true
265+
else {
266+
var res = true
267+
if variance < 1 then res &&= isSubType(argS, argP)
268+
if variance > -1 then res &&= isSubType(argP, argS)
269+
res
270+
}
271+
}
272+
if !result then
273+
constraint = saved
274+
ctx.gadt.restore(savedGadt)
275+
result
276+
case _ =>
277+
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
278+
// Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
279+
// be co-inhabited, just that we cannot extract information out of them directly
280+
// and should upcast.
281+
false
282+
}
222283
}
223284
}
224285
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,14 +1275,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
12751275
else if tp1 eq tp2 then true
12761276
else
12771277
val saved = constraint
1278+
val savedGadt = ctx.gadt.fresh
1279+
inline def restore() =
1280+
state.constraint = saved
1281+
ctx.gadt.restore(savedGadt)
12781282
val savedSuccessCount = successCount
12791283
try
12801284
recCount += 1
12811285
if recCount >= Config.LogPendingSubTypesThreshold then monitored = true
12821286
val result = if monitored then monitoredIsSubType else firstTry
12831287
recCount -= 1
1284-
if !result then
1285-
state.constraint = saved
1288+
if !result then restore()
12861289
else if recCount == 0 && needsGc then
12871290
state.gc()
12881291
needsGc = false
@@ -1291,7 +1294,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
12911294
catch case NonFatal(ex) =>
12921295
if ex.isInstanceOf[AssertionError] then showGoal(tp1, tp2)
12931296
recCount -= 1
1294-
state.constraint = saved
1297+
restore()
12951298
successCount = savedSuccessCount
12961299
throw ex
12971300
}
@@ -2772,8 +2775,8 @@ object TypeComparer {
27722775
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
27732776
comparing(_.dropTransparentTraits(tp, bound))
27742777

2775-
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true)(using Context): Boolean =
2776-
comparing(_.constrainPatternType(pat, scrut, widenParams))
2778+
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
2779+
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
27772780

27782781
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
27792782
comparing(_.explained(op, header))

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ object TypeTestsCasts {
9898
//
9999
// If we perform widening, we will get X = Nothing, and we don't have
100100
// Ident[X] <:< Ident[Int] any more.
101-
TypeComparer.constrainPatternType(P1, X, widenParams = false)
102-
debug.println(TypeComparer.explained(_.constrainPatternType(P1, X, widenParams = false)))
101+
TypeComparer.constrainPatternType(P1, X, forceInvariantRefinement = true)
102+
debug.println(
103+
TypeComparer.explained(_.constrainPatternType(P1, X, forceInvariantRefinement = true))
104+
)
103105
}
104106

105107
// Maximization of the type means we try to cover all possible values

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3851,9 +3851,15 @@ class Typer extends Namer
38513851

38523852
// approximate type params with bounds
38533853
def approx = new ApproximatingTypeMap {
3854+
var alreadyExpanding: List[TypeRef] = Nil
38543855
def apply(tp: Type) = tp.dealias match
38553856
case tp: TypeRef if !tp.symbol.isClass =>
3856-
expandBounds(tp.info.bounds)
3857+
if alreadyExpanding contains tp then tp else
3858+
val saved = alreadyExpanding
3859+
alreadyExpanding ::= tp
3860+
val res = expandBounds(tp.info.bounds)
3861+
alreadyExpanding = saved
3862+
res
38573863
case _ =>
38583864
mapOver(tp)
38593865
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object Test {
2+
sealed abstract class Foo[T]
3+
case object Bar1 extends Foo[Int]
4+
case object Bar2 extends Foo[String]
5+
case object Bar3 extends Foo[AnyRef]
6+
7+
def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match {
8+
case (Bar1, Bar1) => () // error // error
9+
case (Bar2, Bar3) => ()
10+
case (Bar3, _) => ()
11+
}
12+
13+
}

tests/neg/i11103.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@main def test: Unit = {
2+
class Foo
3+
class Bar
4+
5+
trait UpBnd[+A]
6+
trait P extends UpBnd[Foo]
7+
8+
def pmatch[A, T <: UpBnd[A]](s: T): A = s match {
9+
case p: P =>
10+
new Foo // error
11+
}
12+
13+
class UpBndAndB extends UpBnd[Bar] with P
14+
// ClassCastException: Foo cannot be cast to Bar
15+
val x = pmatch(new UpBndAndB)
16+
}

tests/pos/i9740c.scala renamed to tests/neg/i9740c.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ class Foo {
1111
def bar[A <: Txn[A]](x: Exp[A]): Unit = x match
1212
case IntExp(x) =>
1313
case StrExp(x) =>
14-
case UnitExp =>
14+
case UnitExp => // error
1515
case Obj(o) =>
1616
}

tests/pos/i9740b.scala renamed to tests/neg/i9740d.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ class Foo[U <: Int, T <: U] {
77
def bar[A <: T](x: Exp[A]): Unit = x match
88
case IntExp(x) =>
99
case StrExp(x) =>
10-
case UnitExp =>
11-
}
10+
case UnitExp => // error
11+
}

tests/patmat/exhausting.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
32: Pattern Match Exhaustivity: List(_, _*)
44
39: Pattern Match Exhaustivity: Bar3
55
44: Pattern Match Exhaustivity: (Bar2, Bar2)
6-
50: Pattern Match Exhaustivity: (Bar2, Bar2)
6+
49: Pattern Match Exhaustivity: (Bar2, Bar2)

tests/patmat/exhausting.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@ object Test {
4242
}
4343
// fails for: (Bar2, Bar2)
4444
def fail4[T <: AnyRef](xx: (Foo[T], Foo[T])) = xx match {
45-
case (Bar1, Bar1) => ()
4645
case (Bar2, Bar3) => ()
4746
case (Bar3, _) => ()
4847
}

tests/pos/i12476.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object test {
2+
def foo[A, B](m: B) = {
3+
m match {
4+
case _: A =>
5+
m match {
6+
case _: B => // crash with -Yno-deep-subtypes
7+
}
8+
}
9+
}
10+
}

0 commit comments

Comments
 (0)