Skip to content

Commit c2a41b6

Browse files
committed
Refine checking mode for for generators
Add a third mode, Ignore, for generators where the pattern is refutable, but we know it has already been used as a filter in an earlier step.
1 parent b9ecfc0 commit c2a41b6

File tree

5 files changed

+56
-17
lines changed

5 files changed

+56
-17
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1361,22 +1361,22 @@ object desugar {
13611361
}
13621362
}
13631363

1364-
def isIrrefutableGenFrom(gen: GenFrom): Boolean =
1365-
!gen.filtering ||
1364+
def needsFilter(gen: GenFrom): Boolean =
1365+
gen.checkMode != GenCheckMode.Filter ||
13661366
IdPattern.unapply(gen.pat).isDefined ||
13671367
isIrrefutable(gen.pat, gen.expr)
13681368

13691369
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13701370
* matched against `rhs`.
13711371
*/
13721372
def rhsSelect(gen: GenFrom, name: TermName) = {
1373-
val rhs = if (isIrrefutableGenFrom(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1373+
val rhs = if (needsFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13741374
Select(rhs, name)
13751375
}
13761376

13771377
def checkMode(gen: GenFrom) =
1378-
if (gen.filtering) MatchCheck.None // refutable paterns were already eliminated in filter step
1379-
else MatchCheck.IrrefutableGenFrom
1378+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1379+
else MatchCheck.None // refutable paterns were already eliminated in filter step
13801380

13811381
enums match {
13821382
case (gen: GenFrom) :: Nil =>
@@ -1391,13 +1391,13 @@ object desugar {
13911391
val (defpat0, id0) = makeIdPat(gen.pat)
13921392
val (defpats, ids) = (pats map makeIdPat).unzip
13931393
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
1394-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1394+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.checkMode) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
13951395
val allpats = gen.pat :: pats
1396-
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, filtering = false)
1396+
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13971397
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13981398
case (gen: GenFrom) :: test :: rest =>
13991399
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1400-
val genFrom = GenFrom(gen.pat, filtered, filtering = false)
1400+
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
14011401
makeFor(mapName, flatMapName, genFrom :: rest, body)
14021402
case _ =>
14031403
EmptyTree //may happen for erroneous input

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
9999
case class DoWhile(body: Tree, cond: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
100100
case class ForYield(enums: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
101101
case class ForDo(enums: List[Tree], body: Tree)(implicit @constructorOnly src: SourceFile) extends TermTree
102-
case class GenFrom(pat: Tree, expr: Tree, filtering: Boolean)(implicit @constructorOnly src: SourceFile) extends Tree
102+
case class GenFrom(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit @constructorOnly src: SourceFile) extends Tree
103103
case class GenAlias(pat: Tree, expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
104104
case class ContextBounds(bounds: TypeBoundsTree, cxBounds: List[Tree])(implicit @constructorOnly src: SourceFile) extends TypTree
105105
case class PatDef(mods: Modifiers, pats: List[Tree], tpt: Tree, rhs: Tree)(implicit @constructorOnly src: SourceFile) extends DefTree
@@ -116,6 +116,14 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
116116
* `Positioned#checkPos` */
117117
class XMLBlock(stats: List[Tree], expr: Tree)(implicit @constructorOnly src: SourceFile) extends Block(stats, expr)
118118

119+
/** An enum to control checking or filtering of patterns in GenFrom trees */
120+
class GenCheckMode(val x: Int) extends AnyVal
121+
object GenCheckMode {
122+
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123+
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124+
val Filter = new GenCheckMode(2) // filter out non-matching elements
125+
}
126+
119127
// ----- Modifiers -----------------------------------------------------
120128
/** Mod is intended to record syntactic information about modifiers, it's
121129
* NOT a replacement of FlagSet.
@@ -525,9 +533,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
525533
case tree: ForDo if (enums eq tree.enums) && (body eq tree.body) => tree
526534
case _ => finalize(tree, untpd.ForDo(enums, body)(tree.source))
527535
}
528-
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, filtering: Boolean)(implicit ctx: Context): Tree = tree match {
529-
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (filtering == tree.filtering) => tree
530-
case _ => finalize(tree, untpd.GenFrom(pat, expr, filtering)(tree.source))
536+
def GenFrom(tree: Tree)(pat: Tree, expr: Tree, checkMode: GenCheckMode)(implicit ctx: Context): Tree = tree match {
537+
case tree: GenFrom if (pat eq tree.pat) && (expr eq tree.expr) && (checkMode == tree.checkMode) => tree
538+
case _ => finalize(tree, untpd.GenFrom(pat, expr, checkMode)(tree.source))
531539
}
532540
def GenAlias(tree: Tree)(pat: Tree, expr: Tree)(implicit ctx: Context): Tree = tree match {
533541
case tree: GenAlias if (pat eq tree.pat) && (expr eq tree.expr) => tree
@@ -589,8 +597,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
589597
cpy.ForYield(tree)(transform(enums), transform(expr))
590598
case ForDo(enums, body) =>
591599
cpy.ForDo(tree)(transform(enums), transform(body))
592-
case GenFrom(pat, expr, filtering) =>
593-
cpy.GenFrom(tree)(transform(pat), transform(expr), filtering)
600+
case GenFrom(pat, expr, checkMode) =>
601+
cpy.GenFrom(tree)(transform(pat), transform(expr), checkMode)
594602
case GenAlias(pat, expr) =>
595603
cpy.GenAlias(tree)(transform(pat), transform(expr))
596604
case ContextBounds(bounds, cxBounds) =>

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1741,7 +1741,10 @@ object Parsers {
17411741

17421742
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
17431743
atSpan(startOffset(pat), accept(LARROW)) {
1744-
GenFrom(pat, expr(), filtering = casePat || !ctx.settings.strict.value) // don't filter under -strict
1744+
val checkMode =
1745+
if (casePat || !ctx.settings.strict.value) GenCheckMode.Filter // don't filter under -strict
1746+
else GenCheckMode.Check
1747+
GenFrom(pat, expr(), checkMode)
17451748
}
17461749

17471750
/** ForExpr ::= `for' (`(' Enumerators `)' | `{' Enumerators `}')

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -570,8 +570,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
570570
forText(enums, expr, keywordStr(" yield "))
571571
case ForDo(enums, expr) =>
572572
forText(enums, expr, keywordStr(" do "))
573-
case GenFrom(pat, expr, filtering) =>
574-
(Str("case ") provided filtering) ~ toText(pat) ~ " <- " ~ toText(expr)
573+
case GenFrom(pat, expr, checkMode) =>
574+
(Str("case ") provided checkMode == untpd.GenCheckMode.Filter) ~
575+
toText(pat) ~ " <- " ~ toText(expr)
575576
case GenAlias(pat, expr) =>
576577
toText(pat) ~ " = " ~ toText(expr)
577578
case ContextBounds(bounds, cxBounds) =>

tests/neg-strict/filtering-fors.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
object Test {
2+
3+
val xs: List[Any] = ???
4+
5+
for (x <- xs) do () // OK
6+
for (x: Any <- xs) do () // OK
7+
8+
for (x: String <- xs) do () // error
9+
for ((x: String) <- xs) do () // error
10+
for (y@ (x: String) <- xs) do () // error
11+
for ((x, y) <- xs) do () // error
12+
13+
for ((x: String) <- xs if x.isEmpty) do () // error
14+
for ((x: String) <- xs; y = x) do () // error
15+
for ((x: String) <- xs; (y, z) <- xs) do () // error // error
16+
for (case (x: String) <- xs; (y, z) <- xs) do () // error
17+
for ((x: String) <- xs; case (y, z) <- xs) do () // error
18+
19+
for (case x: String <- xs) do () // OK
20+
for (case (x: String) <- xs) do () // OK
21+
for (case y@ (x: String) <- xs) do () // OK
22+
for (case (x, y) <- xs) do () // OK
23+
24+
for (case (x: String) <- xs if x.isEmpty) do () // OK
25+
for (case (x: String) <- xs; y = x) do () // OK
26+
for (case (x: String) <- xs; case (y, z) <- xs) do () // OK
27+
}

0 commit comments

Comments
 (0)