Skip to content

Commit 02cbb4b

Browse files
committed
Refine checking mode for for generators
Add a third mode, Ignore, for generators where the pattern is refutable, but we know it has already been used as a filter in an earlier step.
1 parent 122f6ef commit 02cbb4b

File tree

5 files changed

+56
-17
lines changed

5 files changed

+56
-17
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,22 +1359,22 @@ object desugar {
13591359
}
13601360
}
13611361

1362-
def isIrrefutableGenFrom(gen: GenFrom): Boolean =
1363-
!gen.filtering ||
1362+
def needsFilter(gen: GenFrom): Boolean =
1363+
gen.checkMode != GenCheckMode.Filter ||
13641364
IdPattern.unapply(gen.pat).isDefined ||
13651365
isIrrefutable(gen.pat, gen.expr)
13661366

13671367
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13681368
* matched against `rhs`.
13691369
*/
13701370
def rhsSelect(gen: GenFrom, name: TermName) = {
1371-
val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1371+
val rhs = if (needsFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13721372
Select(rhs, name)
13731373
}
13741374

13751375
def checkMode(gen: GenFrom) =
1376-
if (gen.filtering) MatchCheck.None // refutable paterns were already eliminated in filter step
1377-
else MatchCheck.IrrefutableGenFrom
1376+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1377+
else MatchCheck.None // refutable paterns were already eliminated in filter step
13781378

13791379
enums match {
13801380
case (gen: GenFrom) :: Nil =>
@@ -1389,13 +1389,13 @@ object desugar {
13891389
val (defpat0, id0) = makeIdPat(gen.pat)
13901390
val (defpats, ids) = (pats map makeIdPat).unzip
13911391
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
1392-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1392+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
13931393
val allpats = gen.pat :: pats
1394-
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, filtering = false)
1394+
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13951395
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13961396
case (gen: GenFrom) :: test :: rest =>
13971397
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1398-
val genFrom = GenFrom(gen.pat, filtered, filtering = false)
1398+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
13991399
makeFor(mapName, flatMapName, genFrom :: rest, body)
14001400
case _ =>
14011401
EmptyTree //may happen for erroneous input

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
9999
case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
100100
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
101101
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
102-
case class GenFrom(pat: Tree, expr: Tree, filtering: Boolean)(implicit @constructorOnly src: SourceFile) extends Tree
102+
case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree
103103
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
104104
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
105105
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
@@ -116,6 +116,14 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
116116
* `Positioned#checkPos` */
117117
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
118118

119+
/** An enum to control checking or filtering of patterns in GenFrom trees */
120+
class GenCheckMode(val x: Int) extends AnyVal
121+
object GenCheckMode {
122+
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123+
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124+
val Filter = new GenCheckMode(2) // filter out non-matching elements
125+
}
126+
119127
// ----- Modifiers -----------------------------------------------------
120128
/** Mod is intended to record syntactic information about modifiers, it's
121129
* NOT a replacement of FlagSet.
@@ -525,9 +533,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
525533
case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree
526534
case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source))
527535
}
528-
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, filtering: Boolean)(implicit ctx: Context): Tree = tree match {
529-
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (filtering == tree.filtering) => tree
530-
case _ => finalize(tree, untpd.GenFrom(pat, expr, filtering)(tree.source))
536+
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match {
537+
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree
538+
case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source))
531539
}
532540
def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
533541
case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree
@@ -589,8 +597,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
589597
cpy.ForYield(tree)(transform(enums), transform(expr))
590598
case ForDo(enums, body) =>
591599
cpy.ForDo(tree)(transform(enums), transform(body))
592-
case GenFrom(pat, expr, filtering) =>
593-
cpy.GenFrom(tree)(transform(pat), transform(expr), filtering)
600+
case GenFrom(pat, expr, checkMode) =>
601+
cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode)
594602
case GenAlias(pat, expr) =>
595603
cpy.GenAlias(tree)(transform(pat), transform(expr))
596604
case ContextBounds(bounds, cxBounds) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1717,7 +1717,10 @@ object Parsers {
17171717

17181718
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
17191719
atSpan(startOffset(pat), accept(LARROW)) {
1720-
GenFrom(pat, expr(), filtering = casePat || !ctx.settings.strict.value) // don't filter under -strict
1720+
val checkMode =
1721+
if (casePat || !ctx.settings.strict.value) GenCheckMode.Filter // don't filter under -strict
1722+
else GenCheckMode.Check
1723+
GenFrom(pat, expr(), checkMode)
17211724
}
17221725

17231726
/** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}')

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -564,8 +564,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
564564
forText(enums, expr, keywordStr(" yield "))
565565
case ForDo(enums, expr) =>
566566
forText(enums, expr, keywordStr(" do "))
567-
case GenFrom(pat, expr, filtering) =>
568-
(Str("case ") provided filtering) ~ toText(pat) ~ " <- " ~ toText(expr)
567+
case GenFrom(pat, expr, checkMode) =>
568+
(Str("case ") provided checkMode == untpd.GenCheckMode.Filter) ~
569+
toText(pat) ~ " <- " ~ toText(expr)
569570
case GenAlias(pat, expr) =>
570571
toText(pat) ~ " = " ~ toText(expr)
571572
case ContextBounds(bounds, cxBounds) =>

tests/neg-strict/filtering-fors.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object Test {
2+
3+
val xs: List[Any] = ???
4+
5+
for (x <- xs) do () // OK
6+
for (x: Any <- xs) do () // OK
7+
8+
for (x: String <- xs) do () // error
9+
for ((x: String) <- xs) do () // error
10+
for (y@ (x: String) <- xs) do () // error
11+
for ((x, y) <- xs) do () // error
12+
13+
for ((x: String) <- xs if x.isEmpty) do () // error
14+
for ((x: String) <- xs; y = x) do () // error
15+
for ((x: String) <- xs; (y, z) <- xs) do () // error // error
16+
for (case (x: String) <- xs; (y, z) <- xs) do () // error
17+
for ((x: String) <- xs; case (y, z) <- xs) do () // error
18+
19+
for (case x: String <- xs) do () // OK
20+
for (case (x: String) <- xs) do () // OK
21+
for (case y@ (x: String) <- xs) do () // OK
22+
for (case (x, y) <- xs) do () // OK
23+
24+
for (case (x: String) <- xs if x.isEmpty) do () // OK
25+
for (case (x: String) <- xs; y = x) do () // OK
26+
for (case (x: String) <- xs; case (y, z) <- xs) do () // OK
27+
}

0 commit comments

Comments
 (0)