Skip to content

Commit 6b0ae0b

Browse files
committed
Merge branch 'master' into fix-equality
2 parents d5dfd16 + 02f1ec9 commit 6b0ae0b

File tree

4 files changed

+50
-5
lines changed

4 files changed

+50
-5
lines changed

src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,20 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
7373
final val labelPrefix = "tailLabel"
7474
final val labelFlags = Flags.Synthetic | Flags.Label
7575

76+
/** Symbols of methods that have @tailrec annotatios inside */
77+
private val methodsWithInnerAnnots = new collection.mutable.HashSet[Symbol]()
78+
79+
override def transformUnit(tree: Tree)(implicit ctx: Context, info: TransformerInfo): Tree = {
80+
methodsWithInnerAnnots.clear()
81+
tree
82+
}
83+
84+
override def transformTyped(tree: Typed)(implicit ctx: Context, info: TransformerInfo): Tree = {
85+
if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot))
86+
methodsWithInnerAnnots += ctx.owner.enclosingMethod
87+
tree
88+
}
89+
7690
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit c: Context): TermSymbol = {
7791
val name = c.freshName(labelPrefix)
7892

@@ -137,10 +151,10 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
137151
}
138152
})
139153
}
140-
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) =>
154+
case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
141155
ctx.error("TailRec optimisation not applicable, method is neither private nor final so can be overridden", d.pos)
142156
d
143-
case d if d.symbol.hasAnnotation(defn.TailrecAnnot) =>
157+
case d if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) =>
144158
ctx.error("TailRec optimisation not applicable, not a method", d.pos)
145159
d
146160
case _ => tree
@@ -180,7 +194,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
180194

181195
override def transform(tree: Tree)(implicit c: Context): Tree = {
182196
/* A possibly polymorphic apply to be considered for tail call transformation. */
183-
def rewriteApply(tree: Tree, sym: Symbol): Tree = {
197+
def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = {
184198
def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
185199
(Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
186200
case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
@@ -216,7 +230,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
216230
}
217231
}
218232
def fail(reason: String) = {
219-
if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
233+
if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
220234
else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
221235
continue
222236
}
@@ -299,7 +313,8 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
299313
noTailTransforms(stats),
300314
transform(expr)
301315
)
302-
316+
case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) =>
317+
tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt)
303318
case tree@If(cond, thenp, elsep) =>
304319
tpd.cpy.If(tree)(
305320
noTailTransform(cond),

tests/neg/tailcall/i1221.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import annotation.tailrec
2+
3+
object I1221{
4+
final def foo(a: Int): Int = {
5+
if ((foo(a - 1): @tailrec) > 0) // error: not in tail position
6+
foo(a - 1): @tailrec
7+
else
8+
foo(a - 2): @tailrec
9+
}
10+
}

tests/neg/tailcall/i1221b.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import annotation.tailrec
2+
3+
class Test {
4+
def foo(a: Int): Int = { // error: method is not final
5+
if ((foo(a - 1): @tailrec) > 0)
6+
foo(a - 1): @tailrec
7+
else
8+
foo(a - 2): @tailrec
9+
}
10+
}

tests/pos/tailcall/i1221.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import annotation.tailrec
2+
3+
object i1221{
4+
final def foo(a: Int): Int = {
5+
if (foo(a - 1) > 0)
6+
foo(a - 1): @tailrec
7+
else
8+
foo(a - 2): @tailrec
9+
}
10+
}

0 commit comments

Comments
 (0)