Skip to content

Commit 8196a71

Browse files
committed
Fixes to support case id: T <- ...
This was not classified as a pattern binding before, so no filtering was applied.
1 parent c63519c commit 8196a71

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,16 +1280,19 @@ object desugar {
12801280
*/
12811281
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {
12821282

1283-
/** Make a function value pat => body.
1284-
* If pat is a var pattern id: T then this gives (id: T) => body
1285-
* Otherwise this gives { case pat => body }, where `pat` is allowed to be
1286-
* refutable only if `checkMode` is MatchCheck.None.
1283+
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
1284+
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
1285+
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
1286+
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
12871287
*/
1288-
def makeLambda(pat: Tree, body: Tree, checkMode: MatchCheck): Tree = pat match {
1289-
case IdPattern(named, tpt) =>
1290-
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
1288+
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
1289+
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
1290+
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12911291
case _ =>
1292-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil, checkMode)
1292+
val matchCheckMode =
1293+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1294+
else MatchCheck.None
1295+
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
12931296
}
12941297

12951298
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1361,16 +1364,20 @@ object desugar {
13611364
}
13621365
}
13631366

1364-
def needsFilter(gen: GenFrom): Boolean =
1365-
gen.checkMode != GenCheckMode.Filter ||
1366-
IdPattern.unapply(gen.pat).isDefined ||
1367-
isIrrefutable(gen.pat, gen.expr)
1367+
def needsNoFilter(gen: GenFrom): Boolean =
1368+
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
1369+
isIrrefutable(gen.pat, gen.expr)
1370+
else (
1371+
gen.checkMode != GenCheckMode.FilterNow ||
1372+
IdPattern.unapply(gen.pat).isDefined ||
1373+
isIrrefutable(gen.pat, gen.expr)
1374+
)
13681375

13691376
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13701377
* matched against `rhs`.
13711378
*/
13721379
def rhsSelect(gen: GenFrom, name: TermName) = {
1373-
val rhs = if (needsFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1380+
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13741381
Select(rhs, name)
13751382
}
13761383

@@ -1380,10 +1387,10 @@ object desugar {
13801387

13811388
enums match {
13821389
case (gen: GenFrom) :: Nil =>
1383-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body, checkMode(gen)))
1390+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
13841391
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13851392
val cont = makeFor(mapName, flatMapName, rest, body)
1386-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont, checkMode(gen)))
1393+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
13871394
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13881395
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13891396
val pats = valeqs map { case GenAlias(pat, _) => pat }
@@ -1396,7 +1403,7 @@ object desugar {
13961403
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13971404
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13981405
case (gen: GenFrom) :: test :: rest =>
1399-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1406+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
14001407
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
14011408
makeFor(mapName, flatMapName, genFrom :: rest, body)
14021409
case _ =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
121121
object GenCheckMode {
122122
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123123
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124-
val Filter = new GenCheckMode(2) // filter out non-matching elements
124+
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
125+
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
125126
}
126127

127128
// ----- Modifiers -----------------------------------------------------

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1742,8 +1742,9 @@ object Parsers {
17421742
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
17431743
atSpan(startOffset(pat), accept(LARROW)) {
17441744
val checkMode =
1745-
if (casePat || !ctx.settings.strict.value) GenCheckMode.Filter // don't filter under -strict
1746-
else GenCheckMode.Check
1745+
if (casePat) GenCheckMode.FilterAlways
1746+
else if (ctx.settings.strict.value) GenCheckMode.Check
1747+
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
17471748
GenFrom(pat, expr(), checkMode)
17481749
}
17491750

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -571,7 +571,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
571571
case ForDo(enums, expr) =>
572572
forText(enums, expr, keywordStr(" do "))
573573
case GenFrom(pat, expr, checkMode) =>
574-
(Str("case ") provided checkMode == untpd.GenCheckMode.Filter) ~
574+
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
575575
toText(pat) ~ " <- " ~ toText(expr)
576576
case GenAlias(pat, expr) =>
577577
toText(pat) ~ " = " ~ toText(expr)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ trait Checking {
607607
def fail(pat: Tree, pt: Type): Boolean = {
608608
var reportedPt = pt.dropAnnot(defn.UncheckedAnnot)
609609
if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen
610-
val problem = if (pat.tpe <:< pt) "is more specialized than" else "does not match"
610+
val problem = if (pat.tpe <:< reportedPt) "is more specialized than" else "does not match"
611611
val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before"
612612
ctx.errorOrMigrationWarning(
613613
ex"""pattern's type ${pat.tpe} $problem the right hand side expression's type $reportedPt

0 commit comments

Comments
 (0)