@@ -73,17 +73,17 @@ trait PatternTypeConstrainer { self: TypeComparer =>
73
73
* scrutinee and pattern types. This does not apply if the pattern type is only applied to type variables,
74
74
* in which case the subtyping relationship "heals" the type.
75
75
*/
76
- def constrainPatternType (pat : Type , scrut : Type , widenParams : Boolean = true ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
76
+ def constrainPatternType (pat : Type , scrut : Type , forceInvariantRefinement : Boolean = false ): Boolean = trace(i " constrainPatternType( $scrut, $pat) " , gadts) {
77
77
78
78
def classesMayBeCompatible : Boolean = {
79
79
import Flags ._
80
- val patClassSym = pat.classSymbol
81
- val scrutClassSym = scrut.classSymbol
82
- ! patClassSym .exists || ! scrutClassSym .exists || {
83
- if (patClassSym .is(Final )) patClassSym .derivesFrom(scrutClassSym )
84
- else if (scrutClassSym .is(Final )) scrutClassSym .derivesFrom(patClassSym )
85
- else if (! patClassSym .is(Flags .Trait ) && ! scrutClassSym .is(Flags .Trait ))
86
- patClassSym .derivesFrom(scrutClassSym ) || scrutClassSym .derivesFrom(patClassSym )
80
+ val patCls = pat.classSymbol
81
+ val scrCls = scrut.classSymbol
82
+ ! patCls .exists || ! scrCls .exists || {
83
+ if (patCls .is(Final )) patCls .derivesFrom(scrCls )
84
+ else if (scrCls .is(Final )) scrCls .derivesFrom(patCls )
85
+ else if (! patCls .is(Flags .Trait ) && ! scrCls .is(Flags .Trait ))
86
+ patCls .derivesFrom(scrCls ) || scrCls .derivesFrom(patCls )
87
87
else true
88
88
}
89
89
}
@@ -93,6 +93,14 @@ trait PatternTypeConstrainer { self: TypeComparer =>
93
93
case tp => tp
94
94
}
95
95
96
+ def tryConstrainSimplePatternType (pat : Type , scrut : Type ) = {
97
+ val patCls = pat.classSymbol
98
+ val scrCls = scrut.classSymbol
99
+ patCls.exists && scrCls.exists
100
+ && (patCls.derivesFrom(scrCls) || scrCls.derivesFrom(patCls))
101
+ && constrainSimplePatternType(pat, scrut, forceInvariantRefinement)
102
+ }
103
+
96
104
def constrainUpcasted (scrut : Type ): Boolean = trace(i " constrainUpcasted( $scrut) " , gadts) {
97
105
// Fold a list of types into an AndType
98
106
def buildAndType (xs : List [Type ]): Type = {
@@ -113,15 +121,15 @@ trait PatternTypeConstrainer { self: TypeComparer =>
113
121
val andType = buildAndType(parents)
114
122
! andType.exists || constrainPatternType(pat, andType)
115
123
case scrut @ AppliedType (tycon : TypeRef , _) if tycon.symbol.isClass =>
116
- val patClassSym = pat.classSymbol
124
+ val patCls = pat.classSymbol
117
125
// find all shared parents in the inheritance hierarchy between pat and scrut
118
126
def allParentsSharedWithPat (tp : Type , tpClassSym : ClassSymbol ): List [Symbol ] = {
119
127
var parents = tpClassSym.info.parents
120
128
if parents.nonEmpty && parents.head.classSymbol == defn.ObjectClass then
121
129
parents = parents.tail
122
130
parents flatMap { tp =>
123
131
val sym = tp.classSymbol.asClass
124
- if patClassSym .derivesFrom(sym) then List (sym)
132
+ if patCls .derivesFrom(sym) then List (sym)
125
133
else allParentsSharedWithPat(tp, sym)
126
134
}
127
135
}
@@ -135,42 +143,55 @@ trait PatternTypeConstrainer { self: TypeComparer =>
135
143
case _ => NoType
136
144
}
137
145
if (upcasted.exists)
138
- constrainSimplePatternType (pat, upcasted, widenParams ) || constrainUpcasted(upcasted)
146
+ tryConstrainSimplePatternType (pat, upcasted) || constrainUpcasted(upcasted)
139
147
else true
140
148
}
141
149
}
142
150
143
- scrut.dealias match {
151
+ def dealiasDropNonmoduleRefs (tp : Type ) = tp.dealias match {
152
+ case tp : TermRef =>
153
+ // we drop TermRefs that don't have a class symbol, as they can't
154
+ // meaningfully participate in GADT reasoning and just get in the way.
155
+ // Their info could, for an example, be an AndType. One example where
156
+ // this is important is an enum case that extends its parent and an
157
+ // additional trait - argument-less enum cases desugar to vals.
158
+ // See run/enum-Tree.scala.
159
+ if tp.classSymbol.exists then tp else tp.info
160
+ case tp => tp
161
+ }
162
+
163
+ dealiasDropNonmoduleRefs(scrut) match {
144
164
case OrType (scrut1, scrut2) =>
145
165
either(constrainPatternType(pat, scrut1), constrainPatternType(pat, scrut2))
146
166
case AndType (scrut1, scrut2) =>
147
167
constrainPatternType(pat, scrut1) && constrainPatternType(pat, scrut2)
148
168
case scrut : RefinedOrRecType =>
149
169
constrainPatternType(pat, stripRefinement(scrut))
150
- case scrut => pat.dealias match {
170
+ case scrut => dealiasDropNonmoduleRefs( pat) match {
151
171
case OrType (pat1, pat2) =>
152
172
either(constrainPatternType(pat1, scrut), constrainPatternType(pat2, scrut))
153
173
case AndType (pat1, pat2) =>
154
174
constrainPatternType(pat1, scrut) && constrainPatternType(pat2, scrut)
155
175
case pat : RefinedOrRecType =>
156
176
constrainPatternType(stripRefinement(pat), scrut)
157
177
case pat =>
158
- constrainSimplePatternType(pat, scrut, widenParams) || classesMayBeCompatible && constrainUpcasted(scrut)
178
+ tryConstrainSimplePatternType(pat, scrut)
179
+ || classesMayBeCompatible && constrainUpcasted(scrut)
159
180
}
160
181
}
161
182
}
162
183
163
184
/** Constrain "simple" patterns (see `constrainPatternType`).
164
185
*
165
- * This function attempts to modify pattern and scrutinee type s.t. the pattern must be a subtype of the scrutinee,
166
- * or otherwise it cannot possibly match. In order to do that, we:
167
- *
168
- * 1. Rely on `constrainPatternType` to break the actual scrutinee/pattern types into subcomponents
169
- * 2. Widen type parameters of scrutinee type that are not invariantly refined (see below) by the pattern type.
170
- * 3. Wrap the pattern type in a skolem to avoid overconstraining top-level abstract types in scrutinee type
171
- * 4. Check that `WidenedScrutineeType <: NarrowedPatternType`
186
+ * This function expects to receive two types (scrutinee and pattern), both
187
+ * of which have class symbols, one of which is derived from another. If the
188
+ * type "being derived from" is an applied type, it will 1) "upcast" the
189
+ * deriving type to an applied type with the same constructor and 2) infer
190
+ * constraints for the applied types' arguments that follow from both
191
+ * types being inhabited by one value (the scrutinee).
172
192
*
173
- * Importantly, note that the pattern type may contain type variables.
193
+ * Importantly, note that the pattern type may contain type variables, which
194
+ * are used to infer type arguments to Unapply trees.
174
195
*
175
196
* ## Invariant refinement
176
197
* Essentially, we say that `D[B] extends C[B]` s.t. refines parameter `A` of `trait C[A]` invariantly if
@@ -194,7 +215,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
194
215
* case classes without also appropriately extending the relevant case class
195
216
* (see `RefChecks#checkCaseClassInheritanceInvariant`).
196
217
*/
197
- def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , widenParams : Boolean ): Boolean = {
218
+ def constrainSimplePatternType (patternTp : Type , scrutineeTp : Type , forceInvariantRefinement : Boolean ): Boolean = {
198
219
def refinementIsInvariant (tp : Type ): Boolean = tp match {
199
220
case tp : SingletonType => true
200
221
case tp : ClassInfo => tp.cls.is(Final ) || tp.cls.is(Case )
@@ -212,13 +233,53 @@ trait PatternTypeConstrainer { self: TypeComparer =>
212
233
tp
213
234
}
214
235
215
- val widePt =
216
- if migrateTo3 || refinementIsInvariant(patternTp) then scrutineeTp
217
- else if widenParams then widenVariantParams(scrutineeTp)
218
- else scrutineeTp
219
- val narrowTp = SkolemType (patternTp)
220
- trace(i " constraining simple pattern type $narrowTp <:< $widePt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
221
- isSubType(narrowTp, widePt)
236
+ val patternCls = patternTp.classSymbol
237
+ val scrutineeCls = scrutineeTp.classSymbol
238
+
239
+ // NOTE: we already know that there is a derives-from relationship in either direction
240
+ val upcastPattern =
241
+ patternCls.derivesFrom(scrutineeCls)
242
+
243
+ val pt = if upcastPattern then patternTp.baseType(scrutineeCls) else patternTp
244
+ val tp = if ! upcastPattern then scrutineeTp.baseType(patternCls) else scrutineeTp
245
+
246
+ val assumeInvariantRefinement =
247
+ migrateTo3 || forceInvariantRefinement || refinementIsInvariant(patternTp)
248
+
249
+ trace(i " constraining simple pattern type $tp >:< $pt" , gadts, res => s " $res\n gadt = ${ctx.gadt.debugBoundsDescription}" ) {
250
+ (tp, pt) match {
251
+ case (AppliedType (tyconS, argsS), AppliedType (tyconP, argsP)) =>
252
+ val saved = state.constraint
253
+ val savedGadt = ctx.gadt.fresh
254
+ val result =
255
+ tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
256
+ val variance = param.paramVarianceSign
257
+ if variance != 0 && ! assumeInvariantRefinement then true
258
+ else if argS.isInstanceOf [TypeBounds ] || argP.isInstanceOf [TypeBounds ] then
259
+ // Passing TypeBounds to isSubType on LHS or RHS does the
260
+ // incorrect thing and infers unsound constraints, while simply
261
+ // returning true is sound. However, I believe that it should
262
+ // still be possible to extract useful constraints here.
263
+ // TODO extract GADT information out of wildcard type arguments
264
+ true
265
+ else {
266
+ var res = true
267
+ if variance < 1 then res &&= isSubType(argS, argP)
268
+ if variance > - 1 then res &&= isSubType(argP, argS)
269
+ res
270
+ }
271
+ }
272
+ if ! result then
273
+ constraint = saved
274
+ ctx.gadt.restore(savedGadt)
275
+ result
276
+ case _ =>
277
+ // Give up if we don't get AppliedType, e.g. if we upcasted to Any.
278
+ // Note that this doesn't mean that patternTp, scrutineeTp cannot possibly
279
+ // be co-inhabited, just that we cannot extract information out of them directly
280
+ // and should upcast.
281
+ false
282
+ }
222
283
}
223
284
}
224
285
}
0 commit comments