@@ -7,10 +7,12 @@ import core._
7
7
import Contexts .Context
8
8
import Decorators ._
9
9
import Symbols ._
10
+ import StdNames .nme
10
11
import Types ._
11
12
import NameKinds .TailLabelName
12
13
import MegaPhase .MiniPhase
13
14
import reporting .diagnostic .messages .TailrecNotApplicable
15
+ import util .Property
14
16
15
17
/**
16
18
* A Tail Rec Transformer
@@ -80,28 +82,31 @@ class TailRec extends MiniPhase with FullParameterization {
80
82
tree
81
83
}
82
84
83
- override def transformTyped (tree : Typed )(implicit ctx : Context ): Tree = {
84
- if (tree.tpt.tpe.hasAnnotation(defn. TailrecAnnot ) )
85
+ override def transformApply (tree : Apply )(implicit ctx : Context ): Tree = {
86
+ if (tree.getAttachment( TailRecCallSiteKey ).isDefined )
85
87
methodsWithInnerAnnots += ctx.owner.enclosingMethod
86
88
tree
87
89
}
88
90
89
91
private def mkLabel (method : Symbol , abstractOverClass : Boolean )(implicit ctx : Context ): TermSymbol = {
90
92
val name = TailLabelName .fresh()
91
93
92
- if (method.owner.isClass)
93
- ctx.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false ))
94
+ if (method.owner.isClass) {
95
+ val MethodTpe (paramNames, paramInfos, resultType) = method.info
96
+
97
+ ctx.newSymbol(method, name.toTermName, labelFlags,
98
+ MethodType (nme.SELF :: paramNames, method.enclosingClass.asClass.classInfo.selfType :: paramInfos, resultType))
99
+ }
94
100
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
95
101
}
96
102
97
103
override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context ): tpd.Tree = {
98
104
val sym = tree.symbol
99
105
tree match {
100
- case dd@ DefDef (name, tparams, vparamss0 , tpt, _)
106
+ case dd@ DefDef (name, Nil , vparams0 :: Nil , tpt, _)
101
107
if (sym.isEffectivelyFinal) && ! ((sym is Flags .Accessor ) || (dd.rhs eq EmptyTree ) || (sym is Flags .Label )) =>
102
108
val mandatory = sym.hasAnnotation(defn.TailrecAnnot )
103
109
cpy.DefDef (dd)(rhs = {
104
-
105
110
val defIsTopLevel = sym.owner.isClass
106
111
val origMeth = sym
107
112
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
@@ -115,34 +120,51 @@ class TailRec extends MiniPhase with FullParameterization {
115
120
// and second one will actually apply,
116
121
// now this speculatively transforms tree and throws away result in many cases
117
122
val rhsSemiTransformed = {
118
- val transformer = new TailRecElimination (origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
123
+ val transformer = new TailRecElimination (origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
119
124
val rhs = transformer.transform(dd.rhs)
120
125
rewrote = transformer.rewrote
121
126
rhs
122
127
}
123
128
124
129
if (rewrote) {
125
- val dummyDefDef = cpy. DefDef (tree)(rhs = rhsSemiTransformed )
130
+ assert(dd.tparams.isEmpty, dd )
126
131
if (tree.symbol.owner.isClass) {
127
- val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
128
- val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true )
129
- Block (List (labelDef), call)
132
+ val classSym = tree.symbol.owner.asClass
133
+
134
+ val labelDef = DefDef (label, vrefss => {
135
+ assert(vrefss.size == 1 , vrefss)
136
+ val vrefs = vrefss.head
137
+ val thisRef = vrefs.head
138
+ val origMeth = tree.symbol
139
+ val origVParams = tree.vparamss.flatten map (_.symbol)
140
+ new TreeTypeMap (
141
+ typeMap = identity(_)
142
+ .substThisUnlessStatic(classSym, thisRef.tpe)
143
+ .subst(origVParams, vrefs.tail.map(_.tpe)),
144
+ treeMap = {
145
+ case tree : This if tree.symbol == classSym => thisRef
146
+ case tree => tree
147
+ },
148
+ oldOwners = origMeth :: Nil ,
149
+ newOwners = label :: Nil
150
+ ).transform(rhsSemiTransformed)
151
+ })
152
+ val callIntoLabel = ref(label).appliedToArgs(This (classSym) :: vparams0.map(x => ref(x.symbol)))
153
+ Block (List (labelDef), callIntoLabel)
130
154
} else { // inner method. Tail recursion does not change `this`
131
- val labelDef = polyDefDef(label, trefs => vrefss => {
155
+ val labelDef = DefDef (label, vrefss => {
156
+ assert(vrefss.size == 1 , vrefss)
157
+ val vrefs = vrefss.head
132
158
val origMeth = tree.symbol
133
- val origTParams = tree.tparams.map(_.symbol)
134
159
val origVParams = tree.vparamss.flatten map (_.symbol)
135
160
new TreeTypeMap (
136
161
typeMap = identity(_)
137
- .subst(origTParams ++ origVParams, trefs ++ vrefss.flatten .map(_.tpe)),
162
+ .subst(origVParams, vrefs .map(_.tpe)),
138
163
oldOwners = origMeth :: Nil ,
139
164
newOwners = label :: Nil
140
165
).transform(rhsSemiTransformed)
141
166
})
142
- val callIntoLabel = (
143
- if (dd.tparams.isEmpty) ref(label)
144
- else ref(label).appliedToTypes(dd.tparams.map(_.tpe))
145
- ).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))
167
+ val callIntoLabel = ref(label).appliedToArgs(vparams0.map(x => ref(x.symbol)))
146
168
Block (List (labelDef), callIntoLabel)
147
169
}} else {
148
170
if (mandatory) ctx.error(
@@ -163,13 +185,14 @@ class TailRec extends MiniPhase with FullParameterization {
163
185
164
186
}
165
187
166
- class TailRecElimination (method : Symbol , methTparams : List [ Tree ], enclosingClass : Symbol , thisType : Type , isMandatory : Boolean , label : Symbol , abstractOverClass : Boolean ) extends tpd.TreeMap {
188
+ class TailRecElimination (method : Symbol , enclosingClass : Symbol , thisType : Type , isMandatory : Boolean , label : Symbol , abstractOverClass : Boolean ) extends tpd.TreeMap {
167
189
168
190
import dotty .tools .dotc .ast .tpd ._
169
191
170
192
var rewrote = false
171
193
172
- private val defaultReason = " it contains a recursive call not in tail position"
194
+ /** Symbols of Labeled blocks that are in tail position. */
195
+ private val tailPositionLabeledSyms = new collection.mutable.HashSet [Symbol ]()
173
196
174
197
private [this ] var ctx : TailContext = yesTailContext
175
198
@@ -195,82 +218,60 @@ class TailRec extends MiniPhase with FullParameterization {
195
218
196
219
override def transform (tree : Tree )(implicit c : Context ): Tree = {
197
220
/* A possibly polymorphic apply to be considered for tail call transformation. */
198
- def rewriteApply (tree : Tree , sym : Symbol , required : Boolean = false ): Tree = {
199
- def receiverArgumentsAndSymbol (t : Tree , accArgs : List [List [Tree ]] = Nil , accT : List [Tree ] = Nil ):
200
- (Tree , Tree , List [List [Tree ]], List [Tree ], Symbol ) = t match {
201
- case TypeApply (fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
202
- case Apply (fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT)
203
- case Select (qual, _) => (qual, t, accArgs, accT, t.symbol)
204
- case x : This => (x, x, accArgs, accT, x.symbol)
205
- case x : Ident if x.symbol eq method => (EmptyTree , x, accArgs, accT, x.symbol)
206
- case x => (x, x, accArgs, accT, x.symbol)
221
+ def rewriteApply (tree : Tree , sym : Symbol ): Tree = {
222
+ def receiverArgumentsAndSymbol (t : Tree , accArgs : List [List [Tree ]] = Nil ):
223
+ (Tree , Tree , List [List [Tree ]], Symbol ) = t match {
224
+ case Apply (fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs)
225
+ case Select (qual, _) => (qual, t, accArgs, t.symbol)
226
+ case x : This => (x, x, accArgs, x.symbol)
227
+ case x : Ident if x.symbol eq method => (EmptyTree , x, accArgs, x.symbol)
228
+ case x => (x, x, accArgs, x.symbol)
207
229
}
208
230
209
- val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
210
- val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe}
231
+ val (prefix, call, arguments, symbol) = receiverArgumentsAndSymbol(tree)
211
232
212
- val targs = typeArguments.map(noTailTransform)
213
233
val argumentss = arguments.map(noTailTransforms)
234
+ assert(argumentss.size == 1 , tree)
214
235
215
236
val isRecursiveCall = (method eq sym)
216
237
val recvWiden = prefix.tpe.widenDealias
217
238
218
239
def continue = {
219
240
val method = noTailTransform(call)
220
- val methodWithTargs = if (targs.nonEmpty) TypeApply (method, targs) else method
221
- if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs
222
- else argumentss.foldLeft(methodWithTargs) {
241
+ argumentss.foldLeft(method) {
223
242
// case (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. Interesting that one can do it in Scala2!
224
243
(method, args) => Apply (method, args)
225
244
}
226
245
}
227
246
def fail (reason : String ) = {
247
+ def required = tree.getAttachment(TailRecCallSiteKey ).isDefined
228
248
if (isMandatory || required) c.error(s " Cannot rewrite recursive call: $reason" , tree.pos)
229
249
else c.debuglog(" Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
230
250
continue
231
251
}
232
252
233
253
if (isRecursiveCall) {
234
254
if (ctx.tailPos) {
235
- val receiverIsSame =
236
- recvWiden <:< enclosingClass.appliedRef &&
237
- (sym.isEffectivelyFinal || enclosingClass.appliedRef <:< recvWiden)
238
- val receiverIsThis = prefix.tpe =:= thisType || prefix.tpe.widen =:= thisType
239
-
240
255
def rewriteTailCall (recv : Tree ): Tree = {
241
256
c.debuglog(" Rewriting tail recursive call: " + tree.pos)
242
257
rewrote = true
243
258
val receiver = noTailTransform(recv)
244
259
245
- val callTargs : List [tpd.Tree ] =
246
- if (abstractOverClass) {
247
- val classTypeArgs = recv.tpe.baseType(enclosingClass).argInfos
248
- targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
249
- } else targs
250
-
251
- val method = if (callTargs.nonEmpty) TypeApply (Ident (label.termRef), callTargs) else Ident (label.termRef)
252
- val thisPassed =
260
+ val method = Ident (label.termRef)
261
+ val argumentsWithReceiver =
253
262
if (this .method.owner.isClass)
254
- method.appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head))
255
- else method
256
-
257
- val res =
258
- if (thisPassed.tpe.widen.isParameterless) thisPassed
259
- else argumentss.foldLeft(thisPassed) {
260
- (met, ar) => Apply (met, ar) // Dotty deviation no auto-detupling yet.
261
- }
262
- res
263
- }
263
+ receiver :: argumentss.head
264
+ else
265
+ argumentss.head
264
266
265
- if (! hasConformingTargs) fail(" it changes type arguments on a polymorphic recursive call" )
266
- else {
267
- val recv = noTailTransform(prefix)
268
- if (recv eq EmptyTree ) rewriteTailCall(This (enclosingClass.asClass))
269
- else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
270
- else fail(" it changes type of 'this' on a polymorphic recursive call" )
267
+ Apply (method, argumentsWithReceiver)
271
268
}
269
+
270
+ val recv = noTailTransform(prefix)
271
+ if (recv eq EmptyTree ) rewriteTailCall(This (enclosingClass.asClass))
272
+ else rewriteTailCall(recv)
272
273
}
273
- else fail(defaultReason )
274
+ else fail(" it is not in tail position " )
274
275
} else {
275
276
val receiverIsSuper = (method.name eq sym) && enclosingClass.appliedRef.widen <:< recvWiden
276
277
@@ -305,28 +306,21 @@ class TailRec extends MiniPhase with FullParameterization {
305
306
else tree
306
307
307
308
case tree : Select =>
308
- val sym = tree.symbol
309
- if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
310
- else tpd.cpy.Select (tree)(noTailTransform(tree.qualifier), tree.name)
309
+ tpd.cpy.Select (tree)(noTailTransform(tree.qualifier), tree.name)
311
310
312
- case Apply (fun, args) =>
311
+ case Apply (fun, args) if ! fun. isInstanceOf [ TypeApply ] =>
313
312
val meth = fun.symbol
314
313
if (meth == defn.Boolean_|| || meth == defn.Boolean_&& )
315
314
tpd.cpy.Apply (tree)(fun, transform(args))
316
315
else
317
316
rewriteApply(tree, meth)
318
317
319
- case TypeApply (fun, targs) =>
320
- val meth = fun.symbol
321
- rewriteApply(tree, meth)
322
-
323
318
case tree@ Block (stats, expr) =>
324
319
tpd.cpy.Block (tree)(
325
320
noTailTransforms(stats),
326
321
transform(expr)
327
322
)
328
- case tree @ Typed (t : Apply , tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot ) =>
329
- tpd.Typed (rewriteApply(t, t.fun.symbol, required = true ), tpt)
323
+
330
324
case tree@ If (cond, thenp, elsep) =>
331
325
tpd.cpy.If (tree)(
332
326
noTailTransform(cond),
@@ -357,8 +351,15 @@ class TailRec extends MiniPhase with FullParameterization {
357
351
Literal (_) | TypeTree () | TypeDef (_, _) =>
358
352
tree
359
353
354
+ case Labeled (bind, expr) =>
355
+ if (ctx.tailPos)
356
+ tailPositionLabeledSyms += bind.symbol
357
+ tpd.cpy.Labeled (tree)(bind, transform(expr))
358
+
360
359
case Return (expr, from) =>
361
- tpd.cpy.Return (tree)(noTailTransform(expr), from)
360
+ val fromSym = from.symbol
361
+ val tailPos = fromSym.is(Flags .Label ) && tailPositionLabeledSyms.contains(fromSym)
362
+ tpd.cpy.Return (tree)(transform(expr, new TailContext (tailPos)), from)
362
363
363
364
case _ =>
364
365
super .transform(tree)
@@ -382,4 +383,6 @@ object TailRec {
382
383
383
384
final val noTailContext = new TailContext (false )
384
385
final val yesTailContext = new TailContext (true )
386
+
387
+ object TailRecCallSiteKey extends Property .StickyKey [Unit ]
385
388
}
0 commit comments