@@ -8,12 +8,13 @@ import MegaPhase._
8
8
import SymUtils ._
9
9
import ast .untpd
10
10
import ast .Trees ._
11
+ import dotty .tools .dotc .reporting .diagnostic .messages .TypeMismatch
11
12
import dotty .tools .dotc .util .Positions .Position
12
13
13
14
/** Expand SAM closures that cannot be represented by the JVM as lambdas to anonymous classes.
14
15
* These fall into five categories
15
16
*
16
- * 1. Partial function closures, we need to generate a isDefinedAt method for these.
17
+ * 1. Partial function closures, we need to generate isDefinedAt and applyOrElse methods for these.
17
18
* 2. Closures implementing non-trait classes.
18
19
* 3. Closures implementing classes that inherit from a class other than Object
19
20
* (a lambda cannot not be a run-time subtype of such a class)
@@ -35,8 +36,8 @@ class ExpandSAMs extends MiniPhase {
35
36
tpt.tpe match {
36
37
case NoType => tree // it's a plain function
37
38
case tpe @ SAMType (_) if tpe.isRef(defn.PartialFunctionClass ) =>
38
- checkRefinements(tpe, fn.pos)
39
- toPartialFunction(tree)
39
+ val tpe1 = checkRefinements(tpe, fn.pos)
40
+ toPartialFunction(tree, tpe1 )
40
41
case tpe @ SAMType (_) if isPlatformSam(tpe.classSymbol.asClass) =>
41
42
checkRefinements(tpe, fn.pos)
42
43
tree
@@ -50,50 +51,83 @@ class ExpandSAMs extends MiniPhase {
50
51
tree
51
52
}
52
53
53
- private def toPartialFunction (tree : Block )(implicit ctx : Context ): Tree = {
54
- val Block (
55
- (applyDef @ DefDef (nme.ANON_FUN , Nil , List (List (param)), _, _)) :: Nil ,
56
- Closure (_, _, tpt)) = tree
57
- val applyRhs : Tree = applyDef.rhs
58
- val applyFn = applyDef.symbol.asTerm
59
-
60
- val MethodTpe (paramNames, paramTypes, _) = applyFn.info
61
- val isDefinedAtFn = applyFn.copy(
62
- name = nme.isDefinedAt,
63
- flags = Synthetic | Method ,
64
- info = MethodType (paramNames, paramTypes, defn.BooleanType )).asTerm
65
- val tru = Literal (Constant (true ))
66
- def isDefinedAtRhs (paramRefss : List [List [Tree ]]) = applyRhs match {
67
- case Match (selector, cases) =>
68
- assert(selector.symbol == param.symbol)
69
- val paramRef = paramRefss.head.head
70
- // Again, the alternative
71
- // val List(List(paramRef)) = paramRefs
72
- // fails with a similar self instantiation error
73
- def translateCase (cdef : CaseDef ): CaseDef =
74
- cpy.CaseDef (cdef)(body = tru).changeOwner(applyFn, isDefinedAtFn)
75
- val defaultSym = ctx.newSymbol(isDefinedAtFn, nme.WILDCARD , Synthetic , selector.tpe.widen)
76
- val defaultCase =
77
- CaseDef (
78
- Bind (defaultSym, Underscore (selector.tpe.widen)),
79
- EmptyTree ,
80
- Literal (Constant (false )))
81
- val annotated = Annotated (paramRef, New (ref(defn.UncheckedAnnotType )))
82
- cpy.Match (applyRhs)(annotated, cases.map(translateCase) :+ defaultCase)
54
+ private def toPartialFunction (tree : Block , tpe : Type )(implicit ctx : Context ): Tree = {
55
+ // /** An extractor for match, either contained in a block or standalone. */
56
+ object PartialFunctionRHS {
57
+ def unapply (tree : Tree ): Option [Match ] = tree match {
58
+ case Block (Nil , expr) => unapply(expr)
59
+ case m : Match => Some (m)
60
+ case _ => None
61
+ }
62
+ }
63
+
64
+ val closureDef(anon @ DefDef (_, _, List (List (param)), _, _)) = tree
65
+ anon.rhs match {
66
+ case PartialFunctionRHS (pf) =>
67
+ val anonSym = anon.symbol
68
+
69
+ def overrideSym (sym : Symbol ) = sym.copy(
70
+ owner = anonSym.owner,
71
+ flags = Synthetic | Method | Final ,
72
+ info = tpe.memberInfo(sym),
73
+ coord = tree.pos).asTerm
74
+ val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
75
+ val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
76
+
77
+ def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree ) = {
78
+ val selector = tree.selector
79
+ val selectorTpe = selector.tpe.widen
80
+ val defaultSym = ctx.newSymbol(pfParam.owner, nme.WILDCARD , Synthetic , selectorTpe)
81
+ val defaultCase =
82
+ CaseDef (
83
+ Bind (defaultSym, Underscore (selectorTpe)),
84
+ EmptyTree ,
85
+ defaultValue)
86
+ val unchecked = Annotated (selector, New (ref(defn.UncheckedAnnotType )))
87
+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
88
+ .subst(param.symbol :: Nil , pfParam :: Nil )
89
+ // Needed because a partial function can be written as:
90
+ // param => param match { case "foo" if foo(param) => param }
91
+ // And we need to update all references to 'param'
92
+ }
93
+
94
+ def isDefinedAtRhs (paramRefss : List [List [Tree ]]) = {
95
+ val tru = Literal (Constant (true ))
96
+ def translateCase (cdef : CaseDef ) =
97
+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
98
+ val paramRef = paramRefss.head.head
99
+ val defaultValue = Literal (Constant (false ))
100
+ translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
101
+ }
102
+
103
+ def applyOrElseRhs (paramRefss : List [List [Tree ]]) = {
104
+ val List (paramRef, defaultRef) = paramRefss.head
105
+ def translateCase (cdef : CaseDef ) =
106
+ cdef.changeOwner(anonSym, applyOrElseFn)
107
+ val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
108
+ translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
109
+ }
110
+
111
+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
112
+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)))
113
+
114
+ val parent = defn.AbstractPartialFunctionType .appliedTo(tpe.argInfos)
115
+ val anonCls = AnonClass (parent :: Nil , List (isDefinedAtFn, applyOrElseFn), List (nme.isDefinedAt, nme.applyOrElse))
116
+ cpy.Block (tree)(List (isDefinedAtDef, applyOrElseDef), anonCls)
117
+
83
118
case _ =>
84
- tru
119
+ val found = tpe.baseType(defn.FunctionClass (1 ))
120
+ ctx.error(TypeMismatch (found, tpe), tree.pos)
121
+ tree
85
122
}
86
- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)))
87
- val anonCls = AnonClass (tpt.tpe :: Nil , List (applyFn, isDefinedAtFn), List (nme.apply, nme.isDefinedAt))
88
- cpy.Block (tree)(List (applyDef, isDefinedAtDef), anonCls)
89
123
}
90
124
91
125
private def checkRefinements (tpe : Type , pos : Position )(implicit ctx : Context ): Type = tpe.dealias match {
92
126
case RefinedType (parent, name, _) =>
93
127
if (name.isTermName && tpe.member(name).symbol.ownersIterator.isEmpty) // if member defined in the refinement
94
128
ctx.error(" Lambda does not define " + name, pos)
95
129
checkRefinements(parent, pos)
96
- case _ =>
130
+ case tpe =>
97
131
tpe
98
132
}
99
133
0 commit comments