Skip to content

Commit 31cbd47

Browse files
committed
Crude implementation of removing trailing unit-literal maps from for-comprehensions
1 parent 08a8457 commit 31cbd47

File tree

5 files changed

+111
-6
lines changed

5 files changed

+111
-6
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Phases.Phase
99
import transform.*
1010
import dotty.tools.backend
1111
import backend.jvm.{CollectSuperCalls, GenBCode}
12-
import localopt.StringInterpolatorOpt
12+
import localopt.{StringInterpolatorOpt, DropForMap}
1313

1414
/** The central class of the dotc compiler. The job of a compiler is to create
1515
* runs, which process given `phases` in a given `rootContext`.
@@ -91,7 +91,8 @@ class Compiler {
9191
new ExplicitOuter, // Add accessors to outer classes from nested ones.
9292
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
9393
new StringInterpolatorOpt, // Optimizes raw and s and f string interpolators by rewriting them to string concatenations or formats
94-
new DropBreaks) :: // Optimize local Break throws by rewriting them
94+
new DropBreaks, // Optimize local Break throws by rewriting them
95+
new DropForMap) :: // Drop unused trailing map calls in for comprehensions
9596
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
9697
new UninitializedDefs, // Replaces `compiletime.uninitialized` by `_`
9798
new InlinePatterns, // Remove placeholders of inlined patterns

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ object desugar {
5252
*/
5353
val ContextBoundParam: Property.Key[Unit] = Property.StickyKey()
5454

55+
/** An attachment key to indicate that an Apply is created as a last `map`
56+
* scall in a for-comprehension.
57+
*/
58+
val TrailingForMap: Property.Key[Unit] = Property.StickyKey()
59+
5560
/** What static check should be applied to a Match? */
5661
enum MatchCheck {
5762
case None, Exhaustive, IrrefutablePatDef, IrrefutableGenFrom
@@ -1872,9 +1877,9 @@ object desugar {
18721877
* (Where empty for-comprehensions are excluded by the parser)
18731878
*
18741879
* If the aliases are not followed by a guard, otherwise an error.
1875-
*
1880+
*
18761881
* With betterFors disabled, the translation is as follows:
1877-
*
1882+
*
18781883
* 1.
18791884
*
18801885
* for (P <- G) E ==> G.foreach (P => E)
@@ -2044,14 +2049,22 @@ object desugar {
20442049
if gen.checkMode != GenCheckMode.Filtered // results of withFilter have the wrong type
20452050
&& deepEquals(gen.pat, body)
20462051
then gen.expr // avoid a redundant map with identity
2047-
else Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2052+
else
2053+
val aply = Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
2054+
aply.putAttachment(TrailingForMap, ())
2055+
aply
20482056
case (gen: GenFrom) :: rest
20492057
if rest.dropWhile(_.isInstanceOf[GenAlias]).headOption.forall(e => e.isInstanceOf[GenFrom]) =>
20502058
val cont = makeFor(mapName, flatMapName, rest, body)
20512059
val selectName =
20522060
if rest.exists(_.isInstanceOf[GenFrom]) then flatMapName
20532061
else mapName
2054-
Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2062+
val aply = Apply(rhsSelect(gen, selectName), makeLambda(gen, cont))
2063+
if selectName == mapName then
2064+
aply.pushAttachment(TrailingForMap, ())
2065+
else
2066+
aply
2067+
aply
20552068
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
20562069
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
20572070
val pats = valeqs map { case GenAlias(pat, _) => pat }
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
package dotty.tools.dotc
2+
package transform.localopt
3+
4+
import scala.language.unsafeNulls
5+
6+
import dotty.tools.dotc.ast.tpd
7+
import dotty.tools.dotc.core.Decorators.*
8+
import dotty.tools.dotc.core.Constants.Constant
9+
import dotty.tools.dotc.core.Contexts.*
10+
import dotty.tools.dotc.core.StdNames.*
11+
import dotty.tools.dotc.core.Symbols.*
12+
import dotty.tools.dotc.core.Types.*
13+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
14+
import dotty.tools.dotc.typer.ConstFold
15+
import dotty.tools.dotc.ast.desugar
16+
import scala.util.chaining.*
17+
18+
class DropForMap extends MiniPhase:
19+
import tpd.*
20+
21+
override def phaseName: String = DropForMap.name
22+
23+
override def description: String = DropForMap.description
24+
25+
override def transformApply(tree: tpd.Apply)(using Context): tpd.Tree =
26+
if !tree.hasAttachment(desugar.TrailingForMap) then tree.tap(_.removeAttachment(desugar.TrailingForMap))
27+
else tree match
28+
case Apply(MapCall(f), List(Lambda(List(param), body)))
29+
if isEssentiallyUnitLiteral(param, body) && param.tpt.tpe.isRef(defn.UnitClass) =>
30+
f
31+
case _ =>
32+
tree.tap(_.removeAttachment(desugar.TrailingForMap))
33+
34+
private object Lambda:
35+
def unapply(tree: Tree)(using Context): Option[(List[ValDef], Tree)] =
36+
tree match
37+
case Block(List(defdef: DefDef), Closure(Nil, ref, _)) if ref.symbol == defdef.symbol && !defdef.paramss.exists(_.forall(_.isType)) =>
38+
Some((defdef.termParamss.flatten, defdef.rhs))
39+
case _ => None
40+
41+
private object MapCall:
42+
def unapply(tree: Tree)(using Context): Option[Tree] = tree match
43+
case Select(f, nme.map) => Some(f)
44+
case Apply(fn, _) => unapply(fn)
45+
case TypeApply(fn, _) => unapply(fn)
46+
case _ => None
47+
48+
def isEssentiallyUnitLiteral(param: ValDef, tree: Tree)(using Context): Boolean = tree match
49+
case Literal(Constant(())) => true
50+
case Match(scrutinee, List(CaseDef(_, EmptyTree, body))) => isEssentiallyUnitLiteral(param, body)
51+
case Block(Nil, expr) => isEssentiallyUnitLiteral(param, expr)
52+
case _ => false
53+
54+
object DropForMap:
55+
val name: String = "dropForMap"
56+
val description: String = "Drop unused trailing map calls in for comprehensions"

tests/run/map-unit-elim.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
MySome(())

tests/run/map-unit-elim.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import scala.language.experimental.betterFors
2+
3+
class myOptionPackage(doOnMap: => Unit) {
4+
sealed trait MyOption[+A] {
5+
def map[B](f: A => B): MyOption[B] = this match {
6+
case MySome(x) => {
7+
doOnMap
8+
MySome(f(x))
9+
}
10+
case MyNone => MyNone
11+
}
12+
def flatMap[B](f: A => MyOption[B]): MyOption[B] = this match {
13+
case MySome(x) => f(x)
14+
case MyNone => MyNone
15+
}
16+
}
17+
case class MySome[A](x: A) extends MyOption[A]
18+
case object MyNone extends MyOption[Nothing]
19+
}
20+
21+
object Test extends App {
22+
23+
val myOption = new myOptionPackage(println("map called"))
24+
25+
import myOption.*
26+
27+
val z = for {
28+
a <- MySome(1)
29+
b <- MySome(())
30+
} yield ()
31+
32+
println(z)
33+
34+
}

0 commit comments

Comments
 (0)