Skip to content

Make implicit function variables safe as arguments (fixes #10889) #13750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Compiler {
new ProtectedAccessors, // Add accessors for protected members
new ExtensionMethods, // Expand methods of value classes with extension methods
new UncacheGivenAliases, // Avoid caching RHS of simple parameterless given aliases
new ElimContextClosures, // Unwrap context closures that contain only a context function of compatible type
new ByNameClosures, // Expand arguments to by-name parameters to closures
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
new SpecializeApplyMethods, // Adds specialized methods to FunctionN
Expand Down
77 changes: 77 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/ElimContextClosures.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package dotty.tools
package dotc
package transform

import MegaPhase._
import core._
import Symbols._
import SymDenotations._
import Contexts._
import Types._
import Flags._
import Decorators._
import DenotTransformers._
import core.StdNames.nme
import core.StdNames
import ast.Trees._
import reporting.trace

/** Transforms function arguments which are context functions to
* avoid a build-up of redundant thunks when passed repeatedly,
* e.g. due to recursion.
*
* This is necessary because the compiler produces a contextual
* closure around values passed as arguments where a context function
* is expected, unless that value has the syntactic form of a context
* function literal.
*
* This makes for very ergonomic client code, but the implementation
* requires the wrapper to be genereated before type information is available.
* Thus, it can't be determine if the passed value is already a context function
* of the expected type, and the closure must be generated either way.
*
* Without this phase, when a contextual function is passed as an argument to a
* recursive function, that would have the unfortunate effect of a linear growth
* in transient thunks of identical type wrapped around each other, leading
* to performance degradation, and in some cases, stack overflows.
*
* For additional reading material, please refer to the Simplicitly paper and/or
* the discussion at https://github.com/lampepfl/dotty/issues/10889
*/
class ElimContextClosures extends MiniPhase with IdentityDenotTransformer { thisPhase: DenotTransformer =>
import ast.tpd._
import ast.untpd

override def phaseName:String = ElimContextClosures.name

override def transformApply(tree: Apply)(using Context): Tree =
trace(s"transforming ${tree.show} at phase ${ctx.phase}", show = true) {

def transformArg(arg: Tree, formal: Type): Tree = {
val formal1 = formal.widenDealias
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, widen is enough here. isContextFunction and isSubType already do dealiasing. (It's better to be consistent here since otherwise the next people looking at the compiler see multiple styles and get confused).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Presumably this applies to underlyingBodyType as well as formal1?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorted in 009f18b.

if defn.isContextFunctionType(formal1) && untpd.isContextualClosure(arg) then
val body = unsplice(closureBody(arg)) match {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it would be better to start like this

unsplice(closureBody(arg)) match
  case Apply(Select(body, nme.apply, _) => 
    ... // transformation code
  case _ =>
    arg

If the argument is not a nested closure there's no need to go through the following tests, right?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

closureBody returns arg if arg is not a closure, so we would have to essentially reimplement the closureBody logic inline here to make sure we aren't destructuring a random Apply for no reason. I assumed checking the outer structure first with this if would be more efficient / cleaner.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I meant: after the if defn.isContextFunctionType ..., instead of val body = ... continue with what I suggested.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, I see now, yes. Should be addressed in 009f18b.

case Apply(Select(fn, nme.apply), _) => fn
case other => other
} // no-op if not a nested closure of some kind
val underlyingBodyType = body.tpe.widenDealias
val bodyIsContextual = defn.isContextFunctionType(underlyingBodyType)
val bodyTypeMatches = TypeComparer.isSubType(underlyingBodyType, formal1)
if bodyIsContextual && bodyTypeMatches then
body
else
arg

else
arg
}

val mt @ MethodType(_) = tree.fun.tpe.widen
val args1 = tree.args.zipWithConserve(mt.paramInfos)(transformArg)
cpy.Apply(tree)(tree.fun, args1)
}
}

object ElimContextClosures {
val name: String = "elimContextClosures"
}
1 change: 1 addition & 0 deletions tests/run/contextual-closure-unwrapping.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
true
84 changes: 84 additions & 0 deletions tests/run/contextual-closure-unwrapping.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import scala.annotation.tailrec

import java.lang.management.ManagementFactory
import java.lang.management.RuntimeMXBean

import scala.jdk.CollectionConverters._

object Test {
trait Txn {}
type AtomicOp[Z] = Txn ?=> Z
type VanillaAtomicOp[Z] = Txn => Z

object AbortAndRetry extends scala.util.control.ControlThrowable("abort and retry") {}

def beginTxn:Txn = new Txn {}

@tailrec def retryN[Z](n:Int)(txn:AtomicOp[Z], i:Int = 0):Z = {
try {
given Txn = beginTxn
val ret:Z = txn
if(i < n) { throw AbortAndRetry }
ret
} catch {
case AbortAndRetry => retryN(n)(txn, i + 1)
}
}


@tailrec def safeRetryN[Z](n:Int)(txn:VanillaAtomicOp[Z], i:Int = 0):Z = {
try {
given Txn = beginTxn
val ret:Z = txn.asInstanceOf[AtomicOp[Z]]
if(i < n) { throw AbortAndRetry }
ret
} catch {
case AbortAndRetry => safeRetryN(n)(txn, i + 1)
}
}

object StackSize {
def unapply(arg:String):Option[Int] = {
Option(arg match {
case s"-Xss$rest" => rest
case s"-XX:ThreadStackSize=$rest" => rest
case _ => null
}).map{ rest =>
val shift = (rest.toLowerCase.last match {
case 'k' => 10
case 'm' => 20
case 'g' => 30
case _ => 0
})
(if(shift > 0) {
rest.dropRight(1)
} else {
rest
}).toInt << shift
}
}
}
def main(args:Array[String]) = {
val arguments = ManagementFactory.getRuntimeMXBean.getInputArguments
// 64bit VM defaults to 1024K / 1M stack size, 32bit defaults to 320k
// Use 1024 as upper bound.
val maxStackSize:Int = arguments.asScala.reverseIterator.collectFirst{case StackSize(stackSize) => stackSize}.getOrElse(1 << 20)

val testTxn:VanillaAtomicOp[Boolean] = {
(txn:Txn) =>
given Txn = txn
true
}

Console.println(try {
// maxStackSize is a good upper bound on linear stack growth
// without assuming too much about frame size
// (1 byte / frame is conservative even for a z80)
retryN(maxStackSize)(testTxn.asInstanceOf[AtomicOp[Boolean]]) == safeRetryN(maxStackSize)(testTxn)
} catch {
case e:StackOverflowError =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it easier to just probe stack depth with new Throwable().getStackTrace? for n and n+1 recursions.

(Generally wary of tests for SOE and OOM.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I see what you have in mind, can you be a little bit more explicit?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apologies in advance if I'm missing the whole point, I'm still reading the issue. I was thinking of https://github.com/scala/scala/blob/2.13.x/test/junit/scala/collection/IteratorTest.scala#L36 although counting frames is subject to optimizations.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, sure, we could have the return value from testTxn be the current stack depth and then check retryN with different iteration counts. Happy to do that if it's preferable.

Console.println(s"Exploded after ${e.getStackTrace.length} frames")
false
})
}
}