From 804b7866fc12fd94be3b8faefa64284893f384db Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 4 Jan 2022 19:06:16 +0100 Subject: [PATCH 1/4] Add eta reduction miniphase Fixes #10889. Alternative to #13750, from which the test is taken. --- compiler/src/dotty/tools/dotc/Compiler.scala | 1 + .../tools/dotc/transform/EtaReduce.scala | 46 ++++++++++ tests/run/contextual-closure-unwrapping.check | 1 + tests/run/contextual-closure-unwrapping.scala | 83 +++++++++++++++++++ 4 files changed, 131 insertions(+) create mode 100644 compiler/src/dotty/tools/dotc/transform/EtaReduce.scala create mode 100644 tests/run/contextual-closure-unwrapping.check create mode 100644 tests/run/contextual-closure-unwrapping.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 736e1b08d1e7..2bdf33b259fc 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -106,6 +106,7 @@ class Compiler { List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations + new EtaReduce, new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass` new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses diff --git a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala new file mode 100644 index 000000000000..1689eb62caa4 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala @@ -0,0 +1,46 @@ +package dotty.tools +package dotc +package transform + +import MegaPhase.MiniPhase +import core.* +import Symbols.*, Contexts.*, Types.*, Decorators.* +import StdNames.nme +import ast.Trees.* + +/** Rewrite `(x1, ... xN) => f(x1, ... xN)` for N >= 0 to `f`, + * provided `f` is a pure path of function type. + * + * This optimization is crucial for context functions. The compiler + * produces a contextual closure around values passed as arguments + * where a context function is expected, unless that value has the + * syntactic form of a context function literal. + * + * Without this phase, when a contextual function is passed as an argument to a + * recursive function, that would have the unfortunate effect of a linear growth + * in transient thunks of identical type wrapped around each other, leading + * to performance degradation, and in some cases, stack overflows. + */ +class EtaReduce extends MiniPhase: + import ast.tpd._ + + override def phaseName: String = "etaReduce" + + override def transformBlock(tree: Block)(using Context): Tree = tree match + case Block((meth : DefDef) :: Nil, closure: Closure) + if meth.symbol == closure.meth.symbol => + meth.rhs match + case Apply(Select(fn, nme.apply), args) + if meth.paramss.head.corresponds(args)((param, arg) => + arg.isInstanceOf[Ident] && arg.symbol == param.symbol) + && isPurePath(fn) => + val treeSym = tree.tpe.widen.typeSymbol + val fnSym = fn.tpe.widen.typeSymbol + if treeSym == fnSym && defn.isFunctionClass(fnSym) then + report.log(i"eta reducing $tree --> $fn") + fn + else tree + case _ => tree + case _ => tree + +end EtaReduce \ No newline at end of file diff --git a/tests/run/contextual-closure-unwrapping.check b/tests/run/contextual-closure-unwrapping.check new file mode 100644 index 000000000000..27ba77ddaf61 --- /dev/null +++ b/tests/run/contextual-closure-unwrapping.check @@ -0,0 +1 @@ +true diff --git a/tests/run/contextual-closure-unwrapping.scala b/tests/run/contextual-closure-unwrapping.scala new file mode 100644 index 000000000000..18a1e5f9e5e1 --- /dev/null +++ b/tests/run/contextual-closure-unwrapping.scala @@ -0,0 +1,83 @@ +import scala.annotation.tailrec + +import java.lang.management.ManagementFactory +import java.lang.management.RuntimeMXBean + +import scala.jdk.CollectionConverters._ + +object Test { + trait Txn {} + type AtomicOp[Z] = Txn ?=> Z + type VanillaAtomicOp[Z] = Txn => Z + + object AbortAndRetry extends scala.util.control.ControlThrowable("abort and retry") {} + + def beginTxn:Txn = new Txn {} + + @tailrec def retryN[Z](n:Int)(txn:AtomicOp[Z], i:Int = 0):Z = { + try { + given Txn = beginTxn + val ret:Z = txn + if(i < n) { throw AbortAndRetry } + ret + } catch { + case AbortAndRetry => retryN(n)(txn, i + 1) + } + } + + @tailrec def safeRetryN[Z](n:Int)(txn:VanillaAtomicOp[Z], i:Int = 0):Z = { + try { + given Txn = beginTxn + val ret:Z = txn.asInstanceOf[AtomicOp[Z]] + if(i < n) { throw AbortAndRetry } + ret + } catch { + case AbortAndRetry => safeRetryN(n)(txn, i + 1) + } + } + + object StackSize { + def unapply(arg:String):Option[Int] = { + Option(arg match { + case s"-Xss$rest" => rest + case s"-XX:ThreadStackSize=$rest" => rest + case _ => null + }).map{ rest => + val shift = (rest.toLowerCase.last match { + case 'k' => 10 + case 'm' => 20 + case 'g' => 30 + case _ => 0 + }) + (if(shift > 0) { + rest.dropRight(1) + } else { + rest + }).toInt << shift + } + } + } + def main(args:Array[String]) = { + val arguments = ManagementFactory.getRuntimeMXBean.getInputArguments + // 64bit VM defaults to 1024K / 1M stack size, 32bit defaults to 320k + // Use 1024 as upper bound. + val maxStackSize:Int = arguments.asScala.reverseIterator.collectFirst{case StackSize(stackSize) => stackSize}.getOrElse(1 << 20) + + val testTxn:VanillaAtomicOp[Boolean] = { + (txn:Txn) => + given Txn = txn + true + } + + Console.println(try { + // maxStackSize is a good upper bound on linear stack growth + // without assuming too much about frame size + // (1 byte / frame is conservative even for a z80) + retryN(maxStackSize)(testTxn.asInstanceOf[AtomicOp[Boolean]]) == safeRetryN(maxStackSize)(testTxn) + } catch { + case e:StackOverflowError => + Console.println(s"Exploded after ${e.getStackTrace.length} frames") + false + }) + } +} \ No newline at end of file From a9d2706dfcb776183b37132a428e2c6391b81bc8 Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 4 Jan 2022 19:57:47 +0100 Subject: [PATCH 2/4] Simplify type match test --- .../src/dotty/tools/dotc/transform/EtaReduce.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala index 1689eb62caa4..8afdb5121e19 100644 --- a/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala +++ b/compiler/src/dotty/tools/dotc/transform/EtaReduce.scala @@ -33,13 +33,11 @@ class EtaReduce extends MiniPhase: case Apply(Select(fn, nme.apply), args) if meth.paramss.head.corresponds(args)((param, arg) => arg.isInstanceOf[Ident] && arg.symbol == param.symbol) - && isPurePath(fn) => - val treeSym = tree.tpe.widen.typeSymbol - val fnSym = fn.tpe.widen.typeSymbol - if treeSym == fnSym && defn.isFunctionClass(fnSym) then - report.log(i"eta reducing $tree --> $fn") - fn - else tree + && isPurePath(fn) + && fn.tpe <:< tree.tpe + && defn.isFunctionClass(fn.tpe.widen.typeSymbol) => + report.log(i"eta reducing $tree --> $fn") + fn case _ => tree case _ => tree From 18a522749aee5bf3bf59848351db3cd10c1fc56e Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Tue, 4 Jan 2022 22:02:58 +0100 Subject: [PATCH 3/4] Simpler test --- tests/run/contextual-closure-unwrapping.check | 1 - tests/run/contextual-closure-unwrapping.scala | 83 ------------------- tests/run/i10889.scala | 25 ++++++ 3 files changed, 25 insertions(+), 84 deletions(-) delete mode 100644 tests/run/contextual-closure-unwrapping.check delete mode 100644 tests/run/contextual-closure-unwrapping.scala create mode 100644 tests/run/i10889.scala diff --git a/tests/run/contextual-closure-unwrapping.check b/tests/run/contextual-closure-unwrapping.check deleted file mode 100644 index 27ba77ddaf61..000000000000 --- a/tests/run/contextual-closure-unwrapping.check +++ /dev/null @@ -1 +0,0 @@ -true diff --git a/tests/run/contextual-closure-unwrapping.scala b/tests/run/contextual-closure-unwrapping.scala deleted file mode 100644 index 18a1e5f9e5e1..000000000000 --- a/tests/run/contextual-closure-unwrapping.scala +++ /dev/null @@ -1,83 +0,0 @@ -import scala.annotation.tailrec - -import java.lang.management.ManagementFactory -import java.lang.management.RuntimeMXBean - -import scala.jdk.CollectionConverters._ - -object Test { - trait Txn {} - type AtomicOp[Z] = Txn ?=> Z - type VanillaAtomicOp[Z] = Txn => Z - - object AbortAndRetry extends scala.util.control.ControlThrowable("abort and retry") {} - - def beginTxn:Txn = new Txn {} - - @tailrec def retryN[Z](n:Int)(txn:AtomicOp[Z], i:Int = 0):Z = { - try { - given Txn = beginTxn - val ret:Z = txn - if(i < n) { throw AbortAndRetry } - ret - } catch { - case AbortAndRetry => retryN(n)(txn, i + 1) - } - } - - @tailrec def safeRetryN[Z](n:Int)(txn:VanillaAtomicOp[Z], i:Int = 0):Z = { - try { - given Txn = beginTxn - val ret:Z = txn.asInstanceOf[AtomicOp[Z]] - if(i < n) { throw AbortAndRetry } - ret - } catch { - case AbortAndRetry => safeRetryN(n)(txn, i + 1) - } - } - - object StackSize { - def unapply(arg:String):Option[Int] = { - Option(arg match { - case s"-Xss$rest" => rest - case s"-XX:ThreadStackSize=$rest" => rest - case _ => null - }).map{ rest => - val shift = (rest.toLowerCase.last match { - case 'k' => 10 - case 'm' => 20 - case 'g' => 30 - case _ => 0 - }) - (if(shift > 0) { - rest.dropRight(1) - } else { - rest - }).toInt << shift - } - } - } - def main(args:Array[String]) = { - val arguments = ManagementFactory.getRuntimeMXBean.getInputArguments - // 64bit VM defaults to 1024K / 1M stack size, 32bit defaults to 320k - // Use 1024 as upper bound. - val maxStackSize:Int = arguments.asScala.reverseIterator.collectFirst{case StackSize(stackSize) => stackSize}.getOrElse(1 << 20) - - val testTxn:VanillaAtomicOp[Boolean] = { - (txn:Txn) => - given Txn = txn - true - } - - Console.println(try { - // maxStackSize is a good upper bound on linear stack growth - // without assuming too much about frame size - // (1 byte / frame is conservative even for a z80) - retryN(maxStackSize)(testTxn.asInstanceOf[AtomicOp[Boolean]]) == safeRetryN(maxStackSize)(testTxn) - } catch { - case e:StackOverflowError => - Console.println(s"Exploded after ${e.getStackTrace.length} frames") - false - }) - } -} \ No newline at end of file diff --git a/tests/run/i10889.scala b/tests/run/i10889.scala new file mode 100644 index 000000000000..130b6d2d9636 --- /dev/null +++ b/tests/run/i10889.scala @@ -0,0 +1,25 @@ +import scala.annotation.tailrec +import scala.util.chaining.given + +object Test { + class Ctx + type Op[A] = Ctx ?=> A + + var min = Int.MaxValue + var max = 0 + def stk = new Throwable().getStackTrace.length + + @tailrec def f[A](n: Int)(op: Op[A]): A = + val depth = stk + min = min.min(depth) + max = max.max(depth) + given Ctx = Ctx() + if (n > 0) f(n-1)(op) + else op + + def g(ctx: Ctx) = stk + + def main(args: Array[String]): Unit = + val extra = 3 + f(10)(Ctx ?=> g(summon[Ctx])).tap(res => assert(res <= max + extra, s"min $min, max $max, ran g at $res")) +} \ No newline at end of file From 463448175e91b0bcf65b6c71509539b6c45a5c7c Mon Sep 17 00:00:00 2001 From: Martin Odersky Date: Wed, 5 Jan 2022 11:37:20 +0100 Subject: [PATCH 4/4] Add comment to phase --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 2bdf33b259fc..14664af30c9e 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -106,7 +106,7 @@ class Compiler { List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new PureStats, // Remove pure stats from blocks new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations - new EtaReduce, + new EtaReduce, // Reduce eta expansions of pure paths to the underlying function reference new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]` new sjs.AddLocalJSFakeNews, // Adds fake new invocations to local JS classes in calls to `createLocalJSClass` new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses