Skip to content

Commit ec5a86c

Browse files
committed
Back to nested/peeped GadtExpr (wip)
1 parent ad67b7a commit ec5a86c

File tree

15 files changed

+101
-124
lines changed

15 files changed

+101
-124
lines changed

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -125,24 +125,22 @@ class TreeTypeMap(
125125
transformInlined(inlined)
126126
case GadtExpr(gadt, expr) =>
127127
cpy.GadtExpr(expr)(gadt, transform(expr))
128-
case cdef @ CaseDef(pat, guard, rhs) =>
128+
case cdef @ CaseDef(pat, guard, expr @ GadtExpr(gadt, rhs)) =>
129129
val patVars1 = patVars(pat)
130-
cdef.gadt match
131-
case EmptyGadtConstraint =>
132-
val tmap = withMappedSyms(patVars1)
133-
val pat1 = tmap.transform(pat)
134-
val guard1 = tmap.transform(guard)
135-
val rhs1 = tmap.transform(rhs)
136-
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
137-
case _ =>
138-
val tmap = withMappedSyms(patVars1 ::: cdef.gadt.symbols.diff(substFrom).diff(patVars1))
139-
val gadt1 = tmap.rebuild(cdef.gadt)
140-
inContext(ctx.withGadt(gadt1)) {
141-
val pat1 = tmap.transform(pat)
142-
val guard1 = tmap.transform(guard)
143-
val rhs1 = tmap.transform(rhs)
144-
cpy.CaseDef(cdef)(pat1, guard1, rhs1, gadt1)
145-
}
130+
val tmap = withMappedSyms(patVars1 ::: gadt.symbols.diff(substFrom).diff(patVars1))
131+
val gadt1 = tmap.rebuild(gadt)
132+
inContext(ctx.withGadt(gadt1)) {
133+
val pat1 = tmap.transform(pat)
134+
val guard1 = tmap.transform(guard)
135+
val rhs1 = cpy.GadtExpr(expr)(gadt1, tmap.transform(rhs))
136+
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
137+
}
138+
case cdef @ CaseDef(pat, guard, rhs) =>
139+
val tmap = withMappedSyms(patVars(pat))
140+
val pat1 = tmap.transform(pat)
141+
val guard1 = tmap.transform(guard)
142+
val rhs1 = tmap.transform(rhs)
143+
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
146144
case labeled @ Labeled(bind, expr) =>
147145
val tmap = withMappedSyms(bind.symbol :: Nil)
148146
val bind1 = tmap.transformSub(bind)

compiler/src/dotty/tools/dotc/ast/Trees.scala

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -614,7 +614,7 @@ object Trees {
614614
}
615615

616616
/** case pat if guard => body */
617-
case class CaseDef[-T >: Untyped] private[ast] (pat: Tree[T], guard: Tree[T], body: Tree[T])(val gadt: GadtConstraint)(implicit @constructorOnly src: SourceFile)
617+
case class CaseDef[-T >: Untyped] private[ast] (pat: Tree[T], guard: Tree[T], body: Tree[T])(implicit @constructorOnly src: SourceFile)
618618
extends Tree[T] {
619619
type ThisTree[-T >: Untyped] = CaseDef[T]
620620
}
@@ -1233,9 +1233,9 @@ object Trees {
12331233
case tree: InlineMatch => finalize(tree, untpd.InlineMatch(selector, cases)(sourceFile(tree)))
12341234
case _ => finalize(tree, untpd.Match(selector, cases)(sourceFile(tree)))
12351235
}
1236-
def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint)(using Context): CaseDef = tree match {
1236+
def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef = tree match {
12371237
case tree: CaseDef if (pat eq tree.pat) && (guard eq tree.guard) && (body eq tree.body) => tree
1238-
case _ => finalize(tree, untpd.CaseDef(pat, guard, body, gadt)(sourceFile(tree)))
1238+
case _ => finalize(tree, untpd.CaseDef(pat, guard, body)(sourceFile(tree)))
12391239
}
12401240
def Labeled(tree: Tree)(bind: Bind, expr: Tree)(using Context): Labeled = tree match {
12411241
case tree: Labeled if (bind eq tree.bind) && (expr eq tree.expr) => tree
@@ -1355,8 +1355,8 @@ object Trees {
13551355
If(tree: Tree)(cond, thenp, elsep)
13561356
def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
13571357
Closure(tree: Tree)(env, meth, tpt)
1358-
def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body, gadt: GadtConstraint = tree.gadt)(using Context): CaseDef =
1359-
CaseDef(tree: Tree)(pat, guard, body, gadt)
1358+
def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body)(using Context): CaseDef =
1359+
CaseDef(tree: Tree)(pat, guard, body)
13601360
def Try(tree: Try)(expr: Tree = tree.expr, cases: List[CaseDef] = tree.cases, finalizer: Tree = tree.finalizer)(using Context): Try =
13611361
Try(tree: Tree)(expr, cases, finalizer)
13621362
def UnApply(tree: UnApply)(fun: Tree = tree.fun, implicits: List[Tree] = tree.implicits, patterns: List[Tree] = tree.patterns)(using Context): UnApply =
@@ -1442,8 +1442,8 @@ object Trees {
14421442
cpy.Match(tree)(transform(selector), transformSub(cases))
14431443
case GadtExpr(gadt, expr) =>
14441444
inContext(ctx.withGadt(gadt))(cpy.GadtExpr(tree)(gadt, transform(expr)))
1445-
case cdef @ CaseDef(pat, guard, body) =>
1446-
inContext(ctx.withGadt(cdef.gadt))(cpy.CaseDef(cdef)(transform(pat), transform(guard), transform(body)))
1445+
case CaseDef(pat, guard, body) =>
1446+
cpy.CaseDef(tree)(transform(pat), transform(guard), transform(body))
14471447
case Labeled(bind, expr) =>
14481448
cpy.Labeled(tree)(transformSub(bind), transform(expr))
14491449
case Return(expr, from) =>
@@ -1580,8 +1580,8 @@ object Trees {
15801580
this(this(x, selector), cases)
15811581
case GadtExpr(gadt, expr) =>
15821582
inContext(ctx.withGadt(gadt))(this(x, expr))
1583-
case cdef @ CaseDef(pat, guard, body) =>
1584-
inContext(ctx.withGadt(cdef.gadt))(this(this(this(x, pat), guard), body))
1583+
case CaseDef(pat, guard, body) =>
1584+
this(this(this(x, pat), guard), body)
15851585
case Labeled(bind, expr) =>
15861586
this(this(x, bind), expr)
15871587
case Return(expr, from) =>

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
131131
Closure(meth, tss => rhsFn(tss.head).changeOwner(ctx.owner, meth))
132132
}
133133

134-
def CaseDef(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint = EmptyGadtConstraint)(using Context): CaseDef =
135-
ta.assignType(untpd.CaseDef(pat, guard, body, gadt), pat, body)
134+
def CaseDef(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef =
135+
ta.assignType(untpd.CaseDef(pat, guard, body), pat, body)
136136

137137
def Match(selector: Tree, cases: List[CaseDef])(using Context): Match =
138138
ta.assignType(untpd.Match(selector, cases), selector, cases)
@@ -713,8 +713,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
713713
}
714714
}
715715

716-
override def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint)(using Context): CaseDef = {
717-
val tree1 = untpdCpy.CaseDef(tree)(pat, guard, body, gadt)
716+
override def CaseDef(tree: Tree)(pat: Tree, guard: Tree, body: Tree)(using Context): CaseDef = {
717+
val tree1 = untpdCpy.CaseDef(tree)(pat, guard, body)
718718
tree match {
719719
case tree: CaseDef if body.tpe eq tree.body.tpe => tree1.withTypeUnchecked(tree.tpe)
720720
case _ => ta.assignType(tree1, pat, body)
@@ -770,8 +770,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
770770
If(tree: Tree)(cond, thenp, elsep)
771771
override def Closure(tree: Closure)(env: List[Tree] = tree.env, meth: Tree = tree.meth, tpt: Tree = tree.tpt)(using Context): Closure =
772772
Closure(tree: Tree)(env, meth, tpt)
773-
override def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body, gadt: GadtConstraint = tree.gadt)(using Context): CaseDef =
774-
CaseDef(tree: Tree)(pat, guard, body, gadt)
773+
override def CaseDef(tree: CaseDef)(pat: Tree = tree.pat, guard: Tree = tree.guard, body: Tree = tree.body)(using Context): CaseDef =
774+
CaseDef(tree: Tree)(pat, guard, body)
775775
override def Try(tree: Try)(expr: Tree = tree.expr, cases: List[CaseDef] = tree.cases, finalizer: Tree = tree.finalizer)(using Context): Try =
776776
Try(tree: Tree)(expr, cases, finalizer)
777777
}

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
388388
def Closure(env: List[Tree], meth: Tree, tpt: Tree)(implicit src: SourceFile): Closure = new Closure(env, meth, tpt)
389389
def Match(selector: Tree, cases: List[CaseDef])(implicit src: SourceFile): Match = new Match(selector, cases)
390390
def InlineMatch(selector: Tree, cases: List[CaseDef])(implicit src: SourceFile): Match = new InlineMatch(selector, cases)
391-
def CaseDef(pat: Tree, guard: Tree, body: Tree, gadt: GadtConstraint = EmptyGadtConstraint)(implicit src: SourceFile): CaseDef = new CaseDef(pat, guard, body)(gadt)
391+
def CaseDef(pat: Tree, guard: Tree, body: Tree)(implicit src: SourceFile): CaseDef = new CaseDef(pat, guard, body)
392392
def Labeled(bind: Bind, expr: Tree)(implicit src: SourceFile): Labeled = new Labeled(bind, expr)
393393
def Return(expr: Tree, from: Tree)(implicit src: SourceFile): Return = new Return(expr, from)
394394
def WhileDo(cond: Tree, body: Tree)(implicit src: SourceFile): WhileDo = new WhileDo(cond, body)

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,21 @@ class TreePickler(pickler: TastyPickler) {
472472
writeByte(BLOCK)
473473
stats.foreach(preRegister)
474474
withLength { pickleTree(expr); stats.foreach(pickleTree) }
475+
case GadtExpr(gadt, expr) =>
476+
writeByte(GADTEXPR)
477+
withLength {
478+
for (symbols, nestingLevel) <- gadt.inputs do
479+
writeByte(CONSTRAINT)
480+
withLength {
481+
writeInt(nestingLevel)
482+
for sym <- symbols do
483+
val TypeBounds(lo, hi) = gadt.fullBounds(sym).nn
484+
pickleSymRef(sym)
485+
pickleType(lo)
486+
pickleType(hi)
487+
}
488+
pickleTree(expr)
489+
}
475490
case tree @ If(cond, thenp, elsep) =>
476491
writeByte(IF)
477492
withLength {
@@ -496,23 +511,9 @@ class TreePickler(pickler: TastyPickler) {
496511
else pickleTree(selector)
497512
tree.cases.foreach(pickleTree)
498513
}
499-
case tree @ CaseDef(pat, guard, rhs) =>
514+
case CaseDef(pat, guard, rhs) =>
500515
writeByte(CASEDEF)
501-
withLength {
502-
for (symbols, nestingLevel) <- tree.gadt.inputs do
503-
writeByte(CONSTRAINT)
504-
withLength {
505-
writeInt(nestingLevel)
506-
for sym <- symbols do
507-
val TypeBounds(lo, hi) = tree.gadt.fullBounds(sym).nn
508-
pickleSymRef(sym)
509-
pickleType(lo)
510-
pickleType(hi)
511-
}
512-
pickleTree(pat)
513-
pickleTree(rhs)
514-
pickleTreeUnlessEmpty(guard)
515-
}
516+
withLength { pickleTree(pat); pickleTree(rhs); pickleTreeUnlessEmpty(guard) }
516517
case Return(expr, from) =>
517518
writeByte(RETURN)
518519
withLength { pickleSymRef(from.symbol); pickleTreeUnlessEmpty(expr) }

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,27 @@ class TreeUnpickler(reader: TastyReader,
12451245
skipTree()
12461246
readStats(ctx.owner, end,
12471247
(stats, ctx) => Block(stats, exprReader.readTerm()(using ctx)))
1248+
case GADTEXPR =>
1249+
val gadt = EmptyGadtConstraint.fresh
1250+
while nextByte == CONSTRAINT do
1251+
readByte()
1252+
val end = readEnd()
1253+
val nestingLevel = readInt()
1254+
val constraints = until(end)((readSymRef(), readType(), readType()))
1255+
gadt.addToConstraint(constraints.map(_._1), nestingLevel)
1256+
for (sym, lo, hi) <- constraints do
1257+
if (sym.typeRef <:< lo)(using ctx.withGadt(gadt)) then
1258+
// add in reverse order so that unification runs in the right direction (keep sym)
1259+
// for a counter-example: say the symbol is c: b and the bound is b
1260+
// if we add c >: b it will unify to b: c not c: b
1261+
gadt.addBound(sym, hi, isUpper = true)
1262+
gadt.addBound(sym, lo, isUpper = false)
1263+
else
1264+
gadt.addBound(sym, lo, isUpper = false)
1265+
gadt.addBound(sym, hi, isUpper = true)
1266+
end while
1267+
val expr = inContext(ctx.withGadt(gadt))(readTerm())
1268+
GadtExpr(gadt, expr)
12481269
case INLINED =>
12491270
val exprReader = fork
12501271
skipTree()
@@ -1442,32 +1463,10 @@ class TreeUnpickler(reader: TastyReader,
14421463
val start = currentAddr
14431464
assert(readByte() == CASEDEF)
14441465
val end = readEnd()
1445-
val originalCtx = ctx
1446-
val gadt = if nextByte == CONSTRAINT then EmptyGadtConstraint.fresh else originalCtx.gadt
1447-
while nextByte == CONSTRAINT do
1448-
readByte()
1449-
val end = readEnd()
1450-
val nestingLevel = readInt()
1451-
val constraints = until(end)((readSymRef(), readType(), readType()))
1452-
gadt.addToConstraint(constraints.map(_._1), nestingLevel)
1453-
for (sym, lo, hi) <- constraints do
1454-
if (sym.typeRef <:< lo)(using ctx.withGadt(gadt)) then
1455-
// add in reverse order so that unification runs in the right direction (keep sym)
1456-
// for a counter-example: say the symbol is c: b and the bound is b
1457-
// if we add c >: b it will unify to b: c not c: b
1458-
gadt.addBound(sym, hi, isUpper = true)
1459-
gadt.addBound(sym, lo, isUpper = false)
1460-
else
1461-
gadt.addBound(sym, lo, isUpper = false)
1462-
gadt.addBound(sym, hi, isUpper = true)
1463-
end while
1464-
inContext(ctx.withGadt(gadt)) {
1465-
val pat = readTerm()
1466-
val rhs = readTerm()
1467-
val guard = ifBefore(end)(readTerm(), EmptyTree)
1468-
val gadt1 = if gadt eq originalCtx.gadt then EmptyGadtConstraint else gadt
1469-
setSpan(start, CaseDef(pat, guard, rhs, gadt1))
1470-
}
1466+
val pat = readTerm()
1467+
val rhs = readTerm()
1468+
val guard = ifBefore(end)(readTerm(), EmptyTree)
1469+
setSpan(start, CaseDef(pat, guard, rhs))
14711470
}
14721471

14731472
def readLater[T <: AnyRef](end: Addr, op: TreeReader => Context ?=> T)(using Context): Trees.Lazy[T] =

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -494,11 +494,8 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
494494
else toText(sel)
495495
selTxt ~ keywordStr(" match ") ~ blockText(cases)
496496
}
497-
case cdef @ CaseDef(pat, guard, body) =>
498-
keywordStr("case ") ~ inPattern(toText(pat))
499-
~ (" ~ " ~ toText(cdef.gadt)).provided(cdef.gadt != EmptyGadtConstraint)
500-
~ optText(guard)(keywordStr(" if ") ~ _)
501-
~ " => " ~ caseBlockText(body)
497+
case CaseDef(pat, guard, body) =>
498+
keywordStr("case ") ~ inPattern(toText(pat)) ~ optText(guard)(keywordStr(" if ") ~ _) ~ " => " ~ caseBlockText(body)
502499
case Labeled(bind, expr) =>
503500
changePrec(GlobalPrec) { toText(bind.name) ~ keywordStr("[") ~ toText(bind.symbol.info) ~ keywordStr("]: ") ~ toText(expr) }
504501
case Return(expr, from) =>

compiler/src/dotty/tools/dotc/transform/Erasure.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -816,17 +816,6 @@ object Erasure {
816816
}
817817
}
818818

819-
override def typedGadtExpr(tree: untpd.GadtExpr, pt: Type)(using Context): GadtExpr = tree match {
820-
case GadtExpr(gadt, expr @ If(cond, _, _)) =>
821-
// type the condition without installing the gadt constraints
822-
// so that TypeTestsCasts can correctly check type tests
823-
val cond1 = typed(cond, defn.BooleanType)
824-
val expr1 = cpy.If(expr.withType(expr.tpe))(cond = cond1)
825-
val gadt1 = cpy.GadtExpr(tree.withType(tree.tpe))(gadt, expr1)
826-
super.typedGadtExpr(gadt1, pt)
827-
case _ => super.typedGadtExpr(tree, pt)
828-
}
829-
830819
/** Besides normal typing, this method does uncurrying and collects parameters
831820
* to anonymous functions of arity > 22.
832821
*/

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ object PatternMatcher {
147147
sealed abstract class Plan { val id: Int = nxId; nxId += 1 }
148148

149149
case class TestPlan(test: Test, var scrutinee: Tree, span: Span,
150-
var onSuccess: Plan)(var gadt: GadtConstraint) extends Plan {
150+
var onSuccess: Plan) extends Plan {
151151
override def equals(that: Any): Boolean = that match {
152152
case that: TestPlan => this.scrutinee === that.scrutinee && this.test == that.test
153153
case _ => false
@@ -164,8 +164,6 @@ object PatternMatcher {
164164
object TestPlan {
165165
def apply(test: Test, sym: Symbol, span: Span, ons: Plan): TestPlan =
166166
TestPlan(test, ref(sym), span, ons)
167-
def apply(test: Test, scr: Tree, span: Span, ons: Plan): TestPlan =
168-
TestPlan(test, scr, span, ons)(EmptyGadtConstraint)
169167
}
170168

171169
/** The different kinds of tests */
@@ -455,11 +453,7 @@ object PatternMatcher {
455453
var onSuccess: Plan = ResultPlan(cdef.body)
456454
if (!cdef.guard.isEmpty)
457455
onSuccess = TestPlan(GuardTest, cdef.guard, cdef.guard.span, onSuccess)
458-
patternPlan(scrutinee, cdef.pat, onSuccess) match
459-
case plan: TestPlan =>
460-
plan.gadt = cdef.gadt
461-
plan
462-
case plan => plan
456+
patternPlan(scrutinee, cdef.pat, onSuccess)
463457
}
464458

465459
private def matchPlan(tree: Match): Plan =
@@ -936,8 +930,7 @@ object PatternMatcher {
936930
If(conditions, emit(plan.onSuccess), unitLiteral)
937931
}
938932
}
939-
val tree = emitWithMashedConditions(plan :: Nil)
940-
if plan.gadt == EmptyGadtConstraint then tree else GadtExpr(plan.gadt, tree)
933+
emitWithMashedConditions(plan :: Nil)
941934
case LetPlan(sym, body) =>
942935
val valDef = ValDef(sym, initializer(sym).ensureConforms(sym.info), inferred = true).withSpan(sym.span)
943936
seq(valDef :: Nil, emit(body))

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,13 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
407407
case tree: New if isCheckable(tree) =>
408408
Checking.checkInstantiable(tree.tpe, tree.srcPos)
409409
super.transform(tree)
410+
case cdef @ CaseDef(pat, guard, GadtExpr(gadt, body)) =>
411+
inContext(ctx.withGadt(gadt)) {
412+
val pat1 = transform(pat)
413+
val guard1 = transform(guard)
414+
val body1 = transform(body)
415+
cpy.CaseDef(cdef)(pat1, guard1, GadtExpr(gadt, body1))
416+
}
410417
case tree: Closure if !tree.tpt.isEmpty =>
411418
Checking.checkRealizable(tree.tpt.tpe, tree.srcPos, "SAM type")
412419
super.transform(tree)

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,11 +266,9 @@ abstract class Recheck extends Phase, SymTransformer:
266266
TypeComparer.lub(casesTypes)
267267

268268
def recheckCase(tree: CaseDef, selType: Type, pt: Type)(using Context): Type =
269-
inContext(ctx.withGadt(if tree.gadt == EmptyGadtConstraint then ctx.gadt else tree.gadt)) {
270-
recheck(tree.pat, selType)
271-
recheck(tree.guard, defn.BooleanType)
272-
recheck(tree.body, pt)
273-
}
269+
recheck(tree.pat, selType)
270+
recheck(tree.guard, defn.BooleanType)
271+
recheck(tree.body, pt)
274272

275273
def recheckReturn(tree: Return)(using Context): Type =
276274
// Avoid local pattern defined symbols in returns from matchResult blocks

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1708,9 +1708,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
17081708
/** Type a case. */
17091709
def typedCase(tree: untpd.CaseDef, sel: Tree, wideSelType: Type, pt: Type)(using Context): CaseDef = {
17101710
val originalCtx = ctx
1711-
val previousGadt = tree.gadt match
1712-
case EmptyGadtConstraint => originalCtx.gadt
1713-
case _ => tree.gadt
1711+
val previousGadt = tree.body match
1712+
case GadtExpr(gadt, _) => gadt
1713+
case _ => originalCtx.gadt
17141714
val gadtCtx: Context = ctx.fresh.setGadt(previousGadt.fresh).setNewScope
17151715

17161716
def caseRest(pat: Tree)(using Context) = {
@@ -1723,10 +1723,9 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
17231723
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList)
17241724
if (pt1.isValueType) // insert a cast if body does not conform to expected type if we disregard gadt bounds
17251725
body1 = body1.ensureConforms(pt1)(using originalCtx)
1726-
val gadt1 =
1727-
if (previousGadt eq originalCtx.gadt) && ctx.gadt.eql(previousGadt) then EmptyGadtConstraint
1728-
else ctx.gadt
1729-
assignType(cpy.CaseDef(tree)(pat1, guard1, body1, gadt1), pat1, body1)
1726+
if (previousGadt ne originalCtx.gadt) || !ctx.gadt.eql(previousGadt) then
1727+
body1 = GadtExpr(ctx.gadt, body1)
1728+
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), pat1, body1)
17301729
}
17311730

17321731
val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx)

0 commit comments

Comments
 (0)