
Commit 804b786

Add eta reduction miniphase
Fixes #10889. Alternative to #13750, from which the test is taken.
1 parent 355d2f6 commit 804b786

File tree: 4 files changed (+131, -0 lines)

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions

@@ -106,6 +106,7 @@ class Compiler {
     List(new ElimErasedValueType,  // Expand erased value types to their underlying implementation types
          new PureStats,            // Remove pure stats from blocks
          new VCElideAllocations,   // Peep-hole optimization to eliminate unnecessary value class allocations
+         new EtaReduce,
          new ArrayApply,           // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
          new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass`
          new ElimPolyFunction,     // Rewrite PolyFunction subclasses to FunctionN subclasses
Lines changed: 46 additions & 0 deletions (new file)

package dotty.tools
package dotc
package transform

import MegaPhase.MiniPhase
import core.*
import Symbols.*, Contexts.*, Types.*, Decorators.*
import StdNames.nme
import ast.Trees.*

/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
 * provided `f` is a pure path of function type.
 *
 * This optimization is crucial for context functions. The compiler
 * produces a contextual closure around values passed as arguments
 * where a context function is expected, unless that value has the
 * syntactic form of a context function literal.
 *
 * Without this phase, when a contextual function is passed as an argument to a
 * recursive function, that would have the unfortunate effect of a linear growth
 * in transient thunks of identical type wrapped around each other, leading
 * to performance degradation, and in some cases, stack overflows.
 */
class EtaReduce extends MiniPhase:
  import ast.tpd._

  override def phaseName: String = "etaReduce"

  override def transformBlock(tree: Block)(using Context): Tree = tree match
    case Block((meth: DefDef) :: Nil, closure: Closure)
    if meth.symbol == closure.meth.symbol =>
      meth.rhs match
        case Apply(Select(fn, nme.apply), args)
        if meth.paramss.head.corresponds(args)((param, arg) =>
              arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
           && isPurePath(fn) =>
          val treeSym = tree.tpe.widen.typeSymbol
          val fnSym = fn.tpe.widen.typeSymbol
          if treeSym == fnSym && defn.isFunctionClass(fnSym) then
            report.log(i"eta reducing $tree --> $fn")
            fn
          else tree
        case _ => tree
    case _ => tree

end EtaReduce
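
To make the doc comment concrete, here is a rough source-level sketch, not part of the commit (`Op`, `run`, `forward`, and `EtaReduceSketch` are made-up names), of how such a closure arises around a pure path and why replacing it with the path itself is sound:

// Sketch only: hypothetical names, not part of this commit.
object EtaReduceSketch {
  type Op[A] = String ?=> A              // a context function type

  def run[A](op: Op[A]): A = op(using "ctx")

  def forward[A](op: Op[A]): A =
    // `op` is not a context-function literal, so the compiler wraps it as
    // roughly `(s: String) ?=> op(using s)` before passing it on. That wrapper
    // is the `(x1, ... xN) => f(x1, ... xN)` shape with `f = op`, a pure path,
    // so EtaReduce can rewrite it back to plain `op`.
    run(op)
}

If `forward` were recursive, each level would otherwise add one more wrapper of the same type around `op`, which is exactly the linear thunk growth described in the doc comment and exercised by the test below.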

Lines changed: 1 addition & 0 deletions (new file, expected output of the test below)

true

Lines changed: 83 additions & 0 deletions (new file)

import scala.annotation.tailrec

import java.lang.management.ManagementFactory
import java.lang.management.RuntimeMXBean

import scala.jdk.CollectionConverters._

object Test {
  trait Txn {}
  type AtomicOp[Z] = Txn ?=> Z
  type VanillaAtomicOp[Z] = Txn => Z

  object AbortAndRetry extends scala.util.control.ControlThrowable("abort and retry") {}

  def beginTxn: Txn = new Txn {}

  @tailrec def retryN[Z](n: Int)(txn: AtomicOp[Z], i: Int = 0): Z = {
    try {
      given Txn = beginTxn
      val ret: Z = txn
      if (i < n) { throw AbortAndRetry }
      ret
    } catch {
      case AbortAndRetry => retryN(n)(txn, i + 1)
    }
  }

  @tailrec def safeRetryN[Z](n: Int)(txn: VanillaAtomicOp[Z], i: Int = 0): Z = {
    try {
      given Txn = beginTxn
      val ret: Z = txn.asInstanceOf[AtomicOp[Z]]
      if (i < n) { throw AbortAndRetry }
      ret
    } catch {
      case AbortAndRetry => safeRetryN(n)(txn, i + 1)
    }
  }

  object StackSize {
    def unapply(arg: String): Option[Int] = {
      Option(arg match {
        case s"-Xss$rest" => rest
        case s"-XX:ThreadStackSize=$rest" => rest
        case _ => null
      }).map { rest =>
        val shift = rest.toLowerCase.last match {
          case 'k' => 10
          case 'm' => 20
          case 'g' => 30
          case _ => 0
        }
        (if (shift > 0) rest.dropRight(1) else rest).toInt << shift
      }
    }
  }

  def main(args: Array[String]) = {
    val arguments = ManagementFactory.getRuntimeMXBean.getInputArguments
    // 64-bit VMs default to a 1024K / 1M stack size, 32-bit VMs to 320K.
    // Use 1024K (1 << 20) as the upper bound if no flag is found.
    val maxStackSize: Int =
      arguments.asScala.reverseIterator.collectFirst { case StackSize(stackSize) => stackSize }.getOrElse(1 << 20)

    val testTxn: VanillaAtomicOp[Boolean] = {
      (txn: Txn) =>
        given Txn = txn
        true
    }

    Console.println(try {
      // maxStackSize is a good upper bound on linear stack growth
      // without assuming too much about frame size
      // (1 byte / frame is conservative even for a Z80)
      retryN(maxStackSize)(testTxn.asInstanceOf[AtomicOp[Boolean]]) == safeRetryN(maxStackSize)(testTxn)
    } catch {
      case e: StackOverflowError =>
        Console.println(s"Exploded after ${e.getStackTrace.length} frames")
        false
    })
  }
}
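
A side note on the `StackSize` extractor's arithmetic: a trailing `k`, `m`, or `g` shifts the numeric prefix by 10, 20, or 30 bits, so a flag like `-Xss1m` yields `1 << 20 = 1048576`, the same value used as the fallback upper bound. A small, hypothetical self-check of that behaviour (not part of the commit):

// Hypothetical checks of StackSize.unapply; values follow the shift table above.
object StackSizeCheck {
  def main(args: Array[String]): Unit = {
    assert(Test.StackSize.unapply("-Xss512k") == Some(512 << 10))            // 524288 bytes
    assert(Test.StackSize.unapply("-Xss1m") == Some(1 << 20))                // 1048576 bytes
    assert(Test.StackSize.unapply("-XX:ThreadStackSize=2048") == Some(2048)) // no suffix, no shift
    assert(Test.StackSize.unapply("-Xmx1g") == None)                         // not a stack-size flag
  }
}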
