Skip to content

Commit 276c8f8

Browse files
committed
Fix #502: Optimize Array.apply([...]) to [...]
1 parent cba756d commit 276c8f8

File tree

6 files changed

+177
-0
lines changed

6 files changed

+177
-0
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class Compiler {
9292
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9393
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
9494
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
95+
new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
9596
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
9697
new TailRec, // Rewrite tail recursion to loops
9798
new Mixin, // Expand trait fields and trait initializers
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import MegaPhase._
6+
import Contexts.Context
7+
import Symbols._
8+
import Types._
9+
import StdNames._
10+
import ast.Trees._
11+
import dotty.tools.dotc.ast.tpd
12+
13+
14+
/** This phase rewrites calls to `Array.apply` to primitive array instantion.
15+
*
16+
* Transforms `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
17+
*/
18+
class ArrayApply extends MiniPhase {
19+
import tpd._
20+
21+
override def phaseName: String = "arrayApply"
22+
23+
override def transformApply(tree: tpd.Apply)(implicit ctx: Context): tpd.Tree = {
24+
if (tree.symbol.name == nme.apply && tree.symbol.owner == defn.ArrayModule) { // Is `Array.apply`
25+
tree.args match {
26+
case CleanTree(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
27+
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
28+
seqLit
29+
30+
case elem0 :: CleanTree(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
31+
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
32+
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
33+
34+
case _ =>
35+
tree
36+
}
37+
38+
} else tree
39+
}
40+
41+
// Only optimize when classtag is `ClassTag.apply` or `ClassTag.{Byte, Boolean, ...}`
42+
private def elideClassTag(ct: Tree)(implicit ctx: Context): Boolean = {
43+
ct.symbol.maybeOwner.companionModule == defn.ClassTagModule
44+
}
45+
46+
object CleanTree {
47+
def unapply(tree: Tree)(implicit ctx: Context): Some[Tree] = tree match {
48+
case Block(Nil, expr) => unapply(expr)
49+
case Typed(expr, _) => unapply(expr)
50+
case _ => Some(tree)
51+
}
52+
}
53+
}
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package dotty.tools.backend.jvm
2+
3+
import org.junit.Test
4+
import org.junit.Assert._
5+
6+
import scala.tools.asm.Opcodes._
7+
8+
class ArrayApplyOptTest extends DottyBytecodeTest {
9+
import ASMConverters._
10+
11+
@Test def testArrayEmptyGenericApply= {
12+
test("Array[String]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/String"), Op(POP), Op(RETURN)))
13+
test("Array[Unit]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(POP), Op(RETURN)))
14+
test("Array[Object]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/Object"), Op(POP), Op(RETURN)))
15+
test("Array[Boolean]()", List(Op(ICONST_0), IntOp(NEWARRAY, 4), Op(POP), Op(RETURN)))
16+
test("Array[Char]()", List(Op(ICONST_0), IntOp(NEWARRAY, 5), Op(POP), Op(RETURN)))
17+
test("Array[Float]()", List(Op(ICONST_0), IntOp(NEWARRAY, 6), Op(POP), Op(RETURN)))
18+
test("Array[Double]()", List(Op(ICONST_0), IntOp(NEWARRAY, 7), Op(POP), Op(RETURN)))
19+
test("Array[Byte]()", List(Op(ICONST_0), IntOp(NEWARRAY, 8), Op(POP), Op(RETURN)))
20+
test("Array[Short]()", List(Op(ICONST_0), IntOp(NEWARRAY, 9), Op(POP), Op(RETURN)))
21+
test("Array[Int]()", List(Op(ICONST_0), IntOp(NEWARRAY, 10), Op(POP), Op(RETURN)))
22+
test("Array[Long]()", List(Op(ICONST_0), IntOp(NEWARRAY, 11), Op(POP), Op(RETURN)))
23+
test("Array[T]()", List(Op(ICONST_0), IntOp(NEWARRAY, 10), Op(POP), Op(RETURN)))
24+
}
25+
26+
@Test def testArrayGenericApply= {
27+
test("""Array("a", "b")""", List(Op(ICONST_2), TypeOp(ANEWARRAY, "java/lang/String"), Op(DUP), Op(ICONST_0), Ldc(LDC, "a"), Op(AASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, "b"), Op(AASTORE), Op(POP), Op(RETURN)))
28+
test("""Array[Object]("a", "b")""", List(Op(ICONST_2), TypeOp(ANEWARRAY, "java/lang/Object"), Op(DUP), Op(ICONST_0), Ldc(LDC, "a"), Op(AASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, "b"), Op(AASTORE), Op(POP), Op(RETURN)))
29+
}
30+
31+
@Test def testArrayApplyBoolean =
32+
test("Array(true, false)", List(Op(ICONST_2), IntOp(NEWARRAY, 4), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_0), Op(BASTORE), Op(POP), Op(RETURN)))
33+
34+
@Test def testArrayApplyByte =
35+
test("Array[Byte](1, 2)", List(Op(ICONST_2), IntOp(NEWARRAY, 8), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(BASTORE), Op(POP), Op(RETURN)))
36+
37+
@Test def testArrayApplyShort =
38+
test("Array[Short](1, 2)", List(Op(ICONST_2), IntOp(NEWARRAY, 9), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(SASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(SASTORE), Op(POP), Op(RETURN)))
39+
40+
@Test def testArrayApplyInt = {
41+
test("Array(1, 2)", List(Op(ICONST_2), IntOp(NEWARRAY, 10), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), Op(POP), Op(RETURN)))
42+
test("""Array[T](t, t)""", List(Op(ICONST_2), IntOp(NEWARRAY, 10), Op(DUP), Op(ICONST_0), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE), Op(DUP), Op(ICONST_1), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE), Op(POP), Op(RETURN)))
43+
}
44+
45+
@Test def testArrayApplyLong =
46+
test("Array(2L, 3L)", List(Op(ICONST_2), IntOp(NEWARRAY, 11), Op(DUP), Op(ICONST_0), Ldc(LDC, 2), Op(LASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3), Op(LASTORE), Op(POP), Op(RETURN)))
47+
48+
@Test def testArrayApplyFloat =
49+
test("Array(2.1f, 3.1f)", List(Op(ICONST_2), IntOp(NEWARRAY, 6), Op(DUP), Op(ICONST_0), Ldc(LDC, 2.1f), Op(FASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.1f), Op(FASTORE), Op(POP), Op(RETURN)))
50+
51+
@Test def testArrayApplyDouble =
52+
test("Array(2.2d, 3.2d)", List(Op(ICONST_2), IntOp(NEWARRAY, 7), Op(DUP), Op(ICONST_0), Ldc(LDC, 2.2d), Op(DASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.2d), Op(DASTORE), Op(POP), Op(RETURN)))
53+
54+
@Test def testArrayApplyChar =
55+
test("Array('x', 'y')", List(Op(ICONST_2), IntOp(NEWARRAY, 5), Op(DUP), Op(ICONST_0), IntOp(BIPUSH, 120), Op(CASTORE), Op(DUP), Op(ICONST_1), IntOp(BIPUSH, 121), Op(CASTORE), Op(POP), Op(RETURN)))
56+
57+
@Test def testArrayApplyUnit =
58+
test("Array[Unit]((), ())", List(Op(ICONST_2), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(DUP),
59+
Op(ICONST_0), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(DUP),
60+
Op(ICONST_1), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(POP), Op(RETURN)))
61+
62+
@Test def testArrayInlined = test(
63+
"""{
64+
| inline def array(xs: =>Int*): Array[Int] = Array(xs: _*)
65+
| array(1, 2)
66+
|}""".stripMargin,
67+
List(Op(ICONST_2), IntOp(NEWARRAY, 10), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), TypeOp(CHECKCAST, "[I"), Op(POP), Op(RETURN))
68+
)
69+
70+
@Test def testArrayInlined2 = test(
71+
"""{
72+
| inline def array(x: =>Int, xs: =>Int*): Array[Int] = Array(x, xs: _*)
73+
| array(1, 2)
74+
|}""".stripMargin,
75+
List(Op(ICONST_2), IntOp(NEWARRAY, 10), Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), Op(POP), Op(RETURN))
76+
)
77+
78+
private def test(code: String, expectedInstructions: List[Any])= {
79+
val source =
80+
s"""class Foo {
81+
| import Foo._
82+
| def test: Unit = $code
83+
|}
84+
|object Foo {
85+
| opaque type T = Int
86+
| def t: T = 1
87+
|}
88+
""".stripMargin
89+
90+
checkBCode(source) { dir =>
91+
val clsIn = dir.lookupName("Foo.class", directory = false).input
92+
val clsNode = loadClassNode(clsIn)
93+
val meth = getMethod(clsNode, "test")
94+
95+
val instructions = instructionsFromMethod(meth)
96+
97+
assertEquals(expectedInstructions, instructions)
98+
}
99+
}
100+
101+
}

tests/run/i502.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Ok
2+
foo

tests/run/i502.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import scala.reflect.ClassTag
2+
3+
object Test extends App {
4+
Array[Int](1, 2)
5+
6+
try {
7+
Array[Int](1, 2)(null)
8+
???
9+
} catch {
10+
case _: NullPointerException => println("Ok")
11+
}
12+
13+
Array[Int](1, 2)({println("foo"); the[ClassTag[Int]]})
14+
}

tests/run/t6611b.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object Test extends App {
2+
val a = Array("1")
3+
val a2 = Array(a: _*)
4+
a2(0) = "2"
5+
assert(a(0) == "1")
6+
}

0 commit comments

Comments
 (0)