Skip to content

Commit 48f0e89

Browse files
committed
Specialize Any.{==, !=} when inlining
Their specilized versions are known to have the same semantics as the generic `==` or `!=`. Improvement related to #11998
1 parent c3c7cd3 commit 48f0e89

File tree

2 files changed

+120
-1
lines changed

2 files changed

+120
-1
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1383,7 +1383,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
13831383
val expanded = expandMacro(res.args.head, tree.srcPos)
13841384
typedExpr(expanded) // Inline calls and constant fold code generated by the macro
13851385
case res =>
1386-
inlineIfNeeded(res)
1386+
specializeEq(inlineIfNeeded(res))
13871387
}
13881388
if res.symbol == defn.QuotedRuntime_exprQuote then
13891389
ctx.compilationUnit.needsQuotePickling = true
@@ -1465,6 +1465,21 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
14651465
case tree => tree
14661466
}
14671467

1468+
def specializeEq(tree: Tree): Tree =
1469+
tree match
1470+
case Apply(sel @ Select(arg1, opName), arg2 :: Nil)
1471+
if sel.symbol == defn.Any_== || sel.symbol == defn.Any_!= =>
1472+
defn.ScalaValueClasses().find { cls =>
1473+
arg1.tpe.derivesFrom(cls) && arg2.tpe.derivesFrom(cls)
1474+
} match {
1475+
case Some(cls) =>
1476+
val newOp = cls.requiredMethod(opName, List(cls.typeRef))
1477+
arg1.select(newOp).withSpan(sel.span).appliedTo(arg2).withSpan(tree.span)
1478+
case None => tree
1479+
}
1480+
case _ =>
1481+
tree
1482+
14681483
/** Drop any side-effect-free bindings that are unused in expansion or other reachable bindings.
14691484
* Inline def bindings that are used only once.
14701485
*/

compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,4 +615,108 @@ class InlineBytecodeTests extends DottyBytecodeTest {
615615

616616
}
617617
}
618+
619+
@Test def any_eq_specialization = {
620+
val source = """class Test:
621+
| inline def eql(x: Any, y: Any) = x == y
622+
|
623+
| def testAny(x: Any, y: Any) = eql(x, y)
624+
| def testAnyExpected(x: Any, y: Any) = x == y
625+
|
626+
| def testBoolean(x: Boolean, y: Boolean) = eql(x, y)
627+
| def testBooleanExpected(x: Boolean, y: Boolean) = x == y
628+
|
629+
| def testByte(x: Byte, y: Byte) = eql(x, y)
630+
| def testByteExpected(x: Byte, y: Byte) = x == y
631+
|
632+
| def testShort(x: Short, y: Short) = eql(x, y)
633+
| def testShortExpected(x: Short, y: Short) = x == y
634+
|
635+
| def testInt(x: Int, y: Int) = eql(x, y)
636+
| def testIntExpected(x: Int, y: Int) = x == y
637+
|
638+
| def testLong(x: Long, y: Long) = eql(x, y)
639+
| def testLongExpected(x: Long, y: Long) = x == y
640+
|
641+
| def testFloat(x: Float, y: Float) = eql(x, y)
642+
| def testFloatExpected(x: Float, y: Float) = x == y
643+
|
644+
| def testDouble(x: Double, y: Double) = eql(x, y)
645+
| def testDoubleExpected(x: Double, y: Double) = x == y
646+
|
647+
| def testChar(x: Char, y: Char) = eql(x, y)
648+
| def testCharExpected(x: Char, y: Char) = x == y
649+
|
650+
| def testUnit(x: Unit, y: Unit) = eql(x, y)
651+
| def testUnitExpected(x: Unit, y: Unit) = x == y
652+
""".stripMargin
653+
654+
checkBCode(source) { dir =>
655+
val clsIn = dir.lookupName("Test.class", directory = false).input
656+
val clsNode = loadClassNode(clsIn)
657+
658+
for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do
659+
val meth1 = getMethod(clsNode, s"test$cls")
660+
val meth2 = getMethod(clsNode, s"test${cls}Expected")
661+
662+
val instructions1 = instructionsFromMethod(meth1)
663+
val instructions2 = instructionsFromMethod(meth2)
664+
665+
assert(instructions1 == instructions2,
666+
s"`==` was not properly specialized when inlined in `test$cls`\n" +
667+
diffInstructions(instructions1, instructions2))
668+
}
669+
}
670+
671+
@Test def any_neq_specialization = {
672+
val source = """class Test:
673+
| inline def neql(x: Any, y: Any) = x != y
674+
|
675+
| def testAny(x: Any, y: Any) = neql(x, y)
676+
| def testAnyExpected(x: Any, y: Any) = x != y
677+
|
678+
| def testBoolean(x: Boolean, y: Boolean) = neql(x, y)
679+
| def testBooleanExpected(x: Boolean, y: Boolean) = x != y
680+
|
681+
| def testByte(x: Byte, y: Byte) = neql(x, y)
682+
| def testByteExpected(x: Byte, y: Byte) = x != y
683+
|
684+
| def testShort(x: Short, y: Short) = neql(x, y)
685+
| def testShortExpected(x: Short, y: Short) = x != y
686+
|
687+
| def testInt(x: Int, y: Int) = neql(x, y)
688+
| def testIntExpected(x: Int, y: Int) = x != y
689+
|
690+
| def testLong(x: Long, y: Long) = neql(x, y)
691+
| def testLongExpected(x: Long, y: Long) = x != y
692+
|
693+
| def testFloat(x: Float, y: Float) = neql(x, y)
694+
| def testFloatExpected(x: Float, y: Float) = x != y
695+
|
696+
| def testDouble(x: Double, y: Double) = neql(x, y)
697+
| def testDoubleExpected(x: Double, y: Double) = x != y
698+
|
699+
| def testChar(x: Char, y: Char) = neql(x, y)
700+
| def testCharExpected(x: Char, y: Char) = x != y
701+
|
702+
| def testUnit(x: Unit, y: Unit) = neql(x, y)
703+
| def testUnitExpected(x: Unit, y: Unit) = x != y
704+
""".stripMargin
705+
706+
checkBCode(source) { dir =>
707+
val clsIn = dir.lookupName("Test.class", directory = false).input
708+
val clsNode = loadClassNode(clsIn)
709+
710+
for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do
711+
val meth1 = getMethod(clsNode, s"test$cls")
712+
val meth2 = getMethod(clsNode, s"test${cls}Expected")
713+
714+
val instructions1 = instructionsFromMethod(meth1)
715+
val instructions2 = instructionsFromMethod(meth2)
716+
717+
assert(instructions1 == instructions2,
718+
s"`!=` was not properly specialized when inlined in `test$cls`\n" +
719+
diffInstructions(instructions1, instructions2))
720+
}
721+
}
618722
}

0 commit comments

Comments
 (0)