Skip to content

Commit ccad792

Browse files
Merge pull request #6821 from dotty-staging/array-apply-opt
Fix #502: Optimize `Array.apply([...])` to `[...]`
2 parents 786dfd9 + 1541817 commit ccad792

File tree

7 files changed

+206
-0
lines changed

7 files changed

+206
-0
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class Compiler {
9292
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9393
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
9494
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
95+
new ArrayApply, // Optimize `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
9596
new ElimPolyFunction, // Rewrite PolyFunction subclasses to FunctionN subclasses
9697
new TailRec, // Rewrite tail recursion to loops
9798
new Mixin, // Expand trait fields and trait initializers

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -755,6 +755,8 @@ class Definitions {
755755
@threadUnsafe lazy val ClassTagType: TypeRef = ctx.requiredClassRef("scala.reflect.ClassTag")
756756
def ClassTagClass(implicit ctx: Context): ClassSymbol = ClassTagType.symbol.asClass
757757
def ClassTagModule(implicit ctx: Context): Symbol = ClassTagClass.companionModule
758+
@threadUnsafe lazy val ClassTagModule_applyR: TermRef = ClassTagModule.requiredMethodRef(nme.apply)
759+
def ClassTagModule_apply(implicit ctx: Context): Symbol = ClassTagModule_applyR.symbol
758760

759761
@threadUnsafe lazy val QuotedExprType: TypeRef = ctx.requiredClassRef("scala.quoted.Expr")
760762
def QuotedExprClass(implicit ctx: Context): ClassSymbol = QuotedExprType.symbol.asClass
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import MegaPhase._
6+
import Contexts.Context
7+
import Symbols._
8+
import Types._
9+
import StdNames._
10+
import ast.Trees._
11+
import dotty.tools.dotc.ast.tpd
12+
13+
import scala.reflect.ClassTag
14+
15+
16+
/** This phase rewrites calls to `Array.apply` to a direct instantiation of the array in the bytecode.
17+
*
18+
* Transforms `scala.Array.apply([....])` and `scala.Array.apply(..., [....])` into `[...]`
19+
*/
20+
class ArrayApply extends MiniPhase {
21+
import tpd._
22+
23+
override def phaseName: String = "arrayApply"
24+
25+
override def transformApply(tree: tpd.Apply)(implicit ctx: Context): tpd.Tree = {
26+
if (tree.symbol.name == nme.apply && tree.symbol.owner == defn.ArrayModule) { // Is `Array.apply`
27+
tree.args match {
28+
case StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: ct :: Nil
29+
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) && elideClassTag(ct) =>
30+
seqLit
31+
32+
case elem0 :: StripAscription(Apply(wrapRefArrayMeth, (seqLit: tpd.JavaSeqLiteral) :: Nil)) :: Nil
33+
if defn.WrapArrayMethods().contains(wrapRefArrayMeth.symbol) =>
34+
tpd.JavaSeqLiteral(elem0 :: seqLit.elems, seqLit.elemtpt)
35+
36+
case _ =>
37+
tree
38+
}
39+
40+
} else tree
41+
}
42+
43+
/** Only optimize when classtag if it is one of
44+
* - `ClassTag.apply(classOf[XYZ])`
45+
* - `ClassTag.apply(java.lang.XYZ.Type)` for boxed primitives `XYZ``
46+
* - `ClassTag.XYZ` for primitive types
47+
*/
48+
private def elideClassTag(ct: Tree)(implicit ctx: Context): Boolean = ct match {
49+
case Apply(_, rc :: Nil) if ct.symbol == defn.ClassTagModule_apply =>
50+
rc match {
51+
case _: Literal => true // ClassTag.apply(classOf[XYZ])
52+
case rc: RefTree if rc.name == nme.TYPE_ =>
53+
// ClassTag.apply(java.lang.XYZ.Type)
54+
defn.ScalaBoxedClasses().contains(rc.symbol.maybeOwner.companionClass)
55+
case _ => false
56+
}
57+
case Apply(ctm: RefTree, _) if ctm.symbol.maybeOwner.companionModule == defn.ClassTagModule =>
58+
// ClassTag.XYZ
59+
nme.ScalaValueNames.contains(ctm.name)
60+
case _ => false
61+
}
62+
63+
object StripAscription {
64+
def unapply(tree: Tree)(implicit ctx: Context): Some[Tree] = tree match {
65+
case Typed(expr, _) => unapply(expr)
66+
case _ => Some(tree)
67+
}
68+
}
69+
}
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package dotty.tools.backend.jvm
2+
3+
import org.junit.Test
4+
import org.junit.Assert._
5+
6+
import scala.tools.asm.Opcodes._
7+
8+
class ArrayApplyOptTest extends DottyBytecodeTest {
9+
import ASMConverters._
10+
11+
@Test def testArrayEmptyGenericApply= {
12+
test("Array[String]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/String"), Op(POP), Op(RETURN)))
13+
test("Array[Unit]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(POP), Op(RETURN)))
14+
test("Array[Object]()", List(Op(ICONST_0), TypeOp(ANEWARRAY, "java/lang/Object"), Op(POP), Op(RETURN)))
15+
test("Array[Boolean]()", newArray0Opcodes(T_BOOLEAN))
16+
test("Array[Byte]()", newArray0Opcodes(T_BYTE))
17+
test("Array[Short]()", newArray0Opcodes(T_SHORT))
18+
test("Array[Int]()", newArray0Opcodes(T_INT))
19+
test("Array[Long]()", newArray0Opcodes(T_LONG))
20+
test("Array[Float]()", newArray0Opcodes(T_FLOAT))
21+
test("Array[Double]()", newArray0Opcodes(T_DOUBLE))
22+
test("Array[Char]()", newArray0Opcodes(T_CHAR))
23+
test("Array[T]()", newArray0Opcodes(T_INT))
24+
}
25+
26+
@Test def testArrayGenericApply= {
27+
def opCodes(tpe: String) =
28+
List(Op(ICONST_2), TypeOp(ANEWARRAY, tpe), Op(DUP), Op(ICONST_0), Ldc(LDC, "a"), Op(AASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, "b"), Op(AASTORE), Op(POP), Op(RETURN))
29+
test("""Array("a", "b")""", opCodes("java/lang/String"))
30+
test("""Array[Object]("a", "b")""", opCodes("java/lang/Object"))
31+
}
32+
33+
@Test def testArrayApplyBoolean =
34+
test("Array(true, false)", newArray2Opcodes(T_BOOLEAN, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_0), Op(BASTORE))))
35+
36+
@Test def testArrayApplyByte =
37+
test("Array[Byte](1, 2)", newArray2Opcodes(T_BYTE, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(BASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(BASTORE))))
38+
39+
@Test def testArrayApplyShort =
40+
test("Array[Short](1, 2)", newArray2Opcodes(T_SHORT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(SASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(SASTORE))))
41+
42+
@Test def testArrayApplyInt = {
43+
test("Array(1, 2)", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE))))
44+
test("""Array[T](t, t)""", newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE), Op(DUP), Op(ICONST_1), Field(GETSTATIC, "Foo$", "MODULE$", "LFoo$;"), Invoke(INVOKEVIRTUAL, "Foo$", "t", "()I", false), Op(IASTORE))))
45+
}
46+
47+
@Test def testArrayApplyLong =
48+
test("Array(2L, 3L)", newArray2Opcodes(T_LONG, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2), Op(LASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3), Op(LASTORE))))
49+
50+
@Test def testArrayApplyFloat =
51+
test("Array(2.1f, 3.1f)", newArray2Opcodes(T_FLOAT, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.1f), Op(FASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.1f), Op(FASTORE))))
52+
53+
@Test def testArrayApplyDouble =
54+
test("Array(2.2d, 3.2d)", newArray2Opcodes(T_DOUBLE, List(Op(DUP), Op(ICONST_0), Ldc(LDC, 2.2d), Op(DASTORE), Op(DUP), Op(ICONST_1), Ldc(LDC, 3.2d), Op(DASTORE))))
55+
56+
@Test def testArrayApplyChar =
57+
test("Array('x', 'y')", newArray2Opcodes(T_CHAR, List(Op(DUP), Op(ICONST_0), IntOp(BIPUSH, 120), Op(CASTORE), Op(DUP), Op(ICONST_1), IntOp(BIPUSH, 121), Op(CASTORE))))
58+
59+
@Test def testArrayApplyUnit =
60+
test("Array[Unit]((), ())", List(Op(ICONST_2), TypeOp(ANEWARRAY, "scala/runtime/BoxedUnit"), Op(DUP),
61+
Op(ICONST_0), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(DUP),
62+
Op(ICONST_1), Field(GETSTATIC, "scala/runtime/BoxedUnit", "UNIT", "Lscala/runtime/BoxedUnit;"), Op(AASTORE), Op(POP), Op(RETURN)))
63+
64+
@Test def testArrayInlined = test(
65+
"""{
66+
| inline def array(xs: =>Int*): Array[Int] = Array(xs: _*)
67+
| array(1, 2)
68+
|}""".stripMargin,
69+
newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE), TypeOp(CHECKCAST, "[I")))
70+
)
71+
72+
@Test def testArrayInlined2 = test(
73+
"""{
74+
| inline def array(x: =>Int, xs: =>Int*): Array[Int] = Array(x, xs: _*)
75+
| array(1, 2)
76+
|}""".stripMargin,
77+
newArray2Opcodes(T_INT, List(Op(DUP), Op(ICONST_0), Op(ICONST_1), Op(IASTORE), Op(DUP), Op(ICONST_1), Op(ICONST_2), Op(IASTORE)))
78+
)
79+
80+
private def newArray0Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] =
81+
Op(ICONST_0) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil
82+
83+
private def newArray2Opcodes(tpe: Int, init: List[Any] = Nil): List[Any] =
84+
Op(ICONST_2) :: IntOp(NEWARRAY, tpe) :: init ::: Op(POP) :: Op(RETURN) :: Nil
85+
86+
private def test(code: String, expectedInstructions: List[Any])= {
87+
val source =
88+
s"""class Foo {
89+
| import Foo._
90+
| def test: Unit = $code
91+
|}
92+
|object Foo {
93+
| opaque type T = Int
94+
| def t: T = 1
95+
|}
96+
""".stripMargin
97+
98+
checkBCode(source) { dir =>
99+
val clsIn = dir.lookupName("Foo.class", directory = false).input
100+
val clsNode = loadClassNode(clsIn)
101+
val meth = getMethod(clsNode, "test")
102+
103+
val instructions = instructionsFromMethod(meth)
104+
105+
assertEquals(expectedInstructions, instructions)
106+
}
107+
}
108+
109+
}

tests/run/i502.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Ok
2+
foo
3+
bar

tests/run/i502.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.reflect.ClassTag
2+
3+
object Test extends App {
4+
Array[Int](1, 2)
5+
6+
try {
7+
Array[Int](1, 2)(null)
8+
???
9+
} catch {
10+
case _: NullPointerException => println("Ok")
11+
}
12+
13+
Array[Int](1, 2)({println("foo"); the[ClassTag[Int]]})
14+
15+
Array[Int](1, 2)(ClassTag.apply({ println("bar"); classOf[Int]}))
16+
}

tests/run/t6611b.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object Test extends App {
2+
val a = Array("1")
3+
val a2 = Array(a: _*)
4+
a2(0) = "2"
5+
assert(a(0) == "1")
6+
}

0 commit comments

Comments
 (0)