Skip to content

Commit 23300d0

Browse files
committed
Crude implementation of removing trailing unit-literal maps from for-comprehensions
1 parent 2366bd9 commit 23300d0

File tree

5 files changed

+109
-7
lines changed

5 files changed

+109
-7
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import parsing.Parser
88
import Phases.Phase
99
import transform.*
1010
import backend.jvm.{CollectSuperCalls, GenBCode}
11-
import localopt.StringInterpolatorOpt
11+
import localopt.{StringInterpolatorOpt, DropForMap}
1212

1313
/** The central class of the dotc compiler. The job of a compiler is to create
1414
* runs, which process given `phases` in a given `rootContext`.
@@ -90,7 +90,8 @@ class Compiler {
9090
new ExplicitOuter, // Add accessors to outer classes from nested ones.
9191
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
9292
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
93-
new DropBreaks) :: // Optimize local Break throws by rewriting them
93+
new DropBreaks, // Optimize local Break throws by rewriting them
94+
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
9495
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9596
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9697
new InlinePatterns, // Remove placeholders of inlined patterns

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,11 @@ object desugar {
6464
*/
6565
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
6666

67+
/** An attachment key to indicate that an Apply is created as a last `map`
68+
* scall in a for-comprehension.
69+
*/
70+
val TrailingForMap: Property.Key[Unit] = Property.StickyKey()
71+
6772
/** What static check should be applied to a Match? */
6873
enum MatchCheck {
6974
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -2149,11 +2154,13 @@ object desugar {
21492154
enums match {
21502155
case Nil if betterForsEnabled => body
21512156
case (gen: GenFrom) :: Nil =>
2157+
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
21522158
if betterForsEnabled
2153-
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2154-
&& deepEquals(gen.pat, body)
2155-
then gen.expr // avoid a redundant map with identity
2156-
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2159+
&& gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
2160+
// && deepEquals(gen.pat, body)
2161+
then
2162+
aply.putAttachment(TrailingForMap, ())
2163+
aply
21572164
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
21582165
val cont = makeFor(mapName, flatMapName, rest, body)
21592166
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
@@ -2164,7 +2171,10 @@ object desugar {
21642171
val selectName =
21652172
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
21662173
else mapName
2167-
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2174+
val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2175+
if selectName == mapName then
2176+
aply.pushAttachment(TrailingForMap, ())
2177+
aply
21682178
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
21692179
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
21702180
val pats = valeqs map { case GenAlias(pat, _) => pat }
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import scala.language.unsafeNulls
5+
6+
import dotty.tools.dotc.ast.tpd
7+
import dotty.tools.dotc.core.Decorators.*
8+
import dotty.tools.dotc.core.Constants.Constant
9+
import dotty.tools.dotc.core.Contexts.*
10+
import dotty.tools.dotc.core.StdNames.*
11+
import dotty.tools.dotc.core.Symbols.*
12+
import dotty.tools.dotc.core.Types.*
13+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
14+
import dotty.tools.dotc.typer.ConstFold
15+
import dotty.tools.dotc.ast.desugar
16+
import scala.util.chaining.*
17+
18+
class DropForMap extends MiniPhase:
19+
import tpd.*
20+
21+
override def phaseName: String = DropForMap.name
22+
23+
override def description: String = DropForMap.description
24+
25+
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
26+
if !tree.hasAttachment(desugar.TrailingForMap) then tree.tap(_.removeAttachment(desugar.TrailingForMap))
27+
else tree match
28+
case Apply(MapCall(f), List(Lambda(List(param), body)))
29+
if isEssentiallyUnitLiteral(param, body) && param.tpt.tpe.isRef(defn.UnitClass) =>
30+
f
31+
case _ =>
32+
tree.tap(_.removeAttachment(desugar.TrailingForMap))
33+
34+
private object Lambda:
35+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
36+
tree match
37+
case Block(List(defdef: DefDef), Closure(Nil, ref, _)) if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
38+
Some((defdef.termParamss.flatten, defdef.rhs))
39+
case _ => None
40+
41+
private object MapCall:
42+
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
43+
case Select(f, nme.map) => Some(f)
44+
case Apply(fn, _) => unapply(fn)
45+
case TypeApply(fn, _) => unapply(fn)
46+
case _ => None
47+
48+
def isEssentiallyUnitLiteral(param: ValDef, tree: Tree)(using Context): Boolean = tree match
49+
case Literal(Constant(())) => true
50+
case Match(scrutinee, List(CaseDef(_, EmptyTree, body))) => isEssentiallyUnitLiteral(param, body)
51+
case Block(Nil, expr) => isEssentiallyUnitLiteral(param, expr)
52+
case _ => false
53+
54+
object DropForMap:
55+
val name: String = "dropForMap"
56+
val description: String = "Drop unused trailing map calls in for comprehensions"

tests/run/map-unit-elim.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
MySome(())

tests/run/map-unit-elim.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.language.experimental.betterFors
2+
3+
class myOptionPackage(doOnMap: => Unit) {
4+
sealed trait MyOption[+A] {
5+
def map[B](f: A => B): MyOption[B] = this match {
6+
case MySome(x) => {
7+
doOnMap
8+
MySome(f(x))
9+
}
10+
case MyNone => MyNone
11+
}
12+
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
13+
case MySome(x) => f(x)
14+
case MyNone => MyNone
15+
}
16+
}
17+
case class MySome[A](x: A) extends MyOption[A]
18+
case object MyNone extends MyOption[Nothing]
19+
}
20+
21+
object Test extends App {
22+
23+
val myOption = new myOptionPackage(println("map called"))
24+
25+
import myOption.*
26+
27+
val z = for {
28+
a <- MySome(1)
29+
b <- MySome(())
30+
} yield ()
31+
32+
println(z)
33+
34+
}

0 commit comments

Comments
 (0)