Skip to content

Commit 234c281

Browse files
committed
Fix #4323: enter class type parameters in GADT bounds
Previously, GADTMap only accepts type bounds, it means we need to force type parameters in order to get their bounds. However forcing type parameters creates a cycle: constr --------> bodyIndex /\ | | | | | | | | | | | type params ----------- We change GADTMap to accept `TermRef` to avoid forcing type params, which breaks the cycle.
1 parent 78ab82d commit 234c281

File tree

9 files changed

+52
-11
lines changed

9 files changed

+52
-11
lines changed

compiler/src/dotty/tools/dotc/config/Settings.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ object Settings {
137137
else update((argRest split ",").toList, args)
138138
case (StringTag, _) if choices.nonEmpty =>
139139
if (argRest.isEmpty) missingArg
140-
else if (!choices.contains(argRest))
140+
else if (!choices.contains(argRest.asInstanceOf[T]))
141141
fail(s"$arg is not a valid choice for $name", args)
142142
else update(argRest, args)
143143
case (StringTag, arg :: args) if name == "-d" =>

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,14 +696,19 @@ object Contexts {
696696
// @sharable val theBase = new ContextBase // !!! DEBUG, so that we can use a minimal context for reporting even in code that normally cannot access a context
697697
}
698698

699-
class GADTMap(initBounds: SimpleIdentityMap[Symbol, TypeBounds]) extends util.DotClass {
699+
class GADTMap(initBounds: SimpleIdentityMap[Symbol, Type]) extends util.DotClass {
700700
private[this] var myBounds = initBounds
701-
def setBounds(sym: Symbol, b: TypeBounds): Unit =
701+
def setBounds(sym: Symbol, b: Type): Unit =
702702
myBounds = myBounds.updated(sym, b)
703703
def bounds = myBounds
704+
def get(sym: Symbol)(implicit ctx: Context): TypeBounds = myBounds(sym) match {
705+
case tref: TypeRef => tref.info.bounds
706+
case tb: TypeBounds => tb
707+
case null => null
708+
}
704709
}
705710

706711
@sharable object EmptyGADTMap extends GADTMap(SimpleIdentityMap.Empty) {
707-
override def setBounds(sym: Symbol, b: TypeBounds) = unsupported("EmptyGADTMap.setBounds")
712+
override def setBounds(sym: Symbol, b: Type) = unsupported("EmptyGADTMap.setBounds")
708713
}
709714
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
366366
def thirdTryNamed(tp2: NamedType): Boolean = tp2.info match {
367367
case TypeBounds(lo2, _) =>
368368
def compareGADT: Boolean = {
369-
val gbounds2 = ctx.gadt.bounds(tp2.symbol)
369+
val gbounds2 = ctx.gadt.get(tp2.symbol)
370370
(gbounds2 != null) &&
371371
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
372372
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
@@ -578,7 +578,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
578578
tp1.info match {
579579
case TypeBounds(_, hi1) =>
580580
def compareGADT = {
581-
val gbounds1 = ctx.gadt.bounds(tp1.symbol)
581+
val gbounds1 = ctx.gadt.get(tp1.symbol)
582582
(gbounds1 != null) &&
583583
(isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
584584
narrowGADTBounds(tp1, tp2, approx, isUpper = true)) &&
@@ -1112,7 +1112,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
11121112
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
11131113
if (bound.isRef(tparam)) false
11141114
else {
1115-
val oldBounds = ctx.gadt.bounds(tparam)
1115+
val oldBounds = ctx.gadt.get(tparam)
11161116
val newBounds =
11171117
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
11181118
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)

compiler/src/dotty/tools/dotc/printing/Formatting.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ object Formatting {
170170
case sym: Symbol =>
171171
val info =
172172
if (ctx.gadt.bounds.contains(sym))
173-
sym.info & ctx.gadt.bounds(sym)
173+
sym.info & ctx.gadt.get(sym)
174174
else
175175
sym.info
176176
s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}"
@@ -189,7 +189,7 @@ object Formatting {
189189
case param: TypeParamRef => ctx.typerState.constraint.contains(param)
190190
case skolem: SkolemType => true
191191
case sym: Symbol =>
192-
ctx.gadt.bounds.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty
192+
ctx.gadt.bounds.contains(sym) && ctx.gadt.get(sym) != TypeBounds.empty
193193
case _ =>
194194
assert(false, "unreachable")
195195
false

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,10 @@ trait NamerContextOps { this: Context =>
103103

104104
/** A new context for the interior of a class */
105105
def inClassContext(selfInfo: DotClass /* Should be Type | Symbol*/): Context = {
106-
val localCtx: Context = ctx.fresh.setNewScope
106+
val localCtx: Context = ctx.fresh.setNewScope.setFreshGADTBounds
107+
108+
localCtx.owner.typeParams.foreach (tparam => localCtx.gadt.setBounds(tparam, tparam.typeRef))
109+
107110
selfInfo match {
108111
case sym: Symbol if sym.exists && sym.name != nme.WILDCARD => localCtx.scope.openForMutations.enter(sym)
109112
case _ =>

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ class Typer extends Namer
10231023
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym)
10241024
else ctx.error(new DuplicateBind(b, tree), b.pos)
10251025
if (!ctx.isAfterTyper) {
1026-
val bounds = ctx.gadt.bounds(sym)
1026+
val bounds = ctx.gadt.get(sym)
10271027
if (bounds != null) sym.info = bounds
10281028
}
10291029
b

tests/pos/i4323.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
enum Expr[T] {
2+
case IExpr(value: Int) extends Expr[Int]
3+
case BExpr(value: Boolean) extends Expr[Boolean]
4+
5+
def join(other: Expr[T]): Expr[T] = (this, other) match {
6+
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
7+
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
8+
}
9+
}

tests/pos/i4323b.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
sealed trait Expr[T] {
2+
import Expr._
3+
4+
def join(other: Expr[T]): Expr[T] = (this, other) match {
5+
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
6+
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
7+
}
8+
}
9+
10+
object Expr {
11+
case class IExpr(value: Int) extends Expr[Int]
12+
case class BExpr(value: Boolean) extends Expr[Boolean]
13+
}

tests/pos/i4323c.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
sealed trait Expr[T] { outer =>
2+
class Inner {
3+
def join(other: Expr[T]): Expr[T] = (outer, other) match {
4+
case (IExpr(i1), IExpr(i2)) => IExpr(i1 + i2)
5+
case (BExpr(b1), BExpr(b2)) => BExpr(b1 & b2)
6+
}
7+
}
8+
}
9+
10+
case class IExpr(value: Int) extends Expr[Int]
11+
case class BExpr(value: Boolean) extends Expr[Boolean]

0 commit comments

Comments
 (0)