diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 5326361ada98..44bfc20c5418 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -1525,38 +1525,44 @@ object desugar { * * 1. * - * for (P <- G) E ==> G.foreach (P => E) + * for (P <- G) E ==> G.foreach (P => E) * - * Here and in the following (P => E) is interpreted as the function (P => E) - * if P is a variable pattern and as the partial function { case P => E } otherwise. + * Here and in the following (P => E) is interpreted as the function (P => E) + * if P is a variable pattern and as the partial function { case P => E } otherwise. * * 2. * - * for (P <- G) yield E ==> G.map (P => E) + * for (P <- G) yield P ==> G + * + * if P is a variable or a tuple of variables and G is not a withFilter. + * + * for (P <- G) yield E ==> G.map (P => E) + * + * otherwise * * 3. * - * for (P_1 <- G_1; P_2 <- G_2; ...) ... - * ==> - * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) + * for (P_1 <- G_1; P_2 <- G_2; ...) ... + * ==> + * G_1.flatMap (P_1 => for (P_2 <- G_2; ...) ...) * * 4. * - * for (P <- G; E; ...) ... - * => - * for (P <- G.filter (P => E); ...) ... + * for (P <- G; E; ...) ... + * => + * for (P <- G.filter (P => E); ...) ... * * 5. For any N: * - * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) - * ==> - * for (TupleN(P_1, P_2, ... P_N) <- - * for (x_1 @ P_1 <- G) yield { - * val x_2 @ P_2 = E_2 - * ... - * val x_N & P_N = E_N - * TupleN(x_1, ..., x_N) - * } ...) + * for (P_1 <- G; P_2 = E_2; val P_N = E_N; ...) + * ==> + * for (TupleN(P_1, P_2, ... P_N) <- + * for (x_1 @ P_1 <- G) yield { + * val x_2 @ P_2 = E_2 + * ... + * val x_N & P_N = E_N + * TupleN(x_1, ..., x_N) + * } ...) * * If any of the P_i are variable patterns, the corresponding `x_i @ P_i` is not generated * and the variable constituting P_i is used instead of x_i @@ -1669,7 +1675,7 @@ object desugar { case GenCheckMode.FilterAlways => false // pattern was prefixed by `case` case GenCheckMode.FilterNow | GenCheckMode.CheckAndFilter => isVarBinding(gen.pat) || isIrrefutable(gen.pat, gen.expr) case GenCheckMode.Check => true - case GenCheckMode.Ignore => true + case GenCheckMode.Ignore | GenCheckMode.Filtered => true /** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when * matched against `rhs`. @@ -1679,9 +1685,18 @@ object desugar { Select(rhs, name) } + def deepEquals(t1: Tree, t2: Tree): Boolean = + (unsplice(t1), unsplice(t2)) match + case (Ident(n1), Ident(n2)) => n1 == n2 + case (Tuple(ts1), Tuple(ts2)) => ts1.corresponds(ts2)(deepEquals) + case _ => false + enums match { case (gen: GenFrom) :: Nil => - Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) + if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type + && deepEquals(gen.pat, body) + then gen.expr // avoid a redundant map with identity + else Apply(rhsSelect(gen, mapName), makeLambda(gen, body)) case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) => val cont = makeFor(mapName, flatMapName, rest, body) Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont)) @@ -1703,7 +1718,7 @@ object desugar { makeFor(mapName, flatMapName, vfrom1 :: rest1, body) case (gen: GenFrom) :: test :: rest => val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test)) - val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore) + val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Filtered) makeFor(mapName, flatMapName, genFrom :: rest, body) case _ => EmptyTree //may happen for erroneous input diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 8a6ba48d22c5..517fc17f36c4 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -169,7 +169,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { /** An enum to control checking or filtering of patterns in GenFrom trees */ enum GenCheckMode { - case Ignore // neither filter nor check since filtering was done before + case Ignore // neither filter since pattern is trivially irrefutable + case Filtered // neither filter nor check since filtering was done before case Check // check that pattern is irrefutable case CheckAndFilter // both check and filter (transitional period starting with 3.2) case FilterNow // filter out non-matching elements if we are not in 3.2 or later diff --git a/tests/run/fors.check b/tests/run/fors.check index 50f6385e5845..7b7e8d076108 100644 --- a/tests/run/fors.check +++ b/tests/run/fors.check @@ -45,6 +45,9 @@ hello world hello/1~2 hello/3~4 /1~2 /3~4 world/1~2 world/3~4 (2,1) (4,3) +testTailrec +List((4,Symbol(a)), (5,Symbol(b)), (6,Symbol(c))) + testGivens 123 456 diff --git a/tests/run/fors.scala b/tests/run/fors.scala index 682978b5b3d8..bd7de7d32263 100644 --- a/tests/run/fors.scala +++ b/tests/run/fors.scala @@ -4,6 +4,8 @@ //############################################################################ +import annotation.tailrec + object Test extends App { val xs = List(1, 2, 3) val ys = List(Symbol("a"), Symbol("b"), Symbol("c")) @@ -108,6 +110,17 @@ object Test extends App { for case (x, y) <- xs do print(s"${(y, x)} "); println() } + /////////////////// elimination of map /////////////////// + + @tailrec + def pair[B](xs: List[Int], ys: List[B], n: Int): List[(Int, B)] = + if n == 0 then xs.zip(ys) + else for (x, y) <- pair(xs.map(_ + 1), ys, n - 1) yield (x, y) + + def testTailrec() = + println("\ntestTailrec") + println(pair(xs, ys, 3)) + def testGivens(): Unit = { println("\ntestGivens") @@ -141,5 +154,6 @@ object Test extends App { testOld() testNew() testFiltering() + testTailrec() testGivens() }