|
1 | 1 | package dotty.tools.dotc
|
2 | 2 | package transform
|
3 | 3 |
|
4 |
| -import TreeTransforms.{ MiniPhaseTransform, TransformerInfo } |
5 | 4 | import ast.Trees._, ast.tpd, core._
|
6 | 5 | import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
|
7 | 6 | import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
|
| 7 | +import MegaPhase.MiniPhase |
8 | 8 |
|
9 | 9 | import scala.collection.mutable
|
10 | 10 |
|
11 | 11 | /** Specializes classes that inherit from `FunctionN` where there exists a
|
12 | 12 | * specialized form.
|
13 | 13 | */
|
14 |
| -class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer { |
| 14 | +class SpecializeFunctions extends MiniPhase with InfoTransformer { |
15 | 15 | import ast.tpd._
|
16 | 16 | val phaseName = "specializeFunctions"
|
| 17 | + override def runsAfter = Set(classOf[ElimByName]) |
17 | 18 |
|
18 |
| - private[this] var _blacklistedSymbols: List[Symbol] = _ |
| 19 | + private val jFunction = "scala.compat.java8.JFunction".toTermName |
19 | 20 |
|
20 |
| - private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = { |
21 |
| - if (_blacklistedSymbols eq null) _blacklistedSymbols = List( |
22 |
| - ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol |
23 |
| - ) |
24 |
| - |
25 |
| - _blacklistedSymbols |
26 |
| - } |
27 |
| - |
28 |
| - /** Transforms the type to include decls for specialized applys and replace |
29 |
| - * the class parents with specialized versions. |
30 |
| - */ |
31 |
| - def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match { |
32 |
| - case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) => { |
| 21 | + /** Transforms the type to include decls for specialized applys */ |
| 22 | + override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match { |
| 23 | + case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) => |
33 | 24 | var newApplys = Map.empty[Name, Symbol]
|
34 | 25 |
|
35 |
| - val newParents = tp.parents.mapConserve { parent => |
36 |
| - List(0, 1, 2, 3).flatMap { arity => |
37 |
| - val func = defn.FunctionClass(arity) |
38 |
| - if (!parent.derivesFrom(func)) Nil |
39 |
| - else { |
40 |
| - val typeParams = tp.typeRef.baseArgInfos(func) |
41 |
| - val interface = specInterface(typeParams) |
42 |
| - |
43 |
| - if (interface.exists) { |
44 |
| - if (tp.decls.lookup(nme.apply).exists) { |
45 |
| - val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init) |
46 |
| - newApplys = newApplys + (specializedMethodName -> interface) |
47 |
| - } |
| 26 | + var arity = 0 |
| 27 | + while (arity < 3) { |
| 28 | + val func = defn.FunctionClass(arity) |
| 29 | + if (tp.derivesFrom(func)) { |
| 30 | + val typeParams = tp.cls.typeRef.baseType(func).argInfos |
| 31 | + val isSpecializable = |
| 32 | + defn.isSpecializableFunction( |
| 33 | + sym.asClass, |
| 34 | + typeParams.init, |
| 35 | + typeParams.last |
| 36 | + ) |
48 | 37 |
|
49 |
| - if (parent.isRef(func)) List(interface.typeRef) |
50 |
| - else Nil |
51 |
| - } |
52 |
| - else Nil |
| 38 | + if (isSpecializable && tp.decls.lookup(nme.apply).exists) { |
| 39 | + val interface = specInterface(typeParams) |
| 40 | + val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init) |
| 41 | + newApplys += (specializedMethodName -> interface) |
53 | 42 | }
|
54 | 43 | }
|
55 |
| - .headOption |
56 |
| - .getOrElse(parent) |
| 44 | + arity += 1 |
57 | 45 | }
|
58 | 46 |
|
59 | 47 | def newDecls =
|
60 |
| - if (newApplys.isEmpty) tp.decls |
61 |
| - else |
62 |
| - newApplys.toList.map { case (name, interface) => |
63 |
| - ctx.newSymbol( |
64 |
| - sym, |
65 |
| - name, |
66 |
| - Flags.Override | Flags.Method, |
67 |
| - interface.info.decls.lookup(name).info |
68 |
| - ) |
69 |
| - } |
70 |
| - .foldLeft(tp.decls.cloneScope) { |
71 |
| - (scope, sym) => scope.enter(sym); scope |
72 |
| - } |
| 48 | + newApplys.toList.map { case (name, interface) => |
| 49 | + ctx.newSymbol( |
| 50 | + sym, |
| 51 | + name, |
| 52 | + Flags.Override | Flags.Method | Flags.Synthetic, |
| 53 | + interface.info.decls.lookup(name).info |
| 54 | + ) |
| 55 | + } |
| 56 | + .foldLeft(tp.decls.cloneScope) { |
| 57 | + (scope, sym) => scope.enter(sym); scope |
| 58 | + } |
73 | 59 |
|
74 |
| - tp.derivedClassInfo( |
75 |
| - classParents = newParents, |
76 |
| - decls = newDecls |
77 |
| - ) |
78 |
| - } |
| 60 | + if (newApplys.isEmpty) tp |
| 61 | + else tp.derivedClassInfo(decls = newDecls) |
79 | 62 |
|
80 | 63 | case _ => tp
|
81 | 64 | }
|
82 | 65 |
|
83 | 66 | /** Transforms the `Template` of the classes to contain forwarders from the
|
84 |
| - * generic applys to the specialized ones. Also replaces parents of the |
85 |
| - * class on the tree level and inserts the specialized applys in the |
86 |
| - * template body. |
| 67 | + * generic applys to the specialized ones. Also inserts the specialized applys |
| 68 | + * in the template body. |
87 | 69 | */
|
88 |
| - override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = { |
89 |
| - val applyBuf = new mutable.ListBuffer[Tree] |
90 |
| - val newBody = tree.body.mapConserve { |
91 |
| - case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => { |
92 |
| - val specName = nme.apply.specializedFunction( |
93 |
| - dt.tpe.widen.finalResultType, |
94 |
| - dt.vparamss.head.map(_.symbol.info) |
95 |
| - ) |
96 |
| - |
97 |
| - val specializedApply = tree.symbol.enclosingClass.info.decls.lookup(specName)//member(specName).symbol |
98 |
| - //val specializedApply = tree.symbol.enclosingClass.info.member(specName).symbol |
99 |
| - |
100 |
| - if (false) { |
101 |
| - println(tree.symbol.enclosingClass.show) |
102 |
| - println("'" + specName.show + "'") |
103 |
| - println(specializedApply) |
104 |
| - println(specializedApply.exists) |
105 |
| - } |
106 |
| - |
107 |
| - |
108 |
| - if (specializedApply.exists) { |
109 |
| - val apply = specializedApply.asTerm |
110 |
| - val specializedDecl = |
111 |
| - polyDefDef(apply, trefs => vrefss => { |
112 |
| - dt.rhs |
113 |
| - .changeOwner(dt.symbol, apply) |
114 |
| - .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol)) |
| 70 | + override def transformTemplate(tree: Template)(implicit ctx: Context) = { |
| 71 | + val cls = tree.symbol.enclosingClass.asClass |
| 72 | + if (derivesFromFn012(cls)) { |
| 73 | + val applyBuf = new mutable.ListBuffer[Tree] |
| 74 | + val newBody = tree.body.mapConserve { |
| 75 | + case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => |
| 76 | + val typeParams = dt.vparamss.head.map(_.symbol.info) |
| 77 | + val retType = dt.tpe.widen.finalResultType |
| 78 | + |
| 79 | + val specName = specializedName(nme.apply, typeParams :+ retType) |
| 80 | + val specializedApply = cls.info.decls.lookup(specName) |
| 81 | + if (specializedApply.exists) { |
| 82 | + val apply = specializedApply.asTerm |
| 83 | + val specializedDecl = |
| 84 | + polyDefDef(apply, trefs => vrefss => { |
| 85 | + dt.rhs |
| 86 | + .changeOwner(dt.symbol, apply) |
| 87 | + .subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol)) |
| 88 | + }) |
| 89 | + applyBuf += specializedDecl |
| 90 | + |
| 91 | + // create a forwarding to the specialized apply |
| 92 | + cpy.DefDef(dt)(rhs = { |
| 93 | + tpd |
| 94 | + .ref(apply) |
| 95 | + .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol))) |
115 | 96 | })
|
116 |
| - applyBuf += specializedDecl |
117 |
| - |
118 |
| - // create a forwarding to the specialized apply |
119 |
| - cpy.DefDef(dt)(rhs = { |
120 |
| - tpd |
121 |
| - .ref(apply) |
122 |
| - .appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol))) |
123 |
| - }) |
124 |
| - } else dt |
125 |
| - } |
126 |
| - case x => x |
127 |
| - } |
128 |
| - |
129 |
| - val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity => |
130 |
| - val func = defn.FunctionClass(arity) |
131 |
| - val tr = tree.symbol.enclosingClass.typeRef |
| 97 | + } else dt |
132 | 98 |
|
133 |
| - if (!tr.parents.exists(_.isRef(func))) Nil |
134 |
| - else { |
135 |
| - val typeParams = tr.baseArgInfos(func) |
136 |
| - val interface = specInterface(typeParams) |
137 |
| - |
138 |
| - if (interface.exists) List(interface.info) |
139 |
| - else Nil |
| 99 | + case x => x |
140 | 100 | }
|
141 |
| - }.map(TypeTree) |
142 | 101 |
|
143 |
| - cpy.Template(tree)( |
144 |
| - parents = tree.parents ++ missing, |
145 |
| - body = applyBuf.toList ++ newBody |
146 |
| - ) |
| 102 | + cpy.Template(tree)( |
| 103 | + body = applyBuf.toList ::: newBody |
| 104 | + ) |
| 105 | + } else tree |
147 | 106 | }
|
148 | 107 |
|
149 | 108 | /** Dispatch to specialized `apply`s in user code when available */
|
150 |
| - override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) = |
| 109 | + override def transformApply(tree: Apply)(implicit ctx: Context) = |
151 | 110 | tree match {
|
152 |
| - case app @ Apply(fun, args) |
| 111 | + case Apply(fun, args) |
153 | 112 | if fun.symbol.name == nme.apply &&
|
154 | 113 | fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
|
155 |
| - => { |
| 114 | + => |
156 | 115 | val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
|
157 |
| - val specializedApply = specializedName(nme.apply, params) |
158 |
| - |
159 |
| - if (!params.exists(_.isInstanceOf[ExprType]) && fun.symbol.owner.info.decls.lookup(specializedApply).exists) { |
| 116 | + val isSpecializable = |
| 117 | + defn.isSpecializableFunction( |
| 118 | + fun.symbol.owner.asClass, |
| 119 | + params.init, |
| 120 | + params.last) |
| 121 | + |
| 122 | + if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) { |
| 123 | + val specializedApply = specializedName(nme.apply, params) |
160 | 124 | val newSel = fun match {
|
161 | 125 | case Select(qual, _) =>
|
162 | 126 | qual.select(specializedApply)
|
163 |
| - case _ => { |
| 127 | + case _ => |
164 | 128 | (fun.tpe: @unchecked) match {
|
165 | 129 | case TermRef(prefix: ThisType, name) =>
|
166 | 130 | tpd.This(prefix.cls).select(specializedApply)
|
167 | 131 | case TermRef(prefix: NamedType, name) =>
|
168 | 132 | tpd.ref(prefix).select(specializedApply)
|
169 | 133 | }
|
170 |
| - } |
171 | 134 | }
|
172 | 135 |
|
173 | 136 | newSel.appliedToArgs(args)
|
174 | 137 | }
|
175 | 138 | else tree
|
176 |
| - } |
| 139 | + |
177 | 140 | case _ => tree
|
178 | 141 | }
|
179 | 142 |
|
180 |
| - @inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) = |
181 |
| - name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil) |
| 143 | + private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) = |
| 144 | + name.specializedFunction(args.last, args.init) |
182 | 145 |
|
183 |
| - @inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = { |
184 |
| - val specName = |
185 |
| - ("JFunction" + (typeParams.length - 1)).toTermName |
186 |
| - .specializedFunction(typeParams.last, typeParams.init) |
| 146 | + private def functionName(typeParams: List[Type])(implicit ctx: Context) = |
| 147 | + jFunction ++ (typeParams.length - 1).toString |
187 | 148 |
|
188 |
| - ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName) |
189 |
| - } |
| 149 | + private def specInterface(typeParams: List[Type])(implicit ctx: Context) = |
| 150 | + ctx.getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init)) |
| 151 | + |
| 152 | + private def derivesFromFn012(sym: Symbol)(implicit ctx: Context): Boolean = |
| 153 | + sym.derivesFrom(defn.FunctionClass(0)) || |
| 154 | + sym.derivesFrom(defn.FunctionClass(1)) || |
| 155 | + sym.derivesFrom(defn.FunctionClass(2)) |
190 | 156 | }
|
0 commit comments