Skip to content

Add eta reduction miniphase #14210

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class Compiler {
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
new PureStats, // Remove pure stats from blocks
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
new EtaReduce, // Reduce eta expansions of pure paths to the underlying function reference
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any fundamental reason this is after erasure but the existing BetaReduce phase is before erasure?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment of betaReduce mentions blackbox macros, which would be eliminated after erasure. That might have something to do with it, but I did not dig deeper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

EtaReduce is after erasure since a lot of context function wizardry happens before or at erasure so we want to make sure this is all done before we apply the optimization.

new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass`
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
Expand Down
44 changes: 44 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/EtaReduce.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package dotty.tools
package dotc
package transform

import MegaPhase.MiniPhase
import core.*
import Symbols.*, Contexts.*, Types.*, Decorators.*
import StdNames.nme
import ast.Trees.*

/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`,
* provided `f` is a pure path of function type.
*
* This optimization is crucial for context functions. The compiler
* produces a contextual closure around values passed as arguments
* where a context function is expected, unless that value has the
* syntactic form of a context function literal.
*
* Without this phase, when a contextual function is passed as an argument to a
* recursive function, that would have the unfortunate effect of a linear growth
* in transient thunks of identical type wrapped around each other, leading
* to performance degradation, and in some cases, stack overflows.
*/
class EtaReduce extends MiniPhase:
import ast.tpd._

override def phaseName: String = "etaReduce"

override def transformBlock(tree: Block)(using Context): Tree = tree match
case Block((meth : DefDef) :: Nil, closure: Closure)
if meth.symbol == closure.meth.symbol =>
meth.rhs match
case Apply(Select(fn, nme.apply), args)
if meth.paramss.head.corresponds(args)((param, arg) =>
arg.isInstanceOf[Ident] && arg.symbol == param.symbol)
&& isPurePath(fn)
&& fn.tpe <:< tree.tpe
&& defn.isFunctionClass(fn.tpe.widen.typeSymbol) =>
report.log(i"eta reducing $tree --> $fn")
fn
case _ => tree
case _ => tree

end EtaReduce
25 changes: 25 additions & 0 deletions tests/run/i10889.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import scala.annotation.tailrec
import scala.util.chaining.given

object Test {
class Ctx
type Op[A] = Ctx ?=> A

var min = Int.MaxValue
var max = 0
def stk = new Throwable().getStackTrace.length

@tailrec def f[A](n: Int)(op: Op[A]): A =
val depth = stk
min = min.min(depth)
max = max.max(depth)
given Ctx = Ctx()
if (n > 0) f(n-1)(op)
else op

def g(ctx: Ctx) = stk

def main(args: Array[String]): Unit =
val extra = 3
f(10)(Ctx ?=> g(summon[Ctx])).tap(res => assert(res <= max + extra, s"min $min, max $max, ran g at $res"))
}