Skip to content

Commit 641f84e

Browse files
committed
Fix #4177: Generate optimised applyOrElse implementation for partial function literals
1 parent 14187e2 commit 641f84e

File tree

5 files changed

+100
-26
lines changed

5 files changed

+100
-26
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,11 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
273273
coord = fns.map(_.pos).reduceLeft(_ union _))
274274
val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered
275275
def forwarder(fn: TermSymbol, name: TermName) = {
276-
val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm
277-
DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss))
276+
var flags = Synthetic | Method
277+
val isOverride = parents.exists(_.member(name).hasAltWith(_.info == fn.info))
278+
if (isOverride) flags = flags | Override
279+
val fwdMeth = fn.copy(cls, name, flags).entered.asTerm
280+
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss))
278281
}
279282
val forwarders = (fns, methNames).zipped.map(forwarder)
280283
val cdef = ClassDef(cls, DefDef(constr), forwarders)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,10 @@ class Definitions {
585585

586586
lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction")
587587
def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass
588+
589+
lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse)
590+
def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol
591+
588592
lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction")
589593
def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass
590594
lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL")

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import dotty.tools.dotc.util.Positions.Position
1313
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
1414
* These fall into five categories
1515
*
16-
* 1. Partial function closures, we need to generate a isDefinedAt method for these.
16+
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
1717
* 2. Closures implementing non-trait classes.
1818
* 3. Closures implementing classes that inherit from a class other than Object
1919
* (a lambda cannot not be a run-time subtype of such a class)
@@ -54,38 +54,72 @@ class ExpandSAMs extends MiniPhase {
5454
val Block(
5555
(applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil,
5656
Closure(_, _, tpt)) = tree
57-
val applyRhs: Tree = applyDef.rhs
57+
val applyRhs = applyDef.rhs
5858
val applyFn = applyDef.symbol.asTerm
5959

6060
val MethodTpe(paramNames, paramTypes, _) = applyFn.info
6161
val isDefinedAtFn = applyFn.copy(
6262
name = nme.isDefinedAt,
6363
flags = Synthetic | Method,
6464
info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm
65-
val tru = Literal(Constant(true))
66-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match {
67-
case Match(selector, cases) =>
68-
assert(selector.symbol == param.symbol)
69-
val paramRef = paramRefss.head.head
70-
// Again, the alternative
71-
// val List(List(paramRef)) = paramRefs
72-
// fails with a similar self instantiation error
73-
def translateCase(cdef: CaseDef): CaseDef =
74-
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
75-
val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
76-
val defaultCase =
77-
CaseDef(
78-
Bind(defaultSym, Underscore(selector.tpe.widen)),
79-
EmptyTree,
80-
Literal(Constant(false)))
81-
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
82-
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
83-
case _ =>
84-
tru
65+
66+
val applyOrElseFn = applyFn.copy(
67+
name = nme.applyOrElse,
68+
flags = Synthetic | Method,
69+
info = tpt.tpe.memberInfo(defn.PartialFunction_applyOrElse)).asTerm
70+
71+
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
72+
val tru = Literal(Constant(true))
73+
applyRhs match {
74+
case Match(selector, cases) =>
75+
assert(selector.symbol == param.symbol)
76+
val paramRef = paramRefss.head.head
77+
def translateCase(cdef: CaseDef)=
78+
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
79+
val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
80+
val defaultCase =
81+
CaseDef(
82+
Bind(defaultSym, Underscore(selector.tpe.widen)),
83+
EmptyTree,
84+
Literal(Constant(false)))
85+
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
86+
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
87+
.subst(param.symbol :: Nil, paramRef.symbol :: Nil)
88+
// Needed because a partial function can be written as:
89+
// x => x match { case "foo" if foo(x) => x }
90+
// And we need to update all references to 'x'
91+
case _ =>
92+
tru
93+
}
8594
}
95+
96+
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
97+
val List(paramRef, defaultRef) = paramRefss.head
98+
applyRhs match {
99+
case Match(selector, cases) =>
100+
assert(selector.symbol == param.symbol)
101+
def translateCase(cdef: CaseDef) =
102+
cdef.changeOwner(applyFn, applyOrElseFn)
103+
val defaultSym = ctx.newSymbol(applyOrElseFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
104+
val defaultCase =
105+
CaseDef(
106+
Bind(defaultSym, Underscore(selector.tpe.widen)),
107+
EmptyTree,
108+
defaultRef.select(nme.apply).appliedTo(paramRef))
109+
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
110+
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
111+
.subst(param.symbol :: Nil, paramRef.symbol :: Nil)
112+
// Same as for isDefinedAtRhs. See comment above
113+
case _ =>
114+
ref(applyFn).appliedTo(paramRef)
115+
}
116+
}
117+
86118
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
87-
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt))
88-
cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls)
119+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
120+
121+
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn, applyOrElseFn), List(nme.apply, nme.isDefinedAt, nme.applyOrElse))
122+
cpy.Block(tree)(List(applyDef, isDefinedAtDef, applyOrElseDef), anonCls)
89123
}
90124

91125
private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe match {

tests/pos/i4177.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
class Test {
2+
3+
object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None }
4+
5+
def test: Unit = {
6+
val a: PartialFunction[Int, String] = { case Foo(x) => x }
7+
val b: PartialFunction[Int, String] = { case x => x.toString }
8+
val c: PartialFunction[Int, String] = { x => x.toString }
9+
val d: PartialFunction[Int, String] = x => x.toString
10+
11+
val e: PartialFunction[String, String] = { case x @ "abc" => x }
12+
val f: PartialFunction[String, String] = x => x match { case "abc" => x }
13+
val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x }
14+
}
15+
}

tests/run/i4177.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
object Test {
2+
private[this] var count = 0
3+
4+
def test(x: Int) = { count += 1; true }
5+
6+
object Foo {
7+
def unapply(x: Int): Option[Int] = { count += 1; Some(x) }
8+
}
9+
10+
def main(args: Array[String]): Unit = {
11+
val res = List(1, 2).collect { case x if test(x) => x }
12+
assert(count == 2)
13+
14+
count = 0
15+
val res2 = List(1, 2).collect { case Foo(x) => x }
16+
assert(count == 2)
17+
}
18+
}

0 commit comments

Comments
 (0)