Skip to content

Commit a597fbd

Browse files
committed
Deskolemize PatternTypeConstrainer
Also, add restoring the GADT constraint to TypeComparer.
1 parent 83e17f1 commit a597fbd

File tree

8 files changed

+121
-33
lines changed

8 files changed

+121
-33
lines changed

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 73 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
7373
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
7474
* in which case the subtyping relationship "heals" the type.
7575
*/
76-
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
76+
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false): Boolean = trace(i"constrainPatternType($scrut, $pat)", gadts) {
7777

7878
def classesMayBeCompatible: Boolean = {
7979
import Flags._
80-
val patClassSym = pat.classSymbol
81-
val scrutClassSym = scrut.classSymbol
82-
!patClassSym.exists || !scrutClassSym.exists || {
83-
if (patClassSym.is(Final)) patClassSym.derivesFrom(scrutClassSym)
84-
else if (scrutClassSym.is(Final)) scrutClassSym.derivesFrom(patClassSym)
85-
else if (!patClassSym.is(Flags.Trait) && !scrutClassSym.is(Flags.Trait))
86-
patClassSym.derivesFrom(scrutClassSym) || scrutClassSym.derivesFrom(patClassSym)
80+
val patCls = pat.classSymbol
81+
val scrCls = scrut.classSymbol
82+
!patCls.exists || !scrCls.exists || {
83+
if (patCls.is(Final)) patCls.derivesFrom(scrCls)
84+
else if (scrCls.is(Final)) scrCls.derivesFrom(patCls)
85+
else if (!patCls.is(Flags.Trait) && !scrCls.is(Flags.Trait))
86+
patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls)
8787
else true
8888
}
8989
}
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
9393
case tp => tp
9494
}
9595

96+
def tryConstrainSimplePatternType(pat: Type, scrut: Type) = {
97+
val patCls = pat.classSymbol
98+
val scrCls = scrut.classSymbol
99+
patCls.exists && scrCls.exists
100+
&& (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101+
&& constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102+
}
103+
96104
def constrainUpcasted(scrut: Type): Boolean = trace(i"constrainUpcasted($scrut)", gadts) {
97105
// Fold a list of types into an AndType
98106
def buildAndType(xs: List[Type]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113121
val andType = buildAndType(parents)
114122
!andType.exists || constrainPatternType(pat, andType)
115123
case scrut @ AppliedType(tycon: TypeRef, _) if tycon.symbol.isClass =>
116-
val patClassSym = pat.classSymbol
124+
val patCls = pat.classSymbol
117125
// find all shared parents in the inheritance hierarchy between pat and scrut
118126
def allParentsSharedWithPat(tp: Type, tpClassSym: ClassSymbol): List[Symbol] = {
119127
var parents = tpClassSym.info.parents
120128
if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121129
parents = parents.tail
122130
parents flatMap { tp =>
123131
val sym = tp.classSymbol.asClass
124-
if patClassSym.derivesFrom(sym) then List(sym)
132+
if patCls.derivesFrom(sym) then List(sym)
125133
else allParentsSharedWithPat(tp, sym)
126134
}
127135
}
@@ -135,27 +143,39 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135143
case _ => NoType
136144
}
137145
if (upcasted.exists)
138-
constrainSimplePatternType(pat, upcasted, widenParams) || constrainUpcasted(upcasted)
146+
tryConstrainSimplePatternType(pat, upcasted) || constrainUpcasted(upcasted)
139147
else true
140148
}
141149
}
142150

143-
scrut.dealias match {
151+
def dealiasDropNonmoduleRefs(tp: Type) = tp.dealias match {
152+
case tp: TermRef =>
153+
// we drop TermRefs that don't have a class symbol, as they can't
154+
// meaningfully participate in GADT reasoning and just get in the way.
155+
// Their info could, for an example, be an AndType. One example where
156+
// this is important is an enum case that extends its parent and an
157+
// additional trait - argument-less enum cases desugar to vals.
158+
if tp.classSymbol.exists then tp else tp.info
159+
case tp => tp
160+
}
161+
162+
dealiasDropNonmoduleRefs(scrut) match {
144163
case OrType(scrut1, scrut2) =>
145164
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146165
case AndType(scrut1, scrut2) =>
147166
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148167
case scrut: RefinedOrRecType =>
149168
constrainPatternType(pat, stripRefinement(scrut))
150-
case scrut => pat.dealias match {
169+
case scrut => dealiasDropNonmoduleRefs(pat) match {
151170
case OrType(pat1, pat2) =>
152171
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153172
case AndType(pat1, pat2) =>
154173
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155174
case pat: RefinedOrRecType =>
156175
constrainPatternType(stripRefinement(pat), scrut)
157176
case pat =>
158-
constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
177+
tryConstrainSimplePatternType(pat, scrut)
178+
|| classesMayBeCompatible && constrainUpcasted(scrut)
159179
}
160180
}
161181
}
@@ -194,7 +214,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194214
* case classes without also appropriately extending the relevant case class
195215
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
196216
*/
197-
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, widenParams: Boolean): Boolean = {
217+
def constrainSimplePatternType(patternTp: Type, scrutineeTp: Type, forceInvariantRefinement: Boolean): Boolean = {
198218
def refinementIsInvariant(tp: Type): Boolean = tp match {
199219
case tp: SingletonType => true
200220
case tp: ClassInfo => tp.cls.is(Final) || tp.cls.is(Case)
@@ -212,13 +232,44 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212232
tp
213233
}
214234

215-
val widePt =
216-
if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217-
else if widenParams then widenVariantParams(scrutineeTp)
218-
else scrutineeTp
219-
val narrowTp = SkolemType(patternTp)
220-
trace(i"constraining simple pattern type $narrowTp <:< $widePt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
221-
isSubType(narrowTp, widePt)
235+
val patternCls = patternTp.classSymbol
236+
val scrutineeCls = scrutineeTp.classSymbol
237+
238+
// NOTE: we already know that there is a derives-from relationship in either direction
239+
val upcastPattern =
240+
patternCls.derivesFrom(scrutineeCls)
241+
242+
val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
243+
val tp = if !upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
244+
245+
val assumeInvariantRefinement =
246+
migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
247+
248+
trace(i"constraining simple pattern type $tp >:< $pt", gadts, res => s"$res\ngadt = ${ctx.gadt.debugBoundsDescription}") {
249+
(tp, pt) match {
250+
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
251+
val saved = state.constraint
252+
val savedGadt = ctx.gadt.fresh
253+
val result =
254+
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
255+
val variance = param.paramVarianceSign
256+
if variance != 0 && !assumeInvariantRefinement then true
257+
else if argS.isInstanceOf[TypeBounds] || argP.isInstanceOf[TypeBounds] then true
258+
else {
259+
var res = true
260+
if variance < 1 then res &&= isSubType(argS, argP)
261+
if variance > -1 then res &&= isSubType(argP, argS)
262+
res
263+
}
264+
}
265+
if !result then
266+
constraint = saved
267+
ctx.gadt.restore(savedGadt)
268+
result
269+
case _ =>
270+
// give up if we don't get AppliedType, e.g. if we upcasted to Any.
271+
false
272+
}
222273
}
223274
}
224275
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,14 +1275,17 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
12751275
else if tp1 eq tp2 then true
12761276
else
12771277
val saved = constraint
1278+
val savedGadt = ctx.gadt.fresh
1279+
inline def restore() =
1280+
state.constraint = saved
1281+
ctx.gadt.restore(savedGadt)
12781282
val savedSuccessCount = successCount
12791283
try
12801284
recCount += 1
12811285
if recCount >= Config.LogPendingSubTypesThreshold then monitored = true
12821286
val result = if monitored then monitoredIsSubType else firstTry
12831287
recCount -= 1
1284-
if !result then
1285-
state.constraint = saved
1288+
if !result then restore()
12861289
else if recCount == 0 && needsGc then
12871290
state.gc()
12881291
needsGc = false
@@ -1291,7 +1294,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
12911294
catch case NonFatal(ex) =>
12921295
if ex.isInstanceOf[AssertionError] then showGoal(tp1, tp2)
12931296
recCount -= 1
1294-
state.constraint = saved
1297+
restore()
12951298
successCount = savedSuccessCount
12961299
throw ex
12971300
}
@@ -2763,8 +2766,8 @@ object TypeComparer {
27632766
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
27642767
comparing(_.dropTransparentTraits(tp, bound))
27652768

2766-
def constrainPatternType(pat: Type, scrut: Type, widenParams: Boolean = true)(using Context): Boolean =
2767-
comparing(_.constrainPatternType(pat, scrut, widenParams))
2769+
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
2770+
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
27682771

27692772
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
27702773
comparing(_.explained(op, header))

compiler/src/dotty/tools/dotc/transform/TypeTestsCasts.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,10 @@ object TypeTestsCasts {
9898
//
9999
// If we perform widening, we will get X = Nothing, and we don't have
100100
// Ident[X] <:< Ident[Int] any more.
101-
TypeComparer.constrainPatternType(P1, X, widenParams = false)
102-
debug.println(TypeComparer.explained(_.constrainPatternType(P1, X, widenParams = false)))
101+
TypeComparer.constrainPatternType(P1, X, forceInvariantRefinement = true)
102+
debug.println(
103+
TypeComparer.explained(_.constrainPatternType(P1, X, forceInvariantRefinement = true))
104+
)
103105
}
104106

105107
// Maximization of the type means we try to cover all possible values

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3840,9 +3840,15 @@ class Typer extends Namer
38403840

38413841
// approximate type params with bounds
38423842
def approx = new ApproximatingTypeMap {
3843+
var alreadyExpanding: List[TypeRef] = Nil
38433844
def apply(tp: Type) = tp.dealias match
38443845
case tp: TypeRef if !tp.symbol.isClass =>
3845-
expandBounds(tp.info.bounds)
3846+
if alreadyExpanding contains tp then tp else
3847+
val saved = alreadyExpanding
3848+
alreadyExpanding ::= tp
3849+
val res = expandBounds(tp.info.bounds)
3850+
alreadyExpanding = saved
3851+
res
38463852
case _ =>
38473853
mapOver(tp)
38483854
}

tests/neg/i11103.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
@main def test: Unit = {
2+
class Foo
3+
class Bar
4+
5+
trait UpBnd[+A]
6+
trait P extends UpBnd[Foo]
7+
8+
def pmatch[A, T <: UpBnd[A]](s: T): A = s match {
9+
case p: P =>
10+
new Foo // error
11+
}
12+
13+
class UpBndAndB extends UpBnd[Bar] with P
14+
// ClassCastException: Foo cannot be cast to Bar
15+
val x = pmatch(new UpBndAndB)
16+
}

tests/pos/i9740c.scala renamed to tests/neg/i9740c.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ class Foo {
1111
def bar[A <: Txn[A]](x: Exp[A]): Unit = x match
1212
case IntExp(x) =>
1313
case StrExp(x) =>
14-
case UnitExp =>
14+
case UnitExp => // error
1515
case Obj(o) =>
1616
}

tests/pos/i9740b.scala renamed to tests/neg/i9740d.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@ class Foo[U <: Int, T <: U] {
77
def bar[A <: T](x: Exp[A]): Unit = x match
88
case IntExp(x) =>
99
case StrExp(x) =>
10-
case UnitExp =>
11-
}
10+
case UnitExp => // error
11+
}

tests/pos/i12476.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object test {
2+
def foo[A, B](m: B) = {
3+
m match {
4+
case _: A =>
5+
m match {
6+
case _: B => // crash with -Yno-deep-subtypes
7+
}
8+
}
9+
}
10+
}

0 commit comments

Comments
 (0)