Skip to content

Commit 4eaa795

Browse files
committed
Move TailRec after Erasure.
This used not to be possible because it was too complicated for TailRec to understand trees produced by the pattern matcher. Now that patmat uses `Labeled` blocks instead of label-defs, it is trivial to do so: the argument of a `return` from a labeled block `lab` is in tail position if and only if the labeled block `lab` is itself in tail position. Running TailRec after erasure has two major benefits: * it is much simpler, as it does not have to deal with type parameters, `TypeApply`s, and a bunch of other stuff. * it supports polymorphic tail-recursive calls by construction (recursive calls whose receiver or method has different type parameters). It also has one difficulty: it cannot see the *call-site* `@tailrec` annotations anymore. This is why we add a mini-phase to record such annotations as tree attachments. Having TailRec after erasure will also be necessary to later make it use `Labeled` blocks and loops instead of label-defs itself. Indeed, in that case it will need to declare local `var`s for its term parameters, which would break path-dependent types within the method if it were done before erasure.
1 parent 0a2734b commit 4eaa795

File tree

10 files changed

+143
-95
lines changed

10 files changed

+143
-95
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class Compiler {
6767
new ProtectedAccessors, // Add accessors for protected members
6868
new ExtensionMethods, // Expand methods of value classes with extension methods
6969
new ShortcutImplicits, // Allow implicit functions without creating closures
70-
new TailRec, // Rewrite tail recursion to loops
7170
new ByNameClosures, // Expand arguments to by-name parameters to closures
7271
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
7372
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
@@ -92,10 +91,12 @@ class Compiler {
9291
new ResolveSuper, // Implement super accessors and add forwarders to trait methods
9392
new PrimitiveForwarders, // Add forwarders to trait methods that have a mismatch between generic and primitives
9493
new FunctionXXLForwarders, // Add forwarders for FunctionXXL apply method
94+
new RecordTailRecCallSites, // Records call-site @tailrec annotations as attachments
9595
new ArrayConstructors) :: // Intercept creation of (non-generic) arrays and intrinsify.
9696
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9797
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
9898
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
99+
new TailRec, // Rewrite tail recursion to loops
99100
new Mixin, // Expand trait fields and trait initializers
100101
new LazyVals, // Expand lazy vals
101102
new Memoize, // Add private fields to getters and setters

compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ class TreeTypeMap(
116116
val guard1 = tmap.transform(guard)
117117
val rhs1 = tmap.transform(rhs)
118118
cpy.CaseDef(cdef)(pat1, guard1, rhs1)
119+
case labeled @ Labeled(bind, expr) =>
120+
val tmap = withMappedSyms(bind.symbol :: Nil)
121+
val bind1 = tmap.transformSub(bind)
122+
val expr1 = tmap.transform(expr)
123+
cpy.Labeled(labeled)(bind1, expr1)
119124
case Hole(n, args) =>
120125
Hole(n, args.mapConserve(transform)).withPos(tree.pos).withType(mapType(tree.tpe))
121126
case tree1 =>

compiler/src/dotty/tools/dotc/transform/MixinOps.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(implicit ctx: Cont
2424
name = member.name.stripScala2LocalSuffix,
2525
flags = member.flags &~ Deferred,
2626
info = cls.thisType.memberInfo(member)).enteredAfter(thisPhase).asTerm
27-
res.addAnnotations(member.annotations)
27+
res.addAnnotations(member.annotations.filter(_.symbol != defn.TailrecAnnot))
2828
res
2929
}
3030

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ class PatternMatcher extends MiniPhase {
2727

2828
override def phaseName = PatternMatcher.name
2929
override def runsAfter = Set(ElimRepeated.name)
30-
override def runsAfterGroupsOf = Set(TailRec.name) // tailrec is not capable of reversing the patmat tranformation made for tree
3130

3231
override def transformMatch(tree: Match)(implicit ctx: Context): Tree = {
3332
val translated = new Translator(tree.tpe, this).translateMatch(tree)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import MegaPhase._
6+
import Contexts.Context
7+
import Flags._
8+
import SymUtils._
9+
import Symbols._
10+
import SymDenotations._
11+
import Types._
12+
import Decorators._
13+
import DenotTransformers._
14+
import StdNames._
15+
import NameOps._
16+
import ast.Trees._
17+
import dotty.tools.dotc.ast.tpd
18+
import util.Positions._
19+
import Names._
20+
21+
import collection.mutable
22+
import ResolveSuper._
23+
24+
import scala.collection.immutable.::
25+
26+
27+
/** This phase saves call-site `@tailrec` annotations as attachments.
28+
*
29+
* Since erasure will come before the `tailrec` phase, it will erase the `@tailrec` annotations
30+
* in the `Typed` nodes.
31+
*/
32+
class RecordTailRecCallSites extends MiniPhase {
33+
import ast.tpd._
34+
35+
override def phaseName: String = "recordTailrecCallSites"
36+
37+
override def transformTyped(tree: Typed)(implicit ctx: Context): Tree = {
38+
if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot) && tree.expr.isInstanceOf[Apply])
39+
tree.expr.pushAttachment(TailRec.TailRecCallSiteKey, ())
40+
super.transformTyped(tree)
41+
}
42+
}

compiler/src/dotty/tools/dotc/transform/TailRec.scala

Lines changed: 79 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,12 @@ import core._
77
import Contexts.Context
88
import Decorators._
99
import Symbols._
10+
import StdNames.nme
1011
import Types._
1112
import NameKinds.TailLabelName
1213
import MegaPhase.MiniPhase
1314
import reporting.diagnostic.messages.TailrecNotApplicable
15+
import util.Property
1416

1517
/**
1618
* A Tail Rec Transformer
@@ -80,28 +82,31 @@ class TailRec extends MiniPhase with FullParameterization {
8082
tree
8183
}
8284

83-
override def transformTyped(tree: Typed)(implicit ctx: Context): Tree = {
84-
if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot))
85+
override def transformApply(tree: Apply)(implicit ctx: Context): Tree = {
86+
if (tree.getAttachment(TailRecCallSiteKey).isDefined)
8587
methodsWithInnerAnnots += ctx.owner.enclosingMethod
8688
tree
8789
}
8890

8991
private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit ctx: Context): TermSymbol = {
9092
val name = TailLabelName.fresh()
9193

92-
if (method.owner.isClass)
93-
ctx.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false))
94+
if (method.owner.isClass) {
95+
val MethodTpe(paramNames, paramInfos, resultType) = method.info
96+
97+
ctx.newSymbol(method, name.toTermName, labelFlags,
98+
MethodType(nme.SELF :: paramNames, method.enclosingClass.asClass.classInfo.selfType :: paramInfos, resultType))
99+
}
94100
else ctx.newSymbol(method, name.toTermName, labelFlags, method.info)
95101
}
96102

97103
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = {
98104
val sym = tree.symbol
99105
tree match {
100-
case dd@DefDef(name, tparams, vparamss0, tpt, _)
106+
case dd@DefDef(name, Nil, vparams0 :: Nil, tpt, _)
101107
if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (dd.rhs eq EmptyTree) || (sym is Flags.Label)) =>
102108
val mandatory = sym.hasAnnotation(defn.TailrecAnnot)
103109
cpy.DefDef(dd)(rhs = {
104-
105110
val defIsTopLevel = sym.owner.isClass
106111
val origMeth = sym
107112
val label = mkLabel(sym, abstractOverClass = defIsTopLevel)
@@ -115,34 +120,51 @@ class TailRec extends MiniPhase with FullParameterization {
115120
// and second one will actually apply,
116121
// now this speculatively transforms tree and throws away result in many cases
117122
val rhsSemiTransformed = {
118-
val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
123+
val transformer = new TailRecElimination(origMeth, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel)
119124
val rhs = transformer.transform(dd.rhs)
120125
rewrote = transformer.rewrote
121126
rhs
122127
}
123128

124129
if (rewrote) {
125-
val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed)
130+
assert(dd.tparams.isEmpty, dd)
126131
if (tree.symbol.owner.isClass) {
127-
val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel)
128-
val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true)
129-
Block(List(labelDef), call)
132+
val classSym = tree.symbol.owner.asClass
133+
134+
val labelDef = DefDef(label, vrefss => {
135+
assert(vrefss.size == 1, vrefss)
136+
val vrefs = vrefss.head
137+
val thisRef = vrefs.head
138+
val origMeth = tree.symbol
139+
val origVParams = tree.vparamss.flatten map (_.symbol)
140+
new TreeTypeMap(
141+
typeMap = identity(_)
142+
.substThisUnlessStatic(classSym, thisRef.tpe)
143+
.subst(origVParams, vrefs.tail.map(_.tpe)),
144+
treeMap = {
145+
case tree: This if tree.symbol == classSym => thisRef
146+
case tree => tree
147+
},
148+
oldOwners = origMeth :: Nil,
149+
newOwners = label :: Nil
150+
).transform(rhsSemiTransformed)
151+
})
152+
val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams0.map(x => ref(x.symbol)))
153+
Block(List(labelDef), callIntoLabel)
130154
} else { // inner method. Tail recursion does not change `this`
131-
val labelDef = polyDefDef(label, trefs => vrefss => {
155+
val labelDef = DefDef(label, vrefss => {
156+
assert(vrefss.size == 1, vrefss)
157+
val vrefs = vrefss.head
132158
val origMeth = tree.symbol
133-
val origTParams = tree.tparams.map(_.symbol)
134159
val origVParams = tree.vparamss.flatten map (_.symbol)
135160
new TreeTypeMap(
136161
typeMap = identity(_)
137-
.subst(origTParams ++ origVParams, trefs ++ vrefss.flatten.map(_.tpe)),
162+
.subst(origVParams, vrefs.map(_.tpe)),
138163
oldOwners = origMeth :: Nil,
139164
newOwners = label :: Nil
140165
).transform(rhsSemiTransformed)
141166
})
142-
val callIntoLabel = (
143-
if (dd.tparams.isEmpty) ref(label)
144-
else ref(label).appliedToTypes(dd.tparams.map(_.tpe))
145-
).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol))))
167+
val callIntoLabel = ref(label).appliedToArgs(vparams0.map(x => ref(x.symbol)))
146168
Block(List(labelDef), callIntoLabel)
147169
}} else {
148170
if (mandatory) ctx.error(
@@ -163,13 +185,14 @@ class TailRec extends MiniPhase with FullParameterization {
163185

164186
}
165187

166-
class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
188+
class TailRecElimination(method: Symbol, enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap {
167189

168190
import dotty.tools.dotc.ast.tpd._
169191

170192
var rewrote = false
171193

172-
private val defaultReason = "it contains a recursive call not in tail position"
194+
/** Symbols of Labeled blocks that are in tail position. */
195+
private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]()
173196

174197
private[this] var ctx: TailContext = yesTailContext
175198

@@ -195,82 +218,60 @@ class TailRec extends MiniPhase with FullParameterization {
195218

196219
override def transform(tree: Tree)(implicit c: Context): Tree = {
197220
/* A possibly polymorphic apply to be considered for tail call transformation. */
198-
def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = {
199-
def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil):
200-
(Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match {
201-
case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs)
202-
case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT)
203-
case Select(qual, _) => (qual, t, accArgs, accT, t.symbol)
204-
case x: This => (x, x, accArgs, accT, x.symbol)
205-
case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol)
206-
case x => (x, x, accArgs, accT, x.symbol)
221+
def rewriteApply(tree: Tree, sym: Symbol): Tree = {
222+
def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil):
223+
(Tree, Tree, List[List[Tree]], Symbol) = t match {
224+
case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs)
225+
case Select(qual, _) => (qual, t, accArgs, t.symbol)
226+
case x: This => (x, x, accArgs, x.symbol)
227+
case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, x.symbol)
228+
case x => (x, x, accArgs, x.symbol)
207229
}
208230

209-
val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree)
210-
val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe}
231+
val (prefix, call, arguments, symbol) = receiverArgumentsAndSymbol(tree)
211232

212-
val targs = typeArguments.map(noTailTransform)
213233
val argumentss = arguments.map(noTailTransforms)
234+
assert(argumentss.size == 1, tree)
214235

215236
val isRecursiveCall = (method eq sym)
216237
val recvWiden = prefix.tpe.widenDealias
217238

218239
def continue = {
219240
val method = noTailTransform(call)
220-
val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method
221-
if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs
222-
else argumentss.foldLeft(methodWithTargs) {
241+
argumentss.foldLeft(method) {
223242
// case (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. Interesting that one can do it in Scala2!
224243
(method, args) => Apply(method, args)
225244
}
226245
}
227246
def fail(reason: String) = {
247+
def required = tree.getAttachment(TailRecCallSiteKey).isDefined
228248
if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos)
229249
else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason)
230250
continue
231251
}
232252

233253
if (isRecursiveCall) {
234254
if (ctx.tailPos) {
235-
val receiverIsSame =
236-
recvWiden <:< enclosingClass.appliedRef &&
237-
(sym.isEffectivelyFinal || enclosingClass.appliedRef <:< recvWiden)
238-
val receiverIsThis = prefix.tpe =:= thisType || prefix.tpe.widen =:= thisType
239-
240255
def rewriteTailCall(recv: Tree): Tree = {
241256
c.debuglog("Rewriting tail recursive call: " + tree.pos)
242257
rewrote = true
243258
val receiver = noTailTransform(recv)
244259

245-
val callTargs: List[tpd.Tree] =
246-
if (abstractOverClass) {
247-
val classTypeArgs = recv.tpe.baseType(enclosingClass).argInfos
248-
targs ::: classTypeArgs.map(x => ref(x.typeSymbol))
249-
} else targs
250-
251-
val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef)
252-
val thisPassed =
260+
val method = Ident(label.termRef)
261+
val argumentsWithReceiver =
253262
if (this.method.owner.isClass)
254-
method.appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head))
255-
else method
256-
257-
val res =
258-
if (thisPassed.tpe.widen.isParameterless) thisPassed
259-
else argumentss.foldLeft(thisPassed) {
260-
(met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet.
261-
}
262-
res
263-
}
263+
receiver :: argumentss.head
264+
else
265+
argumentss.head
264266

265-
if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call")
266-
else {
267-
val recv = noTailTransform(prefix)
268-
if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
269-
else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv)
270-
else fail("it changes type of 'this' on a polymorphic recursive call")
267+
Apply(method, argumentsWithReceiver)
271268
}
269+
270+
val recv = noTailTransform(prefix)
271+
if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass))
272+
else rewriteTailCall(recv)
272273
}
273-
else fail(defaultReason)
274+
else fail("it is not in tail position")
274275
} else {
275276
val receiverIsSuper = (method.name eq sym) && enclosingClass.appliedRef.widen <:< recvWiden
276277

@@ -305,28 +306,21 @@ class TailRec extends MiniPhase with FullParameterization {
305306
else tree
306307

307308
case tree: Select =>
308-
val sym = tree.symbol
309-
if (sym == method && ctx.tailPos) rewriteApply(tree, sym)
310-
else tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name)
309+
tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name)
311310

312-
case Apply(fun, args) =>
311+
case Apply(fun, args) if !fun.isInstanceOf[TypeApply] =>
313312
val meth = fun.symbol
314313
if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
315314
tpd.cpy.Apply(tree)(fun, transform(args))
316315
else
317316
rewriteApply(tree, meth)
318317

319-
case TypeApply(fun, targs) =>
320-
val meth = fun.symbol
321-
rewriteApply(tree, meth)
322-
323318
case tree@Block(stats, expr) =>
324319
tpd.cpy.Block(tree)(
325320
noTailTransforms(stats),
326321
transform(expr)
327322
)
328-
case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) =>
329-
tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt)
323+
330324
case tree@If(cond, thenp, elsep) =>
331325
tpd.cpy.If(tree)(
332326
noTailTransform(cond),
@@ -357,8 +351,15 @@ class TailRec extends MiniPhase with FullParameterization {
357351
Literal(_) | TypeTree() | TypeDef(_, _) =>
358352
tree
359353

354+
case Labeled(bind, expr) =>
355+
if (ctx.tailPos)
356+
tailPositionLabeledSyms += bind.symbol
357+
tpd.cpy.Labeled(tree)(bind, transform(expr))
358+
360359
case Return(expr, from) =>
361-
tpd.cpy.Return(tree)(noTailTransform(expr), from)
360+
val fromSym = from.symbol
361+
val tailPos = fromSym.is(Flags.Label) && tailPositionLabeledSyms.contains(fromSym)
362+
tpd.cpy.Return(tree)(transform(expr, new TailContext(tailPos)), from)
362363

363364
case _ =>
364365
super.transform(tree)
@@ -382,4 +383,6 @@ object TailRec {
382383

383384
final val noTailContext = new TailContext(false)
384385
final val yesTailContext = new TailContext(true)
386+
387+
object TailRecCallSiteKey extends Property.StickyKey[Unit]
385388
}

0 commit comments

Comments
 (0)