diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 736e1b08d1e7..44ca4015ac97 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -73,6 +73,7 @@ class Compiler { new ProtectedAccessors, // Add accessors for protected members new ExtensionMethods, // Expand methods of value classes with extension methods new UncacheGivenAliases, // Avoid caching RHS of simple parameterless given aliases + new ElimContextClosures, // Unwrap context closures that contain only a context function of compatible type new ByNameClosures, // Expand arguments to by-name parameters to closures new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope new SpecializeApplyMethods, // Adds specialized methods to FunctionN diff --git a/compiler/src/dotty/tools/dotc/transform/ElimContextClosures.scala b/compiler/src/dotty/tools/dotc/transform/ElimContextClosures.scala new file mode 100644 index 000000000000..e6d5ca527a31 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/ElimContextClosures.scala @@ -0,0 +1,76 @@ +package dotty.tools +package dotc +package transform + +import MegaPhase._ +import core._ +import Symbols._ +import SymDenotations._ +import Contexts._ +import Types._ +import Flags._ +import Decorators._ +import DenotTransformers._ +import core.StdNames.nme +import core.StdNames +import ast.Trees._ +import reporting.trace + +/** Transforms function arguments which are context functions to + * avoid a build-up of redundant thunks when passed repeatedly, + * e.g. due to recursion. + * + * This is necessary because the compiler produces a contextual + * closure around values passed as arguments where a context function + * is expected, unless that value has the syntactic form of a context + * function literal. + * + * This makes for very ergonomic client code, but the implementation + * requires the wrapper to be generated before type information is available. + * Thus, it can't be determined if the passed value is already a context function + * of the expected type, and the closure must be generated either way. + * + * Without this phase, when a contextual function is passed as an argument to a + * recursive function, that would have the unfortunate effect of a linear growth + * in transient thunks of identical type wrapped around each other, leading + * to performance degradation, and in some cases, stack overflows. + * + * For additional reading material, please refer to the Simplicitly paper and/or + * the discussion at https://github.com/lampepfl/dotty/issues/10889 + */ +class ElimContextClosures extends MiniPhase with IdentityDenotTransformer { thisPhase: DenotTransformer => + import ast.tpd._ + import ast.untpd + + override def phaseName:String = ElimContextClosures.name + + override def transformApply(tree: Apply)(using Context): Tree = + trace(s"transforming ${tree.show} at phase ${ctx.phase}", show = true) { + + def transformArg(arg: Tree, formal: Type): Tree = { + val formal1 = formal.widen + if defn.isContextFunctionType(formal1) && untpd.isContextualClosure(arg) then + unsplice(closureBody(arg)) match { + case Apply(Select(body, nme.apply), _) => + val underlyingBodyType = body.tpe.widen + val bodyIsContextual = defn.isContextFunctionType(underlyingBodyType) + val bodyTypeMatches = TypeComparer.isSubType(underlyingBodyType, formal1) + if bodyIsContextual && bodyTypeMatches then + body + else + arg + case other => other + } // no-op if not a nested closure of some kind + else + arg + } + + val mt @ MethodType(_) = tree.fun.tpe.widen + val args1 = tree.args.zipWithConserve(mt.paramInfos)(transformArg) + cpy.Apply(tree)(tree.fun, args1) + } +} + +object ElimContextClosures { + val name: String = "elimContextClosures" +} \ No newline at end of file diff --git a/tests/run/contextual-closure-unwrapping.check b/tests/run/contextual-closure-unwrapping.check new file mode 100644 index 000000000000..27ba77ddaf61 --- /dev/null +++ b/tests/run/contextual-closure-unwrapping.check @@ -0,0 +1 @@ +true diff --git a/tests/run/contextual-closure-unwrapping.scala b/tests/run/contextual-closure-unwrapping.scala new file mode 100644 index 000000000000..448131816b28 --- /dev/null +++ b/tests/run/contextual-closure-unwrapping.scala @@ -0,0 +1,84 @@ +import scala.annotation.tailrec + +import java.lang.management.ManagementFactory +import java.lang.management.RuntimeMXBean + +import scala.jdk.CollectionConverters._ + +object Test { + trait Txn {} + type AtomicOp[Z] = Txn ?=> Z + type VanillaAtomicOp[Z] = Txn => Z + + object AbortAndRetry extends scala.util.control.ControlThrowable("abort and retry") {} + + def beginTxn:Txn = new Txn {} + + @tailrec def retryN[Z](n:Int)(txn:AtomicOp[Z], i:Int = 0):Z = { + try { + given Txn = beginTxn + val ret:Z = txn + if(i < n) { throw AbortAndRetry } + ret + } catch { + case AbortAndRetry => retryN(n)(txn, i + 1) + } + } + + + @tailrec def safeRetryN[Z](n:Int)(txn:VanillaAtomicOp[Z], i:Int = 0):Z = { + try { + given Txn = beginTxn + val ret:Z = txn.asInstanceOf[AtomicOp[Z]] + if(i < n) { throw AbortAndRetry } + ret + } catch { + case AbortAndRetry => safeRetryN(n)(txn, i + 1) + } + } + + object StackSize { + def unapply(arg:String):Option[Int] = { + Option(arg match { + case s"-Xss$rest" => rest + case s"-XX:ThreadStackSize=$rest" => rest + case _ => null + }).map{ rest => + val shift = (rest.toLowerCase.last match { + case 'k' => 10 + case 'm' => 20 + case 'g' => 30 + case _ => 0 + }) + (if(shift > 0) { + rest.dropRight(1) + } else { + rest + }).toInt << shift + } + } + } + def main(args:Array[String]) = { + val arguments = ManagementFactory.getRuntimeMXBean.getInputArguments + // 64bit VM defaults to 1024K / 1M stack size, 32bit defaults to 320k + // Use 1024 as upper bound. + val maxStackSize:Int = arguments.asScala.reverseIterator.collectFirst{case StackSize(stackSize) => stackSize}.getOrElse(1 << 20) + + val testTxn:VanillaAtomicOp[Boolean] = { + (txn:Txn) => + given Txn = txn + true + } + + Console.println(try { + // maxStackSize is a good upper bound on linear stack growth + // without assuming too much about frame size + // (1 byte / frame is conservative even for a z80) + retryN(maxStackSize)(testTxn.asInstanceOf[AtomicOp[Boolean]]) == safeRetryN(maxStackSize)(testTxn) + } catch { + case e:StackOverflowError => + Console.println(s"Exploded after ${e.getStackTrace.length} frames") + false + }) + } +}