diff --git a/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala new file mode 100644 index 0000000..f9609d1 --- /dev/null +++ b/src/main/java/scala/compat/java8/runtime/LambdaDeserializer.scala @@ -0,0 +1,132 @@ +package scala.compat.java8.runtime + +import java.lang.invoke._ + +/** + * This class is only intended to be called by synthetic `$deserializeLambda$` method that the Scala 2.12 + * compiler will add to classes hosting lambdas. + * + * It is not intended to be consumed directly. + */ +object LambdaDeserializer { + /** + * Deserialize a lambda by calling `LambdaMetafactory.altMetafactory` to spin up a lambda class + * and instantiating this class with the captured arguments. + * + * A cache may be provided to ensure that subsequent deserialization of the same lambda expression + * is cheap, it amounts to a reflective call to the constructor of the previously created class. + * However, deserialization of the same lambda expression is not guaranteed to use the same class, + * concurrent deserialization of the same lambda expression may spin up more than one class. + * + * Assumptions: + * - No additional marker interfaces are required beyond `{java.io,scala.}Serializable`. These are + * not stored in `SerializedLambda`, so we can't reconstitute them. + * - No additional bridge methods are passed to `altMetafactory`. Again, these are not stored. + * + * @param lookup The factory for method handles. Must have access to the implementation method, the + * functional interface class, and `java.io.Serializable` or `scala.Serializable` as + * required. + * @param cache A cache used to avoid spinning up a class for each deserialization of a given lambda. May be `null` + * @param serialized The lambda to deserialize. Note that this is typically created by the `readResolve` + * member of the anonymous class created by `LambdaMetaFactory`. + * @return An instance of the functional interface + */ + def deserializeLambda(lookup: MethodHandles.Lookup, cache: java.util.Map[String, MethodHandle], serialized: SerializedLambda): AnyRef = { + def slashDot(name: String) = name.replaceAll("/", ".") + val loader = lookup.lookupClass().getClassLoader + val implClass = loader.loadClass(slashDot(serialized.getImplClass)) + + def makeCallSite: CallSite = { + import serialized._ + def parseDescriptor(s: String) = + MethodType.fromMethodDescriptorString(s, loader) + + val funcInterfaceSignature = parseDescriptor(getFunctionalInterfaceMethodSignature) + val instantiated = parseDescriptor(getInstantiatedMethodType) + val functionalInterfaceClass = loader.loadClass(slashDot(getFunctionalInterfaceClass)) + + val implMethodSig = parseDescriptor(getImplMethodSignature) + // Construct the invoked type from the impl method type. This is the type of a factory + // that will be generated by the meta-factory. It is a method type, with param types + // coming form the types of the captures, and return type being the functional interface. + val invokedType: MethodType = { + // 1. Add receiver for non-static impl methods + val withReceiver = getImplMethodKind match { + case MethodHandleInfo.REF_invokeStatic | MethodHandleInfo.REF_newInvokeSpecial => + implMethodSig + case _ => + implMethodSig.insertParameterTypes(0, implClass) + } + // 2. Remove lambda parameters, leaving only captures. Note: the receiver may be a lambda parameter, + // such as in `Function s = Object::toString` + val lambdaArity = funcInterfaceSignature.parameterCount() + val from = withReceiver.parameterCount() - lambdaArity + val to = withReceiver.parameterCount() + + // 3. Drop the lambda return type and replace with the functional interface. + withReceiver.dropParameterTypes(from, to).changeReturnType(functionalInterfaceClass) + } + + // Lookup the implementation method + val implMethod: MethodHandle = try { + findMember(lookup, getImplMethodKind, implClass, getImplMethodName, implMethodSig) + } catch { + case e: ReflectiveOperationException => throw new IllegalArgumentException("Illegal lambda deserialization", e) + } + + val flags: Int = LambdaMetafactory.FLAG_SERIALIZABLE | LambdaMetafactory.FLAG_MARKERS + val isScalaFunction = functionalInterfaceClass.getName.startsWith("scala.Function") + val markerInterface: Class[_] = loader.loadClass(if (isScalaFunction) ScalaSerializable else JavaIOSerializable) + + LambdaMetafactory.altMetafactory( + lookup, getFunctionalInterfaceMethodName, invokedType, + + /* samMethodType = */ funcInterfaceSignature, + /* implMethod = */ implMethod, + /* instantiatedMethodType = */ instantiated, + /* flags = */ flags.asInstanceOf[AnyRef], + /* markerInterfaceCount = */ 1.asInstanceOf[AnyRef], + /* markerInterfaces[0] = */ markerInterface, + /* bridgeCount = */ 0.asInstanceOf[AnyRef] + ) + } + + val key = serialized.getImplMethodName + " : " + serialized.getImplMethodSignature + val factory: MethodHandle = if (cache == null) { + makeCallSite.getTarget + } else cache.get(key) match { + case null => + val callSite = makeCallSite + val temp = callSite.getTarget + cache.put(key, temp) + temp + case target => target + } + + val captures = Array.tabulate(serialized.getCapturedArgCount)(n => serialized.getCapturedArg(n)) + factory.invokeWithArguments(captures: _*) + } + + private val ScalaSerializable = "scala.Serializable" + + private val JavaIOSerializable = { + // We could actually omit this marker interface as LambdaMetaFactory will add it if + // the FLAG_SERIALIZABLE is set and of the provided markers extend it. But the code + // is cleaner if we uniformly add a single marker, so I'm leaving it in place. + "java.io.Serializable" + } + + private def findMember(lookup: MethodHandles.Lookup, kind: Int, owner: Class[_], + name: String, signature: MethodType): MethodHandle = { + kind match { + case MethodHandleInfo.REF_invokeStatic => + lookup.findStatic(owner, name, signature) + case MethodHandleInfo.REF_newInvokeSpecial => + lookup.findConstructor(owner, signature) + case MethodHandleInfo.REF_invokeVirtual | MethodHandleInfo.REF_invokeInterface => + lookup.findVirtual(owner, name, signature) + case MethodHandleInfo.REF_invokeSpecial => + lookup.findSpecial(owner, name, signature, owner) + } + } +} diff --git a/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java new file mode 100644 index 0000000..723e56c --- /dev/null +++ b/src/test/java/scala/compat/java8/runtime/LambdaDeserializerTest.java @@ -0,0 +1,193 @@ +package scala.compat.java8.runtime; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.invoke.MethodHandle; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.SerializedLambda; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.HashMap; + +public final class LambdaDeserializerTest { + private LambdaHost lambdaHost = new LambdaHost(); + + @Test + public void serializationPrivate() { + F1 f1 = lambdaHost.lambdaBackedByPrivateImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationStatic() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationVirtualMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByVirtualMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationInterfaceMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByInterfaceMethodReference(); + I i = new I() { + }; + Assert.assertEquals(f1.apply(i), reconstitute(f1).apply(i)); + } + + @Test + public void serializationStaticMethodReference() { + F1 f1 = lambdaHost.lambdaBackedByStaticMethodReference(); + Assert.assertEquals(f1.apply(true), reconstitute(f1).apply(true)); + } + + @Test + public void serializationNewInvokeSpecial() { + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + Assert.assertEquals(f1.apply(), reconstitute(f1).apply()); + } + + @Test + public void uncached() { + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0 reconstituted1 = reconstitute(f1); + F0 reconstituted2 = reconstitute(f1); + Assert.assertNotEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cached() { + HashMap cache = new HashMap<>(); + F0 f1 = lambdaHost.lambdaBackedByConstructorCall(); + F0 reconstituted1 = reconstitute(f1, cache); + F0 reconstituted2 = reconstitute(f1, cache); + Assert.assertEquals(reconstituted1.getClass(), reconstituted2.getClass()); + } + + @Test + public void cachedStatic() { + HashMap cache = new HashMap<>(); + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + // Check that deserialization of a static lambda always returns the + // same instance. + Assert.assertSame(reconstitute(f1, cache), reconstitute(f1, cache)); + + // (as is the case with regular invocation.) + Assert.assertSame(f1, lambdaHost.lambdaBackedByStaticImplMethod()); + } + + @Test + public void implMethodNameChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName() + "___", sl.getImplMethodSignature())); + } + + @Test + public void implMethodSignatureChanged() { + F1 f1 = lambdaHost.lambdaBackedByStaticImplMethod(); + SerializedLambda sl = writeReplace(f1); + checkIllegalAccess(copySerializedLambda(sl, sl.getImplMethodName(), sl.getImplMethodSignature().replace("Boolean", "Integer"))); + } + + private void checkIllegalAccess(SerializedLambda serialized) { + try { + LambdaDeserializer.deserializeLambda(MethodHandles.lookup(), null, serialized); + throw new AssertionError(); + } catch (IllegalArgumentException iae) { + if (!iae.getMessage().contains("Illegal lambda deserialization")) { + Assert.fail("Unexpected message: " + iae.getMessage()); + } + } + } + + private SerializedLambda copySerializedLambda(SerializedLambda sl, String implMethodName, String implMethodSignature) { + Object[] captures = new Object[sl.getCapturedArgCount()]; + for (int i = 0; i < captures.length; i++) { + captures[i] = sl.getCapturedArg(i); + } + return new SerializedLambda(loadClass(sl.getCapturingClass()), sl.getFunctionalInterfaceClass(), sl.getFunctionalInterfaceMethodName(), + sl.getFunctionalInterfaceMethodSignature(), sl.getImplMethodKind(), sl.getImplClass(), implMethodName, implMethodSignature, + sl.getInstantiatedMethodType(), captures); + } + + private Class loadClass(String className) { + try { + return Class.forName(className.replace('/', '.')); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + } + private A reconstitute(A f1) { + return reconstitute(f1, null); + } + + @SuppressWarnings("unchecked") + private A reconstitute(A f1, java.util.HashMap cache) { + try { + return (A) LambdaDeserializer.deserializeLambda(LambdaHost.lookup(), cache, writeReplace(f1)); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private SerializedLambda writeReplace(A f1) { + try { + Method writeReplace = f1.getClass().getDeclaredMethod("writeReplace"); + writeReplace.setAccessible(true); + return (SerializedLambda) writeReplace.invoke(f1); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} + + +interface F1 extends Serializable { + B apply(A a); +} + +interface F0 extends Serializable { + A apply(); +} + +class LambdaHost { + public F1 lambdaBackedByPrivateImplMethod() { + int local = 42; + return (b) -> Arrays.asList(local, b ? "true" : "false", LambdaHost.this).toString(); + } + + @SuppressWarnings("Convert2MethodRef") + public F1 lambdaBackedByStaticImplMethod() { + return (b) -> String.valueOf(b); + } + + public F1 lambdaBackedByStaticMethodReference() { + return String::valueOf; + } + + public F1 lambdaBackedByVirtualMethodReference() { + return Object::toString; + } + + public F1 lambdaBackedByInterfaceMethodReference() { + return I::i; + } + + public F0 lambdaBackedByConstructorCall() { + return String::new; + } + + public static MethodHandles.Lookup lookup() { + return MethodHandles.lookup(); + } +} + +interface I { + default String i() { return "i"; }; +}