diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 10bf1c2032db..5f8865ceb451 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -149,12 +149,37 @@ object Inliner { /** Replace `Inlined` node by a block that contains its bindings and expansion */ def dropInlined(inlined: tpd.Inlined)(implicit ctx: Context): Tree = { - val reposition = new TreeMap { - override def transform(tree: Tree)(implicit ctx: Context): Tree = { - super.transform(tree).withPos(inlined.call.pos) + if (enclosingInlineds.nonEmpty) inlined // Remove in the outer most inlined call + else { + val inlinedAtPos = inlined.call.pos + val callSourceFile = ctx.source.file + + /** Removes all Inlined trees, replacing them with blocks. + * Repositions all trees directly inside an inlined expansion of a non empty call to the position of the call. + * Any tree directly inside an empty call (inlined in the inlined code) retains their position. + */ + class Reposition extends TreeMap { + override def transform(tree: Tree)(implicit ctx: Context): Tree = { + tree match { + case tree: Inlined => transformInline(tree) + case _ => + val transformed = super.transform(tree) + enclosingInlineds match { + case call :: _ if call.symbol.sourceFile != callSourceFile => + // Until we implement JSR-45, we cannot represent in output positions in other source files. + // So, reposition inlined code from other files with the call position: + transformed.withPos(inlinedAtPos) + case _ => transformed + } + } + } + def transformInline(tree: tpd.Inlined)(implicit ctx: Context): Tree = { + tpd.seq(transformSub(tree.bindings), transform(tree.expansion)(inlineContext(tree.call))) + } } + + (new Reposition).transformInline(inlined) } - tpd.seq(inlined.bindings, reposition.transform(inlined.expansion)) } } diff --git a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala index d46a91bd1b74..6d6385497949 100644 --- a/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala @@ -335,4 +335,5 @@ class TestBCode extends DottyBytecodeTest { assert(!fooInvoke, "foo should not be called\n") } } + } diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index b639adbfc66b..2ab838ac271e 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -3,6 +3,10 @@ package dotty.tools.backend.jvm import org.junit.Assert._ import org.junit.Test +import scala.tools.asm.Opcodes._ + +import scala.collection.JavaConverters._ + class InlineBytecodeTests extends DottyBytecodeTest { import ASMConverters._ @Test def inlineUnit = { @@ -37,4 +41,244 @@ class InlineBytecodeTests extends DottyBytecodeTest { diffInstructions(instructions2, instructions3)) } } + + @Test def i4947 = { + val source = """class Foo { + | transparent def track[T](f: => T): T = { + | foo("tracking") // line 3 + | f // line 4 + | } + | def main(args: Array[String]): Unit = { // line 6 + | track { // line 7 + | foo("abc") // line 8 + | track { // line 9 + | foo("inner") // line 10 + | } + | } // line 11 + | } + | def foo(str: String): Unit = () + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn, skipDebugInfo = false) + + val track = clsNode.methods.asScala.find(_.name == "track") + assert(track.isEmpty, "method `track` should have been erased") + + val main = getMethod(clsNode, "main") + val instructions = instructionsFromMethod(main) + val expected = + List( + Label(0), + LineNumber(6, Label(0)), + LineNumber(3, Label(0)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(6), + LineNumber(8, Label(6)), + VarOp(ALOAD, 0), + Ldc(LDC, "abc"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(11), + LineNumber(3, Label(11)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(16), + LineNumber(10, Label(16)), + VarOp(ALOAD, 0), + Ldc(LDC, "inner"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Op(RETURN), + Label(22) + ) + assert(instructions == expected, + "`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected)) + + } + } + + @Test def i4947b = { + val source = """class Foo { + | transparent def track2[T](f: => T): T = { + | foo("tracking2") // line 3 + | f // line 4 + | } + | transparent def track[T](f: => T): T = { + | foo("tracking") // line 7 + | track2 { // line 8 + | f // line 9 + | } + | } + | def main(args: Array[String]): Unit = { // line 12 + | track { // line 13 + | foo("abc") // line 14 + | } + | } + | def foo(str: String): Unit = () + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn, skipDebugInfo = false) + + val track = clsNode.methods.asScala.find(_.name == "track") + assert(track.isEmpty, "method `track` should have been erased") + + val track2 = clsNode.methods.asScala.find(_.name == "track2") + assert(track2.isEmpty, "method `track2` should have been erased") + + val main = getMethod(clsNode, "main") + val instructions = instructionsFromMethod(main) + val expected = + List( + Label(0), + LineNumber(12, Label(0)), + LineNumber(7, Label(0)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(6), + LineNumber(3, Label(6)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking2"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(11), + LineNumber(14, Label(11)), + VarOp(ALOAD, 0), + Ldc(LDC, "abc"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Op(RETURN), + Label(17) + ) + assert(instructions == expected, + "`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected)) + + } + } + + @Test def i4947c = { + val source = """class Foo { + | transparent def track2[T](f: => T): T = { + | foo("tracking2") // line 3 + | f // line 4 + | } + | transparent def track[T](f: => T): T = { + | track2 { // line 7 + | foo("fgh") // line 8 + | f // line 9 + | } + | } + | def main(args: Array[String]): Unit = { // line 12 + | track { // line 13 + | foo("abc") // line 14 + | } + | } + | def foo(str: String): Unit = () + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn, skipDebugInfo = false) + + val track = clsNode.methods.asScala.find(_.name == "track") + assert(track.isEmpty, "method `track` should have been erased") + + val track2 = clsNode.methods.asScala.find(_.name == "track2") + assert(track2.isEmpty, "method `track2` should have been erased") + + val main = getMethod(clsNode, "main") + val instructions = instructionsFromMethod(main) + val expected = + List( + Label(0), + LineNumber(12, Label(0)), + LineNumber(3, Label(0)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking2"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(6), + LineNumber(8, Label(6)), + VarOp(ALOAD, 0), + Ldc(LDC, "fgh"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(11), + LineNumber(14, Label(11)), + VarOp(ALOAD, 0), + Ldc(LDC, "abc"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Op(RETURN), + Label(17) + ) + assert(instructions == expected, + "`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected)) + + } + } + + @Test def i4947d = { + val source = """class Foo { + | transparent def track2[T](f: => T): T = { + | foo("tracking2") // line 3 + | f // line 4 + | } + | transparent def track[T](f: => T): T = { + | track2 { // line 7 + | track2 { // line 8 + | f // line 9 + | } + | } + | } + | def main(args: Array[String]): Unit = { // line 13 + | track { // line 14 + | foo("abc") // line 15 + | } + | } + | def foo(str: String): Unit = () + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn, skipDebugInfo = false) + + val track = clsNode.methods.asScala.find(_.name == "track") + assert(track.isEmpty, "method `track` should have been erased") + + val track2 = clsNode.methods.asScala.find(_.name == "track2") + assert(track2.isEmpty, "method `track2` should have been erased") + + val main = getMethod(clsNode, "main") + val instructions = instructionsFromMethod(main) + val expected = + List( + Label(0), + LineNumber(13, Label(0)), + LineNumber(3, Label(0)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking2"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(6), + LineNumber(3, Label(6)), + VarOp(ALOAD, 0), + Ldc(LDC, "tracking2"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Label(11), + LineNumber(15, Label(11)), + VarOp(ALOAD, 0), + Ldc(LDC, "abc"), + Invoke(INVOKEVIRTUAL, "Foo", "foo", "(Ljava/lang/String;)V", false), + Op(RETURN), + Label(17) + ) + assert(instructions == expected, + "`track` was not properly inlined in `main`\n" + diffInstructions(instructions, expected)) + + } + } } diff --git a/tests/run/i4947.check b/tests/run/i4947.check new file mode 100644 index 000000000000..35c8075066c3 --- /dev/null +++ b/tests/run/i4947.check @@ -0,0 +1,4 @@ +track: Test$.main(i4947.scala:4) +track: Test$.main(i4947.scala:5) +main1: Test$.main(i4947.scala:15) +main2: Test$.main(i4947.scala:16) diff --git a/tests/run/i4947.scala b/tests/run/i4947.scala new file mode 100644 index 000000000000..08648487f7fa --- /dev/null +++ b/tests/run/i4947.scala @@ -0,0 +1,20 @@ +object Test { + + transparent def track[T](f: => T): T = { + printStack("track") + printStack("track") + f + } + + def printStack(tag: String): Unit = { + println(tag + ": "+ new Exception().getStackTrace().apply(1)) + } + + def main(args: Array[String]): Unit = { + track { + printStack("main1") + printStack("main2") + } + } + +} diff --git a/tests/run/i4947a.check b/tests/run/i4947a.check new file mode 100644 index 000000000000..0ff7f8edcf9f --- /dev/null +++ b/tests/run/i4947a.check @@ -0,0 +1,14 @@ +track (i = 0): Test$.main(i4947a.scala:4) +track (i = 0): Test$.main(i4947a.scala:5) +track (i = 2): Test$.main(i4947a.scala:4) +track (i = 2): Test$.main(i4947a.scala:5) +main1 (i = -1): Test$.main(i4947a.scala:21) +main2 (i = -1): Test$.main(i4947a.scala:22) +track (i = 1): Test$.main(i4947a.scala:4) +track (i = 1): Test$.main(i4947a.scala:5) +main1 (i = -1): Test$.main(i4947a.scala:21) +main2 (i = -1): Test$.main(i4947a.scala:22) +track (i = 0): Test$.main(i4947a.scala:4) +track (i = 0): Test$.main(i4947a.scala:5) +main1 (i = -1): Test$.main(i4947a.scala:21) +main2 (i = -1): Test$.main(i4947a.scala:22) diff --git a/tests/run/i4947a.scala b/tests/run/i4947a.scala new file mode 100644 index 000000000000..aa0484b04e29 --- /dev/null +++ b/tests/run/i4947a.scala @@ -0,0 +1,27 @@ +object Test { + + transparent def fact[T](transparent i: Int)(f: => T): Int = { + printStack("track", i) + printStack("track", i) + f + if (i == 0) + 1 + else { + i * fact(i-1)(f) + } + } + + def printStack(tag: String, i: Int): Unit = { + println(s"$tag (i = $i): ${new Exception().getStackTrace().apply(1)}") + } + + def main(args: Array[String]): Unit = { + fact(0) { + fact(2) { + printStack("main1", -1) + printStack("main2", -1) + } + } + } + +} diff --git a/tests/run/i4947b.check b/tests/run/i4947b.check new file mode 100644 index 000000000000..183c31bdfc5d --- /dev/null +++ b/tests/run/i4947b.check @@ -0,0 +1,36 @@ +track: Test$.main(Test_2.scala:3) +track: Test$.main(Test_2.scala:3) +main1: Test$.main(Test_2.scala:4) +main2: Test$.main(Test_2.scala:5) +track: Test$.main(Test_2.scala:7) +track: Test$.main(Test_2.scala:7) +track: Test$.main(Test_2.scala:8) +track: Test$.main(Test_2.scala:8) +main3: Test$.main(Test_2.scala:9) +main4: Test$.main(Test_2.scala:10) +track (i = 0): Test$.main(Test_2.scala:13) +track (i = 0): Test$.main(Test_2.scala:13) +track: Test$.main(Test_2.scala:13) +track: Test$.main(Test_2.scala:13) +fact: Test$.main(Test_2.scala:13) +track (i = 2): Test$.main(Test_2.scala:14) +track (i = 2): Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +fact: Test$.main(Test_2.scala:14) +main1 (i = -1): Test$.main(Test_2.scala:15) +main2 (i = -1): Test$.main(Test_2.scala:16) +track (i = 1): Test$.main(Test_2.scala:14) +track (i = 1): Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +fact: Test$.main(Test_2.scala:14) +main1 (i = -1): Test$.main(Test_2.scala:15) +main2 (i = -1): Test$.main(Test_2.scala:16) +track (i = 0): Test$.main(Test_2.scala:14) +track (i = 0): Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +track: Test$.main(Test_2.scala:14) +fact: Test$.main(Test_2.scala:14) +main1 (i = -1): Test$.main(Test_2.scala:15) +main2 (i = -1): Test$.main(Test_2.scala:16) diff --git a/tests/run/i4947b/Lib_1.scala b/tests/run/i4947b/Lib_1.scala new file mode 100644 index 000000000000..8f0799b9cc18 --- /dev/null +++ b/tests/run/i4947b/Lib_1.scala @@ -0,0 +1,28 @@ +object Lib { + transparent def track[T](f: => T): T = { + printStack("track") + printStack("track") + f + } + def printStack(tag: String): Unit = { + println(tag + ": "+ new Exception().getStackTrace().apply(1)) + } + + def printStack(tag: String, i: Int): Unit = { + println(s"$tag (i = $i): ${new Exception().getStackTrace().apply(1)}") + } + + transparent def fact[T](transparent i: Int)(f: => T): Int = { + printStack("track", i) + printStack("track", i) + track { + printStack("fact") + } + f + if (i == 0) + 1 + else { + i * fact(i-1)(f) + } + } +} diff --git a/tests/run/i4947b/Test_2.scala b/tests/run/i4947b/Test_2.scala new file mode 100644 index 000000000000..2ca1cf2e77f1 --- /dev/null +++ b/tests/run/i4947b/Test_2.scala @@ -0,0 +1,21 @@ +object Test { + def main(args: Array[String]): Unit = { + Lib.track { + Lib.printStack("main1") + Lib.printStack("main2") + } + Lib.track { + Lib.track { + Lib.printStack("main3") + Lib.printStack("main4") + } + } + Lib.fact(0) { + Lib.fact(2) { + Lib.printStack("main1", -1) + Lib.printStack("main2", -1) + } + } + } + +} diff --git a/tests/run/i4947c.check b/tests/run/i4947c.check new file mode 100644 index 000000000000..56a401547ae3 --- /dev/null +++ b/tests/run/i4947c.check @@ -0,0 +1,4 @@ +track: Test$.main(i4947c.scala:4) +track: Test$.main(i4947c.scala:5) +main1: Test$.main(i4947c.scala:18) +main2: Test$.main(i4947c.scala:19) diff --git a/tests/run/i4947c.scala b/tests/run/i4947c.scala new file mode 100644 index 000000000000..be1efe8cdb8e --- /dev/null +++ b/tests/run/i4947c.scala @@ -0,0 +1,23 @@ +object Aux { + + transparent def track[T](f: => T): T = { + printStack("track") + printStack("track") + f + } + + def printStack(tag: String): Unit = { + println(tag + ": "+ new Exception().getStackTrace().apply(1)) + } +} + +object Test { + import Aux._ + def main(args: Array[String]): Unit = { + track { + printStack("main1") + printStack("main2") + } + } + +}