@@ -16,49 +16,45 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
16
16
val phaseName = " specializeFunctions"
17
17
override def runsAfter = Set (ElimByName .name)
18
18
19
- private val jFunction = " scala.compat.java8.JFunction" .toTermName
20
-
21
19
/** Transforms the type to include decls for specialized applys */
22
20
override def transformInfo (tp : Type , sym : Symbol )(using Context ) = tp match {
23
- case tp : ClassInfo if ! sym.is(Flags .Package ) && (tp.decls ne EmptyScope ) && derivesFromFn012(sym) =>
24
- var newApplys = Map .empty[ Name , Symbol ]
21
+ case tp : ClassInfo if ! sym.is(Flags .Package ) && derivesFromFn012(sym) =>
22
+ val newApplys = new mutable. ListBuffer [ Symbol ]
25
23
26
24
var arity = 0
27
25
while (arity < 3 ) {
28
26
val func = defn.FunctionClass (arity)
29
27
if (tp.derivesFrom(func)) {
30
- val typeParams = tp.cls.typeRef.baseType(func).argInfos
28
+ val paramTypes = tp.cls.typeRef.baseType(func).argInfos
31
29
val isSpecializable =
32
30
defn.isSpecializableFunction(
33
31
sym.asClass,
34
- typeParams.init,
35
- typeParams.last
32
+ paramTypes.init,
33
+ paramTypes.last
34
+ )
35
+
36
+ val apply = tp.decls.lookup(nme.apply)
37
+ if (isSpecializable && apply.exists) {
38
+ val specializedMethodName = specializedName(nme.apply, paramTypes)
39
+ val applySpecialized = newSymbol(
40
+ sym,
41
+ specializedMethodName,
42
+ Flags .Override | Flags .Method | Flags .Synthetic ,
43
+ apply.info
36
44
)
37
45
38
- if (isSpecializable && tp.decls.lookup(nme.apply).exists) {
39
- val interface = specInterface(typeParams)
40
- val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
41
- newApplys += (specializedMethodName -> interface)
46
+ newApplys += applySpecialized
42
47
}
43
48
}
44
49
arity += 1
45
50
}
46
51
47
- def newDecls =
48
- newApplys.toList.map { case (name, interface) =>
49
- newSymbol(
50
- sym,
51
- name,
52
- Flags .Override | Flags .Method | Flags .Synthetic ,
53
- interface.info.decls.lookup(name).info
54
- )
55
- }
56
- .foldLeft(tp.decls.cloneScope) {
57
- (scope, sym) => scope.enter(sym); scope
58
- }
59
-
60
52
if (newApplys.isEmpty) tp
61
- else tp.derivedClassInfo(decls = newDecls)
53
+ else {
54
+ val scope = tp.decls.cloneScope
55
+ newApplys.toList.foreach { sym => scope.enter(sym) }
56
+ tp.derivedClassInfo(decls = scope)
57
+ }
62
58
63
59
case _ => tp
64
60
}
@@ -69,58 +65,55 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
69
65
*/
70
66
override def transformTemplate (tree : Template )(using Context ) = {
71
67
val cls = tree.symbol.enclosingClass.asClass
72
- if (derivesFromFn012(cls)) {
73
- val applyBuf = new mutable.ListBuffer [Tree ]
74
- val newBody = tree.body.mapConserve {
75
- case dt : DefDef if dt.name == nme.apply && dt.vparamss.length == 1 =>
76
- val typeParams = dt.vparamss.head.map(_.symbol.info)
77
- val retType = dt.tpe.widen.finalResultType
78
-
79
- val specName = specializedName(nme.apply, typeParams :+ retType)
80
- val specializedApply = cls.info.decls.lookup(specName)
81
- if (specializedApply.exists) {
82
- val apply = specializedApply.asTerm
83
- val specializedDecl =
84
- polyDefDef(apply, trefs => vrefss => {
85
- dt.rhs
86
- .changeOwner(dt.symbol, apply)
87
- .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
88
- })
89
- applyBuf += specializedDecl
90
-
91
- // create a forwarding to the specialized apply
92
- cpy.DefDef (dt)(rhs = {
93
- tpd
94
- .ref(apply)
95
- .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
68
+
69
+ if (! derivesFromFn012(cls)) return tree
70
+
71
+ val applyBuf = new mutable.ListBuffer [Tree ]
72
+ val newBody = tree.body.mapConserve {
73
+ case ddef : DefDef if ddef.name == nme.apply && ddef.vparamss.length == 1 =>
74
+ val paramTypes = ddef.vparamss.head.map(_.symbol.info)
75
+ val retType = ddef.tpe.widen.finalResultType
76
+
77
+ val specName = specializedName(nme.apply, paramTypes :+ retType)
78
+ val specializedApply = cls.info.decls.lookup(specName)
79
+ if (specializedApply.exists) {
80
+ val specializedDecl =
81
+ DefDef (specializedApply.asTerm, vparamss => {
82
+ ddef.rhs
83
+ .changeOwner(ddef.symbol, specializedApply)
84
+ .subst(ddef.vparamss.head.map(_.symbol), vparamss.head.map(_.symbol))
96
85
})
97
- } else dt
86
+ applyBuf += specializedDecl
98
87
99
- case x => x
100
- }
88
+ // create a forwarding to the specialized apply
89
+ val args = ddef.vparamss.head.map(vparam => ref(vparam.symbol))
90
+ val rhs = This (cls).select(specializedApply).appliedToArgs(args)
91
+ cpy.DefDef (ddef)(rhs = rhs)
92
+ } else ddef
93
+
94
+ case x => x
95
+ }
101
96
102
- cpy.Template (tree)(
103
- body = applyBuf.toList ::: newBody
104
- )
105
- } else tree
97
+ cpy.Template (tree)(body = applyBuf.toList ::: newBody)
106
98
}
107
99
108
100
/** Dispatch to specialized `apply`s in user code when available */
109
101
override def transformApply (tree : Apply )(using Context ) =
110
102
tree match {
111
103
case Apply (fun, args)
112
104
if fun.symbol.name == nme.apply &&
113
- fun.symbol.owner.derivesFrom(defn. FunctionClass (args.length) )
105
+ derivesFromFn012( fun.symbol.owner)
114
106
=>
115
- val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
107
+ val paramTypes = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
116
108
val isSpecializable =
117
109
defn.isSpecializableFunction(
118
110
fun.symbol.owner.asClass,
119
- params.init,
120
- params.last)
111
+ paramTypes.init,
112
+ paramTypes.last
113
+ )
121
114
122
- if (isSpecializable && ! params .exists(_.isInstanceOf [ExprType ])) {
123
- val specializedApply = specializedName(nme.apply, params )
115
+ if (isSpecializable && ! paramTypes .exists(_.isInstanceOf [ExprType ])) {
116
+ val specializedApply = specializedName(nme.apply, paramTypes )
124
117
val newSel = fun match {
125
118
case Select (qual, _) =>
126
119
qual.select(specializedApply)
@@ -143,12 +136,6 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
143
136
private def specializedName (name : Name , args : List [Type ])(using Context ) =
144
137
name.specializedFunction(args.last, args.init)
145
138
146
- private def functionName (typeParams : List [Type ])(using Context ) =
147
- jFunction ++ (typeParams.length - 1 ).toString
148
-
149
- private def specInterface (typeParams : List [Type ])(using Context ) =
150
- getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init))
151
-
152
139
private def derivesFromFn012 (sym : Symbol )(using Context ): Boolean =
153
140
sym.derivesFrom(defn.FunctionClass (0 )) ||
154
141
sym.derivesFrom(defn.FunctionClass (1 )) ||
0 commit comments