@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
73
73
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
74
74
* in which case the subtyping relationship "heals" the type.
75
75
*/
76
- def constrainPatternType (pat : Type , scrut : Type , widenParams : Boolean = true ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
76
+ def constrainPatternType (pat : Type , scrut : Type , forceInvariantRefinement : Boolean = false ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
77
77
78
78
def classesMayBeCompatible : Boolean = {
79
79
import Flags ._
80
- val patClassSym = pat.classSymbol
81
- val scrutClassSym = scrut.classSymbol
82
- ! patClassSym .exists || ! scrutClassSym .exists || {
83
- if (patClassSym .is(Final )) patClassSym .derivesFrom(scrutClassSym )
84
- else if (scrutClassSym .is(Final )) scrutClassSym .derivesFrom(patClassSym )
85
- else if (! patClassSym .is(Flags .Trait ) && ! scrutClassSym .is(Flags .Trait ))
86
- patClassSym .derivesFrom(scrutClassSym ) || scrutClassSym .derivesFrom(patClassSym )
80
+ val patCls = pat.classSymbol
81
+ val scrCls = scrut.classSymbol
82
+ ! patCls .exists || ! scrCls .exists || {
83
+ if (patCls .is(Final )) patCls .derivesFrom(scrCls )
84
+ else if (scrCls .is(Final )) scrCls .derivesFrom(patCls )
85
+ else if (! patCls .is(Flags .Trait ) && ! scrCls .is(Flags .Trait ))
86
+ patCls .derivesFrom(scrCls ) || scrCls .derivesFrom(patCls )
87
87
else true
88
88
}
89
89
}
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
93
93
case tp => tp
94
94
}
95
95
96
+ def tryConstrainSimplePatternType (pat : Type , scrut : Type ) = {
97
+ val patCls = pat.classSymbol
98
+ val scrCls = scrut.classSymbol
99
+ patCls.exists && scrCls.exists
100
+ && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101
+ && constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102
+ }
103
+
96
104
def constrainUpcasted (scrut : Type ): Boolean = trace(i " constrainUpcasted( $scrut) " , gadts) {
97
105
// Fold a list of types into an AndType
98
106
def buildAndType (xs : List [Type ]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113
121
val andType = buildAndType(parents)
114
122
! andType.exists || constrainPatternType(pat, andType)
115
123
case scrut @ AppliedType (tycon : TypeRef , _) if tycon.symbol.isClass =>
116
- val patClassSym = pat.classSymbol
124
+ val patCls = pat.classSymbol
117
125
// find all shared parents in the inheritance hierarchy between pat and scrut
118
126
def allParentsSharedWithPat (tp : Type , tpClassSym : ClassSymbol ): List [Symbol ] = {
119
127
var parents = tpClassSym.info.parents
120
128
if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121
129
parents = parents.tail
122
130
parents flatMap { tp =>
123
131
val sym = tp.classSymbol.asClass
124
- if patClassSym .derivesFrom(sym) then List (sym)
132
+ if patCls .derivesFrom(sym) then List (sym)
125
133
else allParentsSharedWithPat(tp, sym)
126
134
}
127
135
}
@@ -135,27 +143,39 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135
143
case _ => NoType
136
144
}
137
145
if (upcasted.exists)
138
- constrainSimplePatternType (pat, upcasted, widenParams ) || constrainUpcasted(upcasted)
146
+ tryConstrainSimplePatternType (pat, upcasted) || constrainUpcasted(upcasted)
139
147
else true
140
148
}
141
149
}
142
150
143
- scrut.dealias match {
151
+ def dealiasDropNonmoduleRefs (tp : Type ) = tp.dealias match {
152
+ case tp : TermRef =>
153
+ // we drop TermRefs that don't have a class symbol, as they can't
154
+ // meaningfully participate in GADT reasoning and just get in the way.
155
+ // Their info could, for an example, be an AndType. One example where
156
+ // this is important is an enum case that extends its parent and an
157
+ // additional trait - argument-less enum cases desugar to vals.
158
+ if tp.classSymbol.exists then tp else tp.info
159
+ case tp => tp
160
+ }
161
+
162
+ dealiasDropNonmoduleRefs(scrut) match {
144
163
case OrType (scrut1, scrut2) =>
145
164
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146
165
case AndType (scrut1, scrut2) =>
147
166
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148
167
case scrut : RefinedOrRecType =>
149
168
constrainPatternType(pat, stripRefinement(scrut))
150
- case scrut => pat.dealias match {
169
+ case scrut => dealiasDropNonmoduleRefs( pat) match {
151
170
case OrType (pat1, pat2) =>
152
171
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153
172
case AndType (pat1, pat2) =>
154
173
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155
174
case pat : RefinedOrRecType =>
156
175
constrainPatternType(stripRefinement(pat), scrut)
157
176
case pat =>
158
- constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
177
+ tryConstrainSimplePatternType(pat, scrut)
178
+ || classesMayBeCompatible && constrainUpcasted(scrut)
159
179
}
160
180
}
161
181
}
@@ -194,7 +214,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194
214
* case classes without also appropriately extending the relevant case class
195
215
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
196
216
*/
197
- def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , widenParams : Boolean ): Boolean = {
217
+ def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , forceInvariantRefinement : Boolean ): Boolean = {
198
218
def refinementIsInvariant (tp : Type ): Boolean = tp match {
199
219
case tp : SingletonType => true
200
220
case tp : ClassInfo => tp.cls.is(Final ) || tp.cls.is(Case )
@@ -212,13 +232,44 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212
232
tp
213
233
}
214
234
215
- val widePt =
216
- if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217
- else if widenParams then widenVariantParams(scrutineeTp)
218
- else scrutineeTp
219
- val narrowTp = SkolemType (patternTp)
220
- trace(i " constraining simple pattern type $narrowTp <:< $widePt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
221
- isSubType(narrowTp, widePt)
235
+ val patternCls = patternTp.classSymbol
236
+ val scrutineeCls = scrutineeTp.classSymbol
237
+
238
+ // NOTE: we already know that there is a derives-from relationship in either direction
239
+ val upcastPattern =
240
+ patternCls.derivesFrom(scrutineeCls)
241
+
242
+ val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
243
+ val tp = if ! upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
244
+
245
+ val assumeInvariantRefinement =
246
+ migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
247
+
248
+ trace(i " constraining simple pattern type $tp >:< $pt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
249
+ (tp, pt) match {
250
+ case (AppliedType (tyconS, argsS), AppliedType (tyconP, argsP)) =>
251
+ val saved = state.constraint
252
+ val savedGadt = ctx.gadt.fresh
253
+ val result =
254
+ tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
255
+ val variance = param.paramVarianceSign
256
+ if variance != 0 && ! assumeInvariantRefinement then true
257
+ else if argS.isInstanceOf [TypeBounds ] || argP.isInstanceOf [TypeBounds ] then true
258
+ else {
259
+ var res = true
260
+ if variance < 1 then res &&= isSubType(argS, argP)
261
+ if variance > - 1 then res &&= isSubType(argP, argS)
262
+ res
263
+ }
264
+ }
265
+ if ! result then
266
+ constraint = saved
267
+ ctx.gadt.restore(savedGadt)
268
+ result
269
+ case _ =>
270
+ // give up if we don't get AppliedType, e.g. if we upcasted to Any.
271
+ false
272
+ }
222
273
}
223
274
}
224
275
}
0 commit comments