Skip to content

Fix #4177: Generate optimised applyOrElse implementation for partial function literals #4245

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,12 +282,16 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
val parents1 =
if (parents.head.classSymbol.is(Trait)) parents.head.parents.head :: parents
else parents
val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic, parents1,
val cls = ctx.newNormalizedClassSymbol(owner, tpnme.ANON_CLASS, Synthetic | Final, parents1,
coord = fns.map(_.pos).reduceLeft(_ union _))
val constr = ctx.newConstructor(cls, Synthetic, Nil, Nil).entered
def forwarder(fn: TermSymbol, name: TermName) = {
val fwdMeth = fn.copy(cls, name, Synthetic | Method).entered.asTerm
DefDef(fwdMeth, prefss => ref(fn).appliedToArgss(prefss))
var flags = Synthetic | Method | Final
def isOverriden(denot: SingleDenotation) = fn.info.overrides(denot.info, matchLoosely = true)
val isOverride = parents.exists(_.member(name).hasAltWith(isOverriden))
if (isOverride) flags = flags | Override
val fwdMeth = fn.copy(cls, name, flags).entered.asTerm
polyDefDef(fwdMeth, tprefs => prefss => ref(fn).appliedToTypes(tprefs).appliedToArgss(prefss))
}
val forwarders = (fns, methNames).zipped.map(forwarder)
val cdef = ClassDef(cls, DefDef(constr), forwarders)
Expand Down
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -585,8 +585,14 @@ class Definitions {

lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction")
def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass
lazy val PartialFunction_isDefinedAtR = PartialFunctionClass.requiredMethodRef(nme.isDefinedAt)
def PartialFunction_isDefinedAt(implicit ctx: Context) = PartialFunction_isDefinedAtR.symbol
lazy val PartialFunction_applyOrElseR = PartialFunctionClass.requiredMethodRef(nme.applyOrElse)
def PartialFunction_applyOrElse(implicit ctx: Context) = PartialFunction_applyOrElseR.symbol

lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction")
def AbstractPartialFunctionClass(implicit ctx: Context) = AbstractPartialFunctionType.symbol.asClass

lazy val FunctionXXLType: TypeRef = ctx.requiredClassRef("scala.FunctionXXL")
def FunctionXXLClass(implicit ctx: Context) = FunctionXXLType.symbol.asClass

Expand Down
110 changes: 72 additions & 38 deletions compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ import MegaPhase._
import SymUtils._
import ast.untpd
import ast.Trees._
import dotty.tools.dotc.reporting.diagnostic.messages.TypeMismatch
import dotty.tools.dotc.util.Positions.Position

/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
* These fall into five categories
*
* 1. Partial function closures, we need to generate a isDefinedAt method for these.
* 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
* 2. Closures implementing non-trait classes.
* 3. Closures implementing classes that inherit from a class other than Object
* (a lambda cannot not be a run-time subtype of such a class)
Expand All @@ -35,8 +36,8 @@ class ExpandSAMs extends MiniPhase {
tpt.tpe match {
case NoType => tree // it's a plain function
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
checkRefinements(tpe, fn.pos)
toPartialFunction(tree)
val tpe1 = checkRefinements(tpe, fn.pos)
toPartialFunction(tree, tpe1)
case tpe @ SAMType(_) if isPlatformSam(tpe.classSymbol.asClass) =>
checkRefinements(tpe, fn.pos)
tree
Expand All @@ -50,50 +51,83 @@ class ExpandSAMs extends MiniPhase {
tree
}

private def toPartialFunction(tree: Block)(implicit ctx: Context): Tree = {
val Block(
(applyDef @ DefDef(nme.ANON_FUN, Nil, List(List(param)), _, _)) :: Nil,
Closure(_, _, tpt)) = tree
val applyRhs: Tree = applyDef.rhs
val applyFn = applyDef.symbol.asTerm

val MethodTpe(paramNames, paramTypes, _) = applyFn.info
val isDefinedAtFn = applyFn.copy(
name = nme.isDefinedAt,
flags = Synthetic | Method,
info = MethodType(paramNames, paramTypes, defn.BooleanType)).asTerm
val tru = Literal(Constant(true))
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = applyRhs match {
case Match(selector, cases) =>
assert(selector.symbol == param.symbol)
val paramRef = paramRefss.head.head
// Again, the alternative
// val List(List(paramRef)) = paramRefs
// fails with a similar self instantiation error
def translateCase(cdef: CaseDef): CaseDef =
cpy.CaseDef(cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD, Synthetic, selector.tpe.widen)
val defaultCase =
CaseDef(
Bind(defaultSym, Underscore(selector.tpe.widen)),
EmptyTree,
Literal(Constant(false)))
val annotated = Annotated(paramRef, New(ref(defn.UncheckedAnnotType)))
cpy.Match(applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
private def toPartialFunction(tree: Block, tpe: Type)(implicit ctx: Context): Tree = {
// /** An extractor for match, either contained in a block or standalone. */
object PartialFunctionRHS {
def unapply(tree: Tree): Option[Match] = tree match {
case Block(Nil, expr) => unapply(expr)
case m: Match => Some(m)
case _ => None
}
}

val closureDef(anon @ DefDef(_, _, List(List(param)), _, _)) = tree
anon.rhs match {
case PartialFunctionRHS(pf) =>
val anonSym = anon.symbol

def overrideSym(sym: Symbol) = sym.copy(
owner = anonSym.owner,
flags = Synthetic | Method | Final,
info = tpe.memberInfo(sym),
coord = tree.pos).asTerm
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)

def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
val selector = tree.selector
val selectorTpe = selector.tpe.widen
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
val defaultCase =
CaseDef(
Bind(defaultSym, Underscore(selectorTpe)),
EmptyTree,
defaultValue)
val unchecked = Annotated(selector, New(ref(defn.UncheckedAnnotType)))
cpy.Match(tree)(unchecked, cases :+ defaultCase)
.subst(param.symbol :: Nil, pfParam :: Nil)
// Needed because a partial function can be written as:
// param => param match { case "foo" if foo(param) => param }
// And we need to update all references to 'param'
}

def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
val tru = Literal(Constant(true))
def translateCase(cdef: CaseDef) =
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
val paramRef = paramRefss.head.head
val defaultValue = Literal(Constant(false))
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
val List(paramRef, defaultRef) = paramRefss.head
def translateCase(cdef: CaseDef) =
cdef.changeOwner(anonSym, applyOrElseFn)
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
}

val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))

val parent = defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos)
val anonCls = AnonClass(parent :: Nil, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)

case _ =>
tru
val found = tpe.baseType(defn.FunctionClass(1))
ctx.error(TypeMismatch(found, tpe), tree.pos)
tree
}
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
val anonCls = AnonClass(tpt.tpe :: Nil, List(applyFn, isDefinedAtFn), List(nme.apply, nme.isDefinedAt))
cpy.Block(tree)(List(applyDef, isDefinedAtDef), anonCls)
}

private def checkRefinements(tpe: Type, pos: Position)(implicit ctx: Context): Type = tpe.dealias match {
case RefinedType(parent, name, _) =>
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
ctx.error("Lambda does not define " + name, pos)
checkRefinements(parent, pos)
case _ =>
case tpe =>
tpe
}

Expand Down
12 changes: 12 additions & 0 deletions tests/neg/i4241.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
class Test {
def test: Unit = {
val a: PartialFunction[Int, Int] = { case x => x }
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case _ => 2 }
val c: PartialFunction[Int, Int] = x => { x match { case y => y } }
val d: PartialFunction[Int, Int] = x => { { x match { case y => y } } }

val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case y => y } } // error
val f: PartialFunction[Int, Int] = x => x // error
val g: PartialFunction[Int, String] = { x => x.toString } // error
}
}
18 changes: 18 additions & 0 deletions tests/pos/i4177.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
class Test {

object Foo { def unapply(x: Int) = if (x == 2) Some(x.toString) else None }

def test: Unit = {
val a: PartialFunction[Int, String] = { case Foo(x) => x }
val b: PartialFunction[Int, String] = { case x => x.toString }

val e: PartialFunction[String, String] = { case x @ "abc" => x }
val f: PartialFunction[String, String] = x => x match { case "abc" => x }
val g: PartialFunction[String, String] = x => x match { case "abc" if x.isEmpty => x }

type P = PartialFunction[String,String]
val h: P = { case x => x.toString }

val i: PartialFunction[Int, Int] = { x => x match { case x => x } }
}
}
18 changes: 18 additions & 0 deletions tests/run/i4177.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
object Test {
private[this] var count = 0

def test(x: Int) = { count += 1; true }

object Foo {
def unapply(x: Int): Option[Int] = { count += 1; Some(x) }
}

def main(args: Array[String]): Unit = {
val res = List(1, 2).collect { case x if test(x) => x }
assert(count == 2)

count = 0
val res2 = List(1, 2).collect { case Foo(x) => x }
assert(count == 2)
}
}
11 changes: 8 additions & 3 deletions tests/run/partialFunctions.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
object Test {

def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
def takesPartialFunction(a: PartialFunction[Int, Int]) = a(1)
class Foo(val field: Option[Int])

def main(args: Array[String]): Unit = {
val partialFunction: PartialFunction[Int, Int] = {case a: Int => a}
val p1: PartialFunction[Int, Int] = { case a: Int => a }
assert(takesPartialFunction(p1) == 1)

assert(takesPartialFunction(partialFunction) == 1)
val p2: PartialFunction[Foo, Int] =
foo => foo.field match { case Some(x) => x }
assert(p2.isDefinedAt(new Foo(Some(1))))
assert(!p2.isDefinedAt(new Foo(None)))
}
}