Skip to content

Commit 490cc5f

Browse files
committed
Fix 1365: Fix bindings in patterns
We need to compare pattern types with expected types in order to derive knowledge about pattern-bound variables. This is done use the mechanism of gadt bounds.
1 parent 1719178 commit 490cc5f

File tree

3 files changed

+84
-22
lines changed

3 files changed

+84
-22
lines changed

src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -813,14 +813,24 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
813813
private def narrowGADTBounds(tr: NamedType, bound: Type, isUpper: Boolean): Boolean =
814814
ctx.mode.is(Mode.GADTflexible) && {
815815
val tparam = tr.symbol
816-
typr.println(s"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
817-
!bound.isRef(tparam) && {
818-
val oldBounds = ctx.gadt.bounds(tparam)
819-
val newBounds =
820-
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
821-
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
822-
isSubType(newBounds.lo, newBounds.hi) &&
823-
{ ctx.gadt.setBounds(tparam, newBounds); true }
816+
typr.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
817+
if (bound.isRef(tparam)) false
818+
else bound match {
819+
case bound: TypeRef
820+
if bound.symbol.is(BindDefinedType) && ctx.gadt.bounds.contains(bound.symbol) &&
821+
!tr.symbol.is(BindDefinedType) =>
822+
// Avoid having pattern-bound types in gadt bounds,
823+
// as these might be eliminated once the pattern is typechecked.
824+
// Pattern-bound type symbols should be narrowed first, only if that fails
825+
// should symbols in the environment be constrained.
826+
narrowGADTBounds(bound, tr, !isUpper)
827+
case _ =>
828+
val oldBounds = ctx.gadt.bounds(tparam)
829+
val newBounds =
830+
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
831+
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
832+
isSubType(newBounds.lo, newBounds.hi) &&
833+
{ ctx.gadt.setBounds(tparam, newBounds); true }
824834
}
825835
}
826836

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -448,11 +448,12 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
448448
return typed(untpd.Apply(untpd.TypedSplice(arg), tree.expr), pt)
449449
case _ =>
450450
}
451-
case tref: TypeRef if tref.symbol.isClass && !ctx.isAfterTyper =>
452-
val setBefore = ctx.mode is Mode.GADTflexible
453-
tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
454-
if (!setBefore) ctx.retractMode(Mode.GADTflexible)
455451
case _ =>
452+
if (!ctx.isAfterTyper) {
453+
val setBefore = ctx.mode is Mode.GADTflexible
454+
tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
455+
if (!setBefore) ctx.retractMode(Mode.GADTflexible)
456+
}
456457
}
457458
ascription(tpt1, isWildcard = true)
458459
}
@@ -762,17 +763,37 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
762763
def typedCase(tree: untpd.CaseDef, pt: Type, selType: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
763764
val originalCtx = ctx
764765

765-
def caseRest(pat: Tree)(implicit ctx: Context) = {
766-
pat foreachSubTree {
767-
case b: Bind =>
768-
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
769-
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
770-
case _ =>
766+
/** - replace all references to symbols associated with wildcards by their GADT bounds
767+
* - enter all symbols introduced by a Bind in current scope
768+
*/
769+
val indexPattern = new TreeMap {
770+
val elimWildcardSym = new TypeMap {
771+
def apply(t: Type) = t match {
772+
case ref @ TypeRef(_, tpnme.WILDCARD) if ctx.gadt.bounds.contains(ref.symbol) =>
773+
ctx.gadt.bounds(ref.symbol)
774+
case TypeAlias(ref @ TypeRef(_, tpnme.WILDCARD)) if ctx.gadt.bounds.contains(ref.symbol) =>
775+
ctx.gadt.bounds(ref.symbol)
776+
case _ =>
777+
mapOver(t)
778+
}
771779
}
780+
override def transform(tree: Tree)(implicit ctx: Context) =
781+
super.transform(tree.withType(elimWildcardSym(tree.tpe))) match {
782+
case b: Bind =>
783+
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
784+
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
785+
b.symbol.info = elimWildcardSym(b.symbol.info)
786+
b
787+
case t => t
788+
}
789+
}
790+
791+
def caseRest(pat: Tree)(implicit ctx: Context) = {
792+
val pat1 = indexPattern.transform(pat)
772793
val guard1 = typedExpr(tree.guard, defn.BooleanType)
773794
val body1 = ensureNoLocalRefs(typedExpr(tree.body, pt), pt, ctx.scope.toList)
774795
.ensureConforms(pt)(originalCtx) // insert a cast if body does not conform to expected type if we disregard gadt bounds
775-
assignType(cpy.CaseDef(tree)(pat, guard1, body1), body1)
796+
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), body1)
776797
}
777798

778799
val gadtCtx =
@@ -963,11 +984,30 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
963984
assignType(cpy.ByNameTypeTree(tree)(result1), result1)
964985
}
965986

987+
/** Define a new symbol associated with a Bind or pattern wildcard and
988+
* make it gadt narrowable.
989+
*/
990+
private def newPatternBoundSym(name: Name, info: Type, pos: Position)(implicit ctx: Context) = {
991+
val flags = if (name.isTypeName) BindDefinedType else EmptyFlags
992+
val sym = ctx.newSymbol(ctx.owner, name, flags | Case, info, coord = pos)
993+
if (name.isTypeName) ctx.gadt.setBounds(sym, info.bounds)
994+
sym
995+
}
996+
966997
def typedTypeBoundsTree(tree: untpd.TypeBoundsTree)(implicit ctx: Context): TypeBoundsTree = track("typedTypeBoundsTree") {
967998
val TypeBoundsTree(lo, hi) = desugar.typeBoundsTree(tree)
968999
val lo1 = typed(lo)
9691000
val hi1 = typed(hi)
970-
assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
1001+
val tree1 = assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
1002+
if (ctx.mode.is(Mode.Pattern)) {
1003+
// Associate a pattern-bound type symbol with the wildcard.
1004+
// The bounds of the type symbol can be constrained when comparing a pattern type
1005+
// with an expected type in typedTyped. The type symbol is eliminated once
1006+
// the enclosing pattern has been typechecked; see `indexPattern` in `typedCase`.
1007+
val wildcardSym = newPatternBoundSym(tpnme.WILDCARD, tree1.tpe, tree.pos)
1008+
tree1.withType(wildcardSym.typeRef)
1009+
}
1010+
else tree1
9711011
}
9721012

9731013
def typedBind(tree: untpd.Bind, pt: Type)(implicit ctx: Context): Tree = track("typedBind") {
@@ -983,8 +1023,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
9831023
tpd.cpy.UnApply(body1)(fn, Nil,
9841024
typed(untpd.Bind(tree.name, arg).withPos(tree.pos), arg.tpe) :: Nil)
9851025
case _ =>
986-
val flags = if (tree.isType) BindDefinedType else EmptyFlags
987-
val sym = ctx.newSymbol(ctx.owner, tree.name, flags | Case, body1.tpe, coord = tree.pos)
1026+
val sym = newPatternBoundSym(tree.name, body1.tpe, tree.pos)
9881027
assignType(cpy.Bind(tree)(tree.name, body1), sym)
9891028
}
9901029
}

tests/pos/i1365.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.collection.mutable.ArrayBuffer
2+
3+
trait Message[M]
4+
class Script[S] extends ArrayBuffer[Message[S]] with Message[S]
5+
6+
class Test[A] {
7+
def f(cmd: Message[A]): Unit = cmd match {
8+
case s: Script[_] => s.iterator.foreach(x => f(x))
9+
}
10+
def g(cmd: Message[A]): Unit = cmd match {
11+
case s: Script[z] => s.iterator.foreach(x => g(x))
12+
}
13+
}

0 commit comments

Comments
 (0)