Skip to content

Commit 961a3cd

Browse files
committed
Working better-fors fix
1 parent 23300d0 commit 961a3cd

File tree

7 files changed

+106
-25
lines changed

7 files changed

+106
-25
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ class Compiler {
6868
new InlineVals, // Check right hand-sides of an `inline val`s
6969
new ExpandSAMs, // Expand single abstract method closures to anonymous classes
7070
new ElimRepeated, // Rewrite vararg parameters and arguments
71-
new RefChecks) :: // Various checks mostly related to abstract members and overriding
71+
new RefChecks, // Various checks mostly related to abstract members and overriding
72+
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
7273
List(new semanticdb.ExtractSemanticDB.AppendDiagnostics) :: // Attach warnings to extracted SemanticDB and write to .semanticdb file
7374
List(new init.Checker) :: // Check initialization of objects
7475
List(new ProtectedAccessors, // Add accessors for protected members
@@ -90,8 +91,7 @@ class Compiler {
9091
new ExplicitOuter, // Add accessors to outer classes from nested ones.
9192
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
9293
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
93-
new DropBreaks, // Optimize local Break throws by rewriting them
94-
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
94+
new DropBreaks) :: // Optimize local Break throws by rewriting them
9595
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9696
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9797
new InlinePatterns, // Remove placeholders of inlined patterns

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1971,14 +1971,8 @@ object desugar {
19711971
*
19721972
* 3.
19731973
*
1974-
* for (P <- G) yield P ==> G
1975-
*
1976-
* If betterFors is enabled, P is a variable or a tuple of variables and G is not a withFilter.
1977-
*
19781974
* for (P <- G) yield E ==> G.map (P => E)
19791975
*
1980-
* Otherwise
1981-
*
19821976
* 4.
19831977
*
19841978
* for (P_1 <- G_1; P_2 <- G_2; ...) ...
@@ -2157,7 +2151,7 @@ object desugar {
21572151
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
21582152
if betterForsEnabled
21592153
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2160-
// && deepEquals(gen.pat, body)
2154+
&& deepEquals(gen.pat, body)
21612155
then
21622156
aply.putAttachment(TrailingForMap, ())
21632157
aply

compiler/src/dotty/tools/dotc/transform/localopt/DropForMap.scala

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,22 @@ import dotty.tools.dotc.transform.MegaPhase.MiniPhase
1414
import dotty.tools.dotc.typer.ConstFold
1515
import dotty.tools.dotc.ast.desugar
1616
import scala.util.chaining.*
17+
import tpd.*
1718

1819
class DropForMap extends MiniPhase:
19-
import tpd.*
20+
import DropForMap.*
21+
import Binder.*
2022

2123
override def phaseName: String = DropForMap.name
2224

2325
override def description: String = DropForMap.description
2426

2527
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
26-
if !tree.hasAttachment(desugar.TrailingForMap) then tree.tap(_.removeAttachment(desugar.TrailingForMap))
28+
if !tree.hasAttachment(desugar.TrailingForMap) then tree
2729
else tree match
28-
case Apply(MapCall(f), List(Lambda(List(param), body)))
29-
if isEssentiallyUnitLiteral(param, body) && param.tpt.tpe.isRef(defn.UnitClass) =>
30-
f
30+
case aply @ Apply(MapCall(f), List(Lambda(List(param), body)))
31+
if canDropMap(Single(param), body) && f.tpe =:= aply.tpe => // make sure that the type of the expression won't change
32+
f // drop the map call
3133
case _ =>
3234
tree.tap(_.removeAttachment(desugar.TrailingForMap))
3335

@@ -45,12 +47,57 @@ class DropForMap extends MiniPhase:
4547
case TypeApply(fn, _) => unapply(fn)
4648
case _ => None
4749

48-
def isEssentiallyUnitLiteral(param: ValDef, tree: Tree)(using Context): Boolean = tree match
49-
case Literal(Constant(())) => true
50-
case Match(scrutinee, List(CaseDef(_, EmptyTree, body))) => isEssentiallyUnitLiteral(param, body)
51-
case Block(Nil, expr) => isEssentiallyUnitLiteral(param, expr)
52-
case _ => false
50+
/** We can drop the map call if:
51+
* - it is a Unit literal
52+
* - is an identity function (i.e. the last pattern is the same as the result)
53+
*/
54+
private def canDropMap(params: Binder, tree: Tree)(using Context): Boolean = tree match
55+
case Literal(Constant(())) => params match
56+
case Single(bind) => bind.symbol.info.isRef(defn.UnitClass)
57+
case _ => false
58+
case ident: Ident => params match
59+
case Single(bind) => bind.symbol == ident.symbol
60+
case _ => false
61+
case tree: Apply if tree.tpe.typeSymbol.derivesFrom(defn.TupleClass) => params match
62+
case Tuple(binds) => tree.args.zip(binds).forall((arg, param) => canDropMap(param, arg))
63+
case _ => false
64+
case Match(scrutinee, List(CaseDef(pat, EmptyTree, body))) =>
65+
val newParams = newParamsFromMatch(params, scrutinee, pat)
66+
canDropMap(newParams, body)
67+
case Block(Nil, expr) => canDropMap(params, expr)
68+
case _ =>
69+
false
70+
71+
/** Extract potentially new parameters from a match expression
72+
*/
73+
private def newParamsFromMatch(params: Binder, scrutinee: Tree, pat: Tree)(using Context): Binder =
74+
def extractTraverse(pats: List[Tree]): Option[List[Binder]] = pats match
75+
case Nil => Some(List.empty)
76+
case pat :: pats =>
77+
extractBinders(pat).map(_ +: extractTraverse(pats).get)
78+
def extractBinders(pat: Tree): Option[Binder] = pat match
79+
case bind: Bind => Some(Single(bind))
80+
case tree @ UnApply(fun, implicits, pats)
81+
if implicits.isEmpty && tree.tpe.finalResultType.dealias.typeSymbol.derivesFrom(defn.TupleClass) =>
82+
extractTraverse(pats).map(Tuple.apply)
83+
case _ => None
84+
85+
params match
86+
case Single(bind) if scrutinee.symbol == bind.symbol =>
87+
pat match
88+
case bind: Bind => Single(bind)
89+
case tree @ UnApply(fun, implicits, pats) if implicits.isEmpty =>
90+
val unapplied = tree.tpe.finalResultType.dealias.typeSymbol
91+
if unapplied.derivesFrom(defn.TupleClass) then
92+
extractTraverse(pats).map(Tuple.apply).getOrElse(params)
93+
else params
94+
case _ => params
95+
case _ => params
5396

5497
object DropForMap:
5598
val name: String = "dropForMap"
5699
val description: String = "Drop unused trailing map calls in for comprehensions"
100+
101+
private enum Binder:
102+
case Single(bind: NamedDefTree)
103+
case Tuple(binds: List[Binder])

tests/pos/better-fors-i21804.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.language.experimental.betterFors
2+
3+
case class Container[A](val value: A) {
4+
def map[B](f: A => B): Container[B] = Container(f(value))
5+
}
6+
7+
sealed trait Animal
8+
case class Dog() extends Animal
9+
10+
def opOnDog(dog: Container[Dog]): Container[Animal] =
11+
for
12+
v <- dog
13+
yield v

tests/run/better-fors-map-elim.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
MySome(())
2+
MySome(2)
3+
MySome((2,3))
4+
MySome((2,(3,4)))

tests/run/map-unit-elim.scala renamed to tests/run/better-fors-map-elim.scala

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import scala.language.experimental.betterFors
22

3-
class myOptionPackage(doOnMap: => Unit) {
3+
class myOptionModule(doOnMap: => Unit) {
44
sealed trait MyOption[+A] {
55
def map[B](f: A => B): MyOption[B] = this match {
66
case MySome(x) => {
@@ -16,19 +16,43 @@ class myOptionPackage(doOnMap: => Unit) {
1616
}
1717
case class MySome[A](x: A) extends MyOption[A]
1818
case object MyNone extends MyOption[Nothing]
19+
object MyOption {
20+
def apply[A](x: A): MyOption[A] = MySome(x)
21+
}
1922
}
2023

2124
object Test extends App {
2225

23-
val myOption = new myOptionPackage(println("map called"))
26+
val myOption = new myOptionModule(println("map called"))
2427

2528
import myOption.*
2629

2730
val z = for {
28-
a <- MySome(1)
29-
b <- MySome(())
31+
a <- MyOption(1)
32+
b <- MyOption(())
3033
} yield ()
3134

3235
println(z)
3336

37+
val z2 = for {
38+
a <- MyOption(1)
39+
b <- MyOption(2)
40+
} yield b
41+
42+
println(z2)
43+
44+
val z3 = for {
45+
a <- MyOption(1)
46+
(b, c) <- MyOption((2, 3))
47+
} yield (b, c)
48+
49+
println(z3)
50+
51+
val z4 = for {
52+
a <- MyOption(1)
53+
(b, (c, d)) <- MyOption((2, (3, 4)))
54+
} yield (b, (c, d))
55+
56+
println(z4)
57+
3458
}

tests/run/map-unit-elim.check

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)