Skip to content

Commit 770320b

Browse files
authored
Merge pull request #4245 from dotty-staging/fix-4177
Fix #4177: Generate optimised applyOrElse implementation for partial function literals
2 parents d603361 + 8a91774 commit 770320b

File tree

7 files changed

+141
-44
lines changed

7 files changed

+141
-44
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -282,12 +282,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
282282
val parents1 =
283283
if (parents.head.classSymbol.is(Trait)) parents.head.parents.head :: parents
284284
else parents
285-
val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic, parents1,
285+
val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1,
286286
coord = fns.map(_.pos).reduceLeft(_ union _))
287287
val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered
288288
def forwarder(fn: TermSymbol, name: TermName) = {
289-
val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm
290-
DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss))
289+
var flags = Synthetic | Method | Final
290+
def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true)
291+
val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden))
292+
if (isOverride) flags = flags | Override
293+
val fwdMeth = fn.copy(cls, name, flags).entered.asTerm
294+
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss))
291295
}
292296
val forwarders = (fns, methNames).zipped.map(forwarder)
293297
val cdef = ClassDef(cls, DefDef(constr), forwarders)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,8 +575,14 @@ class Definitions {
575575

576576
lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction")
577577
def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass
578+
lazy val PartialFunction_isDefinedAtR = PartialFunctionClass.requiredMethodRef(nme.isDefinedAt)
579+
def PartialFunction_isDefinedAt(implicit ctx: Context) = PartialFunction_isDefinedAtR.symbol
580+
lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse)
581+
def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol
582+
578583
lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction")
579584
def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass
585+
580586
lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL")
581587
def FunctionXXLClass(implicit ctx: Context) = FunctionXXLType.symbol.asClass
582588

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 72 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@ import MegaPhase._
88
import SymUtils._
99
import ast.untpd
1010
import ast.Trees._
11+
import dotty.tools.dotc.reporting.diagnostic.messages.TypeMismatch
1112
import dotty.tools.dotc.util.Positions.Position
1213

1314
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
1415
* These fall into five categories
1516
*
16-
* 1. Partial function closures, we need to generate a isDefinedAt method for these.
17+
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
1718
* 2. Closures implementing non-trait classes.
1819
* 3. Closures implementing classes that inherit from a class other than Object
1920
* (a lambda cannot not be a run-time subtype of such a class)
@@ -35,8 +36,8 @@ class ExpandSAMs extends MiniPhase {
3536
tpt.tpe match {
3637
case NoType => tree // it's a plain function
3738
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
38-
checkRefinements(tpe, fn.pos)
39-
toPartialFunction(tree)
39+
val tpe1 = checkRefinements(tpe, fn.pos)
40+
toPartialFunction(tree, tpe1)
4041
case tpe @ SAMType(_) if isPlatformSam(tpe.classSymbol.asClass) =>
4142
checkRefinements(tpe, fn.pos)
4243
tree
@@ -50,50 +51,83 @@ class ExpandSAMs extends MiniPhase {
5051
tree
5152
}
5253

53-
private def toPartialFunction(tree: Block)(implicit ctx: Context): Tree = {
54-
val Block(
55-
(applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil,
56-
Closure(_, _, tpt)) = tree
57-
val applyRhs: Tree = applyDef.rhs
58-
val applyFn = applyDef.symbol.asTerm
59-
60-
val MethodTpe(paramNames, paramTypes, _) = applyFn.info
61-
val isDefinedAtFn = applyFn.copy(
62-
name = nme.isDefinedAt,
63-
flags = Synthetic | Method,
64-
info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm
65-
val tru = Literal(Constant(true))
66-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match {
67-
case Match(selector, cases) =>
68-
assert(selector.symbol == param.symbol)
69-
val paramRef = paramRefss.head.head
70-
// Again, the alternative
71-
// val List(List(paramRef)) = paramRefs
72-
// fails with a similar self instantiation error
73-
def translateCase(cdef: CaseDef): CaseDef =
74-
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
75-
val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
76-
val defaultCase =
77-
CaseDef(
78-
Bind(defaultSym, Underscore(selector.tpe.widen)),
79-
EmptyTree,
80-
Literal(Constant(false)))
81-
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
82-
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
54+
private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = {
55+
// /** An extractor for match, either contained in a block or standalone. */
56+
object PartialFunctionRHS {
57+
def unapply(tree: Tree): Option[Match] = tree match {
58+
case Block(Nil, expr) => unapply(expr)
59+
case m: Match => Some(m)
60+
case _ => None
61+
}
62+
}
63+
64+
val closureDef(anon @ DefDef(_, _, List(List(param)), _, _)) = tree
65+
anon.rhs match {
66+
case PartialFunctionRHS(pf) =>
67+
val anonSym = anon.symbol
68+
69+
def overrideSym(sym: Symbol) = sym.copy(
70+
owner = anonSym.owner,
71+
flags = Synthetic | Method | Final,
72+
info = tpe.memberInfo(sym),
73+
coord = tree.pos).asTerm
74+
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
75+
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
76+
77+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
78+
val selector = tree.selector
79+
val selectorTpe = selector.tpe.widen
80+
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
81+
val defaultCase =
82+
CaseDef(
83+
Bind(defaultSym, Underscore(selectorTpe)),
84+
EmptyTree,
85+
defaultValue)
86+
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
87+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
88+
.subst(param.symbol :: Nil, pfParam :: Nil)
89+
// Needed because a partial function can be written as:
90+
// param => param match { case "foo" if foo(param) => param }
91+
// And we need to update all references to 'param'
92+
}
93+
94+
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
95+
val tru = Literal(Constant(true))
96+
def translateCase(cdef: CaseDef) =
97+
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
98+
val paramRef = paramRefss.head.head
99+
val defaultValue = Literal(Constant(false))
100+
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
101+
}
102+
103+
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
104+
val List(paramRef, defaultRef) = paramRefss.head
105+
def translateCase(cdef: CaseDef) =
106+
cdef.changeOwner(anonSym, applyOrElseFn)
107+
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
108+
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
109+
}
110+
111+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
112+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
113+
114+
val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos)
115+
val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
116+
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
117+
83118
case _ =>
84-
tru
119+
val found = tpe.baseType(defn.FunctionClass(1))
120+
ctx.error(TypeMismatch(found, tpe), tree.pos)
121+
tree
85122
}
86-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
87-
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt))
88-
cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls)
89123
}
90124

91125
private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match {
92126
case RefinedType(parent, name, _) =>
93127
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
94128
ctx.error("Lambda does not define " + name, pos)
95129
checkRefinements(parent, pos)
96-
case _ =>
130+
case tpe =>
97131
tpe
98132
}
99133

tests/neg/i4241.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
class Test {
2+
def test: Unit = {
3+
val a: PartialFunction[Int, Int] = { case x => x }
4+
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 }
5+
val c: PartialFunction[Int, Int] = x => { x match { case y => y } }
6+
val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } }
7+
8+
val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error
9+
val f: PartialFunction[Int, Int] = x => x // error
10+
val g: PartialFunction[Int, String] = { x => x.toString } // error
11+
}
12+
}

tests/pos/i4177.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
class Test {
2+
3+
object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None }
4+
5+
def test: Unit = {
6+
val a: PartialFunction[Int, String] = { case Foo(x) => x }
7+
val b: PartialFunction[Int, String] = { case x => x.toString }
8+
9+
val e: PartialFunction[String, String] = { case x @ "abc" => x }
10+
val f: PartialFunction[String, String] = x => x match { case "abc" => x }
11+
val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x }
12+
13+
type P = PartialFunction[String,String]
14+
val h: P = { case x => x.toString }
15+
16+
val i: PartialFunction[Int, Int] = { x => x match { case x => x } }
17+
}
18+
}

tests/run/i4177.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
object Test {
2+
private[this] var count = 0
3+
4+
def test(x: Int) = { count += 1; true }
5+
6+
object Foo {
7+
def unapply(x: Int): Option[Int] = { count += 1; Some(x) }
8+
}
9+
10+
def main(args: Array[String]): Unit = {
11+
val res = List(1, 2).collect { case x if test(x) => x }
12+
assert(count == 2)
13+
14+
count = 0
15+
val res2 = List(1, 2).collect { case Foo(x) => x }
16+
assert(count == 2)
17+
}
18+
}

tests/run/partialFunctions.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
object Test {
22

3-
def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
3+
def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
4+
class Foo(val field: Option[Int])
45

56
def main(args: Array[String]): Unit = {
6-
val partialFunction: PartialFunction[Int, Int] = {case a: Int => a}
7+
val p1: PartialFunction[Int, Int] = { case a: Int => a }
8+
assert(takesPartialFunction(p1) == 1)
79

8-
assert(takesPartialFunction(partialFunction) == 1)
10+
val p2: PartialFunction[Foo, Int] =
11+
foo => foo.field match { case Some(x) => x }
12+
assert(p2.isDefinedAt(new Foo(Some(1))))
13+
assert(!p2.isDefinedAt(new Foo(None)))
914
}
1015
}

0 commit comments

Comments
 (0)