diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala
index fcde37fabe8f..d27497d27f66 100644
--- a/compiler/src/dotty/tools/dotc/ast/tpd.scala
+++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala
@@ -130,6 +130,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Return(expr: Tree, from: Tree)(implicit ctx: Context): Return =
ta.assignType(untpd.Return(expr, from))
+ def Return(expr: Tree, from: Symbol)(implicit ctx: Context): Return =
+ Return(expr, Ident(from.termRef))
+
def WhileDo(cond: Tree, body: Tree)(implicit ctx: Context): WhileDo =
ta.assignType(untpd.WhileDo(cond, body))
diff --git a/compiler/src/dotty/tools/dotc/core/NameKinds.scala b/compiler/src/dotty/tools/dotc/core/NameKinds.scala
index 508479c8dbff..42955de15a50 100644
--- a/compiler/src/dotty/tools/dotc/core/NameKinds.scala
+++ b/compiler/src/dotty/tools/dotc/core/NameKinds.scala
@@ -286,6 +286,8 @@ object NameKinds {
val NonLocalReturnKeyName: UniqueNameKind = new UniqueNameKind("nonLocalReturnKey")
val WildcardParamName: UniqueNameKind = new UniqueNameKind("_$")
val TailLabelName: UniqueNameKind = new UniqueNameKind("tailLabel")
+ val TailLocalName: UniqueNameKind = new UniqueNameKind("$tailLocal")
+ val TailTempName: UniqueNameKind = new UniqueNameKind("$tmp")
val ExceptionBinderName: UniqueNameKind = new UniqueNameKind("ex")
val SkolemName: UniqueNameKind = new UniqueNameKind("?")
val LiftedTreeName: UniqueNameKind = new UniqueNameKind("liftedTree")
diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala
index bde3d80b98ce..f0a5dbaa1fc7 100644
--- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala
+++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala
@@ -4,64 +4,102 @@ package transform
import ast.Trees._
import ast.{TreeTypeMap, tpd}
import core._
+import Constants.Constant
import Contexts.Context
import Decorators._
import Symbols._
import StdNames.nme
import Types._
-import NameKinds.TailLabelName
+import NameKinds.{TailLabelName, TailLocalName, TailTempName}
import MegaPhase.MiniPhase
import reporting.diagnostic.messages.TailrecNotApplicable
-/**
- * A Tail Rec Transformer
- * @author Erik Stenman, Iulian Dragos,
- * ported and heavily modified for dotty by Dmitry Petrashko
- * moved after erasure by Sébastien Doeraene
- * @version 1.1
+/** A Tail Rec Transformer.
*
- * What it does:
- *
- * Finds method calls in tail-position and replaces them with jumps.
- * A call is in a tail-position if it is the last instruction to be
- * executed in the body of a method. This includes being in
- * tail-position of a `return` from a `Labeled` block which is itself
- * in tail-position (which is critical for tail-recursive calls in the
- * cases of a `match`). To identify tail positions, we recurse over
- * the trees that may contain calls in tail-position (trees that can't
- * contain such calls are not transformed).
- *
- *
- * Self-recursive calls in tail-position are replaced by jumps to a
- * label at the beginning of the method. As the JVM provides no way to
- * jump from a method to another one, non-recursive calls in
- * tail-position are not optimized.
- *
- *
- * A method call is self-recursive if it calls the current method and
- * the method is final (otherwise, it could
- * be a call to an overridden method in a subclass).
- * Recursive calls on a different instance are optimized. Since 'this'
- * is not a local variable it is added as a label parameter.
- *
- *
- * This phase has been moved after erasure to allow the use of vars
- * for the parameters combined with a `WhileDo` (upcoming change).
- * This is also beneficial to support polymorphic tail-recursive
- * calls.
- *
- *
- * If a method contains self-recursive calls, a label is added to at
- * the beginning of its body and the calls are replaced by jumps to
- * that label.
- *
- *
- * In scalac, if the method had type parameters, the call must contain
- * the same parameters as type arguments. This is no longer the case in
- * dotc thanks to being located after erasure.
- * In scalac, this is named tailCall but it does only provide optimization for
- * self recursive functions, that's why it's renamed to tailrec
- *
+ * What it does:
+ *
+ * Finds method calls in tail-position and replaces them with jumps.
+ * A call is in a tail-position if it is the last instruction to be
+ * executed in the body of a method. This includes being in
+ * tail-position of a `return` from a `Labeled` block which is itself
+ * in tail-position (which is critical for tail-recursive calls in the
+ * cases of a `match`). To identify tail positions, we recurse over
+ * the trees that may contain calls in tail-position (trees that can't
+ * contain such calls are not transformed).
+ *
+ * When a method contains at least one tail-recursive call, its rhs
+ * is wrapped in the following structure:
+ * {{{
+ * var localForParam1: T1 = param1
+ * ...
+ * while () {
+ * tailResult[ResultType]: {
+ * return {
+ * // original rhs with tail recursive calls transformed (see below)
+ * }
+ * }
+ * }
+ * }}}
+ *
+ * Self-recursive calls in tail-position are then replaced by (a)
+ * reassigning the local `var`s substituting formal parameters and
+ * (b) a `return` from the `tailResult` labeled block, which has the
+ * net effect of looping back to the beginning of the method.
+ * If the receiver is modifed in a recursive call, an additional `var`
+ * is used to replace `this`.
+ *
+ * As a complete example of the transformation, the classical `fact`
+ * function, defined as:
+ * {{{
+ * def fact(n: Int, acc: Int): Int =
+ * if (n == 0) acc
+ * else fact(n - 1, acc * n)
+ * }}}
+ * is rewritten as:
+ * {{{
+ * def fact(n: Int, acc: Int): Int = {
+ * var acc$tailLocal1: Int = acc
+ * var n$tailLocal1: Int = n
+ * while () {
+ * tailLabel1[Unit]: {
+ * return {
+ * if (n$tailLocal1 == 0)
+ * acc$tailLocal1
+ * else {
+ * val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1
+ * val acc$tailLocal1$tmp1: Int = acc$tailLocal1 * n$tailLocal1
+ * n$tailLocal1 = n$tailLocal1$tmp1
+ * acc$tailLocal1 = acc$tailLocal1$tmp1
+ * (return[tailLabel1] ()): Int
+ * }
+ * }
+ * }
+ * }
+ * }
+ * }}}
+ *
+ * As the JVM provides no way to jump from a method to another one,
+ * non-recursive calls in tail-position are not optimized.
+ *
+ * A method call is self-recursive if it calls the current method and
+ * the method is final (otherwise, it could be a call to an overridden
+ * method in a subclass). Recursive calls on a different instance are
+ * optimized.
+ *
+ * This phase has been moved after erasure to allow the use of vars
+ * for the parameters combined with a `WhileDo`. This is also
+ * beneficial to support polymorphic tail-recursive calls.
+ *
+ * In scalac, if the method had type parameters, the call must contain
+ * the same parameters as type arguments. This is no longer the case in
+ * dotc thanks to being located after erasure.
+ * In scalac, this is named tailCall but it does only provide optimization for
+ * self recursive functions, that's why it's renamed to tailrec
+ *
+ * @author
+ * Erik Stenman, Iulian Dragos,
+ * ported and heavily modified for dotty by Dmitry Petrashko
+ * moved after erasure and adapted to emit `Labeled` blocks by Sébastien Doeraene
*/
class TailRec extends MiniPhase {
import TailRec._
@@ -72,25 +110,6 @@ class TailRec extends MiniPhase {
override def runsAfter: Set[String] = Set(Erasure.name) // tailrec assumes erased types
- final val labelFlags: Flags.FlagSet = Flags.Synthetic | Flags.Label | Flags.Method
-
- private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = {
- val name = TailLabelName.fresh()
-
- if (method.owner.isClass) {
- val MethodTpe(paramNames, paramInfos, resultType) = method.info
-
- val enclosingClass = method.enclosingClass.asClass
- val thisParamType =
- if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
- else enclosingClass.classInfo.selfType
-
- ctx.newSymbol(method, name.toTermName, labelFlags,
- MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType))
- }
- else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
- }
-
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = {
val sym = tree.symbol
tree match {
@@ -100,62 +119,57 @@ class TailRec extends MiniPhase {
cpy.DefDef(dd)(rhs = {
val defIsTopLevel = sym.owner.isClass
val origMeth = sym
- val label = mkLabel(sym)
val owner = ctx.owner.enclosingClass.asClass
- var rewrote = false
-
// Note: this can be split in two separate transforms(in different groups),
// than first one will collect info about which transformations and rewritings should be applied
// and second one will actually apply,
// now this speculatively transforms tree and throws away result in many cases
- val rhsSemiTransformed = {
- val transformer = new TailRecElimination(origMeth, owner, mandatory, label)
- val rhs = transformer.transform(dd.rhs)
- rewrote = transformer.rewrote
- rhs
- }
-
- if (rewrote) {
- if (tree.symbol.owner.isClass) {
- val classSym = tree.symbol.owner.asClass
-
- val labelDef = DefDef(label, vrefss => {
- assert(vrefss.size == 1, vrefss)
- val vrefs = vrefss.head
- val thisRef = vrefs.head
- val origMeth = tree.symbol
- val origVParams = vparams.map(_.symbol)
+ val transformer = new TailRecElimination(origMeth, owner, vparams.map(_.symbol), mandatory)
+ val rhsSemiTransformed = transformer.transform(dd.rhs)
+
+ if (transformer.rewrote) {
+ val varForRewrittenThis = transformer.varForRewrittenThis
+ val rewrittenParamSyms = transformer.rewrittenParamSyms
+ val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms
+
+ val initialVarDefs = {
+ val initialParamVarDefs = (rewrittenParamSyms, varsForRewrittenParamSyms).zipped.map {
+ (param, local) => ValDef(local.asTerm, ref(param))
+ }
+ varForRewrittenThis match {
+ case Some(local) => ValDef(local.asTerm, This(tree.symbol.owner.asClass)) :: initialParamVarDefs
+ case none => initialParamVarDefs
+ }
+ }
+
+ val rhsFullyTransformed = varForRewrittenThis match {
+ case Some(localThisSym) =>
+ val thisRef = localThisSym.termRef
new TreeTypeMap(
- typeMap = identity(_)
- .substThisUnlessStatic(classSym, thisRef.tpe)
- .subst(origVParams, vrefs.tail.map(_.tpe)),
+ typeMap = _.substThisUnlessStatic(owner, thisRef)
+ .subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)),
treeMap = {
- case tree: This if tree.symbol == classSym => thisRef
+ case tree: This if tree.symbol == owner => Ident(thisRef)
case tree => tree
- },
- oldOwners = origMeth :: Nil,
- newOwners = label :: Nil
+ }
).transform(rhsSemiTransformed)
- })
- val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams.map(x => ref(x.symbol)))
- Block(List(labelDef), callIntoLabel)
- } else { // inner method. Tail recursion does not change `this`
- val labelDef = DefDef(label, vrefss => {
- assert(vrefss.size == 1, vrefss)
- val vrefs = vrefss.head
- val origMeth = tree.symbol
- val origVParams = vparams.map(_.symbol)
+
+ case none =>
new TreeTypeMap(
- typeMap = identity(_)
- .subst(origVParams, vrefs.map(_.tpe)),
- oldOwners = origMeth :: Nil,
- newOwners = label :: Nil
+ typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef))
).transform(rhsSemiTransformed)
+ }
+
+ Block(
+ initialVarDefs,
+ WhileDo(EmptyTree, {
+ Labeled(transformer.continueLabel.asTerm, {
+ Return(rhsFullyTransformed, origMeth)
+ })
})
- val callIntoLabel = ref(label).appliedToArgs(vparams.map(x => ref(x.symbol)))
- Block(List(labelDef), callIntoLabel)
- }} else {
+ )
+ } else {
if (mandatory) ctx.error(
"TailRec optimisation not applicable, method not tail recursive",
// FIXME: want to report this error on `dd.namePos`, but
@@ -174,12 +188,51 @@ class TailRec extends MiniPhase {
}
- class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
+ class TailRecElimination(method: Symbol, enclosingClass: Symbol, paramSyms: List[Symbol], isMandatory: Boolean) extends tpd.TreeMap {
import dotty.tools.dotc.ast.tpd._
var rewrote: Boolean = false
+ /** The `tailLabelN` label symbol, used to encode a `continue` from the infinite `while` loop. */
+ private[this] var myContinueLabel: Symbol = _
+ def continueLabel(implicit c: Context): Symbol = {
+ if (myContinueLabel == null)
+ myContinueLabel = c.newSymbol(method, TailLabelName.fresh(), Flags.Label, defn.UnitType)
+ myContinueLabel
+ }
+
+ /** The local `var` that replaces `this`, if it is modified in at least one recursive call. */
+ var varForRewrittenThis: Option[Symbol] = None
+ /** The subset of `paramSyms` that are modified in at least one recursive call, and which therefore need a replacement `var`. */
+ var rewrittenParamSyms: List[Symbol] = Nil
+ /** The replacement `var`s for the params in `rewrittenParamSyms`. */
+ var varsForRewrittenParamSyms: List[Symbol] = Nil
+
+ private def getVarForRewrittenThis()(implicit c: Context): Symbol = {
+ varForRewrittenThis match {
+ case Some(sym) => sym
+ case none =>
+ val tpe =
+ if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
+ else enclosingClass.asClass.classInfo.selfType
+ val sym = c.newSymbol(method, nme.SELF, Flags.Synthetic | Flags.Mutable, tpe)
+ varForRewrittenThis = Some(sym)
+ sym
+ }
+ }
+
+ private def getVarForRewrittenParam(param: Symbol)(implicit c: Context): Symbol = {
+ rewrittenParamSyms.indexOf(param) match {
+ case -1 =>
+ val sym = c.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Flags.Synthetic | Flags.Mutable, param.info)
+ rewrittenParamSyms ::= param
+ varsForRewrittenParamSyms ::= sym
+ sym
+ case index => varsForRewrittenParamSyms(index)
+ }
+ }
+
/** Symbols of Labeled blocks that are in tail position. */
private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]()
@@ -233,15 +286,44 @@ class TailRec extends MiniPhase {
if (ctx.tailPos) {
c.debuglog("Rewriting tail recursive call: " + tree.pos)
rewrote = true
- def receiver =
- if (prefix eq EmptyTree) This(enclosingClass.asClass)
- else noTailTransform(prefix)
- val argumentsWithReceiver =
- if (this.method.owner.isClass) receiver :: arguments
- else arguments
-
- tpd.cpy.Apply(tree)(ref(label), argumentsWithReceiver)
+ val assignParamPairs = for {
+ (param, arg) <- paramSyms.zip(arguments)
+ if (arg match {
+ case arg: Ident => arg.symbol != param
+ case _ => true
+ })
+ } yield {
+ (getVarForRewrittenParam(param), arg)
+ }
+
+ val assignThisAndParamPairs = {
+ if (prefix eq EmptyTree) assignParamPairs
+ else {
+ // TODO Opt: also avoid assigning `this` if the prefix is `this.`
+ (getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
+ }
+ }
+
+ val assignments = assignThisAndParamPairs match {
+ case (lhs, rhs) :: Nil =>
+ Assign(ref(lhs), rhs) :: Nil
+ case _ :: _ =>
+ val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
+ val temp = c.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Flags.Synthetic, lhs.info)
+ (ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withPos(tree.pos))
+ }).unzip
+ tempValDefs ::: assigns
+ case nil =>
+ Nil
+ }
+
+ /* The `Typed` node is necessary to perfectly preserve the type of the node.
+ * Without it, lubbing in enclosing if/else or match can infer a different type,
+ * which can cause Ycheck errors.
+ */
+ val tpt = TypeTree(method.info.resultType)
+ seq(assignments, Typed(Return(Literal(Constant(())).withPos(tree.pos), continueLabel), tpt))
}
else fail("it is not in tail position")
} else {
diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
index a6d8d5b8f83e..72fb88bb249f 100644
--- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
+++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
@@ -452,6 +452,11 @@ class TreeChecker extends Phase with SymTransformer {
tree1
}
+ override def typedWhileDo(tree: untpd.WhileDo)(implicit ctx: Context): Tree = {
+ assert((tree.cond ne EmptyTree) || ctx.phase.refChecked, i"invalid empty condition in while at $tree")
+ super.typedWhileDo(tree)
+ }
+
override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(implicit ctx: Context): Tree =
tree
diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
index c02796a33106..6f56eafd7636 100644
--- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
+++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala
@@ -486,7 +486,7 @@ trait TypeAssigner {
tree.withType(defn.NothingType)
def assignType(tree: untpd.WhileDo)(implicit ctx: Context): WhileDo =
- tree.withType(defn.UnitType)
+ tree.withType(if (tree.cond eq EmptyTree) defn.NothingType else defn.UnitType)
def assignType(tree: untpd.Try, expr: Tree, cases: List[CaseDef])(implicit ctx: Context): Try =
if (cases.isEmpty) tree.withType(expr.tpe)
diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala
index 9c3901128aa2..c2fe93d32475 100644
--- a/compiler/src/dotty/tools/dotc/typer/Typer.scala
+++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala
@@ -1127,7 +1127,9 @@ class Typer extends Namer
}
def typedWhileDo(tree: untpd.WhileDo)(implicit ctx: Context): Tree = track("typedWhileDo") {
- val cond1 = typed(tree.cond, defn.BooleanType)
+ val cond1 =
+ if (tree.cond eq EmptyTree) EmptyTree
+ else typed(tree.cond, defn.BooleanType)
val body1 = typed(tree.body, defn.UnitType)
assignType(cpy.WhileDo(tree)(cond1, body1))
}
diff --git a/scala-backend b/scala-backend
index 47eaa7fbbc65..18b3f30e8c30 160000
--- a/scala-backend
+++ b/scala-backend
@@ -1 +1 @@
-Subproject commit 47eaa7fbbc6549e4ad8a07c1d644e947eef9b9bd
+Subproject commit 18b3f30e8c302ddd0e404f1b3dacf877eb6257c5