Skip to content

Commit e823e8b

Browse files
committed
Optimise the Tailrec phase
We can stop traversing a tree in Tailrec as soon as we are not in tail position anymore and we are not within a labeled block in tail position. We keep traversing the tree however if the method is @tailrec annotated to report errors on eventual recursive calls.
1 parent c63cc43 commit e823e8b

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,23 @@ class TailRec extends MiniPhase {
245245
def yesTailTransform(tree: Tree)(implicit ctx: Context): Tree =
246246
transform(tree, tailPosition = true)
247247

248+
/** If not in tail position a tree traversal may not be needed.
249+
*
250+
* A recursive call may still be in tail position if within the return
251+
* expression of a labelled block.
252+
* A tree traversal may also be needed to report a failure to transform
253+
* a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
254+
*/
255+
private def isTraversalNeeded =
256+
isMandatory || tailPositionLabeledSyms.nonEmpty
257+
248258
def noTailTransform(tree: Tree)(implicit ctx: Context): Tree =
249-
transform(tree, tailPosition = false)
259+
if (isTraversalNeeded) transform(tree, tailPosition = false)
260+
else tree
250261

251262
def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit ctx: Context): List[Tr] =
252-
trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
263+
if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
264+
else trees
253265

254266
override def transform(tree: Tree)(implicit ctx: Context): Tree = {
255267
/* Rewrite an Apply to be considered for tail call transformation. */
@@ -390,7 +402,9 @@ class TailRec extends MiniPhase {
390402
case Labeled(bind, expr) =>
391403
if (inTailPosition)
392404
tailPositionLabeledSyms += bind.symbol
393-
cpy.Labeled(tree)(bind, transform(expr))
405+
try cpy.Labeled(tree)(bind, transform(expr))
406+
finally if (inTailPosition)
407+
tailPositionLabeledSyms -= bind.symbol
394408

395409
case Return(expr, from) =>
396410
val fromSym = from.symbol

0 commit comments

Comments
 (0)