Skip to content

Commit 9b43b28

Browse files
Duhemmliufengyun
authored andcommitted
Second version of function specialization
Fix compilation errors Don't change parents in `specializeFunctions` Cleanup Move type to JVM tag conversion to Definitions Add more tests for function specialization Pass by name should not introduce boxing Cleanup No specialization for Function3 Adapt to recent changes in transformers and phases Address review comments Optimize phase `SpecializeFunctions` - Stop immediately if the type doesn't derive from `Function{0,1,2}` - We don't need to check if any of the parents derives from `Function{0,1,2}`, we can just check if the type derives from it.
1 parent 05fede9 commit 9b43b28

File tree

5 files changed

+204
-151
lines changed

5 files changed

+204
-151
lines changed

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -279,9 +279,9 @@ object NameOps {
279279
/** This method is to be used on **type parameters** from a class, since
280280
* this method does sorting based on their names
281281
*/
282-
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {
283-
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => typeToTag(x._1))
284-
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => typeToTag(x._1))
282+
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): Name = {
283+
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
284+
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
285285

286286
likeSpacedN(name ++ nme.specializedTypeNames.prefix ++
287287
methodTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.separator ++
@@ -294,10 +294,10 @@ object NameOps {
294294
*
295295
* `<return type><first type><second type><...>`
296296
*/
297-
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): name.ThisName =
297+
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): Name =
298298
name ++ nme.specializedTypeNames.prefix ++
299-
nme.specializedTypeNames.separator ++ typeToTag(ret) ++
300-
args.map(typeToTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
299+
nme.specializedTypeNames.separator ++ defn.typeTag(ret) ++
300+
args.map(defn.typeTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
301301

302302
/** If name length exceeds allowable limit, replace part of it by hash */
303303
def compactified(using Context): TermName = termName(compactify(name.toString))

compiler/src/dotty/tools/dotc/core/Names.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ object Names {
3535
abstract class Name extends Designator, Showable derives CanEqual {
3636

3737
/** A type for names of the same kind as this name */
38-
type ThisName <: Name { type ThisName = self.ThisName }
38+
type ThisName <: Name
3939

4040
/** Is this name a type name? */
4141
def isTypeName: Boolean
Lines changed: 94 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,156 @@
11
package dotty.tools.dotc
22
package transform
33

4-
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
54
import ast.Trees._, ast.tpd, core._
65
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
76
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
7+
import MegaPhase.MiniPhase
88

99
import scala.collection.mutable
1010

1111
/** Specializes classes that inherit from `FunctionN` where there exists a
1212
* specialized form.
1313
*/
14-
class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
14+
class SpecializeFunctions extends MiniPhase with InfoTransformer {
1515
import ast.tpd._
1616
val phaseName = "specializeFunctions"
17+
override def runsAfter = Set(classOf[ElimByName])
1718

18-
private[this] var _blacklistedSymbols: List[Symbol] = _
19+
private val jFunction = "scala.compat.java8.JFunction".toTermName
1920

20-
private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = {
21-
if (_blacklistedSymbols eq null) _blacklistedSymbols = List(
22-
ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol
23-
)
24-
25-
_blacklistedSymbols
26-
}
27-
28-
/** Transforms the type to include decls for specialized applys and replace
29-
* the class parents with specialized versions.
30-
*/
31-
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
32-
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) => {
21+
/** Transforms the type to include decls for specialized applys */
22+
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
23+
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) =>
3324
var newApplys = Map.empty[Name, Symbol]
3425

35-
val newParents = tp.parents.mapConserve { parent =>
36-
List(0, 1, 2, 3).flatMap { arity =>
37-
val func = defn.FunctionClass(arity)
38-
if (!parent.derivesFrom(func)) Nil
39-
else {
40-
val typeParams = tp.typeRef.baseArgInfos(func)
41-
val interface = specInterface(typeParams)
42-
43-
if (interface.exists) {
44-
if (tp.decls.lookup(nme.apply).exists) {
45-
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
46-
newApplys = newApplys + (specializedMethodName -> interface)
47-
}
26+
var arity = 0
27+
while (arity < 3) {
28+
val func = defn.FunctionClass(arity)
29+
if (tp.derivesFrom(func)) {
30+
val typeParams = tp.cls.typeRef.baseType(func).argInfos
31+
val isSpecializable =
32+
defn.isSpecializableFunction(
33+
sym.asClass,
34+
typeParams.init,
35+
typeParams.last
36+
)
4837

49-
if (parent.isRef(func)) List(interface.typeRef)
50-
else Nil
51-
}
52-
else Nil
38+
if (isSpecializable && tp.decls.lookup(nme.apply).exists) {
39+
val interface = specInterface(typeParams)
40+
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
41+
newApplys += (specializedMethodName -> interface)
5342
}
5443
}
55-
.headOption
56-
.getOrElse(parent)
44+
arity += 1
5745
}
5846

5947
def newDecls =
60-
if (newApplys.isEmpty) tp.decls
61-
else
62-
newApplys.toList.map { case (name, interface) =>
63-
ctx.newSymbol(
64-
sym,
65-
name,
66-
Flags.Override | Flags.Method,
67-
interface.info.decls.lookup(name).info
68-
)
69-
}
70-
.foldLeft(tp.decls.cloneScope) {
71-
(scope, sym) => scope.enter(sym); scope
72-
}
48+
newApplys.toList.map { case (name, interface) =>
49+
ctx.newSymbol(
50+
sym,
51+
name,
52+
Flags.Override | Flags.Method | Flags.Synthetic,
53+
interface.info.decls.lookup(name).info
54+
)
55+
}
56+
.foldLeft(tp.decls.cloneScope) {
57+
(scope, sym) => scope.enter(sym); scope
58+
}
7359

74-
tp.derivedClassInfo(
75-
classParents = newParents,
76-
decls = newDecls
77-
)
78-
}
60+
if (newApplys.isEmpty) tp
61+
else tp.derivedClassInfo(decls = newDecls)
7962

8063
case _ => tp
8164
}
8265

8366
/** Transforms the `Template` of the classes to contain forwarders from the
84-
* generic applys to the specialized ones. Also replaces parents of the
85-
* class on the tree level and inserts the specialized applys in the
86-
* template body.
67+
* generic applys to the specialized ones. Also inserts the specialized applys
68+
* in the template body.
8769
*/
88-
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
89-
val applyBuf = new mutable.ListBuffer[Tree]
90-
val newBody = tree.body.mapConserve {
91-
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
92-
val specName = nme.apply.specializedFunction(
93-
dt.tpe.widen.finalResultType,
94-
dt.vparamss.head.map(_.symbol.info)
95-
)
96-
97-
val specializedApply = tree.symbol.enclosingClass.info.decls.lookup(specName)//member(specName).symbol
98-
//val specializedApply = tree.symbol.enclosingClass.info.member(specName).symbol
99-
100-
if (false) {
101-
println(tree.symbol.enclosingClass.show)
102-
println("'" + specName.show + "'")
103-
println(specializedApply)
104-
println(specializedApply.exists)
105-
}
106-
107-
108-
if (specializedApply.exists) {
109-
val apply = specializedApply.asTerm
110-
val specializedDecl =
111-
polyDefDef(apply, trefs => vrefss => {
112-
dt.rhs
113-
.changeOwner(dt.symbol, apply)
114-
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
70+
override def transformTemplate(tree: Template)(implicit ctx: Context) = {
71+
val cls = tree.symbol.enclosingClass.asClass
72+
if (derivesFromFn012(cls)) {
73+
val applyBuf = new mutable.ListBuffer[Tree]
74+
val newBody = tree.body.mapConserve {
75+
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 =>
76+
val typeParams = dt.vparamss.head.map(_.symbol.info)
77+
val retType = dt.tpe.widen.finalResultType
78+
79+
val specName = specializedName(nme.apply, typeParams :+ retType)
80+
val specializedApply = cls.info.decls.lookup(specName)
81+
if (specializedApply.exists) {
82+
val apply = specializedApply.asTerm
83+
val specializedDecl =
84+
polyDefDef(apply, trefs => vrefss => {
85+
dt.rhs
86+
.changeOwner(dt.symbol, apply)
87+
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
88+
})
89+
applyBuf += specializedDecl
90+
91+
// create a forwarding to the specialized apply
92+
cpy.DefDef(dt)(rhs = {
93+
tpd
94+
.ref(apply)
95+
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
11596
})
116-
applyBuf += specializedDecl
117-
118-
// create a forwarding to the specialized apply
119-
cpy.DefDef(dt)(rhs = {
120-
tpd
121-
.ref(apply)
122-
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
123-
})
124-
} else dt
125-
}
126-
case x => x
127-
}
128-
129-
val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity =>
130-
val func = defn.FunctionClass(arity)
131-
val tr = tree.symbol.enclosingClass.typeRef
97+
} else dt
13298

133-
if (!tr.parents.exists(_.isRef(func))) Nil
134-
else {
135-
val typeParams = tr.baseArgInfos(func)
136-
val interface = specInterface(typeParams)
137-
138-
if (interface.exists) List(interface.info)
139-
else Nil
99+
case x => x
140100
}
141-
}.map(TypeTree)
142101

143-
cpy.Template(tree)(
144-
parents = tree.parents ++ missing,
145-
body = applyBuf.toList ++ newBody
146-
)
102+
cpy.Template(tree)(
103+
body = applyBuf.toList ::: newBody
104+
)
105+
} else tree
147106
}
148107

149108
/** Dispatch to specialized `apply`s in user code when available */
150-
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) =
109+
override def transformApply(tree: Apply)(implicit ctx: Context) =
151110
tree match {
152-
case app @ Apply(fun, args)
111+
case Apply(fun, args)
153112
if fun.symbol.name == nme.apply &&
154113
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
155-
=> {
114+
=>
156115
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
157-
val specializedApply = specializedName(nme.apply, params)
158-
159-
if (!params.exists(_.isInstanceOf[ExprType]) && fun.symbol.owner.info.decls.lookup(specializedApply).exists) {
116+
val isSpecializable =
117+
defn.isSpecializableFunction(
118+
fun.symbol.owner.asClass,
119+
params.init,
120+
params.last)
121+
122+
if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) {
123+
val specializedApply = specializedName(nme.apply, params)
160124
val newSel = fun match {
161125
case Select(qual, _) =>
162126
qual.select(specializedApply)
163-
case _ => {
127+
case _ =>
164128
(fun.tpe: @unchecked) match {
165129
case TermRef(prefix: ThisType, name) =>
166130
tpd.This(prefix.cls).select(specializedApply)
167131
case TermRef(prefix: NamedType, name) =>
168132
tpd.ref(prefix).select(specializedApply)
169133
}
170-
}
171134
}
172135

173136
newSel.appliedToArgs(args)
174137
}
175138
else tree
176-
}
139+
177140
case _ => tree
178141
}
179142

180-
@inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
181-
name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil)
143+
private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
144+
name.specializedFunction(args.last, args.init)
182145

183-
@inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = {
184-
val specName =
185-
("JFunction" + (typeParams.length - 1)).toTermName
186-
.specializedFunction(typeParams.last, typeParams.init)
146+
private def functionName(typeParams: List[Type])(implicit ctx: Context) =
147+
jFunction ++ (typeParams.length - 1).toString
187148

188-
ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName)
189-
}
149+
private def specInterface(typeParams: List[Type])(implicit ctx: Context) =
150+
ctx.getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init))
151+
152+
private def derivesFromFn012(sym: Symbol)(implicit ctx: Context): Boolean =
153+
sym.derivesFrom(defn.FunctionClass(0)) ||
154+
sym.derivesFrom(defn.FunctionClass(1)) ||
155+
sym.derivesFrom(defn.FunctionClass(2))
190156
}

compiler/src/dotty/tools/dotc/transform/SpecializedApplyMethods.scala

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
package dotty.tools.dotc
22
package transform
33

4-
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
54
import ast.Trees._, ast.tpd, core._
65
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
76
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
7+
import MegaPhase.MiniPhase
88

99
/** This phase synthesizes specialized methods for FunctionN, this is done
1010
* since there are no scala signatures in the bytecode for the specialized
@@ -14,14 +14,14 @@ import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
1414
* can hardcode them. This should, however be removed once we're using a
1515
* different standard library.
1616
*/
17-
class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
17+
class SpecializedApplyMethods extends MiniPhase with InfoTransformer {
1818
import ast.tpd._
1919

2020
val phaseName = "specializedApplyMethods"
2121

22-
private[this] var func0Applys: List[Symbol] = _
23-
private[this] var func1Applys: List[Symbol] = _
24-
private[this] var func2Applys: List[Symbol] = _
22+
private[this] var func0Applys: collection.Set[Symbol] = _
23+
private[this] var func1Applys: collection.Set[Symbol] = _
24+
private[this] var func2Applys: collection.Set[Symbol] = _
2525
private[this] var func0: Symbol = _
2626
private[this] var func1: Symbol = _
2727
private[this] var func2: Symbol = _
@@ -30,30 +30,30 @@ class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
3030
val definitions = ctx.definitions
3131
import definitions._
3232

33-
def specApply(sym: Symbol, args: List[Type], ret: Type)(implicit ctx: Context) = {
33+
def specApply(sym: Symbol, args: List[Type], ret: Type)(implicit ctx: Context): Symbol = {
3434
val name = nme.apply.specializedFunction(ret, args)
3535
ctx.newSymbol(sym, name, Flags.Method, MethodType(args, ret))
3636
}
3737

3838
func0 = FunctionClass(0)
39-
func0Applys = for (r <- ScalaValueTypes.toList) yield specApply(func0, Nil, r)
39+
func0Applys = for (r <- defn.Function0SpecializedReturns) yield specApply(func0, Nil, r)
4040

4141
func1 = FunctionClass(1)
4242
func1Applys = for {
43-
r <- List(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
44-
t1 <- List(IntType, LongType, FloatType, DoubleType)
43+
r <- defn.Function1SpecializedReturns
44+
t1 <- defn.Function1SpecializedParams
4545
} yield specApply(func1, List(t1), r)
4646

4747
func2 = FunctionClass(2)
4848
func2Applys = for {
49-
r <- List(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
50-
t1 <- List(IntType, LongType, DoubleType)
51-
t2 <- List(IntType, LongType, DoubleType)
49+
r <- Function2SpecializedReturns
50+
t1 <- Function2SpecializedParams
51+
t2 <- Function2SpecializedReturns
5252
} yield specApply(func2, List(t1, t2), r)
5353
}
5454

5555
/** Add symbols for specialized methods to FunctionN */
56-
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
56+
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
5757
case tp: ClassInfo if defn.isPlainFunctionClass(sym) => {
5858
init()
5959
val newDecls = sym.name.functionArity match {
@@ -75,7 +75,7 @@ class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
7575
}
7676

7777
/** Create bridge methods for FunctionN with specialized applys */
78-
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
78+
override def transformTemplate(tree: Template)(implicit ctx: Context) = {
7979
val owner = tree.symbol.owner
8080
val additionalSymbols =
8181
if (owner eq func0) func0Applys

0 commit comments

Comments
 (0)