Skip to content

Commit 53ab457

Browse files
Merge pull request #14210 from dotty-staging/fix-10889
Add eta reduction miniphase
2 parents c9c6de6 + 4634481 commit 53ab457

File tree

3 files changed

+70
-0
lines changed

3 files changed

+70
-0
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class Compiler {
106106
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
107107
new PureStats, // Remove pure stats from blocks
108108
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
109+
new EtaReduce, // Reduce eta expansions of pure paths to the underlying function reference
109110
new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
110111
new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass`
111112
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import MegaPhase.MiniPhase
6+
import core.*
7+
import Symbols.*, Contexts.*, Types.*, Decorators.*
8+
import StdNames.nme
9+
import ast.Trees.*
10+
11+
/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
12+
* provided `f` is a pure path of function type.
13+
*
14+
* This optimization is crucial for context functions. The compiler
15+
* produces a contextual closure around values passed as arguments
16+
* where a context function is expected, unless that value has the
17+
* syntactic form of a context function literal.
18+
*
19+
* Without this phase, when a contextual function is passed as an argument to a
20+
* recursive function, that would have the unfortunate effect of a linear growth
21+
* in transient thunks of identical type wrapped around each other, leading
22+
* to performance degradation, and in some cases, stack overflows.
23+
*/
24+
class EtaReduce extends MiniPhase:
25+
import ast.tpd._
26+
27+
override def phaseName: String = "etaReduce"
28+
29+
override def transformBlock(tree: Block)(using Context): Tree = tree match
30+
case Block((meth : DefDef) :: Nil, closure: Closure)
31+
if meth.symbol == closure.meth.symbol =>
32+
meth.rhs match
33+
case Apply(Select(fn, nme.apply), args)
34+
if meth.paramss.head.corresponds(args)((param, arg) =>
35+
arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
36+
&& isPurePath(fn)
37+
&& fn.tpe <:< tree.tpe
38+
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
39+
report.log(i"eta reducing $tree --> $fn")
40+
fn
41+
case _ => tree
42+
case _ => tree
43+
44+
end EtaReduce

tests/run/i10889.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.annotation.tailrec
2+
import scala.util.chaining.given
3+
4+
object Test {
5+
class Ctx
6+
type Op[A] = Ctx ?=> A
7+
8+
var min = Int.MaxValue
9+
var max = 0
10+
def stk = new Throwable().getStackTrace.length
11+
12+
@tailrec def f[A](n: Int)(op: Op[A]): A =
13+
val depth = stk
14+
min = min.min(depth)
15+
max = max.max(depth)
16+
given Ctx = Ctx()
17+
if (n > 0) f(n-1)(op)
18+
else op
19+
20+
def g(ctx: Ctx) = stk
21+
22+
def main(args: Array[String]): Unit =
23+
val extra = 3
24+
f(10)(Ctx ?=> g(summon[Ctx])).tap(res => assert(res <= max + extra, s"min $min, max $max, ran g at $res"))
25+
}

0 commit comments

Comments
 (0)