Skip to content

Commit d47ead5

Browse files
committed
Optimise the Tailrec phase
We can stop traversing a tree in Tailrec as soon as we are not in tail position anymore and we are not within a labeled block in tail position. We keep traversing the tree however if the method is @tailrec annotated to report errors on eventual recursive calls.
1 parent 5e67cd0 commit d47ead5

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,23 @@ class TailRec extends MiniPhase {
249249
def yesTailTransform(tree: Tree)(implicit ctx: Context): Tree =
250250
transform(tree, tailPosition = true)
251251

252+
/** If not in tail position a tree traversal may not be needed.
253+
*
254+
* A recursive call may still be in tail position if within the return
255+
* expression of a labeled block.
256+
* A tree traversal may also be needed to report a failure to transform
257+
* a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
258+
*/
259+
private def isTraversalNeeded =
260+
isMandatory || tailPositionLabeledSyms.nonEmpty
261+
252262
def noTailTransform(tree: Tree)(implicit ctx: Context): Tree =
253-
transform(tree, tailPosition = false)
263+
if (isTraversalNeeded) transform(tree, tailPosition = false)
264+
else tree
254265

255266
def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit ctx: Context): List[Tr] =
256-
trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
267+
if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
268+
else trees
257269

258270
override def transform(tree: Tree)(implicit ctx: Context): Tree = {
259271
/* Rewrite an Apply to be considered for tail call transformation. */
@@ -394,7 +406,11 @@ class TailRec extends MiniPhase {
394406
case Labeled(bind, expr) =>
395407
if (inTailPosition)
396408
tailPositionLabeledSyms += bind.symbol
397-
cpy.Labeled(tree)(bind, transform(expr))
409+
try cpy.Labeled(tree)(bind, transform(expr))
410+
finally {
411+
if (inTailPosition)
412+
tailPositionLabeledSyms -= bind.symbol
413+
}
398414

399415
case Return(expr, from) =>
400416
val fromSym = from.symbol

0 commit comments

Comments
 (0)