diff --git a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala index f668a1b68a5f..824a93ed506f 100644 --- a/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala +++ b/compiler/src/dotty/tools/backend/jvm/BCodeSkelBuilder.scala @@ -4,6 +4,8 @@ package jvm import scala.language.unsafeNulls +import scala.annotation.tailrec + import scala.collection.{ mutable, immutable } import scala.tools.asm @@ -484,6 +486,14 @@ trait BCodeSkelBuilder extends BCodeHelpers { slots.getOrElse(locSym, makeLocal(locSym)) } + def reuseLocal(sym: Symbol, loc: Local): Unit = + val existing = slots.put(sym, loc) + if (existing.isDefined) + report.error("attempt to create duplicate local var.", ctx.source.atSpan(sym.span)) + + def reuseThisSlot(sym: Symbol): Unit = + reuseLocal(sym, Local(symInfoTK(sym), sym.javaSimpleName, 0, sym.is(Synthetic))) + private def makeLocal(sym: Symbol, tk: BType): Local = { assert(nxtIdx != -1, "not a valid start index") val loc = Local(tk, sym.javaSimpleName, nxtIdx, sym.is(Synthetic)) @@ -753,18 +763,47 @@ trait BCodeSkelBuilder extends BCodeHelpers { .addFlagIf(isNative, asm.Opcodes.ACC_NATIVE) // native methods of objects are generated in mirror classes // TODO needed? for(ann <- m.symbol.annotations) { ann.symbol.initialize } - initJMethod(flags, params.map(_.symbol)) + val paramSyms = params.map(_.symbol) + initJMethod(flags, paramSyms) if (!isAbstractMethod && !isNative) { + // #14773 Reuse locals slots for tailrec-generated mutable vars + val trimmedRhs: Tree = + @tailrec def loop(stats: List[Tree]): List[Tree] = + stats match + case (tree @ ValDef(TailLocalName(_, _), _, _)) :: rest if tree.symbol.isAllOf(Mutable | Synthetic) => + tree.rhs match + case This(_) => + locals.reuseThisSlot(tree.symbol) + loop(rest) + case rhs: Ident if paramSyms.contains(rhs.symbol) => + locals.reuseLocal(tree.symbol, locals(rhs.symbol)) + loop(rest) + case _ => + stats + case _ => + stats + end loop + + rhs match + case Block(stats, expr) => + val trimmedStats = loop(stats) + if trimmedStats eq stats then + rhs + else + Block(trimmedStats, expr) + case _ => + rhs + end trimmedRhs def emitNormalMethodBody(): Unit = { val veryFirstProgramPoint = currProgramPoint() - genLoad(rhs, returnType) + genLoad(trimmedRhs, returnType) - rhs match { + trimmedRhs match { case (_: Return) | Block(_, (_: Return)) => () - case (_: Apply) | Block(_, (_: Apply)) if rhs.symbol eq defn.throwMethod => () + case (_: Apply) | Block(_, (_: Apply)) if trimmedRhs.symbol eq defn.throwMethod => () case tpd.EmptyTree => report.error("Concrete method has no definition: " + dd + ( if (ctx.settings.Ydebug.value) "(found: " + methSymbol.owner.info.decls.toList.mkString(", ") + ")" diff --git a/compiler/src/dotty/tools/dotc/transform/TailRec.scala b/compiler/src/dotty/tools/dotc/transform/TailRec.scala index 9ee3aa5b1a8c..c2555c3d5b95 100644 --- a/compiler/src/dotty/tools/dotc/transform/TailRec.scala +++ b/compiler/src/dotty/tools/dotc/transform/TailRec.scala @@ -253,7 +253,7 @@ class TailRec extends MiniPhase { val tpe = if (enclosingClass.is(Module)) enclosingClass.thisType else enclosingClass.classInfo.selfType - val sym = newSymbol(method, nme.SELF, Synthetic | Mutable, tpe) + val sym = newSymbol(method, TailLocalName.fresh(nme.SELF), Synthetic | Mutable, tpe) varForRewrittenThis = Some(sym) sym } diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala index 6c796c280723..ce887ec56ba9 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTest.scala @@ -122,7 +122,7 @@ trait DottyBytecodeTest { def assertSameCode(method: MethodNode, expected: List[Instruction]): Unit = assertSameCode(instructionsFromMethod(method).dropNonOp, expected) def assertSameCode(actual: List[Instruction], expected: List[Instruction]): Unit = { - assert(actual === expected, s"\nExpected: $expected\nActual : $actual") + assert(actual === expected, "\n" + diffInstructions(actual, expected)) } def assertInvoke(m: MethodNode, receiver: String, method: String): Unit = @@ -296,4 +296,3 @@ trait DottyBytecodeTest { object DottyBytecodeTest { extension [T](l: List[T]) def stringLines = l.mkString("\n") } - diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index 051dd414e28f..9c5f9c167bf9 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -946,6 +946,103 @@ class TestBCode extends DottyBytecodeTest { } } + @Test def i14773TailRecReuseParamSlots(): Unit = { + val source = + s"""class Foo { + | @scala.annotation.tailrec // explicit @tailrec here + | final def fact(n: Int, acc: Int): Int = + | if n == 0 then acc + | else fact(n - 1, acc * n) + |} + | + |class IntList(head: Int, tail: IntList | Null) { + | // implicit @tailrec + | final def sum(acc: Int): Int = + | val t = tail + | if t == null then acc + head + | else t.sum(acc + head) + |} + """.stripMargin + + checkBCode(source) { dir => + // The mutable local vars for n and acc reuse the slots of the params n and acc + + val fooClass = loadClassNode(dir.lookupName("Foo.class", directory = false).input) + val factMeth = getMethod(fooClass, "fact") + + assertSameCode(factMeth, List( + Label(0), + VarOp(ILOAD, 1), + Op(ICONST_0), + Jump(IF_ICMPNE, Label(7)), + VarOp(ILOAD, 2), + Jump(GOTO, Label(26)), + Label(7), + VarOp(ALOAD, 0), + VarOp(ASTORE, 3), + VarOp(ILOAD, 1), + Op(ICONST_1), + Op(ISUB), + VarOp(ISTORE, 4), + VarOp(ILOAD, 2), + VarOp(ILOAD, 1), + Op(IMUL), + VarOp(ISTORE, 5), + VarOp(ALOAD, 3), + VarOp(ASTORE, 0), + VarOp(ILOAD, 4), + VarOp(ISTORE, 1), + VarOp(ILOAD, 5), + VarOp(ISTORE, 2), + Jump(GOTO, Label(29)), + Label(26), + Op(IRETURN), + Label(29), + Jump(GOTO, Label(0)), + Op(NOP), + Op(ATHROW), + )) + + // The mutable local vars for this and acc reuse the slots of `this` and of the param acc + + val intListClass = loadClassNode(dir.lookupName("IntList.class", directory = false).input) + val sumMeth = getMethod(intListClass, "sum") + + assertSameCode(sumMeth, List( + Label(0), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "tail", "LIntList;"), + VarOp(ASTORE, 2), + VarOp(ALOAD, 2), + Jump(IFNONNULL, Label(12)), + VarOp(ILOAD, 1), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "head", "I"), + Op(IADD), + Jump(GOTO, Label(26)), + Label(12), + VarOp(ALOAD, 2), + VarOp(ASTORE, 3), + VarOp(ILOAD, 1), + VarOp(ALOAD, 0), + Field(GETFIELD, "IntList", "head", "I"), + Op(IADD), + VarOp(ISTORE, 4), + VarOp(ALOAD, 3), + VarOp(ASTORE, 0), + VarOp(ILOAD, 4), + VarOp(ISTORE, 1), + Jump(GOTO, Label(29)), + Label(26), + Op(IRETURN), + Label(29), + Jump(GOTO, Label(0)), + Op(NOP), + Op(ATHROW), + )) + } + } + @Test def getClazz: Unit = { val source = """