Skip to content

Commit 05fede9

Browse files
felixmulderliufengyun
authored andcommitted
Initial implementation of function specialization
Add phases and initial replacement for super Replace all existing combinations of Function1 with specialized version Do transformations on symbol level too Refactor transformations to be more idiomatic Add dispatch to specialized applys Add forwarding method for generic case Don't specialize Function1 tree when invalid to Write test to check for specialized apply Remove `DispatchToSpecializedApply` phase SpecializeFunction1: don't roll over parents, use mapConserve Rewrite to handle all specialized functions Don't remove parents not being specialized Add plain function tests to NameOps and Definitions Rewrite `SpecializeFunctions` from `DenotTransformer` to `InfoTransformer` Add `MiniPhaseTransform` to add specialized methods to FunctionN Add synthetic bridge when compiling FunctionN Fix ordering of specialized names and type parameterized apply Add parent types explicitly when specializing When a class directly extends a specialized function class, we need to replace the parent with the specialized interface. In other cases we don't replace it, even if the parent of a parent has a specialized apply - the symbols would propagate anyway. Make `ThisName` recursive on `self.ThisName` Make sure specialized functions get the correct name
1 parent 6a56706 commit 05fede9

File tree

8 files changed

+520
-7
lines changed

8 files changed

+520
-7
lines changed

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -276,16 +276,29 @@ object NameOps {
276276
case nme.clone_ => nme.clone_
277277
}
278278

279-
def specializedFor(classTargs: List[Type], classTargsNames: List[Name], methodTargs: List[Type], methodTarsNames: List[Name])(using Context): N = {
280-
281-
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
282-
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
279+
/** This method is to be used on **type parameters** from a class, since
280+
* this method does sorting based on their names
281+
*/
282+
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {
283+
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => typeToTag(x._1))
284+
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => typeToTag(x._1))
283285

284286
likeSpacedN(name ++ nme.specializedTypeNames.prefix ++
285287
methodTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.separator ++
286288
classTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix)
287289
}
288290

291+
/** Use for specializing function names ONLY and use it if you are **not**
292+
* creating specialized name from type parameters. The order of names will
293+
* be:
294+
*
295+
* `<return type><first type><second type><...>`
296+
*/
297+
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): name.ThisName =
298+
name ++ nme.specializedTypeNames.prefix ++
299+
nme.specializedTypeNames.separator ++ typeToTag(ret) ++
300+
args.map(typeToTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
301+
289302
/** If name length exceeds allowable limit, replace part of it by hash */
290303
def compactified(using Context): TermName = termName(compactify(name.toString))
291304

compiler/src/dotty/tools/dotc/core/Names.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ object Names {
3535
abstract class Name extends Designator, Showable derives CanEqual {
3636

3737
/** A type for names of the same kind as this name */
38-
type ThisName <: Name
38+
type ThisName <: Name { type ThisName = self.ThisName }
3939

4040
/** Is this name a type name? */
4141
def isTypeName: Boolean
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package dotty.tools
2+
package dotc
3+
package transform
4+
5+
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
6+
import core._
7+
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
8+
import Denotations._, SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
9+
10+
class DispatchToSpecializedApply extends MiniPhaseTransform {
11+
import ast.Trees._
12+
import ast.tpd
13+
14+
val phaseName = "dispatchToSpecializedApply"
15+
16+
override def transformApply(tree: tpd.Apply)(implicit ctx: Context, info: TransformerInfo) =
17+
tree match {
18+
case Apply(select @ Select(id, nme.apply), arg :: Nil) =>
19+
val params = List(arg.tpe, tree.tpe)
20+
val specializedApply = nme.apply.specializedFor(params, params.map(_.typeSymbol.name))
21+
val hasOverridenSpecializedApply = id.tpe.decls.iterator.exists { sym =>
22+
sym.is(Flags.Override) && (sym.name eq specializedApply)
23+
}
24+
25+
if (hasOverridenSpecializedApply) tpd.Apply(tpd.Select(id, specializedApply), arg :: Nil)
26+
else tree
27+
case _ => tree
28+
}
29+
}
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
5+
import ast.Trees._, ast.tpd, core._
6+
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
7+
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
8+
9+
import scala.collection.mutable
10+
11+
/** Specializes classes that inherit from `FunctionN` where there exists a
12+
* specialized form.
13+
*/
14+
class SpecializeFunctions extends MiniPhaseTransform with InfoTransformer {
15+
import ast.tpd._
16+
val phaseName = "specializeFunctions"
17+
18+
private[this] var _blacklistedSymbols: List[Symbol] = _
19+
20+
private def blacklistedSymbols(implicit ctx: Context): List[Symbol] = {
21+
if (_blacklistedSymbols eq null) _blacklistedSymbols = List(
22+
ctx.getClassIfDefined("scala.math.Ordering").asClass.membersNamed("Ops".toTypeName).first.symbol
23+
)
24+
25+
_blacklistedSymbols
26+
}
27+
28+
/** Transforms the type to include decls for specialized applys and replace
29+
* the class parents with specialized versions.
30+
*/
31+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
32+
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) => {
33+
var newApplys = Map.empty[Name, Symbol]
34+
35+
val newParents = tp.parents.mapConserve { parent =>
36+
List(0, 1, 2, 3).flatMap { arity =>
37+
val func = defn.FunctionClass(arity)
38+
if (!parent.derivesFrom(func)) Nil
39+
else {
40+
val typeParams = tp.typeRef.baseArgInfos(func)
41+
val interface = specInterface(typeParams)
42+
43+
if (interface.exists) {
44+
if (tp.decls.lookup(nme.apply).exists) {
45+
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
46+
newApplys = newApplys + (specializedMethodName -> interface)
47+
}
48+
49+
if (parent.isRef(func)) List(interface.typeRef)
50+
else Nil
51+
}
52+
else Nil
53+
}
54+
}
55+
.headOption
56+
.getOrElse(parent)
57+
}
58+
59+
def newDecls =
60+
if (newApplys.isEmpty) tp.decls
61+
else
62+
newApplys.toList.map { case (name, interface) =>
63+
ctx.newSymbol(
64+
sym,
65+
name,
66+
Flags.Override | Flags.Method,
67+
interface.info.decls.lookup(name).info
68+
)
69+
}
70+
.foldLeft(tp.decls.cloneScope) {
71+
(scope, sym) => scope.enter(sym); scope
72+
}
73+
74+
tp.derivedClassInfo(
75+
classParents = newParents,
76+
decls = newDecls
77+
)
78+
}
79+
80+
case _ => tp
81+
}
82+
83+
/** Transforms the `Template` of the classes to contain forwarders from the
84+
* generic applys to the specialized ones. Also replaces parents of the
85+
* class on the tree level and inserts the specialized applys in the
86+
* template body.
87+
*/
88+
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
89+
val applyBuf = new mutable.ListBuffer[Tree]
90+
val newBody = tree.body.mapConserve {
91+
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 => {
92+
val specName = nme.apply.specializedFunction(
93+
dt.tpe.widen.finalResultType,
94+
dt.vparamss.head.map(_.symbol.info)
95+
)
96+
97+
val specializedApply = tree.symbol.enclosingClass.info.decls.lookup(specName)//member(specName).symbol
98+
//val specializedApply = tree.symbol.enclosingClass.info.member(specName).symbol
99+
100+
if (false) {
101+
println(tree.symbol.enclosingClass.show)
102+
println("'" + specName.show + "'")
103+
println(specializedApply)
104+
println(specializedApply.exists)
105+
}
106+
107+
108+
if (specializedApply.exists) {
109+
val apply = specializedApply.asTerm
110+
val specializedDecl =
111+
polyDefDef(apply, trefs => vrefss => {
112+
dt.rhs
113+
.changeOwner(dt.symbol, apply)
114+
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
115+
})
116+
applyBuf += specializedDecl
117+
118+
// create a forwarding to the specialized apply
119+
cpy.DefDef(dt)(rhs = {
120+
tpd
121+
.ref(apply)
122+
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
123+
})
124+
} else dt
125+
}
126+
case x => x
127+
}
128+
129+
val missing: List[TypeTree] = List(0, 1, 2, 3).flatMap { arity =>
130+
val func = defn.FunctionClass(arity)
131+
val tr = tree.symbol.enclosingClass.typeRef
132+
133+
if (!tr.parents.exists(_.isRef(func))) Nil
134+
else {
135+
val typeParams = tr.baseArgInfos(func)
136+
val interface = specInterface(typeParams)
137+
138+
if (interface.exists) List(interface.info)
139+
else Nil
140+
}
141+
}.map(TypeTree)
142+
143+
cpy.Template(tree)(
144+
parents = tree.parents ++ missing,
145+
body = applyBuf.toList ++ newBody
146+
)
147+
}
148+
149+
/** Dispatch to specialized `apply`s in user code when available */
150+
override def transformApply(tree: Apply)(implicit ctx: Context, info: TransformerInfo) =
151+
tree match {
152+
case app @ Apply(fun, args)
153+
if fun.symbol.name == nme.apply &&
154+
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
155+
=> {
156+
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
157+
val specializedApply = specializedName(nme.apply, params)
158+
159+
if (!params.exists(_.isInstanceOf[ExprType]) && fun.symbol.owner.info.decls.lookup(specializedApply).exists) {
160+
val newSel = fun match {
161+
case Select(qual, _) =>
162+
qual.select(specializedApply)
163+
case _ => {
164+
(fun.tpe: @unchecked) match {
165+
case TermRef(prefix: ThisType, name) =>
166+
tpd.This(prefix.cls).select(specializedApply)
167+
case TermRef(prefix: NamedType, name) =>
168+
tpd.ref(prefix).select(specializedApply)
169+
}
170+
}
171+
}
172+
173+
newSel.appliedToArgs(args)
174+
}
175+
else tree
176+
}
177+
case _ => tree
178+
}
179+
180+
@inline private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
181+
name.specializedFor(args, args.map(_.typeSymbol.name), Nil, Nil)
182+
183+
@inline private def specInterface(typeParams: List[Type])(implicit ctx: Context) = {
184+
val specName =
185+
("JFunction" + (typeParams.length - 1)).toTermName
186+
.specializedFunction(typeParams.last, typeParams.init)
187+
188+
ctx.getClassIfDefined("scala.compat.java8.".toTermName ++ specName)
189+
}
190+
}
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import TreeTransforms.{ MiniPhaseTransform, TransformerInfo }
5+
import ast.Trees._, ast.tpd, core._
6+
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
7+
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
8+
9+
/** This phase synthesizes specialized methods for FunctionN, this is done
10+
* since there are no scala signatures in the bytecode for the specialized
11+
* methods.
12+
*
13+
* We know which specializations exist for the different arities, therefore we
14+
* can hardcode them. This should, however be removed once we're using a
15+
* different standard library.
16+
*/
17+
class SpecializedApplyMethods extends MiniPhaseTransform with InfoTransformer {
18+
import ast.tpd._
19+
20+
val phaseName = "specializedApplyMethods"
21+
22+
private[this] var func0Applys: List[Symbol] = _
23+
private[this] var func1Applys: List[Symbol] = _
24+
private[this] var func2Applys: List[Symbol] = _
25+
private[this] var func0: Symbol = _
26+
private[this] var func1: Symbol = _
27+
private[this] var func2: Symbol = _
28+
29+
private def init()(implicit ctx: Context): Unit = if (func0Applys eq null) {
30+
val definitions = ctx.definitions
31+
import definitions._
32+
33+
def specApply(sym: Symbol, args: List[Type], ret: Type)(implicit ctx: Context) = {
34+
val name = nme.apply.specializedFunction(ret, args)
35+
ctx.newSymbol(sym, name, Flags.Method, MethodType(args, ret))
36+
}
37+
38+
func0 = FunctionClass(0)
39+
func0Applys = for (r <- ScalaValueTypes.toList) yield specApply(func0, Nil, r)
40+
41+
func1 = FunctionClass(1)
42+
func1Applys = for {
43+
r <- List(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
44+
t1 <- List(IntType, LongType, FloatType, DoubleType)
45+
} yield specApply(func1, List(t1), r)
46+
47+
func2 = FunctionClass(2)
48+
func2Applys = for {
49+
r <- List(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
50+
t1 <- List(IntType, LongType, DoubleType)
51+
t2 <- List(IntType, LongType, DoubleType)
52+
} yield specApply(func2, List(t1, t2), r)
53+
}
54+
55+
/** Add symbols for specialized methods to FunctionN */
56+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
57+
case tp: ClassInfo if defn.isPlainFunctionClass(sym) => {
58+
init()
59+
val newDecls = sym.name.functionArity match {
60+
case 0 => func0Applys.foldLeft(tp.decls.cloneScope) {
61+
(decls, sym) => decls.enter(sym); decls
62+
}
63+
case 1 => func1Applys.foldLeft(tp.decls.cloneScope) {
64+
(decls, sym) => decls.enter(sym); decls
65+
}
66+
case 2 => func2Applys.foldLeft(tp.decls.cloneScope) {
67+
(decls, sym) => decls.enter(sym); decls
68+
}
69+
case _ => tp.decls
70+
}
71+
72+
tp.derivedClassInfo(decls = newDecls)
73+
}
74+
case _ => tp
75+
}
76+
77+
/** Create bridge methods for FunctionN with specialized applys */
78+
override def transformTemplate(tree: Template)(implicit ctx: Context, info: TransformerInfo) = {
79+
val owner = tree.symbol.owner
80+
val additionalSymbols =
81+
if (owner eq func0) func0Applys
82+
else if (owner eq func1) func1Applys
83+
else if (owner eq func2) func2Applys
84+
else Nil
85+
86+
if (additionalSymbols eq Nil) tree
87+
else cpy.Template(tree)(body = tree.body ++ additionalSymbols.map { apply =>
88+
DefDef(apply.asTerm, { vparamss =>
89+
This(owner.asClass)
90+
.select(nme.apply)
91+
.appliedToArgss(vparamss)
92+
.ensureConforms(apply.info.finalResultType)
93+
})
94+
})
95+
}
96+
}

0 commit comments

Comments
 (0)