Skip to content

Commit 89e0f2a

Browse files
committed
Always generate a partial function from a lambda
`scalac` no longer complains, neither should `dotc`. I verified that the output of `i4241.scala` is the same for both `scalac` and `dotc`. Fixes #12661
1 parent 56abade commit 89e0f2a

File tree

3 files changed

+91
-73
lines changed

3 files changed

+91
-73
lines changed

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 67 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
package dotty.tools.dotc
1+
package dotty.tools
2+
package dotc
23
package transform
34

45
import core._
@@ -7,6 +8,7 @@ import MegaPhase._
78
import SymUtils._
89
import NullOpsDecorator._
910
import ast.Trees._
11+
import ast.untpd
1012
import reporting._
1113
import dotty.tools.dotc.util.Spans.Span
1214

@@ -113,68 +115,72 @@ class ExpandSAMs extends MiniPhase:
113115
}
114116

115117
val closureDef(anon @ DefDef(_, List(List(param)), _, _)) = tree
116-
anon.rhs match {
117-
case PartialFunctionRHS(pf) =>
118-
val anonSym = anon.symbol
119-
val anonTpe = anon.tpe.widen
120-
val parents = List(
121-
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
122-
defn.SerializableType)
123-
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)
124-
125-
def overrideSym(sym: Symbol) = sym.copy(
126-
owner = pfSym,
127-
flags = Synthetic | Method | Final | Override,
128-
info = tpe.memberInfo(sym),
129-
coord = tree.span).asTerm.entered
130-
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
131-
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
132-
133-
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
134-
val selector = tree.selector
135-
val selectorTpe = selector.tpe.widen
136-
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
137-
val defaultCase =
138-
CaseDef(
139-
Bind(defaultSym, Underscore(selectorTpe)),
140-
EmptyTree,
141-
defaultValue)
142-
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
143-
cpy.Match(tree)(unchecked, cases :+ defaultCase)
144-
.subst(param.symbol :: Nil, pfParam :: Nil)
145-
// Needed because a partial function can be written as:
146-
// param => param match { case "foo" if foo(param) => param }
147-
// And we need to update all references to 'param'
148-
}
149-
150-
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
151-
val tru = Literal(Constant(true))
152-
def translateCase(cdef: CaseDef) =
153-
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
154-
val paramRef = paramRefss.head.head
155-
val defaultValue = Literal(Constant(false))
156-
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
157-
}
158-
159-
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
160-
val List(paramRef, defaultRef) = paramRefss(1)
161-
def translateCase(cdef: CaseDef) =
162-
cdef.changeOwner(anonSym, applyOrElseFn)
163-
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
164-
translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
165-
}
166-
167-
val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
168-
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
169-
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
170-
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
171-
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
172-
118+
119+
// The right hand side from which to construct the partial function. This is always a Match.
120+
// If the original rhs is already a Match (possibly in braces), return that.
121+
// Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
122+
def partialFunRHS(tree: Tree): Match = tree match
123+
case m: Match => m
124+
case Block(Nil, expr) => partialFunRHS(expr)
173125
case _ =>
174-
val found = tpe.baseType(defn.Function1)
175-
report.error(TypeMismatch(found, tpe), tree.srcPos)
176-
tree
126+
Match(ref(param.symbol),
127+
CaseDef(untpd.Ident(nme.WILDCARD).withType(param.symbol.info), EmptyTree, tree) :: Nil)
128+
129+
val pfRHS = partialFunRHS(anon.rhs)
130+
val anonSym = anon.symbol
131+
val anonTpe = anon.tpe.widen
132+
val parents = List(
133+
defn.AbstractPartialFunctionClass.typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
134+
defn.SerializableType)
135+
val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS, Synthetic | Final, parents, coord = tree.span)
136+
137+
def overrideSym(sym: Symbol) = sym.copy(
138+
owner = pfSym,
139+
flags = Synthetic | Method | Final | Override,
140+
info = tpe.memberInfo(sym),
141+
coord = tree.span).asTerm.entered
142+
val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt)
143+
val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse)
144+
145+
def translateMatch(tree: Match, pfParam: Symbol, cases: List[CaseDef], defaultValue: Tree)(using Context) = {
146+
val selector = tree.selector
147+
val selectorTpe = selector.tpe.widen
148+
val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD, Synthetic | Case, selectorTpe)
149+
val defaultCase =
150+
CaseDef(
151+
Bind(defaultSym, Underscore(selectorTpe)),
152+
EmptyTree,
153+
defaultValue)
154+
val unchecked = selector.annotated(New(ref(defn.UncheckedAnnot.typeRef)))
155+
cpy.Match(tree)(unchecked, cases :+ defaultCase)
156+
.subst(param.symbol :: Nil, pfParam :: Nil)
157+
// Needed because a partial function can be written as:
158+
// param => param match { case "foo" if foo(param) => param }
159+
// And we need to update all references to 'param'
160+
}
161+
162+
def isDefinedAtRhs(paramRefss: List[List[Tree]])(using Context) = {
163+
val tru = Literal(Constant(true))
164+
def translateCase(cdef: CaseDef) =
165+
cpy.CaseDef(cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
166+
val paramRef = paramRefss.head.head
167+
val defaultValue = Literal(Constant(false))
168+
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177169
}
170+
171+
def applyOrElseRhs(paramRefss: List[List[Tree]])(using Context) = {
172+
val List(paramRef, defaultRef) = paramRefss(1)
173+
def translateCase(cdef: CaseDef) =
174+
cdef.changeOwner(anonSym, applyOrElseFn)
175+
val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
176+
translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177+
}
178+
179+
val constr = newConstructor(pfSym, Synthetic, Nil, Nil).entered
180+
val isDefinedAtDef = transformFollowingDeep(DefDef(isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
181+
val applyOrElseDef = transformFollowingDeep(DefDef(applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
182+
val pfDef = ClassDef(pfSym, DefDef(constr), List(isDefinedAtDef, applyOrElseDef))
183+
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
178184
}
179185

180186
private def checkRefinements(tpe: Type, tree: Tree)(using Context): Type = tpe.dealias match {

tests/neg/i4241.scala

Lines changed: 0 additions & 12 deletions
This file was deleted.

tests/run/i4241.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
object Text extends App {
2+
val a: PartialFunction[Int, Int] = { case x => x }
3+
val b: PartialFunction[Int, Int] = x => x match { case 1 => 1; case 2 => 2 }
4+
val c: PartialFunction[Int, Int] = x => { x match { case 1 => 1 } }
5+
val d: PartialFunction[Int, Int] = x => { { x match { case 1 => 1 } } }
6+
7+
val e: PartialFunction[Int, Int] = x => { println("foo"); x match { case 1 => 1 } }
8+
val f: PartialFunction[Int, Int] = x => x
9+
val g: PartialFunction[Int, String] = { x => x.toString }
10+
val h: PartialFunction[Int, String] = _.toString
11+
assert(a.isDefinedAt(2))
12+
assert(b.isDefinedAt(2))
13+
assert(!b.isDefinedAt(3))
14+
assert(c.isDefinedAt(1))
15+
assert(!c.isDefinedAt(2))
16+
assert(d.isDefinedAt(1))
17+
assert(!d.isDefinedAt(2))
18+
assert(e.isDefinedAt(2))
19+
assert(f.isDefinedAt(2))
20+
assert(g.isDefinedAt(2))
21+
assert(h.isDefinedAt(2))
22+
}
23+
24+

0 commit comments

Comments
 (0)