Skip to content

Commit 6eac6be

Browse files
Special case typing matches w/ match type protos
1 parent 7dc8efc commit 6eac6be

File tree

1 file changed

+83
-4
lines changed

1 file changed

+83
-4
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1161,7 +1161,55 @@ class Typer extends Namer
11611161
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
11621162
val sel1 = typedExpr(tree.selector)
11631163
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
1164-
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1164+
1165+
/** Extractor for match types hidden behind an AppliedType/MatchAlias */
1166+
object MatchTypeInDisguise {
1167+
def unapply(tp: AppliedType): Option[MatchType] = tp match {
1168+
case AppliedType(tycon: TypeRef, args) =>
1169+
tycon.info match {
1170+
case MatchAlias(alias) =>
1171+
alias.applyIfParameterized(args) match {
1172+
case mt: MatchType => Some(mt)
1173+
case _ => None
1174+
}
1175+
case _ => None
1176+
}
1177+
case _ => None
1178+
}
1179+
}
1180+
1181+
/** Does `tree` has the same shape as the given match type?
1182+
* We only support typed patterns with empty guards, but
1183+
* that could potentially be extended in the future.
1184+
*/
1185+
def isMatchTypeShaped(mt: MatchType): Boolean =
1186+
mt.cases.size == tree.cases.size &&
1187+
sel1.tpe.frozen_<:<(mt.scrutinee) &&
1188+
tree.cases.forall(_.guard.isEmpty) &&
1189+
tree.cases
1190+
.map(cas => untpd.unbind(untpd.unsplice(cas.pat)))
1191+
.zip(mt.cases)
1192+
.forall {
1193+
case (pat: Typed, pt) =>
1194+
// To check that pattern types correspond we need to type
1195+
// check `pat` here and throw away the result.
1196+
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
1197+
val pat1 = typedPattern(pat, selType)(gadtCtx)
1198+
val Typed(_, tpt) = tpd.unbind(tpd.unsplice(pat1))
1199+
instantiateMatchTypeProto(pat1, pt) match {
1200+
case defn.MatchCase(patternTp, _) => tpt.tpe frozen_=:= patternTp
1201+
case _ => false
1202+
}
1203+
case _ => false
1204+
}
1205+
1206+
val result = pt match {
1207+
case MatchTypeInDisguise(mt) if isMatchTypeShaped(mt) =>
1208+
typedDependentMatchFinish(tree, sel1, selType, tree.cases, mt)
1209+
case _ =>
1210+
typedMatchFinish(tree, sel1, selType, tree.cases, pt)
1211+
}
1212+
11651213
result match {
11661214
case Match(sel, CaseDef(pat, _, _) :: _) =>
11671215
tree.selector.removeAttachment(desugar.CheckIrrefutable) match {
@@ -1177,6 +1225,21 @@ class Typer extends Namer
11771225
result
11781226
}
11791227

1228+
/** Special typing of Match tree when the expected type is a MatchType,
1229+
* and the patterns of the Match tree and the MatchType correspond.
1230+
*/
1231+
def typedDependentMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: MatchType)(implicit ctx: Context): Tree = {
1232+
var caseCtx = ctx
1233+
val cases1 = tree.cases.zip(pt.cases)
1234+
.map { case (cas, tpe) =>
1235+
val case1 = typedCase(cas, sel, wideSelType, tpe)(given caseCtx)
1236+
caseCtx = Nullables.afterPatternContext(sel, case1.pat)
1237+
case1
1238+
}
1239+
.asInstanceOf[List[CaseDef]]
1240+
assignType(cpy.Match(tree)(sel, cases1), sel, cases1).cast(pt)
1241+
}
1242+
11801243
// Overridden in InlineTyper for inline matches
11811244
def typedMatchFinish(tree: untpd.Match, sel: Tree, wideSelType: Type, cases: List[untpd.CaseDef], pt: Type)(implicit ctx: Context): Tree = {
11821245
val cases1 = harmonic(harmonize, pt)(typedCases(cases, sel, wideSelType, pt.dropIfProto))
@@ -1216,17 +1279,33 @@ class Typer extends Namer
12161279
}
12171280
}
12181281

1282+
/** If the prototype `pt` is the type lambda (when doing a dependent
1283+
* typing of a match), instantiate that type lambda with the pattern
1284+
* variables found in the pattern `pat`.
1285+
*/
1286+
def instantiateMatchTypeProto(pat: Tree, pt: Type)(implicit ctx: Context) = pt match {
1287+
case caseTp: HKTypeLambda =>
1288+
val bindingsSyms = tpd.patVars(pat).reverse
1289+
val bindingsTps = bindingsSyms.collect { case sym if sym.isType => sym.typeRef }
1290+
caseTp.appliedTo(bindingsTps)
1291+
case pt => pt
1292+
}
1293+
12191294
/** Type a case. */
12201295
def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(implicit ctx: Context): CaseDef = {
12211296
val originalCtx = ctx
12221297
val gadtCtx: Context = ctx.fresh.setFreshGADTBounds
12231298

12241299
def caseRest(pat: Tree)(implicit ctx: Context) = {
1300+
val pt1 = instantiateMatchTypeProto(pat, pt) match {
1301+
case defn.MatchCase(_, bodyPt) => bodyPt
1302+
case pt => pt
1303+
}
12251304
val pat1 = indexPattern(tree).transform(pat)
12261305
val guard1 = typedExpr(tree.guard, defn.BooleanType)
1227-
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt), pt, ctx.scope.toList)
1228-
if (pt.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1229-
body1 = body1.ensureConforms(pt)(originalCtx)
1306+
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList)
1307+
if (pt1.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
1308+
body1 = body1.ensureConforms(pt1)(originalCtx)
12301309
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
12311310
}
12321311

0 commit comments

Comments
 (0)