Skip to content

Commit 13ec91b

Browse files
committed
Merge pull request #271 from dotty-staging/fix/i268-gadts
Fixed #264 - failure to typecheck GADTs
2 parents f3c0eaa + 009a1e6 commit 13ec91b

File tree

6 files changed

+67
-45
lines changed

6 files changed

+67
-45
lines changed

src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ object Contexts {
146146
protected def diagnostics_=(diagnostics: Option[StringBuilder]) = _diagnostics = diagnostics
147147
def diagnostics: Option[StringBuilder] = _diagnostics
148148

149+
/** The current bounds in force for type parameters appearing in a GADT */
150+
private var _gadt: GADTMap = _
151+
protected def gadt_=(gadt: GADTMap) = _gadt = gadt
152+
def gadt: GADTMap = _gadt
153+
149154
/** A map in which more contextual properties can be stored */
150155
private var _moreProperties: Map[String, Any] = _
151156
protected def moreProperties_=(moreProperties: Map[String, Any]) = _moreProperties = moreProperties
@@ -418,6 +423,8 @@ object Contexts {
418423
def setSetting[T](setting: Setting[T], value: T): this.type =
419424
setSettings(setting.updateIn(sstate, value))
420425

426+
def setFreshGADTBounds: this.type = { this.gadt = new GADTMap(gadt.bounds); this }
427+
421428
def setDebug = setSetting(base.settings.debug, true)
422429
}
423430

@@ -439,6 +446,7 @@ object Contexts {
439446
moreProperties = Map.empty
440447
typeComparer = new TypeComparer(this)
441448
searchHistory = new SearchHistory(0, Map())
449+
gadt = new GADTMap(SimpleMap.Empty)
442450
}
443451

444452
object NoContext extends Context {
@@ -593,6 +601,13 @@ object Contexts {
593601
implicit val ctx: Context = initctx
594602
}
595603

604+
class GADTMap(initBounds: SimpleMap[Symbol, TypeBounds]) {
605+
private var myBounds = initBounds
606+
def setBounds(sym: Symbol, b: TypeBounds): Unit =
607+
myBounds = myBounds.updated(sym, b)
608+
def bounds = myBounds
609+
}
610+
596611
/** Initial size of superId table */
597612
private final val InitialSuperIdsSize = 4096
598613

src/dotty/tools/dotc/core/Flags.scala

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,9 +301,6 @@ object Flags {
301301
/** Method is assumed to be stable */
302302
final val Stable = termFlag(24, "<stable>")
303303

304-
/** Info can be refined during GADT pattern match */
305-
final val GADTFlexType = typeFlag(24, "<gadt-flex-type>")
306-
307304
/** A case parameter accessor */
308305
final val CaseAccessor = termFlag(25, "<caseaccessor>")
309306

@@ -553,7 +550,7 @@ object Flags {
553550

554551
/** A Java interface, potentially with default methods */
555552
final val JavaTrait = allOf(JavaDefined, Trait, NoInits)
556-
553+
557554
/** A Java interface */
558555
final val JavaInterface = allOf(JavaDefined, Trait)
559556

src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -450,31 +450,6 @@ object Symbols {
450450
*/
451451
def pos: Position = if (coord.isPosition) coord.toPosition else NoPosition
452452

453-
// -------- GADT handling -----------------------------------------------
454-
455-
/** Perform given operation `op` where this symbol allows tightening of
456-
* its type bounds.
457-
*/
458-
private[dotc] def withGADTFlexType[T](op: () => T)(implicit ctx: Context): () => T = { () =>
459-
assert((denot is TypeParam) && denot.owner.isTerm)
460-
val saved = denot
461-
denot = denot.copySymDenotation(initFlags = denot.flags | GADTFlexType)
462-
try op()
463-
finally denot = saved
464-
}
465-
466-
/** Disallow tightening of type bounds for this symbol from now on */
467-
private[dotc] def resetGADTFlexType()(implicit ctx: Context): Unit = {
468-
assert(denot is GADTFlexType)
469-
denot = denot.copySymDenotation(initFlags = denot.flags &~ GADTFlexType)
470-
}
471-
472-
/** Change info of this symbol to new, tightened type bounds */
473-
private[core] def changeGADTInfo(bounds: TypeBounds)(implicit ctx: Context): Unit = {
474-
assert(denot is GADTFlexType)
475-
denot = denot.copySymDenotation(info = bounds)
476-
}
477-
478453
// -------- Printing --------------------------------------------------------
479454

480455
/** The prefix string to be used when displaying this symbol without denotation */

src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -618,9 +618,12 @@ class TypeComparer(initctx: Context) extends DotClass {
618618
secondTry(OrType.make(derivedRef(tp11), derivedRef(tp12)), tp2)
619619
*/
620620
case TypeBounds(lo1, hi1) =>
621-
if ((ctx.mode is Mode.GADTflexible) && (tp1.symbol is GADTFlexType) &&
622-
!isSubTypeWhenFrozen(hi1, tp2))
623-
trySetType(tp1, TypeBounds(lo1, hi1 & tp2))
621+
val gbounds1 = ctx.gadt.bounds(tp1.symbol)
622+
if (gbounds1 != null)
623+
isSubTypeWhenFrozen(gbounds1.hi, tp2) ||
624+
(ctx.mode is Mode.GADTflexible) &&
625+
narrowGADTBounds(tp1, TypeBounds(gbounds1.lo, gbounds1.hi & tp2)) ||
626+
tryRebase2nd
624627
else if (lo1 eq hi1) isSubType(hi1, tp2)
625628
else tryRebase2nd
626629
case _ =>
@@ -637,9 +640,12 @@ class TypeComparer(initctx: Context) extends DotClass {
637640
}
638641
def compareNamed: Boolean = tp2.info match {
639642
case TypeBounds(lo2, hi2) =>
640-
if ((ctx.mode is Mode.GADTflexible) && (tp2.symbol is GADTFlexType) &&
641-
!isSubTypeWhenFrozen(tp1, lo2))
642-
trySetType(tp2, TypeBounds(lo2 | tp1, hi2))
643+
val gbounds2 = ctx.gadt.bounds(tp2.symbol)
644+
if (gbounds2 != null)
645+
isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
646+
(ctx.mode is Mode.GADTflexible) &&
647+
narrowGADTBounds(tp2, TypeBounds(gbounds2.lo | tp1, gbounds2.hi)) ||
648+
tryRebase3rd
643649
else
644650
((frozenConstraint || !isCappable(tp1)) && isSubType(tp1, lo2)
645651
|| tryRebase3rd)
@@ -912,9 +918,9 @@ class TypeComparer(initctx: Context) extends DotClass {
912918
tp.exists && !tp.isLambda
913919
}
914920

915-
def trySetType(tr: NamedType, bounds: TypeBounds): Boolean =
921+
def narrowGADTBounds(tr: NamedType, bounds: TypeBounds): Boolean =
916922
isSubType(bounds.lo, bounds.hi) &&
917-
{ tr.symbol.changeGADTInfo(bounds); true }
923+
{ ctx.gadt.setBounds(tr.symbol, bounds); true }
918924

919925
// Tests around `matches`
920926

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -628,10 +628,10 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
628628
def typedCases(cases: List[untpd.CaseDef], selType: Type, pt: Type)(implicit ctx: Context) = {
629629

630630
/** gadtSyms = "all type parameters of enclosing methods that appear
631-
* non-variantly in the selector type" todo: should typevars
632-
* which appear with variances +1 and -1 (in different
633-
* places) be considered as well?
634-
*/
631+
* non-variantly in the selector type" todo: should typevars
632+
* which appear with variances +1 and -1 (in different
633+
* places) be considered as well?
634+
*/
635635
val gadtSyms: Set[Symbol] = ctx.traceIndented(i"GADT syms of $selType", gadts) {
636636
val accu = new TypeAccumulator[Set[Symbol]] {
637637
def apply(tsyms: Set[Symbol], t: Type): Set[Symbol] = {
@@ -650,9 +650,13 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
650650
cases mapconserve (typedCase(_, pt, selType, gadtSyms))
651651
}
652652

653+
/** Type a case. Overridden in ReTyper, that's why it's separate from
654+
* typedCases.
655+
*/
653656
def typedCase(tree: untpd.CaseDef, pt: Type, selType: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
657+
val originalCtx = ctx
658+
654659
def caseRest(pat: Tree)(implicit ctx: Context) = {
655-
gadtSyms foreach (_.resetGADTFlexType)
656660
pat foreachSubTree {
657661
case b: Bind =>
658662
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
@@ -661,11 +665,21 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
661665
}
662666
val guard1 = typedExpr(tree.guard, defn.BooleanType)
663667
val body1 = typedExpr(tree.body, pt)
668+
.ensureConforms(pt)(originalCtx) // insert a cast if body does not conform to expected type if we disregard gadt bounds
664669
assignType(cpy.CaseDef(tree)(pat, guard1, body1), body1)
665670
}
666-
val doCase: () => CaseDef =
667-
() => caseRest(typedPattern(tree.pat, selType))(ctx.fresh.setNewScope)
668-
(doCase /: gadtSyms)((op, tsym) => tsym.withGADTFlexType(op))()
671+
672+
val gadtCtx =
673+
if (gadtSyms.isEmpty) ctx
674+
else {
675+
val c = ctx.fresh.setFreshGADTBounds
676+
for (sym <- gadtSyms)
677+
if (!c.gadt.bounds.contains(sym))
678+
c.gadt.setBounds(sym, TypeBounds.empty)
679+
c
680+
}
681+
val pat1 = typedPattern(tree.pat, selType)(gadtCtx)
682+
caseRest(pat1)(gadtCtx.fresh.setNewScope)
669683
}
670684

671685
def typedReturn(tree: untpd.Return)(implicit ctx: Context): Return = track("typedReturn") {

tests/pos/i0268.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package typespatmat
2+
3+
sealed trait Box2[T]
4+
final case class Int2(x: Int) extends Box2[Int]
5+
final case class Str2(x: String)
6+
extends Box2[String]
7+
final case class Gen[T](x: T) extends Box2[T]
8+
9+
object Box2 {
10+
def double2[T](x: Box2[T]): T = x match {
11+
case Int2(i) => i * 2
12+
case Str2(s) => s + s
13+
case Gen(x) => x
14+
}
15+
}

0 commit comments

Comments
 (0)