Skip to content

Commit 7fbec1c

Browse files
committed
Third version of function specialization
- Fix isPlainFunctionClass Previous implementation is incorrect, as scala.Function1$ would qualify. - Create the symbol at the next phase Create the symbol at the next phase, so that it is a valid member of the corresponding function for all valid periods of its SymDenotations. Otherwise, the valid period will offset by 1, which causes a stale symbol in compiling stdlib. - Handle abstract apply and multiple applys - Don't specialize abstract apply - Fast specialization We avoid going through InfoTransformer, which will cause all symbols to be checked. The reason why it works is that the specialized base classes, i.e. Function0-2 already have all the relevant definitions. - Use StringBuilder instead of StringBuffer (thanks @smarter) StringBuffer is synchronized thus is slower.
1 parent 9b43b28 commit 7fbec1c

15 files changed

+314
-296
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,21 +69,23 @@ class Compiler {
6969
new CacheAliasImplicits, // Cache RHS of parameterless alias implicits
7070
new ByNameClosures, // Expand arguments to by-name parameters to closures
7171
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
72+
new SpecializeApplyMethods, // Adds specialized methods to FunctionN
7273
new RefChecks) :: // Various checks mostly related to abstract members and overriding
7374
List(new ElimOpaque, // Turn opaque into normal aliases
7475
new TryCatchPatterns, // Compile cases in try/catch
7576
new PatternMatcher, // Compile pattern matches
7677
new sjs.ExplicitJSClasses, // Make all JS classes explicit (Scala.js only)
7778
new ExplicitOuter, // Add accessors to outer classes from nested ones.
7879
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
80+
new ElimByName, // Expand by-name parameter references
7981
new StringInterpolatorOpt) :: // Optimizes raw and s string interpolators by rewriting them to string concatentations
8082
List(new PruneErasedDefs, // Drop erased definitions from scopes and simplify erased expressions
8183
new InlinePatterns, // Remove placeholders of inlined patterns
8284
new VCInlineMethods, // Inlines calls to value class methods
8385
new SeqLiterals, // Express vararg arguments as arrays
8486
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods
8587
new Getters, // Replace non-private vals and vars with getter defs (fields are added later)
86-
new ElimByName, // Expand by-name parameter references
88+
new SpecializeFunctions, // Specialized Function{0,1,2} by replacing super with specialized super
8789
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
8890
new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization
8991
new ElimOuterSelect, // Expand outer selections

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -784,8 +784,8 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
784784
def tupleArgs(tree: Tree)(using Context): List[Tree] = tree match {
785785
case Block(Nil, expr) => tupleArgs(expr)
786786
case Inlined(_, Nil, expr) => tupleArgs(expr)
787-
case Apply(fn, args)
788-
if fn.symbol.name == nme.apply &&
787+
case Apply(fn: NameTree, args)
788+
if fn.name == nme.apply &&
789789
fn.symbol.owner.is(Module) &&
790790
defn.isTupleClass(fn.symbol.owner.companionClass) => args
791791
case _ => Nil

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1209,7 +1209,11 @@ class Definitions {
12091209
else funType(n)
12101210
).symbol.asClass
12111211

1212-
@tu lazy val Function0_apply: Symbol = FunctionClass(0).requiredMethod(nme.apply)
1212+
@tu lazy val Function0_apply: Symbol = Function0.requiredMethod(nme.apply)
1213+
1214+
@tu lazy val Function0: Symbol = FunctionClass(0)
1215+
@tu lazy val Function1: Symbol = FunctionClass(1)
1216+
@tu lazy val Function2: Symbol = FunctionClass(2)
12131217

12141218
def FunctionType(n: Int, isContextual: Boolean = false, isErased: Boolean = false)(using Context): TypeRef =
12151219
FunctionClass(n, isContextual && !ctx.erasedTypes, isErased).typeRef
@@ -1244,7 +1248,7 @@ class Definitions {
12441248

12451249
def isBottomClassAfterErasure(cls: Symbol): Boolean = cls == NothingClass || cls == NullClass
12461250

1247-
/** Is a function class.
1251+
/** Is any function class where
12481252
* - FunctionXXL
12491253
* - FunctionN for N >= 0
12501254
* - ContextFunctionN for N >= 0
@@ -1253,6 +1257,11 @@ class Definitions {
12531257
*/
12541258
def isFunctionClass(cls: Symbol): Boolean = scalaClassName(cls).isFunction
12551259

1260+
/** Is a function class where
1261+
* - FunctionN for N >= 0 and N != XXL
1262+
*/
1263+
def isPlainFunctionClass(cls: Symbol) = isVarArityClass(cls, str.Function)
1264+
12561265
/** Is an context function class.
12571266
* - ContextFunctionN for N >= 0
12581267
* - ErasedContextFunctionN for N > 0
@@ -1488,6 +1497,25 @@ class Definitions {
14881497
false
14891498
})
14901499

1500+
@tu lazy val Function0SpecializedApplyNames: collection.Set[TermName] =
1501+
for r <- Function0SpecializedReturnTypes
1502+
yield nme.apply.specializedFunction(r, Nil).asTermName
1503+
1504+
@tu lazy val Function1SpecializedApplyNames: collection.Set[TermName] =
1505+
for
1506+
r <- Function1SpecializedReturnTypes
1507+
t1 <- Function1SpecializedParamTypes
1508+
yield
1509+
nme.apply.specializedFunction(r, List(t1)).asTermName
1510+
1511+
@tu lazy val Function2SpecializedApplyNames: collection.Set[TermName] =
1512+
for
1513+
r <- Function2SpecializedReturnTypes
1514+
t1 <- Function2SpecializedParamTypes
1515+
t2 <- Function2SpecializedParamTypes
1516+
yield
1517+
nme.apply.specializedFunction(r, List(t1, t2)).asTermName
1518+
14911519
def functionArity(tp: Type)(using Context): Int = tp.dropDependentRefinement.dealias.argInfos.length - 1
14921520

14931521
/** Return underlying context function type (i.e. instance of an ContextFunctionN class)

compiler/src/dotty/tools/dotc/core/NameOps.scala

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ object NameOps {
231231
def isFunction: Boolean =
232232
(name eq tpnme.FunctionXXL) || checkedFunArity(functionSuffixStart) >= 0
233233

234+
/** Is a function name
235+
* - FunctionN for N >= 0
236+
*/
237+
def isPlainFunction: Boolean = functionArity >= 0
238+
234239
/** Is an context function name, i.e one of ContextFunctionN or ErasedContextFunctionN for N >= 0
235240
*/
236241
def isContextFunction: Boolean =
@@ -279,7 +284,7 @@ object NameOps {
279284
/** This method is to be used on **type parameters** from a class, since
280285
* this method does sorting based on their names
281286
*/
282-
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): Name = {
287+
def specializedFor(classTargs: List[Type], classTargsNames: List[Name], methodTargs: List[Type], methodTarsNames: List[Name])(using Context): N = {
283288
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
284289
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
285290

@@ -294,10 +299,15 @@ object NameOps {
294299
*
295300
* `<return type><first type><second type><...>`
296301
*/
297-
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): Name =
298-
name ++ nme.specializedTypeNames.prefix ++
299-
nme.specializedTypeNames.separator ++ defn.typeTag(ret) ++
300-
args.map(defn.typeTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix
302+
def specializedFunction(ret: Type, args: List[Type])(using Context): Name =
303+
val sb = new StringBuilder
304+
sb.append(name.toString)
305+
sb.append(nme.specializedTypeNames.prefix.toString)
306+
sb.append(nme.specializedTypeNames.separator)
307+
sb.append(defn.typeTag(ret).toString)
308+
args.foreach { arg => sb.append(defn.typeTag(arg)) }
309+
sb.append(nme.specializedTypeNames.suffix)
310+
termName(sb.toString)
301311

302312
/** If name length exceeds allowable limit, replace part of it by hash */
303313
def compactified(using Context): TermName = termName(compactify(name.toString))

compiler/src/dotty/tools/dotc/transform/DispatchToSpecializedApply.scala

Lines changed: 0 additions & 29 deletions
This file was deleted.

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ class ExpandSAMs extends MiniPhase {
162162
cpy.Block(tree)(pfDef :: Nil, New(pfSym.typeRef, Nil))
163163

164164
case _ =>
165-
val found = tpe.baseType(defn.FunctionClass(1))
165+
val found = tpe.baseType(defn.Function1)
166166
report.error(TypeMismatch(found, tpe), tree.srcPos)
167167
tree
168168
}

compiler/src/dotty/tools/dotc/transform/FunctionXXLForwarders.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,18 +39,20 @@ class FunctionXXLForwarders extends MiniPhase with IdentityDenotTransformer {
3939
ref(receiver.symbol).appliedToArgss(argss).cast(defn.ObjectType)
4040
}
4141

42+
if impl.symbol.owner.is(Trait) then return impl
43+
4244
val forwarders =
4345
for {
44-
tree <- if (impl.symbol.owner.is(Trait)) Nil else impl.body
45-
if tree.symbol.is(Method) && tree.symbol.name == nme.apply &&
46-
tree.symbol.signature.paramsSig.size > MaxImplementedFunctionArity &&
47-
tree.symbol.allOverriddenSymbols.exists(sym => defn.isXXLFunctionClass(sym.owner))
46+
(ddef: DefDef) <- impl.body
47+
if ddef.name == nme.apply && ddef.symbol.is(Method) &&
48+
ddef.symbol.signature.paramsSig.size > MaxImplementedFunctionArity &&
49+
ddef.symbol.allOverriddenSymbols.exists(sym => defn.isXXLFunctionClass(sym.owner))
4850
}
4951
yield {
5052
val xsType = defn.ArrayType.appliedTo(List(defn.ObjectType))
5153
val methType = MethodType(List(nme.args))(_ => List(xsType), _ => defn.ObjectType)
52-
val meth = newSymbol(tree.symbol.owner, nme.apply, Synthetic | Method, methType)
53-
DefDef(meth, paramss => forwarderRhs(tree, paramss.head.head))
54+
val meth = newSymbol(ddef.symbol.owner, nme.apply, Synthetic | Method, methType)
55+
DefDef(meth, paramss => forwarderRhs(ddef, paramss.head.head))
5456
}
5557

5658
cpy.Template(impl)(body = forwarders ::: impl.body)
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import ast.Trees._, ast.tpd, core._
5+
import Contexts._, Types._, Decorators._, Symbols._, DenotTransformers._
6+
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
7+
import MegaPhase.MiniPhase
8+
9+
import scala.collection.mutable
10+
11+
12+
/** This phase synthesizes specialized methods for FunctionN, this is done
13+
* since there are no scala signatures in the bytecode for the specialized
14+
* methods.
15+
*
16+
* We know which specializations exist for the different arities, therefore we
17+
* can hardcode them. This should, however be removed once we're using a
18+
* different standard library.
19+
*/
20+
class SpecializeApplyMethods extends MiniPhase with InfoTransformer {
21+
import ast.tpd._
22+
23+
val phaseName = "specializeApplyMethods"
24+
25+
override def isEnabled(using Context): Boolean =
26+
!ctx.settings.scalajs.value
27+
28+
private def specApplySymbol(sym: Symbol, args: List[Type], ret: Type)(using Context): Symbol = {
29+
val name = nme.apply.specializedFunction(ret, args)
30+
// Create the symbol at the next phase, so that it is a valid member of the
31+
// corresponding function for all valid periods of its SymDenotations.
32+
// Otherwise, the valid period will offset by 1, which causes a stale symbol
33+
// in compiling stdlib.
34+
atNextPhase(newSymbol(sym, name, Flags.Method, MethodType(args, ret)))
35+
}
36+
37+
private inline def specFun0(inline op: Type => Unit)(using Context): Unit = {
38+
for (r <- defn.Function0SpecializedReturnTypes) do
39+
op(r)
40+
}
41+
42+
private inline def specFun1(inline op: (Type, Type) => Unit)(using Context): Unit = {
43+
for
44+
r <- defn.Function1SpecializedReturnTypes
45+
t1 <- defn.Function1SpecializedParamTypes
46+
do
47+
op(t1, r)
48+
}
49+
50+
private inline def specFun2(inline op: (Type, Type, Type) => Unit)(using Context): Unit = {
51+
for
52+
r <- defn.Function2SpecializedReturnTypes
53+
t1 <- defn.Function2SpecializedParamTypes
54+
t2 <- defn.Function2SpecializedParamTypes
55+
do
56+
op(t1, t2, r)
57+
}
58+
59+
override def infoMayChange(sym: Symbol)(using Context) =
60+
sym == defn.Function0
61+
|| sym == defn.Function1
62+
|| sym == defn.Function2
63+
64+
/** Add symbols for specialized methods to FunctionN */
65+
override def transformInfo(tp: Type, sym: Symbol)(using Context) = tp match {
66+
case tp: ClassInfo =>
67+
if sym == defn.Function0 then
68+
val scope = tp.decls.cloneScope
69+
specFun0 { r => scope.enter(specApplySymbol(sym, Nil, r)) }
70+
tp.derivedClassInfo(decls = scope)
71+
72+
else if sym == defn.Function1 then
73+
val scope = tp.decls.cloneScope
74+
specFun1 { (t1, r) => scope.enter(specApplySymbol(sym, t1 :: Nil, r)) }
75+
tp.derivedClassInfo(decls = scope)
76+
77+
else if sym == defn.Function2 then
78+
val scope = tp.decls.cloneScope
79+
specFun2 { (t1, t2, r) => scope.enter(specApplySymbol(sym, t1 :: t2 :: Nil, r)) }
80+
tp.derivedClassInfo(decls = scope)
81+
82+
else tp
83+
84+
case _ => tp
85+
}
86+
87+
/** Create bridge methods for FunctionN with specialized applys */
88+
override def transformTemplate(tree: Template)(using Context) = {
89+
val cls = tree.symbol.owner.asClass
90+
91+
def synthesizeApply(names: collection.Set[TermName]): Tree = {
92+
val applyBuf = new mutable.ListBuffer[DefDef]
93+
names.foreach { name =>
94+
val applySym = cls.info.decls.lookup(name)
95+
val ddef = DefDef(
96+
applySym.asTerm,
97+
{ vparamss =>
98+
This(cls)
99+
.select(nme.apply)
100+
.appliedToArgss(vparamss)
101+
.ensureConforms(applySym.info.finalResultType)
102+
}
103+
)
104+
applyBuf += ddef
105+
}
106+
cpy.Template(tree)(body = tree.body ++ applyBuf)
107+
}
108+
109+
if cls == defn.Function0 then
110+
synthesizeApply(defn.Function0SpecializedApplyNames)
111+
else if cls == defn.Function1 then
112+
synthesizeApply(defn.Function1SpecializedApplyNames)
113+
else if cls == defn.Function2 then
114+
synthesizeApply(defn.Function2SpecializedApplyNames)
115+
else
116+
tree
117+
}
118+
}

0 commit comments

Comments
 (0)