@@ -73,6 +73,20 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
73
73
final val labelPrefix = " tailLabel"
74
74
final val labelFlags = Flags .Synthetic | Flags .Label
75
75
76
+ /** Symbols of methods that have @tailrec annotatios inside */
77
+ private val methodsWithInnerAnnots = new collection.mutable.HashSet [Symbol ]()
78
+
79
+ override def transformUnit (tree : Tree )(implicit ctx : Context , info : TransformerInfo ): Tree = {
80
+ methodsWithInnerAnnots.clear()
81
+ tree
82
+ }
83
+
84
+ override def transformTyped (tree : Typed )(implicit ctx : Context , info : TransformerInfo ): Tree = {
85
+ if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot ))
86
+ methodsWithInnerAnnots += ctx.owner.enclosingMethod
87
+ tree
88
+ }
89
+
76
90
private def mkLabel (method : Symbol , abstractOverClass : Boolean )(implicit c : Context ): TermSymbol = {
77
91
val name = c.freshName(labelPrefix)
78
92
@@ -137,10 +151,10 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
137
151
}
138
152
})
139
153
}
140
- case d : DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot ) =>
154
+ case d : DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot ) || methodsWithInnerAnnots.contains(d.symbol) =>
141
155
ctx.error(" TailRec optimisation not applicable, method is neither private nor final so can be overridden" , d.pos)
142
156
d
143
- case d if d.symbol.hasAnnotation(defn.TailrecAnnot ) =>
157
+ case d if d.symbol.hasAnnotation(defn.TailrecAnnot ) || methodsWithInnerAnnots.contains(d.symbol) =>
144
158
ctx.error(" TailRec optimisation not applicable, not a method" , d.pos)
145
159
d
146
160
case _ => tree
@@ -180,7 +194,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
180
194
181
195
override def transform (tree : Tree )(implicit c : Context ): Tree = {
182
196
/* A possibly polymorphic apply to be considered for tail call transformation. */
183
- def rewriteApply (tree : Tree , sym : Symbol ): Tree = {
197
+ def rewriteApply (tree : Tree , sym : Symbol , required : Boolean = false ): Tree = {
184
198
def receiverArgumentsAndSymbol (t : Tree , accArgs : List [List [Tree ]] = Nil , accT : List [Tree ] = Nil ):
185
199
(Tree , Tree , List [List [Tree ]], List [Tree ], Symbol ) = t match {
186
200
case TypeApply (fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
@@ -216,7 +230,7 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
216
230
}
217
231
}
218
232
def fail (reason : String ) = {
219
- if (isMandatory) c.error(s " Cannot rewrite recursive call: $reason" , tree.pos)
233
+ if (isMandatory || required ) c.error(s " Cannot rewrite recursive call: $reason" , tree.pos)
220
234
else c.debuglog(" Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
221
235
continue
222
236
}
@@ -299,7 +313,8 @@ class TailRec extends MiniPhaseTransform with DenotTransformer with FullParamete
299
313
noTailTransforms(stats),
300
314
transform(expr)
301
315
)
302
-
316
+ case tree @ Typed (t : Apply , tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot ) =>
317
+ tpd.Typed (rewriteApply(t, t.fun.symbol, required = true ), tpt)
303
318
case tree@ If (cond, thenp, elsep) =>
304
319
tpd.cpy.If (tree)(
305
320
noTailTransform(cond),
0 commit comments