Skip to content

Commit 378acae

Browse files
Duhemmliufengyun
authored andcommitted
Second version of function specialization
Fix compilation errors Don't change parents in `specializeFunctions` Cleanup Move type to JVM tag conversion to Definitions Add more tests for function specialization Pass by name should not introduce boxing Cleanup No specialization for Function3 Adapt to recent changes in transformers and phases Address review comments Optimize phase `SpecializeFunctions` - Stop immediately if the type doesn't derive from `Function{0,1,2}` - We don't need to check if any of the parents derives from `Function{0,1,2}`, we can just check if the type derives from it.
1 parent caca442 commit 378acae

File tree

5 files changed

+205
-170
lines changed

5 files changed

+205
-170
lines changed

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -274,30 +274,12 @@ object NameOps {
274274
case nme.clone_ => nme.clone_
275275
}
276276

277-
<<<<<<< HEAD
278-
def specializedFor(classTargs: List[Type], classTargsNames: List[Name], methodTargs: List[Type], methodTarsNames: List[Name])(using Context): N = {
279-
=======
280-
private def typeToTag(tp: Types.Type)(implicit ctx: Context): Name =
281-
tp.classSymbol match {
282-
case t if t eq defn.IntClass => nme.specializedTypeNames.Int
283-
case t if t eq defn.BooleanClass => nme.specializedTypeNames.Boolean
284-
case t if t eq defn.ByteClass => nme.specializedTypeNames.Byte
285-
case t if t eq defn.LongClass => nme.specializedTypeNames.Long
286-
case t if t eq defn.ShortClass => nme.specializedTypeNames.Short
287-
case t if t eq defn.FloatClass => nme.specializedTypeNames.Float
288-
case t if t eq defn.UnitClass => nme.specializedTypeNames.Void
289-
case t if t eq defn.DoubleClass => nme.specializedTypeNames.Double
290-
case t if t eq defn.CharClass => nme.specializedTypeNames.Char
291-
case _ => nme.specializedTypeNames.Object
292-
}
293-
>>>>>>> Fix ordering of specialized names and type parameterized apply
294-
295277
/** This method is to be used on **type parameters** from a class, since
296278
* this method does sorting based on their names
297279
*/
298-
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {
299-
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => typeToTag(x._1))
300-
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => typeToTag(x._1))
280+
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): Name = {
281+
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
282+
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
301283

302284
likeSpacedN(name ++ nme.specializedTypeNames.prefix ++
303285
methodTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.separator ++
@@ -310,10 +292,10 @@ object NameOps {
310292
*
311293
* `<return type><first type><second type><...>`
312294
*/
313-
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): name.ThisName =
295+
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): Name =
314296
name ++ nme.specializedTypeNames.prefix ++
315-
nme.specializedTypeNames.separator ++ typeToTag(ret) ++
316-
args.map(typeToTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
297+
nme.specializedTypeNames.separator ++ defn.typeTag(ret) ++
298+
args.map(defn.typeTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
317299

318300
/** If name length exceeds allowable limit, replace part of it by hash */
319301
def compactified(using Context): TermName = termName(compactify(name.toString))

compiler/src/dotty/tools/dotc/core/Names.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ object Names {
3232
* in a name table. A derived term name adds a tag, and possibly a number
3333
* or a further simple name to some other name.
3434
*/
35-
abstract class Name extends Designator, Showable derives Eql { self =>
35+
abstract class Name extends Designator with PreName {
3636

3737
/** A type for names of the same kind as this name */
38-
type ThisName <: Name { type ThisName = self.ThisName }
38+
type ThisName <: Name
3939

4040
/** Is this name a type name? */
4141
def isTypeName: Boolean
Lines changed: 94 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,156 @@
11
package dotty.tools.dotc
22
package transform
33

4-
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
54
import ast.Trees._, ast.tpd, core._
65
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
76
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
7+
import MegaPhase.MiniPhase
88

99
import scala.collection.mutable
1010

1111
/** Specializes classes that inherit from `FunctionN` where there exists a
1212
* specialized form.
1313
*/
14-
class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
14+
class SpecializeFunctions extends MiniPhase with InfoTransformer {
1515
import ast.tpd._
1616
val phaseName = "specializeFunctions"
17+
override def runsAfter = Set(classOf[ElimByName])
1718

18-
private[this] var _blacklistedSymbols: List[Symbol] = _
19+
private val jFunction = "scala.compat.java8.JFunction".toTermName
1920

20-
private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = {
21-
if (_blacklistedSymbols eq null) _blacklistedSymbols = List(
22-
ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol
23-
)
24-
25-
_blacklistedSymbols
26-
}
27-
28-
/** Transforms the type to include decls for specialized applys and replace
29-
* the class parents with specialized versions.
30-
*/
31-
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
32-
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) => {
21+
/** Transforms the type to include decls for specialized applys */
22+
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
23+
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) =>
3324
var newApplys = Map.empty[Name, Symbol]
3425

35-
val newParents = tp.parents.mapConserve { parent =>
36-
List(0, 1, 2, 3).flatMap { arity =>
37-
val func = defn.FunctionClass(arity)
38-
if (!parent.derivesFrom(func)) Nil
39-
else {
40-
val typeParams = tp.typeRef.baseArgInfos(func)
41-
val interface = specInterface(typeParams)
42-
43-
if (interface.exists) {
44-
if (tp.decls.lookup(nme.apply).exists) {
45-
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
46-
newApplys = newApplys + (specializedMethodName -> interface)
47-
}
26+
var arity = 0
27+
while (arity < 3) {
28+
val func = defn.FunctionClass(arity)
29+
if (tp.derivesFrom(func)) {
30+
val typeParams = tp.cls.typeRef.baseType(func).argInfos
31+
val isSpecializable =
32+
defn.isSpecializableFunction(
33+
sym.asClass,
34+
typeParams.init,
35+
typeParams.last
36+
)
4837

49-
if (parent.isRef(func)) List(interface.typeRef)
50-
else Nil
51-
}
52-
else Nil
38+
if (isSpecializable && tp.decls.lookup(nme.apply).exists) {
39+
val interface = specInterface(typeParams)
40+
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
41+
newApplys += (specializedMethodName -> interface)
5342
}
5443
}
55-
.headOption
56-
.getOrElse(parent)
44+
arity += 1
5745
}
5846

5947
def newDecls =
60-
if (newApplys.isEmpty) tp.decls
61-
else
62-
newApplys.toList.map { case (name, interface) =>
63-
ctx.newSymbol(
64-
sym,
65-
name,
66-
Flags.Override | Flags.Method,
67-
interface.info.decls.lookup(name).info
68-
)
69-
}
70-
.foldLeft(tp.decls.cloneScope) {
71-
(scope, sym) => scope.enter(sym); scope
72-
}
48+
newApplys.toList.map { case (name, interface) =>
49+
ctx.newSymbol(
50+
sym,
51+
name,
52+
Flags.Override | Flags.Method | Flags.Synthetic,
53+
interface.info.decls.lookup(name).info
54+
)
55+
}
56+
.foldLeft(tp.decls.cloneScope) {
57+
(scope, sym) => scope.enter(sym); scope
58+
}
7359

74-
tp.derivedClassInfo(
75-
classParents = newParents,
76-
decls = newDecls
77-
)
78-
}
60+
if (newApplys.isEmpty) tp
61+
else tp.derivedClassInfo(decls = newDecls)
7962

8063
case _ => tp
8164
}
8265

8366
/** Transforms the `Template` of the classes to contain forwarders from the
84-
* generic applys to the specialized ones. Also replaces parents of the
85-
* class on the tree level and inserts the specialized applys in the
86-
* template body.
67+
* generic applys to the specialized ones. Also inserts the specialized applys
68+
* in the template body.
8769
*/
88-
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
89-
val applyBuf = new mutable.ListBuffer[Tree]
90-
val newBody = tree.body.mapConserve {
91-
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
92-
val specName = nme.apply.specializedFunction(
93-
dt.tpe.widen.finalResultType,
94-
dt.vparamss.head.map(_.symbol.info)
95-
)
96-
97-
val specializedApply = tree.symbol.enclosingClass.info.decls.lookup(specName)//member(specName).symbol
98-
//val specializedApply = tree.symbol.enclosingClass.info.member(specName).symbol
99-
100-
if (false) {
101-
println(tree.symbol.enclosingClass.show)
102-
println("'" + specName.show + "'")
103-
println(specializedApply)
104-
println(specializedApply.exists)
105-
}
106-
107-
108-
if (specializedApply.exists) {
109-
val apply = specializedApply.asTerm
110-
val specializedDecl =
111-
polyDefDef(apply, trefs => vrefss => {
112-
dt.rhs
113-
.changeOwner(dt.symbol, apply)
114-
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
70+
override def transformTemplate(tree: Template)(implicit ctx: Context) = {
71+
val cls = tree.symbol.enclosingClass.asClass
72+
if (derivesFromFn012(cls)) {
73+
val applyBuf = new mutable.ListBuffer[Tree]
74+
val newBody = tree.body.mapConserve {
75+
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 =>
76+
val typeParams = dt.vparamss.head.map(_.symbol.info)
77+
val retType = dt.tpe.widen.finalResultType
78+
79+
val specName = specializedName(nme.apply, typeParams :+ retType)
80+
val specializedApply = cls.info.decls.lookup(specName)
81+
if (specializedApply.exists) {
82+
val apply = specializedApply.asTerm
83+
val specializedDecl =
84+
polyDefDef(apply, trefs => vrefss => {
85+
dt.rhs
86+
.changeOwner(dt.symbol, apply)
87+
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
88+
})
89+
applyBuf += specializedDecl
90+
91+
// create a forwarding to the specialized apply
92+
cpy.DefDef(dt)(rhs = {
93+
tpd
94+
.ref(apply)
95+
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
11596
})
116-
applyBuf += specializedDecl
117-
118-
// create a forwarding to the specialized apply
119-
cpy.DefDef(dt)(rhs = {
120-
tpd
121-
.ref(apply)
122-
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
123-
})
124-
} else dt
125-
}
126-
case x => x
127-
}
128-
129-
val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity =>
130-
val func = defn.FunctionClass(arity)
131-
val tr = tree.symbol.enclosingClass.typeRef
97+
} else dt
13298

133-
if (!tr.parents.exists(_.isRef(func))) Nil
134-
else {
135-
val typeParams = tr.baseArgInfos(func)
136-
val interface = specInterface(typeParams)
137-
138-
if (interface.exists) List(interface.info)
139-
else Nil
99+
case x => x
140100
}
141-
}.map(TypeTree)
142101

143-
cpy.Template(tree)(
144-
parents = tree.parents ++ missing,
145-
body = applyBuf.toList ++ newBody
146-
)
102+
cpy.Template(tree)(
103+
body = applyBuf.toList ::: newBody
104+
)
105+
} else tree
147106
}
148107

149108
/** Dispatch to specialized `apply`s in user code when available */
150-
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) =
109+
override def transformApply(tree: Apply)(implicit ctx: Context) =
151110
tree match {
152-
case app @ Apply(fun, args)
111+
case Apply(fun, args)
153112
if fun.symbol.name == nme.apply &&
154113
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
155-
=> {
114+
=>
156115
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
157-
val specializedApply = specializedName(nme.apply, params)
158-
159-
if (!params.exists(_.isInstanceOf[ExprType]) && fun.symbol.owner.info.decls.lookup(specializedApply).exists) {
116+
val isSpecializable =
117+
defn.isSpecializableFunction(
118+
fun.symbol.owner.asClass,
119+
params.init,
120+
params.last)
121+
122+
if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) {
123+
val specializedApply = specializedName(nme.apply, params)
160124
val newSel = fun match {
161125
case Select(qual, _) =>
162126
qual.select(specializedApply)
163-
case _ => {
127+
case _ =>
164128
(fun.tpe: @unchecked) match {
165129
case TermRef(prefix: ThisType, name) =>
166130
tpd.This(prefix.cls).select(specializedApply)
167131
case TermRef(prefix: NamedType, name) =>
168132
tpd.ref(prefix).select(specializedApply)
169133
}
170-
}
171134
}
172135

173136
newSel.appliedToArgs(args)
174137
}
175138
else tree
176-
}
139+
177140
case _ => tree
178141
}
179142

180-
@inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
181-
name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil)
143+
private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
144+
name.specializedFunction(args.last, args.init)
182145

183-
@inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = {
184-
val specName =
185-
("JFunction" + (typeParams.length - 1)).toTermName
186-
.specializedFunction(typeParams.last, typeParams.init)
146+
private def functionName(typeParams: List[Type])(implicit ctx: Context) =
147+
jFunction ++ (typeParams.length - 1).toString
187148

188-
ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName)
189-
}
149+
private def specInterface(typeParams: List[Type])(implicit ctx: Context) =
150+
ctx.getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init))
151+
152+
private def derivesFromFn012(sym: Symbol)(implicit ctx: Context): Boolean =
153+
sym.derivesFrom(defn.FunctionClass(0)) ||
154+
sym.derivesFrom(defn.FunctionClass(1)) ||
155+
sym.derivesFrom(defn.FunctionClass(2))
190156
}

0 commit comments

Comments
 (0)