diff --git a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala index 8fd690aa8633..22350ca5dfef 100644 --- a/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala +++ b/compiler/test/dotty/tools/backend/jvm/InlineBytecodeTests.scala @@ -104,6 +104,30 @@ class InlineBytecodeTests extends DottyBytecodeTest { } } + @Test def inlineNn = { + val source = + s""" + |class Foo { + | def meth1(x: Int | Null): Int = x.nn + | def meth2(x: Int | Null): Int = scala.runtime.Scala3RunTime.nn(x) + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "`nn` was not properly inlined in `meth1`\n" + + diffInstructions(instructions1, instructions2)) + } + } + @Test def i4947 = { val source = """class Foo { | transparent inline def track[T](inline f: T): T = { diff --git a/library/src/dotty/DottyPredef.scala b/library/src/dotty/DottyPredef.scala index 5e2512461386..72412c8c417b 100644 --- a/library/src/dotty/DottyPredef.scala +++ b/library/src/dotty/DottyPredef.scala @@ -79,7 +79,6 @@ object DottyPredef { * * Note that `.nn` performs a checked cast, so if invoked on a null value it'll throw an NPE. */ - extension [T](x: T | Null) def nn: x.type & T = - if (x == null) throw new NullPointerException("tried to cast away nullability, but value is null") - else x.asInstanceOf[x.type & T] + extension [T](x: T | Null) inline def nn: x.type & T = + scala.runtime.Scala3RunTime.nn(x) } diff --git a/library/src/scala/runtime/Scala3RunTime.scala b/library/src/scala/runtime/Scala3RunTime.scala index 592e4e21396c..fa141628ceec 100644 --- a/library/src/scala/runtime/Scala3RunTime.scala +++ b/library/src/scala/runtime/Scala3RunTime.scala @@ -10,4 +10,12 @@ object Scala3RunTime: def assertFailed(): Nothing = throw new java.lang.AssertionError("assertion failed") + /** Called by the inline extension def `nn`. + * + * Extracted to minimize the bytecode size at call site. + */ + def nn[T](x: T | Null): x.type & T = + if (x == null) throw new NullPointerException("tried to cast away nullability, but value is null") + else x.asInstanceOf[x.type & T] + end Scala3RunTime