Skip to content

Commit dfa3280

Browse files
authored
Merge pull request #1377 from dotty-staging/#1365
Fix 1365: Fix bindings in patterns
2 parents f37e45a + 9bde23a commit dfa3280

File tree

3 files changed

+80
-22
lines changed

3 files changed

+80
-22
lines changed

src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -936,14 +936,24 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
936936
private def narrowGADTBounds(tr: NamedType, bound: Type, isUpper: Boolean): Boolean =
937937
ctx.mode.is(Mode.GADTflexible) && {
938938
val tparam = tr.symbol
939-
typr.println(s"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
940-
!bound.isRef(tparam) && {
941-
val oldBounds = ctx.gadt.bounds(tparam)
942-
val newBounds =
943-
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
944-
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
945-
isSubType(newBounds.lo, newBounds.hi) &&
946-
{ ctx.gadt.setBounds(tparam, newBounds); true }
939+
typr.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
940+
if (bound.isRef(tparam)) false
941+
else bound match {
942+
case bound: TypeRef
943+
if bound.symbol.is(BindDefinedType) && ctx.gadt.bounds.contains(bound.symbol) &&
944+
!tr.symbol.is(BindDefinedType) =>
945+
// Avoid having pattern-bound types in gadt bounds,
946+
// as these might be eliminated once the pattern is typechecked.
947+
// Pattern-bound type symbols should be narrowed first, only if that fails
948+
// should symbols in the environment be constrained.
949+
narrowGADTBounds(bound, tr, !isUpper)
950+
case _ =>
951+
val oldBounds = ctx.gadt.bounds(tparam)
952+
val newBounds =
953+
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
954+
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
955+
isSubType(newBounds.lo, newBounds.hi) &&
956+
{ ctx.gadt.setBounds(tparam, newBounds); true }
947957
}
948958
}
949959

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -454,11 +454,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
454454
return typed(untpd.Apply(untpd.TypedSplice(arg), tree.expr), pt)
455455
case _ =>
456456
}
457-
case tref: TypeRef if tref.symbol.isClass && !ctx.isAfterTyper =>
458-
val setBefore = ctx.mode is Mode.GADTflexible
459-
tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
460-
if (!setBefore) ctx.retractMode(Mode.GADTflexible)
461457
case _ =>
458+
if (!ctx.isAfterTyper) tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
462459
}
463460
ascription(tpt1, isWildcard = true)
464461
}
@@ -774,17 +771,37 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
774771
def typedCase(tree: untpd.CaseDef, pt: Type, selType: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
775772
val originalCtx = ctx
776773

777-
def caseRest(pat: Tree)(implicit ctx: Context) = {
778-
pat foreachSubTree {
779-
case b: Bind =>
780-
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
781-
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
782-
case _ =>
774+
/** - replace all references to symbols associated with wildcards by their GADT bounds
775+
* - enter all symbols introduced by a Bind in current scope
776+
*/
777+
val indexPattern = new TreeMap {
778+
val elimWildcardSym = new TypeMap {
779+
def apply(t: Type) = t match {
780+
case ref @ TypeRef(_, tpnme.WILDCARD) if ctx.gadt.bounds.contains(ref.symbol) =>
781+
ctx.gadt.bounds(ref.symbol)
782+
case TypeAlias(ref @ TypeRef(_, tpnme.WILDCARD)) if ctx.gadt.bounds.contains(ref.symbol) =>
783+
ctx.gadt.bounds(ref.symbol)
784+
case _ =>
785+
mapOver(t)
786+
}
783787
}
788+
override def transform(tree: Tree)(implicit ctx: Context) =
789+
super.transform(tree.withType(elimWildcardSym(tree.tpe))) match {
790+
case b: Bind =>
791+
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
792+
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
793+
b.symbol.info = elimWildcardSym(b.symbol.info)
794+
b
795+
case t => t
796+
}
797+
}
798+
799+
def caseRest(pat: Tree)(implicit ctx: Context) = {
800+
val pat1 = indexPattern.transform(pat)
784801
val guard1 = typedExpr(tree.guard, defn.BooleanType)
785802
val body1 = ensureNoLocalRefs(typedExpr(tree.body, pt), pt, ctx.scope.toList)
786803
.ensureConforms(pt)(originalCtx) // insert a cast if body does not conform to expected type if we disregard gadt bounds
787-
assignType(cpy.CaseDef(tree)(pat, guard1, body1), body1)
804+
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), body1)
788805
}
789806

790807
val gadtCtx =
@@ -983,11 +1000,30 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
9831000
assignType(cpy.ByNameTypeTree(tree)(result1), result1)
9841001
}
9851002

1003+
/** Define a new symbol associated with a Bind or pattern wildcard and
1004+
* make it gadt narrowable.
1005+
*/
1006+
private def newPatternBoundSym(name: Name, info: Type, pos: Position)(implicit ctx: Context) = {
1007+
val flags = if (name.isTypeName) BindDefinedType else EmptyFlags
1008+
val sym = ctx.newSymbol(ctx.owner, name, flags | Case, info, coord = pos)
1009+
if (name.isTypeName) ctx.gadt.setBounds(sym, info.bounds)
1010+
sym
1011+
}
1012+
9861013
def typedTypeBoundsTree(tree: untpd.TypeBoundsTree)(implicit ctx: Context): TypeBoundsTree = track("typedTypeBoundsTree") {
9871014
val TypeBoundsTree(lo, hi) = desugar.typeBoundsTree(tree)
9881015
val lo1 = typed(lo)
9891016
val hi1 = typed(hi)
990-
assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
1017+
val tree1 = assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
1018+
if (ctx.mode.is(Mode.Pattern)) {
1019+
// Associate a pattern-bound type symbol with the wildcard.
1020+
// The bounds of the type symbol can be constrained when comparing a pattern type
1021+
// with an expected type in typedTyped. The type symbol is eliminated once
1022+
// the enclosing pattern has been typechecked; see `indexPattern` in `typedCase`.
1023+
val wildcardSym = newPatternBoundSym(tpnme.WILDCARD, tree1.tpe, tree.pos)
1024+
tree1.withType(wildcardSym.typeRef)
1025+
}
1026+
else tree1
9911027
}
9921028

9931029
def typedBind(tree: untpd.Bind, pt: Type)(implicit ctx: Context): Tree = track("typedBind") {
@@ -1003,8 +1039,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
10031039
tpd.cpy.UnApply(body1)(fn, Nil,
10041040
typed(untpd.Bind(tree.name, arg).withPos(tree.pos), arg.tpe) :: Nil)
10051041
case _ =>
1006-
val flags = if (tree.isType) BindDefinedType else EmptyFlags
1007-
val sym = ctx.newSymbol(ctx.owner, tree.name, flags | Case, body1.tpe, coord = tree.pos)
1042+
val sym = newPatternBoundSym(tree.name, body1.tpe, tree.pos)
10081043
assignType(cpy.Bind(tree)(tree.name, body1), sym)
10091044
}
10101045
}

tests/pos/i1365.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.collection.mutable.ArrayBuffer
2+
3+
trait Message[M]
4+
class Script[S] extends ArrayBuffer[Message[S]] with Message[S]
5+
6+
class Test[A] {
7+
def f(cmd: Message[A]): Unit = cmd match {
8+
case s: Script[_] => s.iterator.foreach(x => f(x))
9+
}
10+
def g(cmd: Message[A]): Unit = cmd match {
11+
case s: Script[z] => s.iterator.foreach(x => g(x))
12+
}
13+
}

0 commit comments

Comments
 (0)