diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index fcde37fabe8f..d27497d27f66 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -130,6 +130,9 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { def Return(expr: Tree, from: Tree)(implicit ctx: Context): Return = ta.assignType(untpd.Return(expr, from)) + def Return(expr: Tree, from: Symbol)(implicit ctx: Context): Return = + Return(expr, Ident(from.termRef)) + def WhileDo(cond: Tree, body: Tree)(implicit ctx: Context): WhileDo = ta.assignType(untpd.WhileDo(cond, body)) diff --git a/compiler/src/dotty/tools/dotc/core/NameKinds.scala b/compiler/src/dotty/tools/dotc/core/NameKinds.scala index 508479c8dbff..42955de15a50 100644 --- a/compiler/src/dotty/tools/dotc/core/NameKinds.scala +++ b/compiler/src/dotty/tools/dotc/core/NameKinds.scala @@ -286,6 +286,8 @@ object NameKinds { val NonLocalReturnKeyName: UniqueNameKind = new UniqueNameKind("nonLocalReturnKey") val WildcardParamName: UniqueNameKind = new UniqueNameKind("_$") val TailLabelName: UniqueNameKind = new UniqueNameKind("tailLabel") + val TailLocalName: UniqueNameKind = new UniqueNameKind("$tailLocal") + val TailTempName: UniqueNameKind = new UniqueNameKind("$tmp") val ExceptionBinderName: UniqueNameKind = new UniqueNameKind("ex") val SkolemName: UniqueNameKind = new UniqueNameKind("?") val LiftedTreeName: UniqueNameKind = new UniqueNameKind("liftedTree") diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index bde3d80b98ce..f0a5dbaa1fc7 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -4,64 +4,102 @@ package transform import ast.Trees._ import ast.{TreeTypeMap, tpd} import core._ +import Constants.Constant import Contexts.Context import Decorators._ import Symbols._ import StdNames.nme import Types._ -import NameKinds.TailLabelName +import NameKinds.{TailLabelName, TailLocalName, TailTempName} import MegaPhase.MiniPhase import reporting.diagnostic.messages.TailrecNotApplicable -/** - * A Tail Rec Transformer - * @author Erik Stenman, Iulian Dragos, - * ported and heavily modified for dotty by Dmitry Petrashko - * moved after erasure by Sébastien Doeraene - * @version 1.1 +/** A Tail Rec Transformer. * - * What it does: - *

- * Finds method calls in tail-position and replaces them with jumps. - * A call is in a tail-position if it is the last instruction to be - * executed in the body of a method. This includes being in - * tail-position of a `return` from a `Labeled` block which is itself - * in tail-position (which is critical for tail-recursive calls in the - * cases of a `match`). To identify tail positions, we recurse over - * the trees that may contain calls in tail-position (trees that can't - * contain such calls are not transformed). - *

- *

- * Self-recursive calls in tail-position are replaced by jumps to a - * label at the beginning of the method. As the JVM provides no way to - * jump from a method to another one, non-recursive calls in - * tail-position are not optimized. - *

- *

- * A method call is self-recursive if it calls the current method and - * the method is final (otherwise, it could - * be a call to an overridden method in a subclass). - * Recursive calls on a different instance are optimized. Since 'this' - * is not a local variable it is added as a label parameter. - *

- *

- * This phase has been moved after erasure to allow the use of vars - * for the parameters combined with a `WhileDo` (upcoming change). - * This is also beneficial to support polymorphic tail-recursive - * calls. - *

- *

- * If a method contains self-recursive calls, a label is added to at - * the beginning of its body and the calls are replaced by jumps to - * that label. - *

- *

- * In scalac, if the method had type parameters, the call must contain - * the same parameters as type arguments. This is no longer the case in - * dotc thanks to being located after erasure. - * In scalac, this is named tailCall but it does only provide optimization for - * self recursive functions, that's why it's renamed to tailrec - *

+ * What it does: + * + * Finds method calls in tail-position and replaces them with jumps. + * A call is in a tail-position if it is the last instruction to be + * executed in the body of a method. This includes being in + * tail-position of a `return` from a `Labeled` block which is itself + * in tail-position (which is critical for tail-recursive calls in the + * cases of a `match`). To identify tail positions, we recurse over + * the trees that may contain calls in tail-position (trees that can't + * contain such calls are not transformed). + * + * When a method contains at least one tail-recursive call, its rhs + * is wrapped in the following structure: + * {{{ + * var localForParam1: T1 = param1 + * ... + * while () { + * tailResult[ResultType]: { + * return { + * // original rhs with tail recursive calls transformed (see below) + * } + * } + * } + * }}} + * + * Self-recursive calls in tail-position are then replaced by (a) + * reassigning the local `var`s substituting formal parameters and + * (b) a `return` from the `tailResult` labeled block, which has the + * net effect of looping back to the beginning of the method. + * If the receiver is modifed in a recursive call, an additional `var` + * is used to replace `this`. + * + * As a complete example of the transformation, the classical `fact` + * function, defined as: + * {{{ + * def fact(n: Int, acc: Int): Int = + * if (n == 0) acc + * else fact(n - 1, acc * n) + * }}} + * is rewritten as: + * {{{ + * def fact(n: Int, acc: Int): Int = { + * var acc$tailLocal1: Int = acc + * var n$tailLocal1: Int = n + * while () { + * tailLabel1[Unit]: { + * return { + * if (n$tailLocal1 == 0) + * acc$tailLocal1 + * else { + * val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1 + * val acc$tailLocal1$tmp1: Int = acc$tailLocal1 * n$tailLocal1 + * n$tailLocal1 = n$tailLocal1$tmp1 + * acc$tailLocal1 = acc$tailLocal1$tmp1 + * (return[tailLabel1] ()): Int + * } + * } + * } + * } + * } + * }}} + * + * As the JVM provides no way to jump from a method to another one, + * non-recursive calls in tail-position are not optimized. + * + * A method call is self-recursive if it calls the current method and + * the method is final (otherwise, it could be a call to an overridden + * method in a subclass). Recursive calls on a different instance are + * optimized. + * + * This phase has been moved after erasure to allow the use of vars + * for the parameters combined with a `WhileDo`. This is also + * beneficial to support polymorphic tail-recursive calls. + * + * In scalac, if the method had type parameters, the call must contain + * the same parameters as type arguments. This is no longer the case in + * dotc thanks to being located after erasure. + * In scalac, this is named tailCall but it does only provide optimization for + * self recursive functions, that's why it's renamed to tailrec + * + * @author + * Erik Stenman, Iulian Dragos, + * ported and heavily modified for dotty by Dmitry Petrashko + * moved after erasure and adapted to emit `Labeled` blocks by Sébastien Doeraene */ class TailRec extends MiniPhase { import TailRec._ @@ -72,25 +110,6 @@ class TailRec extends MiniPhase { override def runsAfter: Set[String] = Set(Erasure.name) // tailrec assumes erased types - final val labelFlags: Flags.FlagSet = Flags.Synthetic | Flags.Label | Flags.Method - - private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = { - val name = TailLabelName.fresh() - - if (method.owner.isClass) { - val MethodTpe(paramNames, paramInfos, resultType) = method.info - - val enclosingClass = method.enclosingClass.asClass - val thisParamType = - if (enclosingClass.is(Flags.Module)) enclosingClass.thisType - else enclosingClass.classInfo.selfType - - ctx.newSymbol(method, name.toTermName, labelFlags, - MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType)) - } - else ctx.newSymbol(method, name.toTermName, labelFlags, method.info) - } - override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = { val sym = tree.symbol tree match { @@ -100,62 +119,57 @@ class TailRec extends MiniPhase { cpy.DefDef(dd)(rhs = { val defIsTopLevel = sym.owner.isClass val origMeth = sym - val label = mkLabel(sym) val owner = ctx.owner.enclosingClass.asClass - var rewrote = false - // Note: this can be split in two separate transforms(in different groups), // than first one will collect info about which transformations and rewritings should be applied // and second one will actually apply, // now this speculatively transforms tree and throws away result in many cases - val rhsSemiTransformed = { - val transformer = new TailRecElimination(origMeth, owner, mandatory, label) - val rhs = transformer.transform(dd.rhs) - rewrote = transformer.rewrote - rhs - } - - if (rewrote) { - if (tree.symbol.owner.isClass) { - val classSym = tree.symbol.owner.asClass - - val labelDef = DefDef(label, vrefss => { - assert(vrefss.size == 1, vrefss) - val vrefs = vrefss.head - val thisRef = vrefs.head - val origMeth = tree.symbol - val origVParams = vparams.map(_.symbol) + val transformer = new TailRecElimination(origMeth, owner, vparams.map(_.symbol), mandatory) + val rhsSemiTransformed = transformer.transform(dd.rhs) + + if (transformer.rewrote) { + val varForRewrittenThis = transformer.varForRewrittenThis + val rewrittenParamSyms = transformer.rewrittenParamSyms + val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms + + val initialVarDefs = { + val initialParamVarDefs = (rewrittenParamSyms, varsForRewrittenParamSyms).zipped.map { + (param, local) => ValDef(local.asTerm, ref(param)) + } + varForRewrittenThis match { + case Some(local) => ValDef(local.asTerm, This(tree.symbol.owner.asClass)) :: initialParamVarDefs + case none => initialParamVarDefs + } + } + + val rhsFullyTransformed = varForRewrittenThis match { + case Some(localThisSym) => + val thisRef = localThisSym.termRef new TreeTypeMap( - typeMap = identity(_) - .substThisUnlessStatic(classSym, thisRef.tpe) - .subst(origVParams, vrefs.tail.map(_.tpe)), + typeMap = _.substThisUnlessStatic(owner, thisRef) + .subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)), treeMap = { - case tree: This if tree.symbol == classSym => thisRef + case tree: This if tree.symbol == owner => Ident(thisRef) case tree => tree - }, - oldOwners = origMeth :: Nil, - newOwners = label :: Nil + } ).transform(rhsSemiTransformed) - }) - val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams.map(x => ref(x.symbol))) - Block(List(labelDef), callIntoLabel) - } else { // inner method. Tail recursion does not change `this` - val labelDef = DefDef(label, vrefss => { - assert(vrefss.size == 1, vrefss) - val vrefs = vrefss.head - val origMeth = tree.symbol - val origVParams = vparams.map(_.symbol) + + case none => new TreeTypeMap( - typeMap = identity(_) - .subst(origVParams, vrefs.map(_.tpe)), - oldOwners = origMeth :: Nil, - newOwners = label :: Nil + typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)) ).transform(rhsSemiTransformed) + } + + Block( + initialVarDefs, + WhileDo(EmptyTree, { + Labeled(transformer.continueLabel.asTerm, { + Return(rhsFullyTransformed, origMeth) + }) }) - val callIntoLabel = ref(label).appliedToArgs(vparams.map(x => ref(x.symbol))) - Block(List(labelDef), callIntoLabel) - }} else { + ) + } else { if (mandatory) ctx.error( "TailRec optimisation not applicable, method not tail recursive", // FIXME: want to report this error on `dd.namePos`, but @@ -174,12 +188,51 @@ class TailRec extends MiniPhase { } - class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap { + class TailRecElimination(method: Symbol, enclosingClass: Symbol, paramSyms: List[Symbol], isMandatory: Boolean) extends tpd.TreeMap { import dotty.tools.dotc.ast.tpd._ var rewrote: Boolean = false + /** The `tailLabelN` label symbol, used to encode a `continue` from the infinite `while` loop. */ + private[this] var myContinueLabel: Symbol = _ + def continueLabel(implicit c: Context): Symbol = { + if (myContinueLabel == null) + myContinueLabel = c.newSymbol(method, TailLabelName.fresh(), Flags.Label, defn.UnitType) + myContinueLabel + } + + /** The local `var` that replaces `this`, if it is modified in at least one recursive call. */ + var varForRewrittenThis: Option[Symbol] = None + /** The subset of `paramSyms` that are modified in at least one recursive call, and which therefore need a replacement `var`. */ + var rewrittenParamSyms: List[Symbol] = Nil + /** The replacement `var`s for the params in `rewrittenParamSyms`. */ + var varsForRewrittenParamSyms: List[Symbol] = Nil + + private def getVarForRewrittenThis()(implicit c: Context): Symbol = { + varForRewrittenThis match { + case Some(sym) => sym + case none => + val tpe = + if (enclosingClass.is(Flags.Module)) enclosingClass.thisType + else enclosingClass.asClass.classInfo.selfType + val sym = c.newSymbol(method, nme.SELF, Flags.Synthetic | Flags.Mutable, tpe) + varForRewrittenThis = Some(sym) + sym + } + } + + private def getVarForRewrittenParam(param: Symbol)(implicit c: Context): Symbol = { + rewrittenParamSyms.indexOf(param) match { + case -1 => + val sym = c.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Flags.Synthetic | Flags.Mutable, param.info) + rewrittenParamSyms ::= param + varsForRewrittenParamSyms ::= sym + sym + case index => varsForRewrittenParamSyms(index) + } + } + /** Symbols of Labeled blocks that are in tail position. */ private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]() @@ -233,15 +286,44 @@ class TailRec extends MiniPhase { if (ctx.tailPos) { c.debuglog("Rewriting tail recursive call: " + tree.pos) rewrote = true - def receiver = - if (prefix eq EmptyTree) This(enclosingClass.asClass) - else noTailTransform(prefix) - val argumentsWithReceiver = - if (this.method.owner.isClass) receiver :: arguments - else arguments - - tpd.cpy.Apply(tree)(ref(label), argumentsWithReceiver) + val assignParamPairs = for { + (param, arg) <- paramSyms.zip(arguments) + if (arg match { + case arg: Ident => arg.symbol != param + case _ => true + }) + } yield { + (getVarForRewrittenParam(param), arg) + } + + val assignThisAndParamPairs = { + if (prefix eq EmptyTree) assignParamPairs + else { + // TODO Opt: also avoid assigning `this` if the prefix is `this.` + (getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs + } + } + + val assignments = assignThisAndParamPairs match { + case (lhs, rhs) :: Nil => + Assign(ref(lhs), rhs) :: Nil + case _ :: _ => + val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield { + val temp = c.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Flags.Synthetic, lhs.info) + (ValDef(temp, rhs), Assign(ref(lhs), ref(temp)).withPos(tree.pos)) + }).unzip + tempValDefs ::: assigns + case nil => + Nil + } + + /* The `Typed` node is necessary to perfectly preserve the type of the node. + * Without it, lubbing in enclosing if/else or match can infer a different type, + * which can cause Ycheck errors. + */ + val tpt = TypeTree(method.info.resultType) + seq(assignments, Typed(Return(Literal(Constant(())).withPos(tree.pos), continueLabel), tpt)) } else fail("it is not in tail position") } else { diff --git a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala index a6d8d5b8f83e..72fb88bb249f 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeChecker.scala @@ -452,6 +452,11 @@ class TreeChecker extends Phase with SymTransformer { tree1 } + override def typedWhileDo(tree: untpd.WhileDo)(implicit ctx: Context): Tree = { + assert((tree.cond ne EmptyTree) || ctx.phase.refChecked, i"invalid empty condition in while at $tree") + super.typedWhileDo(tree) + } + override def ensureNoLocalRefs(tree: Tree, pt: Type, localSyms: => List[Symbol])(implicit ctx: Context): Tree = tree diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index c02796a33106..6f56eafd7636 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -486,7 +486,7 @@ trait TypeAssigner { tree.withType(defn.NothingType) def assignType(tree: untpd.WhileDo)(implicit ctx: Context): WhileDo = - tree.withType(defn.UnitType) + tree.withType(if (tree.cond eq EmptyTree) defn.NothingType else defn.UnitType) def assignType(tree: untpd.Try, expr: Tree, cases: List[CaseDef])(implicit ctx: Context): Try = if (cases.isEmpty) tree.withType(expr.tpe) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 9c3901128aa2..c2fe93d32475 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -1127,7 +1127,9 @@ class Typer extends Namer } def typedWhileDo(tree: untpd.WhileDo)(implicit ctx: Context): Tree = track("typedWhileDo") { - val cond1 = typed(tree.cond, defn.BooleanType) + val cond1 = + if (tree.cond eq EmptyTree) EmptyTree + else typed(tree.cond, defn.BooleanType) val body1 = typed(tree.body, defn.UnitType) assignType(cpy.WhileDo(tree)(cond1, body1)) } diff --git a/scala-backend b/scala-backend index 47eaa7fbbc65..18b3f30e8c30 160000 --- a/scala-backend +++ b/scala-backend @@ -1 +1 @@ -Subproject commit 47eaa7fbbc6549e4ad8a07c1d644e947eef9b9bd +Subproject commit 18b3f30e8c302ddd0e404f1b3dacf877eb6257c5