Skip to content

Specialize Any.{==, !=} when inlining #12157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1383,7 +1383,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
val expanded = expandMacro(res.args.head, tree.srcPos)
typedExpr(expanded) // Inline calls and constant fold code generated by the macro
case res =>
inlineIfNeeded(res)
specializeEq(inlineIfNeeded(res))
}
if res.symbol == defn.QuotedRuntime_exprQuote then
ctx.compilationUnit.needsQuotePickling = true
Expand Down Expand Up @@ -1465,6 +1465,21 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
case tree => tree
}

def specializeEq(tree: Tree): Tree =
tree match
case Apply(sel @ Select(arg1, opName), arg2 :: Nil)
if sel.symbol == defn.Any_== || sel.symbol == defn.Any_!= =>
defn.ScalaValueClasses().find { cls =>
arg1.tpe.derivesFrom(cls) && arg2.tpe.derivesFrom(cls)
} match {
case Some(cls) =>
val newOp = cls.requiredMethod(opName, List(cls.typeRef))
arg1.select(newOp).withSpan(sel.span).appliedTo(arg2).withSpan(tree.span)
case None => tree
}
case _ =>
tree
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the criteria for deciding what methods should be specialized? Is it possible to generalize the specialization?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The criteria is that it is an == or != on Any

sel.symbol == defn.Any_== || sel.symbol == defn.Any_!=

and that the types of the compared values are the same value class

defn.ScalaValueClasses().find { cls =>
          arg1.tpe.derivesFrom(cls) && arg2.tpe.derivesFrom(cls)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might generalize it for different classes. For example Short == Int.

The other potential generalization is on Numeric ops but that is much more complex.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking why even specialize == and !=. This seems to related to the discussion about whether inlining should re-resolve selections. Such specialization helps very little in that perspective, I'm thinking whether it's worthwhile to complicate the compiler.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, in this case it is about the partial evaluation of the operation. If we have (2: Any) == (3: Any) we know that this will end calling 2 == 3 and return that result.

All the testing frameworks will benefit from this as the variants of assertEquals will generate this code. That is already a large enough use case.


/** Drop any side-effect-free bindings that are unused in expansion or other reachable bindings.
* Inline def bindings that are used only once.
*/
Expand Down
104 changes: 104 additions & 0 deletions compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -615,4 +615,108 @@ class InlineBytecodeTests extends DottyBytecodeTest {

}
}

@Test def any_eq_specialization = {
val source = """class Test:
| inline def eql(x: Any, y: Any) = x == y
|
| def testAny(x: Any, y: Any) = eql(x, y)
| def testAnyExpected(x: Any, y: Any) = x == y
|
| def testBoolean(x: Boolean, y: Boolean) = eql(x, y)
| def testBooleanExpected(x: Boolean, y: Boolean) = x == y
|
| def testByte(x: Byte, y: Byte) = eql(x, y)
| def testByteExpected(x: Byte, y: Byte) = x == y
|
| def testShort(x: Short, y: Short) = eql(x, y)
| def testShortExpected(x: Short, y: Short) = x == y
|
| def testInt(x: Int, y: Int) = eql(x, y)
| def testIntExpected(x: Int, y: Int) = x == y
|
| def testLong(x: Long, y: Long) = eql(x, y)
| def testLongExpected(x: Long, y: Long) = x == y
|
| def testFloat(x: Float, y: Float) = eql(x, y)
| def testFloatExpected(x: Float, y: Float) = x == y
|
| def testDouble(x: Double, y: Double) = eql(x, y)
| def testDoubleExpected(x: Double, y: Double) = x == y
|
| def testChar(x: Char, y: Char) = eql(x, y)
| def testCharExpected(x: Char, y: Char) = x == y
|
| def testUnit(x: Unit, y: Unit) = eql(x, y)
| def testUnitExpected(x: Unit, y: Unit) = x == y
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do
val meth1 = getMethod(clsNode, s"test$cls")
val meth2 = getMethod(clsNode, s"test${cls}Expected")

val instructions1 = instructionsFromMethod(meth1)
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
s"`==` was not properly specialized when inlined in `test$cls`\n" +
diffInstructions(instructions1, instructions2))
}
}

@Test def any_neq_specialization = {
val source = """class Test:
| inline def neql(x: Any, y: Any) = x != y
|
| def testAny(x: Any, y: Any) = neql(x, y)
| def testAnyExpected(x: Any, y: Any) = x != y
|
| def testBoolean(x: Boolean, y: Boolean) = neql(x, y)
| def testBooleanExpected(x: Boolean, y: Boolean) = x != y
|
| def testByte(x: Byte, y: Byte) = neql(x, y)
| def testByteExpected(x: Byte, y: Byte) = x != y
|
| def testShort(x: Short, y: Short) = neql(x, y)
| def testShortExpected(x: Short, y: Short) = x != y
|
| def testInt(x: Int, y: Int) = neql(x, y)
| def testIntExpected(x: Int, y: Int) = x != y
|
| def testLong(x: Long, y: Long) = neql(x, y)
| def testLongExpected(x: Long, y: Long) = x != y
|
| def testFloat(x: Float, y: Float) = neql(x, y)
| def testFloatExpected(x: Float, y: Float) = x != y
|
| def testDouble(x: Double, y: Double) = neql(x, y)
| def testDoubleExpected(x: Double, y: Double) = x != y
|
| def testChar(x: Char, y: Char) = neql(x, y)
| def testCharExpected(x: Char, y: Char) = x != y
|
| def testUnit(x: Unit, y: Unit) = neql(x, y)
| def testUnitExpected(x: Unit, y: Unit) = x != y
""".stripMargin

checkBCode(source) { dir =>
val clsIn = dir.lookupName("Test.class", directory = false).input
val clsNode = loadClassNode(clsIn)

for cls <- List("Boolean", "Byte", "Short", "Int", "Long", "Float", "Double", "Char", "Unit") do
val meth1 = getMethod(clsNode, s"test$cls")
val meth2 = getMethod(clsNode, s"test${cls}Expected")

val instructions1 = instructionsFromMethod(meth1)
val instructions2 = instructionsFromMethod(meth2)

assert(instructions1 == instructions2,
s"`!=` was not properly specialized when inlined in `test$cls`\n" +
diffInstructions(instructions1, instructions2))
}
}
}