Skip to content

Commit b920577

Browse files
committed
Use Labeled blocks in TailRec, instead of label-defs.
It's easier to first explain on an example. Consider the following tail-recursive method: def fact(n: Int, acc: Int): Int = if (n == 0) acc else fact(n - 1, n * acc) It is now translated as follows by the `tailrec` transform: def fact(n: Int, acc: Int): Int = { var n$tailLocal1: Int = n var acc$tailLocal1: Int = acc while (true) { tailLabel1[Unit]: { return { if (n$tailLocal1 == 0) { acc } else { val n$tailLocal1$tmp1: Int = n$tailLocal1 - 1 val acc$tailLocal1$tmp1: Int = n$tailLocal1 * acc$tailLocal1 n$tailLocal1 = n$tailLocal1$tmp1 acc$tailLocal1 = acc$tailLocal1$tmp1 (return[tailLabel1] ()): Int } } } } throw null // unreachable code } First, we allocate local `var`s for every parameter, as well as `this` if necessary. When we find a tail-recursive call, we evaluate the arguments into temporaries, then assign them to the `var`s. It is necessary to use temporaries in order not to use the new contents of a param local when computing the new value of another param local. We avoid reassigning param locals if their rhs (i.e., the actual argument to the recursive call) is itself, which does happen quite often in practice. In particular, we thus avoid reassigning the local var for `this` if the prefix is empty. We could further optimize this by avoiding the reassignment if the prefix is non-empty but equivalent to `this`. If only one parameter ends up changing value in any particular tail-recursive call, we can avoid the temporaries and directly assign it. This is also a fairly common situation, especially after discarding useless assignments to the local for `this`. After all that, we `return` from a labeled block, which is right inside an infinite `while` loop. The net result is to loop back to the beginning, implementing the jump. The `return` node is explicitly ascribed with the previous result type, so that lubs upstream are not affected (not doing so can cause Ycheck errors). For control flows that do *not* end up in a tail-recursive call, the result value is given to an explicit `return` out of the enclosing method, which prevents the looping. There is one pretty ugly artifact: after the `while` loop, we must insert a `throw null` for the body to still typecheck as an `Int` (the result type of the `def`). This could be avoided if we dared type a `WhileDo(Literal(Constant(true)), body)` as having type `Nothing` rather than `Unit`. This is probably dangerous, though, as we have no guarantee that further transformations will leave the `true` alone, especially in the presence of compiler plugins. If the `true` gets wrapped in any way, the type of the `WhileDo` will be altered, and chaos will ensue. In the future, we could enhance the codegen to avoid emitting that dead code. This should not be too difficult: * emitting a `WhileDo` whose argument is `true` would set the generated `BType` to `Nothing`. * then, when emitting a `Block`, we would drop any statements and expr following a statement whose generated `BType` was `Nothing`. This commit does not go to such lengths, however. This change removes the last source of label-defs in the compiler. After this commit, we will be able to entirely remove label-defs.
1 parent 7b1e8d9 commit b920577

File tree

2 files changed

+119
-73
lines changed

2 files changed

+119
-73
lines changed

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,8 @@ object NameKinds {
288288
val NonLocalReturnKeyName = new UniqueNameKind("nonLocalReturnKey")
289289
val WildcardParamName = new UniqueNameKind("_$")
290290
val TailLabelName = new UniqueNameKind("tailLabel")
291+
val TailLocalName = new UniqueNameKind("$tailLocal")
292+
val TailTempName = new UniqueNameKind("$tmp")
291293
val ExceptionBinderName = new UniqueNameKind("ex")
292294
val SkolemName = new UniqueNameKind("?")
293295
val LiftedTreeName = new UniqueNameKind("liftedTree")

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 117 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ package transform
44
import ast.Trees._
55
import ast.{TreeTypeMap, tpd}
66
import core._
7+
import Constants.Constant
78
import Contexts.Context
89
import Decorators._
910
import Symbols._
1011
import StdNames.nme
1112
import Types._
12-
import NameKinds.TailLabelName
13+
import NameKinds.{TailLabelName, TailLocalName, TailTempName}
1314
import MegaPhase.MiniPhase
1415
import reporting.diagnostic.messages.TailrecNotApplicable
1516
import util.Property
@@ -72,25 +73,6 @@ class TailRec extends MiniPhase {
7273

7374
override def runsAfter = Set(ShortcutImplicits.name) // Replaces non-tail calls by tail calls
7475

75-
final val labelFlags = Flags.Synthetic | Flags.Label | Flags.Method
76-
77-
private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = {
78-
val name = TailLabelName.fresh()
79-
80-
if (method.owner.isClass) {
81-
val MethodTpe(paramNames, paramInfos, resultType) = method.info
82-
83-
val enclosingClass = method.enclosingClass.asClass
84-
val thisParamType =
85-
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
86-
else enclosingClass.classInfo.selfType
87-
88-
ctx.newSymbol(method, name.toTermName, labelFlags,
89-
MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType))
90-
}
91-
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
92-
}
93-
9476
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = {
9577
val sym = tree.symbol
9678
tree match {
@@ -100,63 +82,62 @@ class TailRec extends MiniPhase {
10082
cpy.DefDef(dd)(rhs = {
10183
val defIsTopLevel = sym.owner.isClass
10284
val origMeth = sym
103-
val label = mkLabel(sym)
10485
val owner = ctx.owner.enclosingClass.asClass
10586

106-
var rewrote = false
107-
10887
// Note: this can be split in two separate transforms(in different groups),
10988
// than first one will collect info about which transformations and rewritings should be applied
11089
// and second one will actually apply,
11190
// now this speculatively transforms tree and throws away result in many cases
112-
val rhsSemiTransformed = {
113-
val transformer = new TailRecElimination(origMeth, owner, mandatory, label)
114-
val rhs = transformer.transform(dd.rhs)
115-
rewrote = transformer.rewrote
116-
rhs
117-
}
91+
val transformer = new TailRecElimination(origMeth, owner, vparams0.map(_.symbol), mandatory)
92+
val rhsSemiTransformed = transformer.transform(dd.rhs)
11893

119-
if (rewrote) {
94+
if (transformer.rewrote) {
12095
assert(dd.tparams.isEmpty, dd)
121-
if (tree.symbol.owner.isClass) {
122-
val classSym = tree.symbol.owner.asClass
123-
124-
val labelDef = DefDef(label, vrefss => {
125-
assert(vrefss.size == 1, vrefss)
126-
val vrefs = vrefss.head
127-
val thisRef = vrefs.head
128-
val origMeth = tree.symbol
129-
val origVParams = tree.vparamss.flatten map (_.symbol)
96+
97+
val varForRewrittenThis = transformer.varForRewrittenThis
98+
val rewrittenParamSyms = transformer.rewrittenParamSyms
99+
val varsForRewrittenParamSyms = transformer.varsForRewrittenParamSyms
100+
101+
val initialValDefs = {
102+
val initialParamValDefs = for ((param, local) <- rewrittenParamSyms.zip(varsForRewrittenParamSyms)) yield {
103+
ValDef(local.asTerm, Ident(param.termRef))
104+
}
105+
varForRewrittenThis match {
106+
case Some(local) => ValDef(local.asTerm, This(tree.symbol.owner.asClass)) :: initialParamValDefs
107+
case none => initialParamValDefs
108+
}
109+
}
110+
111+
val rhsFullyTransformed = varForRewrittenThis match {
112+
case Some(localThisSym) =>
113+
val classSym = tree.symbol.owner.asClass
114+
val thisRef = localThisSym.termRef
130115
new TreeTypeMap(
131-
typeMap = identity(_)
132-
.substThisUnlessStatic(classSym, thisRef.tpe)
133-
.subst(origVParams, vrefs.tail.map(_.tpe)),
116+
typeMap = _.substThisUnlessStatic(classSym, thisRef)
117+
.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef)),
134118
treeMap = {
135-
case tree: This if tree.symbol == classSym => thisRef
119+
case tree: This if tree.symbol == classSym => Ident(thisRef)
136120
case tree => tree
137-
},
138-
oldOwners = origMeth :: Nil,
139-
newOwners = label :: Nil
121+
}
140122
).transform(rhsSemiTransformed)
141-
})
142-
val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams0.map(x => ref(x.symbol)))
143-
Block(List(labelDef), callIntoLabel)
144-
} else { // inner method. Tail recursion does not change `this`
145-
val labelDef = DefDef(label, vrefss => {
146-
assert(vrefss.size == 1, vrefss)
147-
val vrefs = vrefss.head
148-
val origMeth = tree.symbol
149-
val origVParams = tree.vparamss.flatten map (_.symbol)
123+
124+
case none =>
150125
new TreeTypeMap(
151-
typeMap = identity(_)
152-
.subst(origVParams, vrefs.map(_.tpe)),
153-
oldOwners = origMeth :: Nil,
154-
newOwners = label :: Nil
126+
typeMap = _.subst(rewrittenParamSyms, varsForRewrittenParamSyms.map(_.termRef))
155127
).transform(rhsSemiTransformed)
156-
})
157-
val callIntoLabel = ref(label).appliedToArgs(vparams0.map(x => ref(x.symbol)))
158-
Block(List(labelDef), callIntoLabel)
159-
}} else {
128+
}
129+
130+
Block(
131+
initialValDefs :::
132+
WhileDo(Literal(Constant(true)), {
133+
Labeled(transformer.continueLabel.get.asTerm, {
134+
Return(rhsFullyTransformed, ref(origMeth))
135+
})
136+
}) ::
137+
Nil,
138+
Throw(Literal(Constant(null))) // unreachable code
139+
)
140+
} else {
160141
if (mandatory) ctx.error(
161142
"TailRec optimisation not applicable, method not tail recursive",
162143
// FIXME: want to report this error on `dd.namePos`, but
@@ -175,12 +156,51 @@ class TailRec extends MiniPhase {
175156

176157
}
177158

178-
class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap {
159+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, paramSyms: List[Symbol], isMandatory: Boolean) extends tpd.TreeMap {
179160

180161
import dotty.tools.dotc.ast.tpd._
181162

182163
var rewrote = false
183164

165+
var continueLabel: Option[Symbol] = None
166+
var varForRewrittenThis: Option[Symbol] = None
167+
var rewrittenParamSyms: List[Symbol] = Nil
168+
var varsForRewrittenParamSyms: List[Symbol] = Nil
169+
170+
private def getContinueLabel()(implicit c: Context): Symbol = {
171+
continueLabel match {
172+
case Some(sym) => sym
173+
case none =>
174+
val sym = c.newSymbol(method, TailLabelName.fresh(), Flags.Label, defn.UnitType)
175+
continueLabel = Some(sym)
176+
sym
177+
}
178+
}
179+
180+
private def getVarForRewrittenThis()(implicit c: Context): Symbol = {
181+
varForRewrittenThis match {
182+
case Some(sym) => sym
183+
case none =>
184+
val tpe =
185+
if (enclosingClass.is(Flags.Module)) enclosingClass.thisType
186+
else enclosingClass.asClass.classInfo.selfType
187+
val sym = c.newSymbol(method, nme.SELF, Flags.Synthetic | Flags.Mutable, tpe)
188+
varForRewrittenThis = Some(sym)
189+
sym
190+
}
191+
}
192+
193+
private def getVarForRewrittenParam(param: Symbol)(implicit c: Context): Symbol = {
194+
rewrittenParamSyms.indexOf(param) match {
195+
case -1 =>
196+
val sym = c.newSymbol(method, TailLocalName.fresh(param.name.toTermName), Flags.Synthetic | Flags.Mutable, param.info)
197+
rewrittenParamSyms ::= param
198+
varsForRewrittenParamSyms ::= sym
199+
sym
200+
case index => varsForRewrittenParamSyms(index)
201+
}
202+
}
203+
184204
/** Symbols of Labeled blocks that are in tail position. */
185205
private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]()
186206

@@ -233,16 +253,40 @@ class TailRec extends MiniPhase {
233253
if (ctx.tailPos) {
234254
c.debuglog("Rewriting tail recursive call: " + tree.pos)
235255
rewrote = true
236-
def receiver =
237-
if (prefix eq EmptyTree) This(enclosingClass.asClass)
238-
else noTailTransform(prefix)
239-
240-
val method = Ident(label.termRef)
241-
val argumentsWithReceiver =
242-
if (this.method.owner.isClass) receiver :: arguments
243-
else arguments
244256

245-
Apply(method, argumentsWithReceiver)
257+
val assignParamPairs = for {
258+
(param, arg) <- paramSyms.zip(arguments)
259+
if (arg match {
260+
case arg: Ident => arg.symbol != param
261+
case _ => true
262+
})
263+
} yield {
264+
(getVarForRewrittenParam(param), arg)
265+
}
266+
267+
val assignThisAndParamPairs = {
268+
if (prefix eq EmptyTree) assignParamPairs
269+
else {
270+
// TODO Opt: also avoid assigning `this` if the prefix is `this.`
271+
(getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
272+
}
273+
}
274+
275+
val assignments = assignThisAndParamPairs match {
276+
case (lhs, rhs) :: Nil =>
277+
Assign(Ident(lhs.termRef), rhs) :: Nil
278+
case _ :: _ =>
279+
val (tempValDefs, assigns) = (for ((lhs, rhs) <- assignThisAndParamPairs) yield {
280+
val temp = c.newSymbol(method, TailTempName.fresh(lhs.name.toTermName), Flags.Synthetic, lhs.info)
281+
(ValDef(temp, rhs), Assign(Ident(lhs.termRef), Ident(temp.termRef)))
282+
}).unzip
283+
tempValDefs ::: assigns
284+
case nil =>
285+
Nil
286+
}
287+
288+
val tpt = TypeTree(method.info.resultType)
289+
Block(assignments, Typed(Return(Literal(Constant(())), Ident(getContinueLabel().termRef)), tpt))
246290
}
247291
else fail("it is not in tail position")
248292
} else {

0 commit comments

Comments
 (0)