Skip to content

Fix bytecode generation for Single Abstract Method lambdas #11839

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1414,7 +1414,7 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
def genLoadTry(tree: Try): BType

def genInvokeDynamicLambda(ctor: Symbol, lambdaTarget: Symbol, environmentSize: Int, functionalInterface: Symbol): BType = {
import java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE
import java.lang.invoke.LambdaMetafactory.{FLAG_BRIDGES, FLAG_SERIALIZABLE}

report.debuglog(s"Using invokedynamic rather than `new ${ctor.owner}`")
val generatedType = classBTypeFromSymbol(functionalInterface)
Expand Down Expand Up @@ -1445,9 +1445,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
val functionalInterfaceDesc: String = generatedType.descriptor
val desc = capturedParamsTypes.map(tpe => toTypeKind(tpe)).mkString(("("), "", ")") + functionalInterfaceDesc
// TODO specialization
val constrainedType = new MethodBType(lambdaParamTypes.map(p => toTypeKind(p)), toTypeKind(lambdaTarget.info.resultType)).toASMType
val instantiatedMethodType = new MethodBType(lambdaParamTypes.map(p => toTypeKind(p)), toTypeKind(lambdaTarget.info.resultType)).toASMType

val abstractMethod = atPhase(erasurePhase) {
val samMethod = atPhase(erasurePhase) {
val samMethods = toDenot(functionalInterface).info.possibleSamMethods.toList
samMethods match {
case x :: Nil => x.symbol
Expand All @@ -1457,21 +1457,38 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
}
}

val methodName = abstractMethod.javaSimpleName
val applyN = {
val mt = asmMethodType(abstractMethod)
mt.toASMType
val methodName = samMethod.javaSimpleName
val samMethodType = asmMethodType(samMethod).toASMType
val needsGenericBridge = samMethodType != instantiatedMethodType
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest adding the same comment than in Scala 2 for context:

Suggested change
val needsGenericBridge = samMethodType != instantiatedMethodType
// scala/bug#10334: make sure that a lambda object for `T => U` has a method `apply(T)U`, not only the `(Object)Object`
// version. Using the lambda a structural type `{def apply(t: T): U}` causes a reflective lookup for this method.
val needsGenericBridge = samMethodType != instantiatedMethodType

val bridgeMethods = atPhase(erasurePhase.prev){
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can actually be run at erasurePhase because when running at a certain phase we see the types as they are before the phase denotation transformer is run.

Suggested change
val bridgeMethods = atPhase(erasurePhase.prev){
val bridgeMethods = atPhase(erasurePhase) {

samMethod.allOverriddenSymbols.toList
}
val bsmArgs0 = Seq(applyN, targetHandle, constrainedType)
val bsmArgs =
if (isSerializable)
bsmArgs0 :+ Int.box(FLAG_SERIALIZABLE)
val overriddenMethodTypes = bridgeMethods.map(b => asmMethodType(b).toASMType)

// any methods which `samMethod` overrides need bridges made for them
// this is done automatically during erasure for classes we generate, but LMF needs to have them explicitly mentioned
// so we have to compute them at this relatively late point.
val bridgeTypes = (
if (needsGenericBridge)
instantiatedMethodType +: overriddenMethodTypes
else
bsmArgs0
overriddenMethodTypes
).distinct.filterNot(_ == samMethodType)

val needsBridges = bridgeTypes.nonEmpty

def flagIf(b: Boolean, flag: Int): Int = if (b) flag else 0
val flags = flagIf(isSerializable, FLAG_SERIALIZABLE) | flagIf(needsBridges, FLAG_BRIDGES)

val bsmArgs0 = Seq(samMethodType, targetHandle, instantiatedMethodType)
val bsmArgs1 = if (flags != 0) Seq(Int.box(flags)) else Seq.empty
val bsmArgs2 = if needsBridges then bridgeTypes.length +: bridgeTypes else Seq.empty

val bsmArgs = bsmArgs0 ++ bsmArgs1 ++ bsmArgs2

val metafactory =
if (isSerializable)
lambdaMetaFactoryAltMetafactoryHandle // altMetafactory needed to be able to pass the SERIALIZABLE flag
if (flags != 0)
lambdaMetaFactoryAltMetafactoryHandle // altMetafactory required to be able to pass the flags and additional arguments if needed
else
lambdaMetaFactoryMetafactoryHandle

Expand Down
12 changes: 7 additions & 5 deletions compiler/src/dotty/tools/backend/jvm/GenBCode.scala
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,13 @@ class GenBCodePipeline(val int: DottyBackendInterface, val primitives: DottyPrim
val insn = iter.next()
insn match {
case indy: InvokeDynamicInsnNode
// No need to check the exact bsmArgs because we only generate
// altMetafactory indy calls for serializable lambdas.
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
indyLambdaBodyMethods += implMethod
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
import java.lang.invoke.LambdaMetafactory.FLAG_SERIALIZABLE
val metafactoryFlags = indy.bsmArgs(3).asInstanceOf[Integer].toInt
val isSerializable = (metafactoryFlags & FLAG_SERIALIZABLE) != 0
if isSerializable then
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
indyLambdaBodyMethods += implMethod
case _ =>
}
}
Expand Down
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -927,8 +927,10 @@ object Types {
*/
final def possibleSamMethods(using Context): Seq[SingleDenotation] = {
record("possibleSamMethods")
abstractTermMembers.toList.filterConserve(m =>
!m.symbol.matchingMember(defn.ObjectType).exists && !m.symbol.isSuperAccessor)
atPhaseNoLater(erasurePhase.prev) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here:

Suggested change
atPhaseNoLater(erasurePhase.prev) {
atPhaseNoLater(erasurePhase) {

abstractTermMembers.toList.filterConserve(m =>
!m.symbol.matchingMember(defn.ObjectType).exists && !m.symbol.isSuperAccessor)
}
}

/** The set of abstract type members of this type. */
Expand Down
16 changes: 10 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/Erasure.scala
Original file line number Diff line number Diff line change
Expand Up @@ -431,10 +431,14 @@ object Erasure {
val implParamTypes = implType.paramInfos
val implResultType = implType.resultType
val implReturnsUnit = implResultType.classSymbol eq defn.UnitClass
// The SAM that this closure should implement
val SAMType(sam) = lambdaType: @unchecked
val samParamTypes = sam.paramInfos
val samResultType = sam.resultType
// The SAM that this closure should implement.
// At this point it should be already guaranteed that there's only one method to implement
val Seq(unerasedSam) = lambdaType.possibleSamMethods
// We're now in erasure so the alternatives will have erased types
val Seq(erasedSam: MethodType) = unerasedSam.symbol.alternatives.map(_.info)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid this complication I suggest changing possibleSamMethods to do .map(_.current) at the end so the denotations always correspond to the current phase.


val samParamTypes = erasedSam.paramInfos
val samResultType = erasedSam.resultType

/** Can the implementation parameter type `tp` be auto-adapted to a different
* parameter type in the SAM?
Expand Down Expand Up @@ -498,12 +502,12 @@ object Erasure {
val bridgeType =
if paramAdaptationNeeded then
if resultAdaptationNeeded then
sam
erasedSam
else
implType.derivedLambdaType(paramInfos = samParamTypes)
else
implType.derivedLambdaType(resType = samResultType)
val bridge = newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method, bridgeType)
val bridge = newSymbol(ctx.owner, AdaptedClosureName(meth.symbol.name.asTermName), Flags.Synthetic | Flags.Method | Flags.Bridge, bridgeType)
Closure(bridge, bridgeParamss =>
inContext(ctx.withOwner(bridge)) {
val List(bridgeParams) = bridgeParamss
Expand Down
3 changes: 3 additions & 0 deletions tests/run/i10068a.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
42
Foo
Foo
31 changes: 31 additions & 0 deletions tests/run/i10068a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
sealed trait Partial
sealed trait Total extends Partial

case object Foo extends Total

trait P[A] {
def bar(a: A): Partial
}

trait T[A] extends P[A] {
def bar(a: A): Total
}

object T {
def make[A](x: Total): T[A] =
a => x
}

object Test {
def total[A](a: A)(ev: T[A]): Total = ev.bar(a)
def partial[A](a: A)(ev: P[A]): Partial = ev.bar(a)

def go[A](a: A)(ev: T[A]): Unit = {
println(a)
println(total(a)(ev))
println(partial(a)(ev))
}

def main(args: Array[String]): Unit =
go(42)(T.make(Foo))
}
23 changes: 23 additions & 0 deletions tests/run/i10068b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
trait Foo[A] {
def xxx(a1: A, a2: A): A
def xxx(a: A): A = xxx(a, a)
}

trait Bar[A] extends Foo[A] {
def yyy(a1: A, a2: A) = xxx(a1, a2)
}

trait Baz[A] extends Bar[A]

object Test:
def main(args: Array[String]): Unit =
val foo: Foo[String] = { (s1, s2) => s1 ++ s2 }
val bar: Bar[String] = { (s1, s2) => s1 ++ s2 }
val baz: Baz[String] = { (s1, s2) => s1 ++ s2 }

val s = "abc"
val ss = "abcabc"
assert(foo.xxx(s) == ss)
assert(bar.yyy(s, s) == ss)
assert(baz.xxx(s) == ss)
assert(baz.yyy(s, s) == ss)
45 changes: 45 additions & 0 deletions tests/run/i10068c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Taken from: https://github.com/scala/scala/pull/6087

trait JsonValue
class JsonObject extends JsonValue
class JsonString extends JsonValue

trait JsonEncoder[A] {
def encode(value: A): JsonValue
}

trait JsonObjectEncoder[A] extends JsonEncoder[A] {
def encode(value: A): JsonObject
}

object JsonEncoderInstances {

val seWorks: JsonEncoder[String] =
new JsonEncoder[String] {
def encode(value: String) = new JsonString
}

implicit val stringEncoder: JsonEncoder[String] =
s => new JsonString
//new JsonEncoder[String] {
// def encode(value: String) = new JsonString
//}

def leWorks[A](implicit encoder: JsonEncoder[A]): JsonObjectEncoder[List[A]] =
new JsonObjectEncoder[List[A]] {
def encode(value: List[A]) = new JsonObject
}

implicit def listEncoder[A](implicit encoder: JsonEncoder[A]): JsonObjectEncoder[List[A]] =
l => new JsonObject
// new JsonObjectEncoder[List[A]] {
// def encode(value: List[A]) = new JsonObject
// }

}

object Test extends App {
import JsonEncoderInstances._

implicitly[JsonEncoder[List[String]]].encode("" :: Nil)
}
56 changes: 56 additions & 0 deletions tests/run/i10068d.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Taken from: https://github.com/scala/scala/pull/6087

trait A
trait B extends A
trait C extends B
object it extends C

/* try as many weird diamondy things as I can think of */
trait SAM_A { def apply(): A }
trait SAM_A1 extends SAM_A { def apply(): A }
trait SAM_B extends SAM_A1 { def apply(): B }
trait SAM_B1 extends SAM_A1 { def apply(): B }
trait SAM_B2 extends SAM_B with SAM_B1
trait SAM_C extends SAM_B2 { def apply(): C }

trait SAM_F extends (() => A) with SAM_C
trait SAM_F1 extends (() => C) with SAM_F


object Test extends App {

val s1: SAM_A = () => it
val s2: SAM_A1 = () => it
val s3: SAM_B = () => it
val s4: SAM_B1 = () => it
val s5: SAM_B2 = () => it
val s6: SAM_C = () => it
val s7: SAM_F = () => it
val s8: SAM_F1 = () => it

(s1(): A)

(s2(): A)

(s3(): B)
(s3(): A)

(s4(): B)
(s4(): A)

(s5(): B)
(s5(): A)

(s6(): C)
(s6(): B)
(s6(): A)

(s7(): C)
(s7(): B)
(s7(): A)

(s8(): C)
(s8(): B)
(s8(): A)

}
25 changes: 25 additions & 0 deletions tests/run/i11676.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
sealed trait PartialOrdering
sealed trait Ordering extends PartialOrdering

object Ordering {
def fromCompare(n: Int): Ordering = new Ordering {}
}

trait PartialOrd[-A] {
def checkCompare(l: A, r: A): PartialOrdering
}

trait Ord[-A] extends PartialOrd[A] {
def checkCompare(l: A, r: A): Ordering
}

object Ord {
def fromScala[A](implicit ordering: scala.math.Ordering[A]): Ord[A] =
(l: A, r: A) => Ordering.fromCompare(ordering.compare(l, r))
}

object Test {
def main(args: Array[String]): Unit =
val intOrd = Ord.fromScala[Int]
intOrd.checkCompare(1, 3)
}