diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 042e6609927c..564646b446c7 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -67,7 +67,6 @@ class Compiler { new ProtectedAccessors, // Add accessors for protected members new ExtensionMethods, // Expand methods of value classes with extension methods new ShortcutImplicits, // Allow implicit functions without creating closures - new TailRec, // Rewrite tail recursion to loops new ByNameClosures, // Expand arguments to by-name parameters to closures new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope @@ -96,6 +95,7 @@ class Compiler { List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations + new TailRec, // Rewrite tail recursion to loops new Mixin, // Expand trait fields and trait initializers new LazyVals, // Expand lazy vals new Memoize, // Add private fields to getters and setters diff --git a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala index d854482537ec..856597b4b7ff 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala @@ -116,6 +116,11 @@ class TreeTypeMap( val guard1 = tmap.transform(guard) val rhs1 = tmap.transform(rhs) cpy.CaseDef(cdef)(pat1, guard1, rhs1) + case labeled @ Labeled(bind, expr) => + val tmap = withMappedSyms(bind.symbol :: Nil) + val bind1 = tmap.transformSub(bind) + val expr1 = tmap.transform(expr) + cpy.Labeled(labeled)(bind1, expr1) case Hole(n, args) => Hole(n, args.mapConserve(transform)).withPos(tree.pos).withType(mapType(tree.tpe)) case tree1 => diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index d1d1f6df7314..546eef756466 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -208,17 +208,17 @@ object Types { def loop(tp: Type): Boolean = tp match { case tp: TypeRef => val sym = tp.symbol - if (sym.isClass) sym.derivesFrom(cls) else loop(tp.superType): @tailrec + if (sym.isClass) sym.derivesFrom(cls) else loop(tp.superType) case tp: AppliedType => tp.superType.derivesFrom(cls) case tp: MatchType => tp.bound.derivesFrom(cls) || tp.reduced.derivesFrom(cls) case tp: TypeProxy => - loop(tp.underlying): @tailrec + loop(tp.underlying) case tp: AndType => - loop(tp.tp1) || loop(tp.tp2): @tailrec + loop(tp.tp1) || loop(tp.tp2) case tp: OrType => - loop(tp.tp1) && loop(tp.tp2): @tailrec + loop(tp.tp1) && loop(tp.tp2) case tp: JavaArrayType => cls == defn.ObjectClass case _ => @@ -403,16 +403,16 @@ object Types { */ final def classSymbol(implicit ctx: Context): Symbol = this match { case ConstantType(constant) => - constant.tpe.classSymbol: @tailrec + constant.tpe.classSymbol case tp: TypeRef => val sym = tp.symbol - if (sym.isClass) sym else tp.superType.classSymbol: @tailrec + if (sym.isClass) sym else tp.superType.classSymbol case tp: ClassInfo => tp.cls case tp: SingletonType => NoSymbol case tp: TypeProxy => - tp.underlying.classSymbol: @tailrec + tp.underlying.classSymbol case AndType(l, r) => val lsym = l.classSymbol val rsym = r.classSymbol @@ -436,9 +436,9 @@ object Types { tp.cls :: Nil case tp: TypeRef => val sym = tp.symbol - if (sym.isClass) sym.asClass :: Nil else tp.superType.classSymbols: @tailrec + if (sym.isClass) sym.asClass :: Nil else tp.superType.classSymbols case tp: TypeProxy => - tp.underlying.classSymbols: @tailrec + tp.underlying.classSymbols case AndType(l, r) => l.classSymbols union r.classSymbols case OrType(l, r) => @@ -479,7 +479,7 @@ object Types { case tp: ClassInfo => tp.decls case tp: TypeProxy => - tp.underlying.decls: @tailrec + tp.underlying.decls case _ => EmptyScope } @@ -725,7 +725,7 @@ object Types { val ns = tp.parent.memberNames(keepOnly, pre) if (keepOnly(pre, tp.refinedName)) ns + tp.refinedName else ns case tp: TypeProxy => - tp.underlying.memberNames(keepOnly, pre): @tailrec + tp.underlying.memberNames(keepOnly, pre) case tp: AndType => tp.tp1.memberNames(keepOnly, pre) | tp.tp2.memberNames(keepOnly, pre) case tp: OrType => @@ -1042,21 +1042,21 @@ object Types { case tp: TypeRef => if (tp.symbol.isClass) tp else tp.info match { - case TypeAlias(alias) => alias.dealias1(keep): @tailrec + case TypeAlias(alias) => alias.dealias1(keep) case _ => tp } case app @ AppliedType(tycon, args) => val tycon1 = tycon.dealias1(keep) - if (tycon1 ne tycon) app.superType.dealias1(keep): @tailrec + if (tycon1 ne tycon) app.superType.dealias1(keep) else this case tp: TypeVar => val tp1 = tp.instanceOpt - if (tp1.exists) tp1.dealias1(keep): @tailrec else tp + if (tp1.exists) tp1.dealias1(keep) else tp case tp: AnnotatedType => val tp1 = tp.parent.dealias1(keep) if (keep(tp)(ctx)) tp.derivedAnnotatedType(tp1, tp.annot) else tp1 case tp: LazyRef => - tp.ref.dealias1(keep): @tailrec + tp.ref.dealias1(keep) case _ => this } diff --git a/compiler/src/dotty/tools/dotc/transform/MixinOps.scala b/compiler/src/dotty/tools/dotc/transform/MixinOps.scala index 921e3762f186..16dbd0ad3ba6 100644 --- a/compiler/src/dotty/tools/dotc/transform/MixinOps.scala +++ b/compiler/src/dotty/tools/dotc/transform/MixinOps.scala @@ -24,7 +24,7 @@ class MixinOps(cls: ClassSymbol, thisPhase: DenotTransformer)(implicit ctx: Cont name = member.name.stripScala2LocalSuffix, flags = member.flags &~ Deferred, info = cls.thisType.memberInfo(member)).enteredAfter(thisPhase).asTerm - res.addAnnotations(member.annotations) + res.addAnnotations(member.annotations.filter(_.symbol != defn.TailrecAnnot)) res } diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 98f614ae5a18..a04c7329710a 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -27,7 +27,6 @@ class PatternMatcher extends MiniPhase { override def phaseName = PatternMatcher.name override def runsAfter = Set(ElimRepeated.name) - override def runsAfterGroupsOf = Set(TailRec.name) // tailrec is not capable of reversing the patmat tranformation made for tree override def transformMatch(tree: Match)(implicit ctx: Context): Tree = { val translated = new Translator(tree.tpe, this).translateMatch(tree) diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index fb8160c5a146..d91bf1d13dfd 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -7,25 +7,30 @@ import core._ import Contexts.Context import Decorators._ import Symbols._ +import StdNames.nme import Types._ import NameKinds.TailLabelName import MegaPhase.MiniPhase import reporting.diagnostic.messages.TailrecNotApplicable +import util.Property /** * A Tail Rec Transformer * @author Erik Stenman, Iulian Dragos, * ported and heavily modified for dotty by Dmitry Petrashko + * moved after erasure by Sébastien Doeraene * @version 1.1 * * What it does: *

* Finds method calls in tail-position and replaces them with jumps. * A call is in a tail-position if it is the last instruction to be - * executed in the body of a method. This is done by recursing over + * executed in the body of a method. This includes being in + * tail-position of a `return` from a `Labeled` block which is itself + * in tail-position (which is critical for tail-recursive calls in the + * cases of a `match`). To identify tail positions, we recurse over * the trees that may contain calls in tail-position (trees that can't - * contain such calls are not transformed). However, they are not that - * many. + * contain such calls are not transformed). *

*

* Self-recursive calls in tail-position are replaced by jumps to a @@ -37,16 +42,14 @@ import reporting.diagnostic.messages.TailrecNotApplicable * A method call is self-recursive if it calls the current method and * the method is final (otherwise, it could * be a call to an overridden method in a subclass). - * - * Recursive calls on a different instance - * are optimized. Since 'this' is not a local variable it s added as - * a label parameter. + * Recursive calls on a different instance are optimized. Since 'this' + * is not a local variable it is added as a label parameter. *

*

- * This phase has been moved before pattern matching to catch more - * of the common cases of tail recursive functions. This means that - * more cases should be taken into account (like nested function, and - * pattern cases). + * This phase has been moved after erasure to allow the use of vars + * for the parameters combined with a `WhileDo` (upcoming change). + * This is also beneficial to support polymorphic tail-recursive + * calls. *

*

* If a method contains self-recursive calls, a label is added to at @@ -54,59 +57,52 @@ import reporting.diagnostic.messages.TailrecNotApplicable * that label. *

*

- * - * In scalac, If the method had type parameters, the call must contain same - * parameters as type arguments. This is no longer case in dotc. + * In scalac, if the method had type parameters, the call must contain + * the same parameters as type arguments. This is no longer the case in + * dotc thanks to being located after erasure. * In scalac, this is named tailCall but it does only provide optimization for * self recursive functions, that's why it's renamed to tailrec *

*/ -class TailRec extends MiniPhase with FullParameterization { +class TailRec extends MiniPhase { import TailRec._ import dotty.tools.dotc.ast.tpd._ override def phaseName: String = TailRec.name - override def runsAfter = Set(ShortcutImplicits.name) // Replaces non-tail calls by tail calls + override def runsAfter = Set(Erasure.name) // tailrec assumes erased types final val labelFlags = Flags.Synthetic | Flags.Label | Flags.Method - /** Symbols of methods that have @tailrec annotatios inside */ - private val methodsWithInnerAnnots = new collection.mutable.HashSet[Symbol]() + private def mkLabel(method: Symbol)(implicit ctx: Context): TermSymbol = { + val name = TailLabelName.fresh() - override def transformUnit(tree: Tree)(implicit ctx: Context): Tree = { - methodsWithInnerAnnots.clear() - tree - } + if (method.owner.isClass) { + val MethodTpe(paramNames, paramInfos, resultType) = method.info - override def transformTyped(tree: Typed)(implicit ctx: Context): Tree = { - if (tree.tpt.tpe.hasAnnotation(defn.TailrecAnnot)) - methodsWithInnerAnnots += ctx.owner.enclosingMethod - tree - } + val enclosingClass = method.enclosingClass.asClass + val thisParamType = + if (enclosingClass.is(Flags.Module)) enclosingClass.thisType + else enclosingClass.classInfo.selfType - private def mkLabel(method: Symbol, abstractOverClass: Boolean)(implicit ctx: Context): TermSymbol = { - val name = TailLabelName.fresh() - - if (method.owner.isClass) - ctx.newSymbol(method, name.toTermName, labelFlags, fullyParameterizedType(method.info, method.enclosingClass.asClass, abstractOverClass, liftThisType = false)) + ctx.newSymbol(method, name.toTermName, labelFlags, + MethodType(nme.SELF :: paramNames, thisParamType :: paramInfos, resultType)) + } else ctx.newSymbol(method, name.toTermName, labelFlags, method.info) } override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = { val sym = tree.symbol tree match { - case dd@DefDef(name, tparams, vparamss0, tpt, _) + case dd@DefDef(name, Nil, vparams :: Nil, tpt, _) if (sym.isEffectivelyFinal) && !((sym is Flags.Accessor) || (dd.rhs eq EmptyTree) || (sym is Flags.Label)) => val mandatory = sym.hasAnnotation(defn.TailrecAnnot) cpy.DefDef(dd)(rhs = { - val defIsTopLevel = sym.owner.isClass val origMeth = sym - val label = mkLabel(sym, abstractOverClass = defIsTopLevel) + val label = mkLabel(sym) val owner = ctx.owner.enclosingClass.asClass - val thisTpe = owner.thisType.widen var rewrote = false @@ -115,34 +111,50 @@ class TailRec extends MiniPhase with FullParameterization { // and second one will actually apply, // now this speculatively transforms tree and throws away result in many cases val rhsSemiTransformed = { - val transformer = new TailRecElimination(origMeth, dd.tparams, owner, thisTpe, mandatory, label, abstractOverClass = defIsTopLevel) + val transformer = new TailRecElimination(origMeth, owner, mandatory, label) val rhs = transformer.transform(dd.rhs) rewrote = transformer.rewrote rhs } if (rewrote) { - val dummyDefDef = cpy.DefDef(tree)(rhs = rhsSemiTransformed) if (tree.symbol.owner.isClass) { - val labelDef = fullyParameterizedDef(label, dummyDefDef, abstractOverClass = defIsTopLevel) - val call = forwarder(label, dd, abstractOverClass = defIsTopLevel, liftThisType = true) - Block(List(labelDef), call) + val classSym = tree.symbol.owner.asClass + + val labelDef = DefDef(label, vrefss => { + assert(vrefss.size == 1, vrefss) + val vrefs = vrefss.head + val thisRef = vrefs.head + val origMeth = tree.symbol + val origVParams = vparams.map(_.symbol) + new TreeTypeMap( + typeMap = identity(_) + .substThisUnlessStatic(classSym, thisRef.tpe) + .subst(origVParams, vrefs.tail.map(_.tpe)), + treeMap = { + case tree: This if tree.symbol == classSym => thisRef + case tree => tree + }, + oldOwners = origMeth :: Nil, + newOwners = label :: Nil + ).transform(rhsSemiTransformed) + }) + val callIntoLabel = ref(label).appliedToArgs(This(classSym) :: vparams.map(x => ref(x.symbol))) + Block(List(labelDef), callIntoLabel) } else { // inner method. Tail recursion does not change `this` - val labelDef = polyDefDef(label, trefs => vrefss => { + val labelDef = DefDef(label, vrefss => { + assert(vrefss.size == 1, vrefss) + val vrefs = vrefss.head val origMeth = tree.symbol - val origTParams = tree.tparams.map(_.symbol) - val origVParams = tree.vparamss.flatten map (_.symbol) + val origVParams = vparams.map(_.symbol) new TreeTypeMap( typeMap = identity(_) - .subst(origTParams ++ origVParams, trefs ++ vrefss.flatten.map(_.tpe)), + .subst(origVParams, vrefs.map(_.tpe)), oldOwners = origMeth :: Nil, newOwners = label :: Nil ).transform(rhsSemiTransformed) }) - val callIntoLabel = ( - if (dd.tparams.isEmpty) ref(label) - else ref(label).appliedToTypes(dd.tparams.map(_.tpe)) - ).appliedToArgss(vparamss0.map(_.map(x=> ref(x.symbol)))) + val callIntoLabel = ref(label).appliedToArgs(vparams.map(x => ref(x.symbol))) Block(List(labelDef), callIntoLabel) }} else { if (mandatory) ctx.error( @@ -155,7 +167,7 @@ class TailRec extends MiniPhase with FullParameterization { dd.rhs } }) - case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) || methodsWithInnerAnnots.contains(d.symbol) => + case d: DefDef if d.symbol.hasAnnotation(defn.TailrecAnnot) => ctx.error(TailrecNotApplicable(sym), sym.pos) d case _ => tree @@ -163,13 +175,14 @@ class TailRec extends MiniPhase with FullParameterization { } - class TailRecElimination(method: Symbol, methTparams: List[Tree], enclosingClass: Symbol, thisType: Type, isMandatory: Boolean, label: Symbol, abstractOverClass: Boolean) extends tpd.TreeMap { + class TailRecElimination(method: Symbol, enclosingClass: Symbol, isMandatory: Boolean, label: Symbol) extends tpd.TreeMap { import dotty.tools.dotc.ast.tpd._ var rewrote = false - private val defaultReason = "it contains a recursive call not in tail position" + /** Symbols of Labeled blocks that are in tail position. */ + private val tailPositionLabeledSyms = new collection.mutable.HashSet[Symbol]() private[this] var ctx: TailContext = yesTailContext @@ -191,88 +204,50 @@ class TailRec extends MiniPhase with FullParameterization { transform(tree, noTailContext) def noTailTransforms[Tr <: Tree](trees: List[Tr])(implicit c: Context): List[Tr] = - trees.map(noTailTransform).asInstanceOf[List[Tr]] + trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]] override def transform(tree: Tree)(implicit c: Context): Tree = { - /* A possibly polymorphic apply to be considered for tail call transformation. */ - def rewriteApply(tree: Tree, sym: Symbol, required: Boolean = false): Tree = { - def receiverArgumentsAndSymbol(t: Tree, accArgs: List[List[Tree]] = Nil, accT: List[Tree] = Nil): - (Tree, Tree, List[List[Tree]], List[Tree], Symbol) = t match { - case TypeApply(fun, targs) if fun.symbol eq t.symbol => receiverArgumentsAndSymbol(fun, accArgs, targs) - case Apply(fn, args) if fn.symbol == t.symbol => receiverArgumentsAndSymbol(fn, args :: accArgs, accT) - case Select(qual, _) => (qual, t, accArgs, accT, t.symbol) - case x: This => (x, x, accArgs, accT, x.symbol) - case x: Ident if x.symbol eq method => (EmptyTree, x, accArgs, accT, x.symbol) - case x => (x, x, accArgs, accT, x.symbol) + /* Rewrite an Apply to be considered for tail call transformation. */ + def rewriteApply(tree: Apply): Tree = { + val call = tree.fun + val sym = call.symbol + val arguments = noTailTransforms(tree.args) + + val prefix = call match { + case Select(qual, _) => qual + case x: Ident if x.symbol eq method => EmptyTree + case x => x } - val (prefix, call, arguments, typeArguments, symbol) = receiverArgumentsAndSymbol(tree) - val hasConformingTargs = (typeArguments zip methTparams).forall{x => x._1.tpe <:< x._2.tpe} + val isRecursiveCall = (method eq sym) - val targs = typeArguments.map(noTailTransform) - val argumentss = arguments.map(noTailTransforms) + def continue = + tpd.cpy.Apply(tree)(noTailTransform(call), arguments) - val isRecursiveCall = (method eq sym) - val recvWiden = prefix.tpe.widenDealias - - def continue = { - val method = noTailTransform(call) - val methodWithTargs = if (targs.nonEmpty) TypeApply(method, targs) else method - if (methodWithTargs.tpe.widen.isParameterless) methodWithTargs - else argumentss.foldLeft(methodWithTargs) { - // case (method, args) => Apply(method, args) // Dotty deviation no auto-detupling yet. Interesting that one can do it in Scala2! - (method, args) => Apply(method, args) - } - } def fail(reason: String) = { - if (isMandatory || required) c.error(s"Cannot rewrite recursive call: $reason", tree.pos) + if (isMandatory) c.error(s"Cannot rewrite recursive call: $reason", tree.pos) else c.debuglog("Cannot rewrite recursive call at: " + tree.pos + " because: " + reason) continue } if (isRecursiveCall) { if (ctx.tailPos) { - val receiverIsSame = - recvWiden <:< enclosingClass.appliedRef && - (sym.isEffectivelyFinal || enclosingClass.appliedRef <:< recvWiden) - val receiverIsThis = prefix.tpe =:= thisType || prefix.tpe.widen =:= thisType - - def rewriteTailCall(recv: Tree): Tree = { - c.debuglog("Rewriting tail recursive call: " + tree.pos) - rewrote = true - val receiver = noTailTransform(recv) - - val callTargs: List[tpd.Tree] = - if (abstractOverClass) { - val classTypeArgs = recv.tpe.baseType(enclosingClass).argInfos - targs ::: classTypeArgs.map(x => ref(x.typeSymbol)) - } else targs - - val method = if (callTargs.nonEmpty) TypeApply(Ident(label.termRef), callTargs) else Ident(label.termRef) - val thisPassed = - if (this.method.owner.isClass) - method.appliedTo(receiver.ensureConforms(method.tpe.widen.firstParamTypes.head)) - else method - - val res = - if (thisPassed.tpe.widen.isParameterless) thisPassed - else argumentss.foldLeft(thisPassed) { - (met, ar) => Apply(met, ar) // Dotty deviation no auto-detupling yet. - } - res - } - - if (!hasConformingTargs) fail("it changes type arguments on a polymorphic recursive call") - else { - val recv = noTailTransform(prefix) - if (recv eq EmptyTree) rewriteTailCall(This(enclosingClass.asClass)) - else if (receiverIsSame || receiverIsThis) rewriteTailCall(recv) - else fail("it changes type of 'this' on a polymorphic recursive call") - } + c.debuglog("Rewriting tail recursive call: " + tree.pos) + rewrote = true + def receiver = + if (prefix eq EmptyTree) This(enclosingClass.asClass) + else noTailTransform(prefix) + + val argumentsWithReceiver = + if (this.method.owner.isClass) receiver :: arguments + else arguments + + tpd.cpy.Apply(tree)(ref(label), argumentsWithReceiver) } - else fail(defaultReason) + else fail("it is not in tail position") } else { - val receiverIsSuper = (method.name eq sym) && enclosingClass.appliedRef.widen <:< recvWiden + // FIXME `(method.name eq sym)` is always false (Name vs Symbol). What is this trying to do? + val receiverIsSuper = (method.name eq sym) && enclosingClass.appliedRef.widen <:< prefix.tpe.widenDealias if (receiverIsSuper) fail("it contains a recursive call targeting a supertype") else continue @@ -298,35 +273,22 @@ class TailRec extends MiniPhase with FullParameterization { } val res: Tree = tree match { - - case Ident(qual) => - val sym = tree.symbol - if (sym == method && ctx.tailPos) rewriteApply(tree, sym) - else tree - - case tree: Select => - val sym = tree.symbol - if (sym == method && ctx.tailPos) rewriteApply(tree, sym) - else tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name) - - case Apply(fun, args) => + case tree@Apply(fun, args) => val meth = fun.symbol if (meth == defn.Boolean_|| || meth == defn.Boolean_&&) tpd.cpy.Apply(tree)(fun, transform(args)) else - rewriteApply(tree, meth) + rewriteApply(tree) - case TypeApply(fun, targs) => - val meth = fun.symbol - rewriteApply(tree, meth) + case tree: Select => + tpd.cpy.Select(tree)(noTailTransform(tree.qualifier), tree.name) case tree@Block(stats, expr) => tpd.cpy.Block(tree)( noTailTransforms(stats), transform(expr) ) - case tree @ Typed(t: Apply, tpt) if tpt.tpe.hasAnnotation(defn.TailrecAnnot) => - tpd.Typed(rewriteApply(t, t.fun.symbol, required = true), tpt) + case tree@If(cond, thenp, elsep) => tpd.cpy.If(tree)( noTailTransform(cond), @@ -357,8 +319,15 @@ class TailRec extends MiniPhase with FullParameterization { Literal(_) | TypeTree() | TypeDef(_, _) => tree + case Labeled(bind, expr) => + if (ctx.tailPos) + tailPositionLabeledSyms += bind.symbol + tpd.cpy.Labeled(tree)(bind, transform(expr)) + case Return(expr, from) => - tpd.cpy.Return(tree)(noTailTransform(expr), from) + val fromSym = from.symbol + val tailPos = fromSym.is(Flags.Label) && tailPositionLabeledSyms.contains(fromSym) + tpd.cpy.Return(tree)(transform(expr, new TailContext(tailPos)), from) case _ => super.transform(tree) @@ -367,12 +336,6 @@ class TailRec extends MiniPhase with FullParameterization { res } } - - /** If references to original `target` from fully parameterized method `derived` should be - * rewired to some fully parameterized method, that method symbol, - * otherwise NoSymbol. - */ - override protected def rewiredTarget(target: Symbol, derived: Symbol)(implicit ctx: Context): Symbol = NoSymbol } object TailRec { diff --git a/tests/neg-tailcall/i1221.scala b/tests/neg-tailcall/i1221.scala deleted file mode 100644 index 7cf9312f5f03..000000000000 --- a/tests/neg-tailcall/i1221.scala +++ /dev/null @@ -1,10 +0,0 @@ -import annotation.tailrec - -object I1221{ - final def foo(a: Int): Int = { - if ((foo(a - 1): @tailrec) > 0) // error: not in tail position - foo(a - 1): @tailrec - else - foo(a - 2): @tailrec - } -} diff --git a/tests/neg-tailcall/i1221b.scala b/tests/neg-tailcall/i1221b.scala deleted file mode 100644 index f8e2add9a7a0..000000000000 --- a/tests/neg-tailcall/i1221b.scala +++ /dev/null @@ -1,10 +0,0 @@ -import annotation.tailrec - -class Test { - def foo(a: Int): Int = { // error: method is not final - if ((foo(a - 1): @tailrec) > 0) - foo(a - 1): @tailrec - else - foo(a - 2): @tailrec - } -} diff --git a/tests/neg-tailcall/t1672b.scala b/tests/neg-tailcall/t1672b.scala index 294eebdef248..20185c430116 100644 --- a/tests/neg-tailcall/t1672b.scala +++ b/tests/neg-tailcall/t1672b.scala @@ -4,9 +4,9 @@ object Test1772B { try { throw new RuntimeException } catch { - case _: Throwable => bar + case _: Throwable => bar // error: it is not in tail position } finally { - bar + bar // error: it is not in tail position } } @@ -15,7 +15,7 @@ object Test1772B { try { throw new RuntimeException } catch { - case _: Throwable => baz + case _: Throwable => baz // error: it is not in tail position } finally { ??? } @@ -26,14 +26,14 @@ object Test1772B { try { throw new RuntimeException } catch { - case _: Throwable => boz; ??? + case _: Throwable => boz; ??? // error: it is not in tail position } } @annotation.tailrec def bez : Nothing = { // error: TailRec optimisation not applicable try { - bez + bez // error: it is not in tail position } finally { ??? } diff --git a/tests/neg-tailcall/t6574.scala b/tests/neg-tailcall/t6574.scala index 462ef800f435..dfae77953234 100644 --- a/tests/neg-tailcall/t6574.scala +++ b/tests/neg-tailcall/t6574.scala @@ -4,7 +4,7 @@ class Bad[X, Y](val v: Int) extends AnyVal { println("tail") } - @annotation.tailrec final def differentTypeArgs: Unit = { // error - {(); new Bad[String, Unit](0)}.differentTypeArgs // error + @annotation.tailrec final def differentTypeArgs: Unit = { + {(); new Bad[String, Unit](0)}.differentTypeArgs } } diff --git a/tests/neg-tailcall/tailrec-3.scala b/tests/neg-tailcall/tailrec-3.scala index b0c56061527d..391f39c90dab 100644 --- a/tests/neg-tailcall/tailrec-3.scala +++ b/tests/neg-tailcall/tailrec-3.scala @@ -7,7 +7,7 @@ object Test { case _ => Nil } @tailrec private def quux3(xs: List[String]): Boolean = xs match { - case x :: xs if quux3(List("abc")) => quux3(xs) + case x :: xs if quux3(List("abc")) => quux3(xs) // error case _ => false } } diff --git a/tests/neg-tailcall/tailrec.scala b/tests/neg-tailcall/tailrec.scala index f2047cdaf1be..159b2c82962a 100644 --- a/tests/neg-tailcall/tailrec.scala +++ b/tests/neg-tailcall/tailrec.scala @@ -36,6 +36,12 @@ class Winners { case Nil => Nil case x :: xs => succ3(xs, x :: acc) } + + @tailrec final def fail3[T](x: Int): Int = fail3(x - 1) + + class Tom[T](x: Int) { + @tailrec final def fail4[U](other: Tom[U], x: Int): Int = other.fail4[U](other, x - 1) + } } object Failures { @@ -54,12 +60,4 @@ class Failures { case Nil => Nil case x :: xs => x :: fail2[T](xs) // error } - - // unsafe - @tailrec final def fail3[T](x: Int): Int = fail3(x - 1) // error // error: recursive application has different type arguments - - // unsafe - class Tom[T](x: Int) { - @tailrec final def fail4[U](other: Tom[U], x: Int): Int = other.fail4[U](other, x - 1) // error // error - } } diff --git a/tests/pos/tailcall/i1221.scala b/tests/pos/tailcall/i1221.scala deleted file mode 100644 index ba1dbc4a96c9..000000000000 --- a/tests/pos/tailcall/i1221.scala +++ /dev/null @@ -1,10 +0,0 @@ -import annotation.tailrec - -object i1221{ - final def foo(a: Int): Int = { - if (foo(a - 1) > 0) - foo(a - 1): @tailrec - else - foo(a - 2): @tailrec - } -}