Skip to content

Commit be3c312

Browse files
committed
Fix scala#4442: Make lambdas serializable
1 parent 6ffe218 commit be3c312

File tree

7 files changed

+443
-6
lines changed

7 files changed

+443
-6
lines changed

compiler/src/dotty/tools/backend/jvm/DottyBackendInterface.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -706,6 +706,7 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
706706
def isJavaEntryPoint: Boolean = CollectEntryPoints.isJavaEntryPoint(sym)
707707

708708
def isClassConstructor: Boolean = toDenot(sym).isClassConstructor
709+
def isSerializable: Boolean = toDenot(sym).isSerializable
709710

710711
/**
711712
* True for module classes of modules that are top-level or owned only by objects. Module classes
@@ -855,6 +856,9 @@ class DottyBackendInterface(outputDirectory: AbstractFile, val superCallsMap: Ma
855856

856857
def samMethod(): Symbol =
857858
toDenot(sym).info.abstractTermMembers.headOption.getOrElse(toDenot(sym).info.member(nme.apply)).symbol
859+
860+
def isFunctionClass: Boolean =
861+
defn.isFunctionClass(sym)
858862
}
859863

860864

compiler/src/dotty/tools/backend/jvm/GenBCode.scala

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import dotty.tools.dotc.ast.tpd
66
import dotty.tools.dotc.core.Phases.Phase
77

88
import scala.collection.mutable
9+
import scala.collection.JavaConverters._
910
import scala.tools.asm.CustomAttr
1011
import scala.tools.nsc.backend.jvm._
1112
import dotty.tools.dotc.transform.SymUtils._
@@ -23,6 +24,7 @@ import java.io.DataOutputStream
2324

2425

2526
import scala.tools.asm
27+
import scala.tools.asm.Handle
2628
import scala.tools.asm.tree._
2729
import tpd._
2830
import StdNames._
@@ -308,6 +310,93 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
308310
// BackendStats.timed(BackendStats.methodOptTimer)(localOpt.methodOptimizations(classNode))
309311
}
310312

313+
/*
314+
* Add:
315+
*
316+
* private static Object $deserializeLambda$(SerializedLambda l) {
317+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$0](l)
318+
* catch {
319+
* case i: IllegalArgumentException =>
320+
* try return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup$1](l)
321+
* catch {
322+
* case i: IllegalArgumentException =>
323+
* ...
324+
* return indy[scala.runtime.LambdaDeserialize.bootstrap, targetMethodGroup${NUM_GROUPS-1}](l)
325+
* }
326+
*
327+
* We use invokedynamic here to enable caching within the deserializer without needing to
328+
* host a static field in the enclosing class. This allows us to add this method to interfaces
329+
* that define lambdas in default methods.
330+
*
331+
* SI-10232 we can't pass arbitrary number of method handles to the final varargs parameter of the bootstrap
332+
* method due to a limitation in the JVM. Instead, we emit a separate invokedynamic bytecode for each group of target
333+
* methods.
334+
*/
335+
def addLambdaDeserialize(classNode: ClassNode, implMethodsArray: Array[Handle]): Unit = {
336+
import asm.Opcodes._
337+
import BCodeBodyBuilder._
338+
import bTypes._
339+
import coreBTypes._
340+
341+
val cw = classNode
342+
343+
// Make sure to reference the ClassBTypes of all types that are used in the code generated
344+
// here (e.g. java/util/Map) are initialized. Initializing a ClassBType adds it to
345+
// `classBTypeFromInternalNameMap`. When writing the classfile, the asm ClassWriter computes
346+
// stack map frames and invokes the `getCommonSuperClass` method. This method expects all
347+
// ClassBTypes mentioned in the source code to exist in the map.
348+
349+
val serlamObjDesc = MethodBType(jliSerializedLambdaRef :: Nil, ObjectReference).descriptor
350+
351+
val mv = cw.visitMethod(ACC_PRIVATE + ACC_STATIC + ACC_SYNTHETIC, "$deserializeLambda$", serlamObjDesc, null, null)
352+
def emitLambdaDeserializeIndy(targetMethods: Seq[Handle]): Unit = {
353+
mv.visitVarInsn(ALOAD, 0)
354+
mv.visitInvokeDynamicInsn("lambdaDeserialize", serlamObjDesc, lambdaDeserializeBootstrapHandle, targetMethods: _*)
355+
}
356+
357+
val targetMethodGroupLimit = 255 - 1 - 3 // JVM limit. See See MAX_MH_ARITY in CallSite.java
358+
val groups: Array[Array[Handle]] = implMethodsArray.grouped(targetMethodGroupLimit).toArray
359+
val numGroups = groups.length
360+
361+
import scala.tools.asm.Label
362+
val initialLabels = Array.fill(numGroups - 1)(new Label())
363+
val terminalLabel = new Label
364+
def nextLabel(i: Int) = if (i == numGroups - 2) terminalLabel else initialLabels(i + 1)
365+
366+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
367+
mv.visitTryCatchBlock(label, nextLabel(i), nextLabel(i), jlIllegalArgExceptionRef.internalName)
368+
}
369+
for ((label, i) <- initialLabels.iterator.zipWithIndex) {
370+
mv.visitLabel(label)
371+
emitLambdaDeserializeIndy(groups(i))
372+
mv.visitInsn(ARETURN)
373+
}
374+
mv.visitLabel(terminalLabel)
375+
emitLambdaDeserializeIndy(groups(numGroups - 1))
376+
mv.visitInsn(ARETURN)
377+
}
378+
379+
/* Support deserialization of lambdas defined in this class */
380+
def addLambdaDeserializeIfNeeded(classNode: ClassNode): Unit = {
381+
val indyLambdaBodyMethods = new mutable.ArrayBuffer[Handle]
382+
for (m <- classNode.methods.asScala) {
383+
val iter = m.instructions.iterator
384+
while (iter.hasNext) {
385+
val insn = iter.next()
386+
insn match {
387+
case indy: InvokeDynamicInsnNode
388+
if indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryMetafactoryHandle ||
389+
indy.bsm == BCodeBodyBuilder.lambdaMetaFactoryAltMetafactoryHandle =>
390+
val implMethod = indy.bsmArgs(1).asInstanceOf[Handle]
391+
indyLambdaBodyMethods += implMethod
392+
case _ =>
393+
}
394+
}
395+
}
396+
if (indyLambdaBodyMethods.nonEmpty)
397+
addLambdaDeserialize(classNode, indyLambdaBodyMethods.toArray)
398+
}
399+
311400
def run(): Unit = {
312401
while (true) {
313402
val item = q2.poll
@@ -317,7 +406,9 @@ class GenBCodePipeline(val entryPoints: List[Symbol], val int: DottyBackendInter
317406
}
318407
else {
319408
try {
320-
localOptimizations(item.plain.classNode)
409+
val plainNode = item.plain.classNode
410+
addLambdaDeserializeIfNeeded(plainNode)
411+
localOptimizations(plainNode)
321412
addToQ3(item)
322413
} catch {
323414
case ex: Throwable =>

scala-backend

tests/run/lambda-serialization.scala

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream, PrintWriter, StringWriter}
2+
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
3+
4+
class C extends java.io.Serializable {
5+
val fs = List(
6+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
7+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
8+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
9+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
10+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
11+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
12+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
13+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
14+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
15+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
16+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
17+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
18+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
19+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => (),
20+
() => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => (), () => () ,() => (), () => (), () => (), () => (), () => ()
21+
)
22+
private def foo(): Unit = {
23+
assert(false, "should not be called!!!")
24+
}
25+
}
26+
27+
trait FakeSam { def apply(): Unit }
28+
29+
object Test {
30+
def main(args: Array[String]): Unit = {
31+
allRealLambdasRoundTrip()
32+
fakeLambdaFailsToDeserialize()
33+
}
34+
35+
def allRealLambdasRoundTrip(): Unit = {
36+
new C().fs.map(x => serializeDeserialize(x).apply())
37+
}
38+
39+
def fakeLambdaFailsToDeserialize(): Unit = {
40+
val fake = new SerializedLambda(classOf[C], classOf[FakeSam].getName, "apply", "()V",
41+
MethodHandleInfo.REF_invokeVirtual, classOf[C].getName, "foo", "()V", "()V", Array(new C))
42+
try {
43+
serializeDeserialize(fake).asInstanceOf[FakeSam].apply()
44+
assert(false)
45+
} catch {
46+
case ex: Exception =>
47+
val stackTrace = stackTraceString(ex)
48+
assert(stackTrace.contains("Illegal lambda deserialization"), stackTrace)
49+
}
50+
}
51+
52+
def serializeDeserialize[T <: AnyRef](obj: T) = {
53+
val buffer = new ByteArrayOutputStream
54+
val out = new ObjectOutputStream(buffer)
55+
out.writeObject(obj)
56+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
57+
in.readObject.asInstanceOf[T]
58+
}
59+
60+
def stackTraceString(ex: Throwable): String = {
61+
val writer = new StringWriter
62+
ex.printStackTrace(new PrintWriter(writer))
63+
writer.toString
64+
}
65+
}
66+

0 commit comments

Comments
 (0)