Skip to content

Commit 6ad9b44

Browse files
committed
[indylambda] Relieve LambdaMetafactory of boxing duties
`LambdaMetafactory` generates code to perform a limited number of type adaptations when delegating from its implementation of the functional interface method to the lambda target method. These adaptations are: numeric widening, casting, boxing and unboxing. However, the semantics of unboxing numerics in Java differs to Scala: they treat `UNBOX(null)` as cause to raise a `NullPointerException`, Scala (in `BoxesRuntime.unboxTo{Byte,Short,...}`) reinterprets the null as zero. Furthermore, Java has no idea how to adapt between a value class and its wrapped type, nor from a void return to `BoxedUnit`. This commit detects when the lambda target method would require such adaptation. If it does, an extra method, `$anonfun$1$adapted` is created to perform the adaptation, and this is used as the target of the lambda. This obviates the use of `JProcedureN` for `Unit` returning lambdas, we know use `JFunctionN` as the functional interface and bind this to an `$adapted` method that summons the instance of `BoxedUnit` after calling the `void` returning lambda target. The enclosed test cases fail without boxing changes. They don't execute with indylambda enabled under regular partest runs yet, you need to add scala-java8-compat to scala-library and pass the SCALAC_OPTS to partest manually to try this out, as described in scala#4463. Once we enable indylambda by default, however, this test will exercise the code in this patch all the time. It is also possible to run the tests with: ``` % curl https://oss.sonatype.org/content/repositories/releases/org/scala-lang/modules/scala-java8-compat_2.11/0.4.0/scala-java8-compat_2.11-0.4.0.jar > scala-java8-compat_2.11-0.4.0.jar % export INDYLAMBDA="-Ydelambdafy:method -Ybackend:GenBCode -target:jvm-1.8 -classpath .:scala-java8-compat_2.11-0.4.0.jar" qscalac $INDYLAMBDA test/files/run/indylambda-boxing/*.scala && qscala $INDYLAMBDA Test ```
1 parent 99d3ab3 commit 6ad9b44

File tree

6 files changed

+157
-31
lines changed

6 files changed

+157
-31
lines changed

src/compiler/scala/tools/nsc/transform/Delambdafy.scala

Lines changed: 115 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,8 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
8888
case class DelambdafyAnonClass(lambdaClassDef: ClassDef, newExpr: Tree) extends TransformedFunction
8989
case class InvokeDynamicLambda(tree: Apply) extends TransformedFunction
9090

91+
private val boxingBridgeMethods = mutable.ArrayBuffer[Tree]()
92+
9193
// here's the main entry point of the transform
9294
override def transform(tree: Tree): Tree = tree match {
9395
// the main thing we care about is lambdas
@@ -105,6 +107,12 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
105107
// ... or an invokedynamic call
106108
super.transform(apply)
107109
}
110+
case Template(_, _, _) =>
111+
try {
112+
// during this call boxingBridgeMethods will be populated from the Function case
113+
val Template(parents, self, body) = super.transform(tree)
114+
Template(parents, self, body ++ boxingBridgeMethods)
115+
} finally boxingBridgeMethods.clear()
108116
case _ => super.transform(tree)
109117
}
110118

@@ -137,6 +145,61 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
137145

138146
val isStatic = target.hasFlag(STATIC)
139147

148+
def createBoxingBridgeMethod(functionParamTypes: List[Type], functionResultType: Type): Tree = {
149+
val methSym = oldClass.newMethod(target.name.append("$adapted").toTermName, target.pos, target.flags | FINAL | ARTIFACT)
150+
var neededAdaptation = false
151+
def boxedType(tpe: Type): Type = {
152+
if (isPrimitiveValueClass(tpe.typeSymbol)) {neededAdaptation = true; ObjectTpe}
153+
else if (enteringErasure(tpe.typeSymbol.isDerivedValueClass)) {neededAdaptation = true; ObjectTpe}
154+
else tpe
155+
}
156+
val targetParams: List[Symbol] = target.paramss.head
157+
val numCaptures = targetParams.length - functionParamTypes.length
158+
val (targetCaptureParams, targetFunctionParams) = targetParams.splitAt(numCaptures)
159+
val bridgeParams: List[Symbol] =
160+
targetCaptureParams.map(param => methSym.newSyntheticValueParam(param.tpe, param.name.toTermName)) :::
161+
map2(targetFunctionParams, functionParamTypes)((param, tp) => methSym.newSyntheticValueParam(boxedType(tp), param.name.toTermName))
162+
163+
val bridgeResultType: Type = {
164+
if (target.info.resultType == UnitTpe && functionResultType != UnitTpe) {
165+
neededAdaptation = true
166+
ObjectTpe
167+
} else
168+
boxedType(functionResultType)
169+
}
170+
val methodType = MethodType(bridgeParams, bridgeResultType)
171+
methSym setInfo methodType
172+
if (!neededAdaptation)
173+
EmptyTree
174+
else {
175+
val bridgeParamTrees = bridgeParams.map(ValDef(_))
176+
177+
oldClass.info.decls enter methSym
178+
179+
val body = localTyper.typedPos(originalFunction.pos) {
180+
val newTarget = Select(gen.mkAttributedThis(oldClass), target)
181+
val args: List[Tree] = mapWithIndex(bridgeParams) { (param, i) =>
182+
if (i < numCaptures) {
183+
gen.mkAttributedRef(param)
184+
} else {
185+
val functionParam = functionParamTypes(i - numCaptures)
186+
val targetParam = targetParams(i)
187+
if (enteringErasure(functionParam.typeSymbol.isDerivedValueClass)) {
188+
val casted = cast(gen.mkAttributedRef(param), functionParam)
189+
val unboxed = unbox(casted, ErasedValueType(functionParam.typeSymbol, targetParam.tpe)).modifyType(postErasure.elimErasedValueType)
190+
unboxed
191+
} else adaptToType(gen.mkAttributedRef(param), targetParam.tpe)
192+
}
193+
}
194+
gen.mkMethodCall(newTarget, args)
195+
}
196+
val body1 = if (enteringErasure(functionResultType.typeSymbol.isDerivedValueClass))
197+
adaptToType(box(body.setType(ErasedValueType(functionResultType.typeSymbol, body.tpe)), "boxing lambda target"), bridgeResultType)
198+
else adaptToType(body, bridgeResultType)
199+
val methDef0 = DefDef(methSym, List(bridgeParamTrees), body1)
200+
postErasure.newTransformer(unit).transform(methDef0).asInstanceOf[DefDef]
201+
}
202+
}
140203
/**
141204
* Creates the apply method for the anonymous subclass of FunctionN
142205
*/
@@ -292,22 +355,56 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
292355
thisArg ::: captureArgs
293356
}
294357

295-
val functionalInterface = java8CompatFunctionalInterface(target, originalFunction.tpe)
358+
val arity = originalFunction.vparams.length
359+
360+
// Reconstruct the type of the function entering erasure.
361+
// We do this by taking the type after erasure, and re-boxing `ErasedValueType`.
362+
//
363+
// Unfortunately, the more obvious `enteringErasure(target.info)` doesn't work
364+
// as we would like, value classes in parameter position show up as the unboxed types.
365+
val (functionParamTypes, functionResultType) = exitingErasure {
366+
def boxed(tp: Type) = tp match {
367+
case ErasedValueType(valueClazz, _) => TypeRef(NoPrefix, valueClazz, Nil)
368+
case _ => tp
369+
}
370+
// We don't need to deeply map `boxedValueClassType` over the infos as `ErasedValueType`
371+
// will only appear directly as a parameter type in a method signature, as shown
372+
// https://gist.github.com/retronym/ba81dbd462282c504ff8
373+
val info = target.info
374+
val boxedParamTypes = info.paramTypes.takeRight(arity).map(boxed)
375+
(boxedParamTypes, boxed(info.resultType))
376+
}
377+
val functionType = definitions.functionType(functionParamTypes, functionResultType)
378+
379+
val (functionalInterface, isSpecialized) = java8CompatFunctionalInterface(target, functionType)
296380
if (functionalInterface.exists) {
297381
// Create a symbol representing a fictional lambda factory method that accepts the captured
298382
// arguments and returns a Function.
299-
val msym = currentOwner.newMethod(nme.ANON_FUN_NAME, originalFunction.pos, ARTIFACT)
383+
val msym = currentOwner.newMethod(nme.ANON_FUN_NAME, originalFunction.pos, ARTIFACT)
300384
val argTypes: List[Type] = allCaptureArgs.map(_.tpe)
301385
val params = msym.newSyntheticValueParams(argTypes)
302-
msym.setInfo(MethodType(params, originalFunction.tpe))
386+
msym.setInfo(MethodType(params, functionType))
303387
val arity = originalFunction.vparams.length
304388

389+
val lambdaTarget =
390+
if (isSpecialized)
391+
target
392+
else {
393+
createBoxingBridgeMethod(functionParamTypes, functionResultType) match {
394+
case EmptyTree =>
395+
target
396+
case bridge =>
397+
boxingBridgeMethods += bridge
398+
bridge.symbol
399+
}
400+
}
401+
305402
// We then apply this symbol to the captures.
306403
val apply = localTyper.typedPos(originalFunction.pos)(Apply(Ident(msym), allCaptureArgs)).asInstanceOf[Apply]
307404

308405
// The backend needs to know the target of the lambda and the functional interface in order
309406
// to emit the invokedynamic instruction. We pass this information as tree attachment.
310-
apply.updateAttachment(LambdaMetaFactoryCapable(target, arity, functionalInterface))
407+
apply.updateAttachment(LambdaMetaFactoryCapable(lambdaTarget, arity, functionalInterface))
311408
InvokeDynamicLambda(apply)
312409
} else {
313410
val anonymousClassDef = makeAnonymousClass
@@ -469,34 +566,24 @@ abstract class Delambdafy extends Transform with TypingTransformers with ast.Tre
469566
final case class LambdaMetaFactoryCapable(target: Symbol, arity: Int, functionalInterface: Symbol)
470567

471568
// The functional interface that can be used to adapt the lambda target method `target` to the
472-
// given function type. Returns `NoSymbol` if the compiler settings are unsuitable, or `LambdaMetaFactory`
473-
// would be unable to generate the correct implementation (e.g. functions referring to derived value classes)
474-
private def java8CompatFunctionalInterface(target: Symbol, functionType: Type): Symbol = {
569+
// given function type. Returns `NoSymbol` if the compiler settings are unsuitable.
570+
private def java8CompatFunctionalInterface(target: Symbol, functionType: Type): (Symbol, Boolean) = {
475571
val canUseLambdaMetafactory: Boolean = {
476-
val hasValueClass = exitingErasure {
477-
val methodType: Type = target.info
478-
methodType.exists(_.isInstanceOf[ErasedValueType])
479-
}
480572
val isTarget18 = settings.target.value.contains("jvm-1.8")
481-
settings.isBCodeActive && isTarget18 && !hasValueClass
573+
settings.isBCodeActive && isTarget18
482574
}
483575

484-
def functionalInterface: Symbol = {
485-
val sym = functionType.typeSymbol
486-
val pack = currentRun.runDefinitions.Scala_Java8_CompatPackage
487-
val name1 = specializeTypes.specializedFunctionName(sym, functionType.typeArgs)
488-
val paramTps :+ restpe = functionType.typeArgs
489-
val arity = paramTps.length
490-
if (name1.toTypeName == sym.name) {
491-
val returnUnit = restpe.typeSymbol == UnitClass
492-
val functionInterfaceArray =
493-
if (returnUnit) currentRun.runDefinitions.Scala_Java8_CompatPackage_JProcedure
494-
else currentRun.runDefinitions.Scala_Java8_CompatPackage_JFunction
495-
functionInterfaceArray.apply(arity)
496-
} else {
497-
pack.info.decl(name1.toTypeName.prepend("J"))
498-
}
576+
val sym = functionType.typeSymbol
577+
val pack = currentRun.runDefinitions.Scala_Java8_CompatPackage
578+
val name1 = specializeTypes.specializedFunctionName(sym, functionType.typeArgs)
579+
val paramTps :+ restpe = functionType.typeArgs
580+
val arity = paramTps.length
581+
val isSpecialized = name1.toTypeName != sym.name
582+
val functionalInterface = if (!isSpecialized) {
583+
currentRun.runDefinitions.Scala_Java8_CompatPackage_JFunction(arity)
584+
} else {
585+
pack.info.decl(name1.toTypeName.prepend("J"))
499586
}
500-
if (canUseLambdaMetafactory) functionalInterface else NoSymbol
587+
(if (canUseLambdaMetafactory) functionalInterface else NoSymbol, isSpecialized)
501588
}
502589
}

src/reflect/scala/reflect/internal/Definitions.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,8 +1515,7 @@ trait Definitions extends api.StandardDefinitions {
15151515
private lazy val PolySigMethods: Set[Symbol] = Set[Symbol](MethodHandle.info.decl(sn.Invoke), MethodHandle.info.decl(sn.InvokeExact)).filter(_.exists)
15161516

15171517
lazy val Scala_Java8_CompatPackage = rootMirror.getPackageIfDefined("scala.compat.java8")
1518-
lazy val Scala_Java8_CompatPackage_JFunction = (0 to MaxTupleArity).toArray map (i => getMemberIfDefined(Scala_Java8_CompatPackage.moduleClass, TypeName("JFunction" + i)))
1519-
lazy val Scala_Java8_CompatPackage_JProcedure = (0 to MaxTupleArity).toArray map (i => getMemberIfDefined(Scala_Java8_CompatPackage.moduleClass, TypeName("JProcedure" + i)))
1518+
lazy val Scala_Java8_CompatPackage_JFunction = (0 to MaxFunctionArity).toArray map (i => getMemberIfDefined(Scala_Java8_CompatPackage.moduleClass, TypeName("JFunction" + i)))
15201519
}
15211520
}
15221521
}

src/reflect/scala/reflect/internal/transform/PostErasure.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ trait PostErasure {
99
object elimErasedValueType extends TypeMap {
1010
def apply(tp: Type) = tp match {
1111
case ConstantType(Constant(tp: Type)) => ConstantType(Constant(apply(tp)))
12-
case ErasedValueType(_, underlying) => underlying
12+
case ErasedValueType(_, underlying) =>
13+
underlying
1314
case _ => mapOver(tp)
1415
}
1516
}
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
object Test {
2+
def main(args: Array[String]): Unit = {
3+
val i2s = (x: Int) => ""
4+
assert(i2s.asInstanceOf[AnyRef => String].apply(null) == "")
5+
val i2i = (x: Int) => x + 1
6+
assert(i2i.asInstanceOf[AnyRef => Int].apply(null) == 1)
7+
}
8+
}
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
class VC(private val i: Int) extends AnyVal
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
class Capture
2+
class Test {
3+
def test1 = (i: Int) => ""
4+
def test2 = (i: VC) => i
5+
def test3 = (i: Int) => i
6+
7+
def test4 = {val c = new Capture; (i: Int) => {(c, Test.this.toString); 42} }
8+
def test5 = {val c = new Capture; (i: VC) => (c, Test.this.toString) }
9+
def test6 = {val c = new Capture; (i: Int) => (c, Test.this.toString) }
10+
11+
def test7 = {val vc = new Capture; (i: Int) => vc }
12+
def test8 = {val c = 42; (s: String) => (s, c)}
13+
def test9 = {val c = 42; (s: String) => ()}
14+
}
15+
16+
object Test {
17+
def main(args: Array[String]): Unit = {
18+
val t = new Test
19+
assert(t.test1.apply(42) == "")
20+
assert(t.test2.apply(new VC(42)) == new VC(42))
21+
assert(t.test3.apply(-1) == -1)
22+
t.test4.apply(0)
23+
t.test5.apply(new VC(42))
24+
t.test6.apply(42)
25+
t.test7.apply(0)
26+
t.test8.apply("")
27+
t.test9.apply("")
28+
}
29+
}

0 commit comments

Comments
 (0)