1
- package dotty .tools .dotc
1
+ package dotty .tools
2
+ package dotc
2
3
package transform
3
4
4
5
import core ._
@@ -7,6 +8,7 @@ import MegaPhase._
7
8
import SymUtils ._
8
9
import NullOpsDecorator ._
9
10
import ast .Trees ._
11
+ import ast .untpd
10
12
import reporting ._
11
13
import dotty .tools .dotc .util .Spans .Span
12
14
@@ -113,68 +115,72 @@ class ExpandSAMs extends MiniPhase:
113
115
}
114
116
115
117
val closureDef(anon @ DefDef (_, List (List (param)), _, _)) = tree
116
- anon.rhs match {
117
- case PartialFunctionRHS (pf) =>
118
- val anonSym = anon.symbol
119
- val anonTpe = anon.tpe.widen
120
- val parents = List (
121
- defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
122
- defn.SerializableType )
123
- val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
124
-
125
- def overrideSym (sym : Symbol ) = sym.copy(
126
- owner = pfSym,
127
- flags = Synthetic | Method | Final | Override ,
128
- info = tpe.memberInfo(sym),
129
- coord = tree.span).asTerm.entered
130
- val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
131
- val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
132
-
133
- def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
134
- val selector = tree.selector
135
- val selectorTpe = selector.tpe.widen
136
- val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
137
- val defaultCase =
138
- CaseDef (
139
- Bind (defaultSym, Underscore (selectorTpe)),
140
- EmptyTree ,
141
- defaultValue)
142
- val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
143
- cpy.Match (tree)(unchecked, cases :+ defaultCase)
144
- .subst(param.symbol :: Nil , pfParam :: Nil )
145
- // Needed because a partial function can be written as:
146
- // param => param match { case "foo" if foo(param) => param }
147
- // And we need to update all references to 'param'
148
- }
149
-
150
- def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
151
- val tru = Literal (Constant (true ))
152
- def translateCase (cdef : CaseDef ) =
153
- cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
154
- val paramRef = paramRefss.head.head
155
- val defaultValue = Literal (Constant (false ))
156
- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
157
- }
158
-
159
- def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
160
- val List (paramRef, defaultRef) = paramRefss(1 )
161
- def translateCase (cdef : CaseDef ) =
162
- cdef.changeOwner(anonSym, applyOrElseFn)
163
- val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
164
- translateMatch(pf, paramRef.symbol, pf.cases.map(translateCase), defaultValue)
165
- }
166
-
167
- val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
168
- val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
169
- val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
170
- val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
171
- cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
172
-
118
+
119
+ // The right hand side from which to construct the partial function. This is always a Match.
120
+ // If the original rhs is already a Match (possibly in braces), return that.
121
+ // Otherwise construct a match `x match case _ => rhs` where `x` is the parameter of the closure.
122
+ def partialFunRHS (tree : Tree ): Match = tree match
123
+ case m : Match => m
124
+ case Block (Nil , expr) => partialFunRHS(expr)
173
125
case _ =>
174
- val found = tpe.baseType(defn.Function1 )
175
- report.error(TypeMismatch (found, tpe), tree.srcPos)
176
- tree
126
+ Match (ref(param.symbol),
127
+ CaseDef (untpd.Ident (nme.WILDCARD ).withType(param.symbol.info), EmptyTree , tree) :: Nil )
128
+
129
+ val pfRHS = partialFunRHS(anon.rhs)
130
+ val anonSym = anon.symbol
131
+ val anonTpe = anon.tpe.widen
132
+ val parents = List (
133
+ defn.AbstractPartialFunctionClass .typeRef.appliedTo(anonTpe.firstParamTypes.head, anonTpe.resultType),
134
+ defn.SerializableType )
135
+ val pfSym = newNormalizedClassSymbol(anonSym.owner, tpnme.ANON_CLASS , Synthetic | Final , parents, coord = tree.span)
136
+
137
+ def overrideSym (sym : Symbol ) = sym.copy(
138
+ owner = pfSym,
139
+ flags = Synthetic | Method | Final | Override ,
140
+ info = tpe.memberInfo(sym),
141
+ coord = tree.span).asTerm.entered
142
+ val isDefinedAtFn = overrideSym(defn.PartialFunction_isDefinedAt )
143
+ val applyOrElseFn = overrideSym(defn.PartialFunction_applyOrElse )
144
+
145
+ def translateMatch (tree : Match , pfParam : Symbol , cases : List [CaseDef ], defaultValue : Tree )(using Context ) = {
146
+ val selector = tree.selector
147
+ val selectorTpe = selector.tpe.widen
148
+ val defaultSym = newSymbol(pfParam.owner, nme.WILDCARD , Synthetic | Case , selectorTpe)
149
+ val defaultCase =
150
+ CaseDef (
151
+ Bind (defaultSym, Underscore (selectorTpe)),
152
+ EmptyTree ,
153
+ defaultValue)
154
+ val unchecked = selector.annotated(New (ref(defn.UncheckedAnnot .typeRef)))
155
+ cpy.Match (tree)(unchecked, cases :+ defaultCase)
156
+ .subst(param.symbol :: Nil , pfParam :: Nil )
157
+ // Needed because a partial function can be written as:
158
+ // param => param match { case "foo" if foo(param) => param }
159
+ // And we need to update all references to 'param'
160
+ }
161
+
162
+ def isDefinedAtRhs (paramRefss : List [List [Tree ]])(using Context ) = {
163
+ val tru = Literal (Constant (true ))
164
+ def translateCase (cdef : CaseDef ) =
165
+ cpy.CaseDef (cdef)(body = tru).changeOwner(anonSym, isDefinedAtFn)
166
+ val paramRef = paramRefss.head.head
167
+ val defaultValue = Literal (Constant (false ))
168
+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177
169
}
170
+
171
+ def applyOrElseRhs (paramRefss : List [List [Tree ]])(using Context ) = {
172
+ val List (paramRef, defaultRef) = paramRefss(1 )
173
+ def translateCase (cdef : CaseDef ) =
174
+ cdef.changeOwner(anonSym, applyOrElseFn)
175
+ val defaultValue = defaultRef.select(nme.apply).appliedTo(paramRef)
176
+ translateMatch(pfRHS, paramRef.symbol, pfRHS.cases.map(translateCase), defaultValue)
177
+ }
178
+
179
+ val constr = newConstructor(pfSym, Synthetic , Nil , Nil ).entered
180
+ val isDefinedAtDef = transformFollowingDeep(DefDef (isDefinedAtFn, isDefinedAtRhs(_)(using ctx.withOwner(isDefinedAtFn))))
181
+ val applyOrElseDef = transformFollowingDeep(DefDef (applyOrElseFn, applyOrElseRhs(_)(using ctx.withOwner(applyOrElseFn))))
182
+ val pfDef = ClassDef (pfSym, DefDef (constr), List (isDefinedAtDef, applyOrElseDef))
183
+ cpy.Block (tree)(pfDef :: Nil , New (pfSym.typeRef, Nil ))
178
184
}
179
185
180
186
private def checkRefinements (tpe : Type , tree : Tree )(using Context ): Type = tpe.dealias match {
0 commit comments