Skip to content

Commit 16554c0

Browse files
committed
Fixed #264 - failure to typecheck GADTs
The previous scheme derived the right bounds, but then failed to use them because a TypeRef already has a set info (its bounds). Changing the bounds in the symbol by a side effect does not affect that. This is good! But it showed that the previous scheme was too fragile because it used a sneaky side effect when updating the symbol info which failed to propgate into the cached info in TypeRef. We now keep GADT computed bounds separate form the symbol info in a map `gadt` in the current context.
1 parent 21fa5dd commit 16554c0

File tree

6 files changed

+62
-45
lines changed

6 files changed

+62
-45
lines changed

src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ object Contexts {
146146
protected def diagnostics_=(diagnostics: Option[StringBuilder]) = _diagnostics = diagnostics
147147
def diagnostics: Option[StringBuilder] = _diagnostics
148148

149+
/** The current bounds in force for type parameters appearing in a GADT */
150+
private var _gadt: GADTMap = _
151+
protected def gadt_=(gadt: GADTMap) = _gadt = gadt
152+
def gadt: GADTMap = _gadt
153+
149154
/** A map in which more contextual properties can be stored */
150155
private var _moreProperties: Map[String, Any] = _
151156
protected def moreProperties_=(moreProperties: Map[String, Any]) = _moreProperties = moreProperties
@@ -418,6 +423,8 @@ object Contexts {
418423
def setSetting[T](setting: Setting[T], value: T): this.type =
419424
setSettings(setting.updateIn(sstate, value))
420425

426+
def setFreshGADTBounds: this.type = { this.gadt = new GADTMap(gadt.bounds); this }
427+
421428
def setDebug = setSetting(base.settings.debug, true)
422429
}
423430

@@ -439,6 +446,7 @@ object Contexts {
439446
moreProperties = Map.empty
440447
typeComparer = new TypeComparer(this)
441448
searchHistory = new SearchHistory(0, Map())
449+
gadt = new GADTMap(SimpleMap.Empty)
442450
}
443451

444452
object NoContext extends Context {
@@ -593,6 +601,8 @@ object Contexts {
593601
implicit val ctx: Context = initctx
594602
}
595603

604+
class GADTMap(var bounds: SimpleMap[Symbol, TypeBounds])
605+
596606
/** Initial size of superId table */
597607
private final val InitialSuperIdsSize = 4096
598608

src/dotty/tools/dotc/core/Flags.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,6 @@ object Flags {
301301
/** Method is assumed to be stable */
302302
final val Stable = termFlag(24, "<stable>")
303303

304-
/** Info can be refined during GADT pattern match */
305-
final val GADTFlexType = typeFlag(24, "<gadt-flex-type>")
306-
307304
/** A case parameter accessor */
308305
final val CaseAccessor = termFlag(25, "<caseaccessor>")
309306

@@ -553,7 +550,7 @@ object Flags {
553550

554551
/** A Java interface, potentially with default methods */
555552
final val JavaTrait = allOf(JavaDefined, Trait, NoInits)
556-
553+
557554
/** A Java interface */
558555
final val JavaInterface = allOf(JavaDefined, Trait)
559556

src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -450,31 +450,6 @@ object Symbols {
450450
*/
451451
def pos: Position = if (coord.isPosition) coord.toPosition else NoPosition
452452

453-
// -------- GADT handling -----------------------------------------------
454-
455-
/** Perform given operation `op` where this symbol allows tightening of
456-
* its type bounds.
457-
*/
458-
private[dotc] def withGADTFlexType[T](op: () => T)(implicit ctx: Context): () => T = { () =>
459-
assert((denot is TypeParam) && denot.owner.isTerm)
460-
val saved = denot
461-
denot = denot.copySymDenotation(initFlags = denot.flags | GADTFlexType)
462-
try op()
463-
finally denot = saved
464-
}
465-
466-
/** Disallow tightening of type bounds for this symbol from now on */
467-
private[dotc] def resetGADTFlexType()(implicit ctx: Context): Unit = {
468-
assert(denot is GADTFlexType)
469-
denot = denot.copySymDenotation(initFlags = denot.flags &~ GADTFlexType)
470-
}
471-
472-
/** Change info of this symbol to new, tightened type bounds */
473-
private[core] def changeGADTInfo(bounds: TypeBounds)(implicit ctx: Context): Unit = {
474-
assert(denot is GADTFlexType)
475-
denot = denot.copySymDenotation(info = bounds)
476-
}
477-
478453
// -------- Printing --------------------------------------------------------
479454

480455
/** The prefix string to be used when displaying this symbol without denotation */

src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -617,9 +617,12 @@ class TypeComparer(initctx: Context) extends DotClass {
617617
secondTry(OrType.make(derivedRef(tp11), derivedRef(tp12)), tp2)
618618
*/
619619
case TypeBounds(lo1, hi1) =>
620-
if ((ctx.mode is Mode.GADTflexible) && (tp1.symbol is GADTFlexType) &&
621-
!isSubTypeWhenFrozen(hi1, tp2))
622-
trySetType(tp1, TypeBounds(lo1, hi1 & tp2))
620+
val gbounds1 = ctx.gadt.bounds(tp1.symbol)
621+
if (gbounds1 != null)
622+
isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
623+
(ctx.mode is Mode.GADTflexible) &&
624+
narrowGADTBounds(tp1, TypeBounds(gbounds1.lo, gbounds1.hi & tp2)) ||
625+
tryRebase2nd
623626
else if (lo1 eq hi1) isSubType(hi1, tp2)
624627
else tryRebase2nd
625628
case _ =>
@@ -636,9 +639,12 @@ class TypeComparer(initctx: Context) extends DotClass {
636639
}
637640
def compareNamed: Boolean = tp2.info match {
638641
case TypeBounds(lo2, hi2) =>
639-
if ((ctx.mode is Mode.GADTflexible) && (tp2.symbol is GADTFlexType) &&
640-
!isSubTypeWhenFrozen(tp1, lo2))
641-
trySetType(tp2, TypeBounds(lo2 | tp1, hi2))
642+
val gbounds2 = ctx.gadt.bounds(tp2.symbol)
643+
if (gbounds2 != null)
644+
isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
645+
(ctx.mode is Mode.GADTflexible) &&
646+
narrowGADTBounds(tp2, TypeBounds(gbounds2.lo | tp1, gbounds2.hi)) ||
647+
tryRebase3rd
642648
else
643649
((frozenConstraint || !isCappable(tp1)) && isSubType(tp1, lo2)
644650
|| tryRebase3rd)
@@ -911,9 +917,9 @@ class TypeComparer(initctx: Context) extends DotClass {
911917
tp.exists && !tp.isLambda
912918
}
913919

914-
def trySetType(tr: NamedType, bounds: TypeBounds): Boolean =
920+
def narrowGADTBounds(tr: NamedType, bounds: TypeBounds): Boolean =
915921
isSubType(bounds.lo, bounds.hi) &&
916-
{ tr.symbol.changeGADTInfo(bounds); true }
922+
{ ctx.gadt.bounds = ctx.gadt.bounds.updated(tr.symbol, bounds); true }
917923

918924
// Tests around `matches`
919925

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -628,10 +628,10 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
628628
def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context) = {
629629

630630
/** gadtSyms = "all type parameters of enclosing methods that appear
631-
* non-variantly in the selector type" todo: should typevars
632-
* which appear with variances +1 and -1 (in different
633-
* places) be considered as well?
634-
*/
631+
* non-variantly in the selector type" todo: should typevars
632+
* which appear with variances +1 and -1 (in different
633+
* places) be considered as well?
634+
*/
635635
val gadtSyms: Set[Symbol] = ctx.traceIndented(i"GADT syms of $selType", gadts) {
636636
val accu = new TypeAccumulator[Set[Symbol]] {
637637
def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = {
@@ -650,9 +650,13 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
650650
cases mapconserve (typedCase(_, pt, selType, gadtSyms))
651651
}
652652

653+
/** Type a case. Overridden in ReTyper, that's why it's separate from
654+
* typedCases.
655+
*/
653656
def typedCase(tree: untpd.CaseDef, pt: Type, selType: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
657+
val originalCtx = ctx
658+
654659
def caseRest(pat: Tree)(implicit ctx: Context) = {
655-
gadtSyms foreach (_.resetGADTFlexType)
656660
pat foreachSubTree {
657661
case b: Bind =>
658662
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
@@ -661,11 +665,21 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
661665
}
662666
val guard1 = typedExpr(tree.guard, defn.BooleanType)
663667
val body1 = typedExpr(tree.body, pt)
668+
.ensureConforms(pt)(originalCtx) // insert a cast if body does not conform to expected type if we disregard gadt bounds
664669
assignType(cpy.CaseDef(tree)(pat, guard1, body1), body1)
665670
}
666-
val doCase: () => CaseDef =
667-
() => caseRest(typedPattern(tree.pat, selType))(ctx.fresh.setNewScope)
668-
(doCase /: gadtSyms)((op, tsym) => tsym.withGADTFlexType(op))()
671+
672+
val gadtCtx =
673+
if (gadtSyms.isEmpty) ctx
674+
else {
675+
val c = ctx.fresh.setFreshGADTBounds
676+
for (sym <- gadtSyms)
677+
if (!c.gadt.bounds.contains(sym))
678+
c.gadt.bounds = c.gadt.bounds.updated(sym, TypeBounds.empty)
679+
c
680+
}
681+
val pat1 = typedPattern(tree.pat, selType)(gadtCtx)
682+
caseRest(pat1)(gadtCtx.fresh.setNewScope)
669683
}
670684

671685
def typedReturn(tree: untpd.Return)(implicit ctx: Context): Return = track("typedReturn") {

tests/pos/i0268.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package typespatmat
2+
3+
sealed trait Box2[T]
4+
final case class Int2(x: Int) extends Box2[Int]
5+
final case class Str2(x: String)
6+
extends Box2[String]
7+
final case class Gen[T](x: T) extends Box2[T]
8+
9+
object Box2 {
10+
def double2[T](x: Box2[T]): T = x match {
11+
case Int2(i) => i * 2
12+
case Str2(s) => s + s
13+
case Gen(x) => x
14+
}
15+
}

0 commit comments

Comments
 (0)