@@ -1161,7 +1161,55 @@ class Typer extends Namer
1161
1161
if (tree.isInline) checkInInlineContext(" inline match" , tree.posd)
1162
1162
val sel1 = typedExpr(tree.selector)
1163
1163
val selType = fullyDefinedType(sel1.tpe, " pattern selector" , tree.span).widen
1164
- val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1164
+
1165
+ /** Extractor for match types hidden behind an AppliedType/MatchAlias */
1166
+ object MatchTypeInDisguise {
1167
+ def unapply (tp : AppliedType ): Option [MatchType ] = tp match {
1168
+ case AppliedType (tycon : TypeRef , args) =>
1169
+ tycon.info match {
1170
+ case MatchAlias (alias) =>
1171
+ alias.applyIfParameterized(args) match {
1172
+ case mt : MatchType => Some (mt)
1173
+ case _ => None
1174
+ }
1175
+ case _ => None
1176
+ }
1177
+ case _ => None
1178
+ }
1179
+ }
1180
+
1181
+ /** Does `tree` has the same shape as the given match type?
1182
+ * We only support typed patterns with empty guards, but
1183
+ * that could potentially be extended in the future.
1184
+ */
1185
+ def isMatchTypeShaped (mt : MatchType ): Boolean =
1186
+ mt.cases.size == tree.cases.size &&
1187
+ sel1.tpe.frozen_<:< (mt.scrutinee) &&
1188
+ tree.cases.forall(_.guard.isEmpty) &&
1189
+ tree.cases
1190
+ .map(cas => untpd.unbind(untpd.unsplice(cas.pat)))
1191
+ .zip(mt.cases)
1192
+ .forall {
1193
+ case (pat : Typed , pt) =>
1194
+ // To check that pattern types correspond we need to type
1195
+ // check `pat` here and throw away the result.
1196
+ val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
1197
+ val pat1 = typedPattern(pat, selType)(gadtCtx)
1198
+ val Typed (_, tpt) = tpd.unbind(tpd.unsplice(pat1))
1199
+ instantiateMatchTypeProto(pat1, pt) match {
1200
+ case defn.MatchCase (patternTp, _) => tpt.tpe frozen_=:= patternTp
1201
+ case _ => false
1202
+ }
1203
+ case _ => false
1204
+ }
1205
+
1206
+ val result = pt match {
1207
+ case MatchTypeInDisguise (mt) if isMatchTypeShaped(mt) =>
1208
+ typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
1209
+ case _ =>
1210
+ typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1211
+ }
1212
+
1165
1213
result match {
1166
1214
case Match (sel, CaseDef (pat, _, _) :: _) =>
1167
1215
tree.selector.removeAttachment(desugar.CheckIrrefutable ) match {
@@ -1177,6 +1225,21 @@ class Typer extends Namer
1177
1225
result
1178
1226
}
1179
1227
1228
+ /** Special typing of Match tree when the expected type is a MatchType,
1229
+ * and the patterns of the Match tree and the MatchType correspond.
1230
+ */
1231
+ def typedDependentMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : MatchType )(implicit ctx : Context ): Tree = {
1232
+ var caseCtx = ctx
1233
+ val cases1 = tree.cases.zip(pt.cases)
1234
+ .map { case (cas, tpe) =>
1235
+ val case1 = typedCase(cas, sel, wideSelType, tpe)(given caseCtx )
1236
+ caseCtx = Nullables .afterPatternContext(sel, case1.pat)
1237
+ case1
1238
+ }
1239
+ .asInstanceOf [List [CaseDef ]]
1240
+ assignType(cpy.Match (tree)(sel, cases1), sel, cases1).cast(pt)
1241
+ }
1242
+
1180
1243
// Overridden in InlineTyper for inline matches
1181
1244
def typedMatchFinish (tree : untpd.Match , sel : Tree , wideSelType : Type , cases : List [untpd.CaseDef ], pt : Type )(implicit ctx : Context ): Tree = {
1182
1245
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
@@ -1216,17 +1279,33 @@ class Typer extends Namer
1216
1279
}
1217
1280
}
1218
1281
1282
+ /** If the prototype `pt` is the type lambda (when doing a dependent
1283
+ * typing of a match), instantiate that type lambda with the pattern
1284
+ * variables found in the pattern `pat`.
1285
+ */
1286
+ def instantiateMatchTypeProto (pat : Tree , pt : Type )(implicit ctx : Context ) = pt match {
1287
+ case caseTp : HKTypeLambda =>
1288
+ val bindingsSyms = tpd.patVars(pat).reverse
1289
+ val bindingsTps = bindingsSyms.collect { case sym if sym.isType => sym.typeRef }
1290
+ caseTp.appliedTo(bindingsTps)
1291
+ case pt => pt
1292
+ }
1293
+
1219
1294
/** Type a case. */
1220
1295
def typedCase (tree : untpd.CaseDef , sel : Tree , wideSelType : Type , pt : Type )(implicit ctx : Context ): CaseDef = {
1221
1296
val originalCtx = ctx
1222
1297
val gadtCtx : Context = ctx.fresh.setFreshGADTBounds
1223
1298
1224
1299
def caseRest (pat : Tree )(implicit ctx : Context ) = {
1300
+ val pt1 = instantiateMatchTypeProto(pat, pt) match {
1301
+ case defn.MatchCase (_, bodyPt) => bodyPt
1302
+ case pt => pt
1303
+ }
1225
1304
val pat1 = indexPattern(tree).transform(pat)
1226
1305
val guard1 = typedExpr(tree.guard, defn.BooleanType )
1227
- var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt ), pt , ctx.scope.toList)
1228
- if (pt .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1229
- body1 = body1.ensureConforms(pt )(originalCtx)
1306
+ var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1 ), pt1 , ctx.scope.toList)
1307
+ if (pt1 .isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1308
+ body1 = body1.ensureConforms(pt1 )(originalCtx)
1230
1309
assignType(cpy.CaseDef (tree)(pat1, guard1, body1), pat1, body1)
1231
1310
}
1232
1311
0 commit comments