Skip to content

Commit 5c9b7ec

Browse files
committed
Fixes to support case id: T <- ...
This was not classified as a pattern binding before, so no filtering was applied.
1 parent 18c05d0 commit 5c9b7ec

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,16 +1278,19 @@ object desugar {
12781278
*/
12791279
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {
12801280

1281-
/** Make a function value pat => body.
1282-
* If pat is a var pattern id: T then this gives (id: T) => body
1283-
* Otherwise this gives { case pat => body }, where `pat` is allowed to be
1284-
* refutable only if `checkMode` is MatchCheck.None.
1281+
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
1282+
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
1283+
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
1284+
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
12851285
*/
1286-
def makeLambda(pat: Tree, body: Tree, checkMode: MatchCheck): Tree = pat match {
1287-
case IdPattern(named, tpt) =>
1288-
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
1286+
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
1287+
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
1288+
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12891289
case _ =>
1290-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil, checkMode)
1290+
val matchCheckMode =
1291+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1292+
else MatchCheck.None
1293+
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
12911294
}
12921295

12931296
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1359,16 +1362,20 @@ object desugar {
13591362
}
13601363
}
13611364

1362-
def needsFilter(gen: GenFrom): Boolean =
1363-
gen.checkMode != GenCheckMode.Filter ||
1364-
IdPattern.unapply(gen.pat).isDefined ||
1365-
isIrrefutable(gen.pat, gen.expr)
1365+
def needsNoFilter(gen: GenFrom): Boolean =
1366+
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
1367+
isIrrefutable(gen.pat, gen.expr)
1368+
else (
1369+
gen.checkMode != GenCheckMode.FilterNow ||
1370+
IdPattern.unapply(gen.pat).isDefined ||
1371+
isIrrefutable(gen.pat, gen.expr)
1372+
)
13661373

13671374
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13681375
* matched against `rhs`.
13691376
*/
13701377
def rhsSelect(gen: GenFrom, name: TermName) = {
1371-
val rhs = if (needsFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1378+
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13721379
Select(rhs, name)
13731380
}
13741381

@@ -1378,10 +1385,10 @@ object desugar {
13781385

13791386
enums match {
13801387
case (gen: GenFrom) :: Nil =>
1381-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body, checkMode(gen)))
1388+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
13821389
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13831390
val cont = makeFor(mapName, flatMapName, rest, body)
1384-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont, checkMode(gen)))
1391+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
13851392
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13861393
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13871394
val pats = valeqs map { case GenAlias(pat, _) => pat }
@@ -1394,7 +1401,7 @@ object desugar {
13941401
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13951402
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13961403
case (gen: GenFrom) :: test :: rest =>
1397-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1404+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
13981405
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
13991406
makeFor(mapName, flatMapName, genFrom :: rest, body)
14001407
case _ =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
121121
object GenCheckMode {
122122
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123123
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124-
val Filter = new GenCheckMode(2) // filter out non-matching elements
124+
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
125+
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
125126
}
126127

127128
// ----- Modifiers -----------------------------------------------------

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,8 +1718,9 @@ object Parsers {
17181718
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
17191719
atSpan(startOffset(pat), accept(LARROW)) {
17201720
val checkMode =
1721-
if (casePat || !ctx.settings.strict.value) GenCheckMode.Filter // don't filter under -strict
1722-
else GenCheckMode.Check
1721+
if (casePat) GenCheckMode.FilterAlways
1722+
else if (ctx.settings.strict.value) GenCheckMode.Check
1723+
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
17231724
GenFrom(pat, expr(), checkMode)
17241725
}
17251726

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
565565
case ForDo(enums, expr) =>
566566
forText(enums, expr, keywordStr(" do "))
567567
case GenFrom(pat, expr, checkMode) =>
568-
(Str("case ") provided checkMode == untpd.GenCheckMode.Filter) ~
568+
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
569569
toText(pat) ~ " <- " ~ toText(expr)
570570
case GenAlias(pat, expr) =>
571571
toText(pat) ~ " = " ~ toText(expr)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ trait Checking {
609609
else {
610610
var reportedPt = pt.dropAnnot(defn.UncheckedAnnot)
611611
if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen
612-
val problem = if (pat.tpe <:< pt) "is more specialized than" else "does not match"
612+
val problem = if (pat.tpe <:< reportedPt) "is more specialized than" else "does not match"
613613
val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before"
614614
ctx.errorOrMigrationWarning(
615615
ex"""pattern's type ${pat.tpe} $problem the right hand side expression's type $reportedPt

0 commit comments

Comments
 (0)