Skip to content

Commit 58234f3

Browse files
committed
Fix #4446: Inline implementation of PF methods into its anonymous class
1 parent 80cd7e2 commit 58234f3

File tree

2 files changed

+33
-12
lines changed

2 files changed

+33
-12
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,18 @@ class ExpandSAMs extends MiniPhase {
7171
case PartialFunctionRHS(pf) =>
7272
val anonSym = anon.symbol
7373

74+
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
75+
val pfSym = ctx.newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.pos)
76+
7477
def overrideSym(sym: Symbol) = sym.copy(
75-
owner = anonSym.owner,
76-
flags = Synthetic | Method | Final,
78+
owner = pfSym,
79+
flags = Synthetic | Method | Final | Override,
7780
info = tpe.memberInfo(sym),
78-
coord = tree.pos).asTerm
81+
coord = tree.pos).asTerm.entered
7982
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
8083
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
8184

82-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree) = {
85+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(implicit ctx: Context) = {
8386
val selector = tree.selector
8487
val selectorTpe = selector.tpe.widen
8588
val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD, Synthetic, selectorTpe)
@@ -96,7 +99,7 @@ class ExpandSAMs extends MiniPhase {
9699
// And we need to update all references to 'param'
97100
}
98101

99-
def isDefinedAtRhs(paramRefss: List[List[Tree]]) = {
102+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
100103
val tru = Literal(Constant(true))
101104
def translateCase(cdef: CaseDef) =
102105
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
@@ -105,20 +108,19 @@ class ExpandSAMs extends MiniPhase {
105108
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
106109
}
107110

108-
def applyOrElseRhs(paramRefss: List[List[Tree]]) = {
111+
def applyOrElseRhs(paramRefss: List[List[Tree]])(implicit ctx: Context) = {
109112
val List(paramRef, defaultRef) = paramRefss.head
110113
def translateCase(cdef: CaseDef) =
111114
cdef.changeOwner(anonSym, applyOrElseFn)
112115
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
113116
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
114117
}
115118

116-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)))
117-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)))
118-
119-
val parents = List(defn.AbstractPartialFunctionType.appliedTo(tpe.argInfos), defn.SerializableType)
120-
val anonCls = AnonClass(parents, List(isDefinedAtFn, applyOrElseFn), List(nme.isDefinedAt, nme.applyOrElse))
121-
cpy.Block(tree)(List(isDefinedAtDef, applyOrElseDef), anonCls)
119+
val constr = ctx.newConstructor(pfSym, Synthetic, Nil, Nil).entered
120+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(ctx.withOwner(isDefinedAtFn))))
121+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(ctx.withOwner(applyOrElseFn))))
122+
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
123+
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
122124

123125
case _ =>
124126
val found = tpe.baseType(defn.FunctionClass(1))

tests/run/i4446.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
class Foo {
2+
def foo: PartialFunction[Int, Int] = { case x => x + 1 }
3+
}
4+
5+
object Test {
6+
def serializeDeserialize[T <: AnyRef](obj: T): T = {
7+
import java.io._
8+
val buffer = new ByteArrayOutputStream
9+
val out = new ObjectOutputStream(buffer)
10+
out.writeObject(obj)
11+
val in = new ObjectInputStream(new ByteArrayInputStream(buffer.toByteArray))
12+
in.readObject.asInstanceOf[T]
13+
}
14+
15+
def main(args: Array[String]): Unit = {
16+
val adder = serializeDeserialize((new Foo).foo)
17+
assert(adder(1) == 2)
18+
}
19+
}

0 commit comments

Comments
 (0)