From a91294476a71cd0ef9b6a3ecb94a87a19b539416 Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Tue, 12 May 2015 15:26:59 +1000 Subject: [PATCH 1/2] Add a generic deserializer for Java/Scala 2.12 lambdas Java support serialization of lambdas by using the serialization proxy pattern. Deserialization of a lambda uses `LambdaMetafactory` to create a new anonymous subclass. More details of the scheme are documented: https://docs.oracle.com/javase/8/docs/api/java/lang/invoke/SerializedLambda.html From those docs: > SerializedLambda has a readResolve method that looks for a > (possibly private) static method called $deserializeLambda$ > in the capturing class, invokes that with itself as the first > argument, and returns the result. Lambda classes implementing > $deserializeLambda$ are responsible for validating that the > properties of the SerializedLambda are consistent with a lambda > actually captured by that class. The Java compiler generates code in `$deserializeLambda$` that switches on the implementation method name and signature to locate an invokedynamic instruction generated for the particular lambda expression. Then, the `SerializedLambda` is further unpacked, validating that this implementation method still represents the same functional interface as it did when it was serialized. (The source may have been recompiled in the interim.) In Java, serializable lambda expressions are the exception rather than the rule. In Scala, however, the serializability of `FunctionN` means that we would end up generating a large amount of code to support deserialization. Instead, we are pursuing an alternative approach in which the `$deserializeLambda$` method is a simple forwarder to the generic deserializer added here. This is capable of deserializing lambdas created by the Java compiler, although this is not its intended use case. The enclosed tests use Java lambdas. This generic deserializer also works by calling `LambdaMetafactory`, but it does so explicitly, rather than implicitly during linkage of the `invokedynamic` instruction. We have to mimic the caching property of `invokedynamic` instruction to ensure we reuse the classes when constructing. I originally tried using a central cache, but wasn't able to come up with a scheme to avoid potential classloader memory leaks. Instead, I now allow the caller to provide a cache. The scala compiler will host an instance of this cache in each class that hosts a lambda. This is analagous the the `MethodCache` used by reflective calls. If the name or signature of the implementation method has changed, we fail during deserialization with an `IllegalArgumentError.` However, we do not fail fast in a few cases that Java would, as we cannot reflect on the "current" functional interface supported by this implementation method. We just instantiate using the "previous" functional interface class/method. This might: 1. fail inside `LambdaMetafactory` if the new implementation method is not compatible with the old functional interface. 2. pass through `LambdaMetafactory` by chance, but fail when instantiating the class in other cases. For example: ``` % tail sandbox/test{1,2}.scala ==> sandbox/test1.scala <== class C { def test: (String => String) = { val s: String = "" (t) => s + t } } ==> sandbox/test2.scala <== class C { def test: (String, String) => String = { (s, t) => s + t } } % (for i in 1 2; do scalac -Ydelambdafy:method -Xprint:delambdafy sandbox/test$i.scala 2>&1 ; done) | grep 'def $anon' final private[this] def $anonfun$1(t: String, s$1: String): String = s$1.+(t); final private[this] def $anonfun$1(s: String, t: String): String = s.+(t); ``` 3. Silently create an instance of the old functional interface. For example, imagine switching from `FuncInterface1` to `FuncInterface2` where these were identical other than the name. I don't believe that these are showstoppers. Failing test case demonstrating overly weak cache --- .../java8/runtime/LambdaDeserializer.scala | 132 +++++++++++++ .../java8/runtime/LambdaDeserializerTest.java | 181 ++++++++++++++++++ 2 files changed, 313 insertions(+) create mode 100644 src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala create mode 100644 src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java diff --git a/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala new file mode 100644 index 0000000..f9609d1 --- /dev/null +++ b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala @@ -0,0 +1,132 @@ +package scala.compat.java8.runtime + +import java.lang.invoke._ + +/** + * This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12 + * compiler will add to classes hosting lambdas. + * + * It is not intended to be consumed directly. + */ +object LambdaDeserializer { + /** + * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class + * and instantiating this class with the captured arguments. + * + * A cache may be provided to ensure that subsequent deserialization of the same lambda expression + * is cheap, it amounts to a reflective call to the constructor of the previously created class. + * However, deserialization of the same lambda expression is not guaranteed to use the same class, + * concurrent deserialization of the same lambda expression may spin up more than one class. + * + * Assumptions: + * - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are + * not stored in `SerializedLambda`, so we can't reconstitute them. + * - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored. + * + * @param lookup The factory for method handles. Must have access to the implementation method, the + * functional interface class, and `java.io.Serializable` or `scala.Serializable` as + * required. + * @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null` + * @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve` + * member of the anonymous class created by `LambdaMetaFactory`. + * @return An instance of the functional interface + */ + def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = { + def slashDot(name: String) = name.replaceAll("/", ".") + val loader = lookup.lookupClass().getClassLoader + val implClass = loader.loadClass(slashDot(serialized.getImplClass)) + + def makeCallSite: CallSite = { + import serialized._ + def parseDescriptor(s: String) = + MethodType.fromMethodDescriptorString(s, loader) + + val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature) + val instantiated = parseDescriptor(getInstantiatedMethodType) + val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass)) + + val implMethodSig = parseDescriptor(getImplMethodSignature) + // Construct the invoked type from the impl method type. This is the type of a factory + // that will be generated by the meta-factory. It is a method type, with param types + // coming form the types of the captures, and return type being the functional interface. + val invokedType: MethodType = { + // 1. Add receiver for non-static impl methods + val withReceiver = getImplMethodKind match { + case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial => + implMethodSig + case _ => + implMethodSig.insertParameterTypes(0, implClass) + } + // 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter, + // such as in `Function s = Object::toString` + val lambdaArity = funcInterfaceSignature.parameterCount() + val from = withReceiver.parameterCount() - lambdaArity + val to = withReceiver.parameterCount() + + // 3. Drop the lambda return type and replace with the functional interface. + withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass) + } + + // Lookup the implementation method + val implMethod: MethodHandle = try { + findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig) + } catch { + case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e) + } + + val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS + val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function") + val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable) + + LambdaMetafactory.altMetafactory( + lookup, getFunctionalInterfaceMethodName, invokedType, + + /* samMethodType = */ funcInterfaceSignature, + /* implMethod = */ implMethod, + /* instantiatedMethodType = */ instantiated, + /* flags = */ flags.asInstanceOf[AnyRef], + /* markerInterfaceCount = */ 1.asInstanceOf[AnyRef], + /* markerInterfaces[0] = */ markerInterface, + /* bridgeCount = */ 0.asInstanceOf[AnyRef] + ) + } + + val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature + val factory: MethodHandle = if (cache == null) { + makeCallSite.getTarget + } else cache.get(key) match { + case null => + val callSite = makeCallSite + val temp = callSite.getTarget + cache.put(key, temp) + temp + case target => target + } + + val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n)) + factory.invokeWithArguments(captures: _*) + } + + private val ScalaSerializable = "scala.Serializable" + + private val JavaIOSerializable = { + // We could actually omit this marker interface as LambdaMetaFactory will add it if + // the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code + // is cleaner if we uniformly add a single marker, so I'm leaving it in place. + "java.io.Serializable" + } + + private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_], + name: String, signature: MethodType): MethodHandle = { + kind match { + case MethodHandleInfo.REF_invokeStatic => + lookup.findStatic(owner, name, signature) + case MethodHandleInfo.REF_newInvokeSpecial => + lookup.findConstructor(owner, signature) + case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface => + lookup.findVirtual(owner, name, signature) + case MethodHandleInfo.REF_invokeSpecial => + lookup.findSpecial(owner, name, signature, owner) + } + } +} diff --git a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000..3a03750 --- /dev/null +++ b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java @@ -0,0 +1,181 @@ +package scala.compat.java8.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1 f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationStatic() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationVirtualMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByVirtualMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationInterfaceMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByInterfaceMethodReference(); + I i = new I() { + }; + Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i)); + } + + @Test + public void serializationStaticMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByStaticMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationNewInvokeSpecial() { + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + Assert.assertEquals(f1.apply(), reconstitute(f1).apply()); + } + + @Test + public void uncached() { + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0 reconstituted1 = reconstitute(f1); + F0 reconstituted2 = reconstitute(f1); + Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cached() { + HashMap cache = new HashMap<>(); + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0 reconstituted1 = reconstitute(f1, cache); + F0 reconstituted2 = reconstitute(f1, cache); + Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void implMethodNameChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda serialized) { + try { + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + private A reconstitute(A f1) { + return reconstitute(f1, null); + } + + @SuppressWarnings("unchecked") + private A reconstitute(A f1, java.util.HashMap cache) { + try { + return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private SerializedLambda writeReplace(A f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1 extends Serializable { + B apply(A a); +} + +interface F0 extends Serializable { + A apply(); +} + +class LambdaHost { + public F1 lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + @SuppressWarnings("Convert2MethodRef") + public F1 lambdaBackedByStaticImplMethod() { + return (b) -> String.valueOf(b); + } + + public F1 lambdaBackedByStaticMethodReference() { + return String::valueOf; + } + + public F1 lambdaBackedByVirtualMethodReference() { + return Object::toString; + } + + public F1 lambdaBackedByInterfaceMethodReference() { + return I::i; + } + + public F0 lambdaBackedByConstructorCall() { + return String::new; + } + + public static MethodHandles.Lookup lookup() { + return MethodHandles.lookup(); + } +} + +interface I { + default String i() { return "i"; }; +} From 921b212b609bc8aa08ccac46e3c76048e2e6978e Mon Sep 17 00:00:00 2001 From: Jason Zaugg Date: Sun, 17 May 2015 18:41:29 +1000 Subject: [PATCH 2/2] Test static lambda hoisting works via LambdaDeserializer LambdaMetafactory returns a ConstantCallSite bound to a shared instance of a lambda, rather than a reference to the no-arg constructor. This is a technique to avoid unnecessary allocations. This test checks that we preserve this property when deserializing. --- .../compat/java8/runtime/LambdaDeserializerTest.java | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java index 3a03750..723e56c 100644 --- a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java +++ b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java @@ -69,6 +69,18 @@ public void cached() { Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); } + @Test + public void cachedStatic() { + HashMap cache = new HashMap<>(); + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + // Check that deserialization of a static lambda always returns the + // same instance. + Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache)); + + // (as is the case with regular invocation.) + Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod()); + } + @Test public void implMethodNameChanged() { F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod();