diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 736e1b08d1e7..14664af30c9e 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -106,6 +106,7 @@ class Compiler { List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations + new EtaReduce, // Reduce eta expansions of pure paths to the underlying function reference new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass` new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses diff --git a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala new file mode 100644 index 000000000000..8afdb5121e19 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala @@ -0,0 +1,44 @@ +package dotty.tools +package dotc +package transform + +import MegaPhase.MiniPhase +import core.* +import Symbols.*, Contexts.*, Types.*, Decorators.* +import StdNames.nme +import ast.Trees.* + +/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`, + * provided `f` is a pure path of function type. + * + * This optimization is crucial for context functions. The compiler + * produces a contextual closure around values passed as arguments + * where a context function is expected, unless that value has the + * syntactic form of a context function literal. + * + * Without this phase, when a contextual function is passed as an argument to a + * recursive function, that would have the unfortunate effect of a linear growth + * in transient thunks of identical type wrapped around each other, leading + * to performance degradation, and in some cases, stack overflows. + */ +class EtaReduce extends MiniPhase: + import ast.tpd._ + + override def phaseName: String = "etaReduce" + + override def transformBlock(tree: Block)(using Context): Tree = tree match + case Block((meth : DefDef) :: Nil, closure: Closure) + if meth.symbol == closure.meth.symbol => + meth.rhs match + case Apply(Select(fn, nme.apply), args) + if meth.paramss.head.corresponds(args)((param, arg) => + arg.isInstanceOf[Ident] && arg.symbol == param.symbol) + && isPurePath(fn) + && fn.tpe <:< tree.tpe + && defn.isFunctionClass(fn.tpe.widen.typeSymbol) => + report.log(i"eta reducing $tree --> $fn") + fn + case _ => tree + case _ => tree + +end EtaReduce \ No newline at end of file diff --git a/tests/run/i10889.scala b/tests/run/i10889.scala new file mode 100644 index 000000000000..130b6d2d9636 --- /dev/null +++ b/tests/run/i10889.scala @@ -0,0 +1,25 @@ +import scala.annotation.tailrec +import scala.util.chaining.given + +object Test { + class Ctx + type Op[A] = Ctx ?=> A + + var min = Int.MaxValue + var max = 0 + def stk = new Throwable().getStackTrace.length + + @tailrec def f[A](n: Int)(op: Op[A]): A = + val depth = stk + min = min.min(depth) + max = max.max(depth) + given Ctx = Ctx() + if (n > 0) f(n-1)(op) + else op + + def g(ctx: Ctx) = stk + + def main(args: Array[String]): Unit = + val extra = 3 + f(10)(Ctx ?=> g(summon[Ctx])).tap(res => assert(res <= max + extra, s"min $min, max $max, ran g at $res")) +} \ No newline at end of file