Skip to content

Commit ef48dc0

Browse files
committed
Refactor SpecializeFunctions
1 parent 8114182 commit ef48dc0

File tree

2 files changed

+56
-69
lines changed

2 files changed

+56
-69
lines changed

compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala

Lines changed: 55 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,49 +16,45 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
1616
val phaseName = "specializeFunctions"
1717
override def runsAfter = Set(ElimByName.name)
1818

19-
private val jFunction = "scala.compat.java8.JFunction".toTermName
20-
2119
/** Transforms the type to include decls for specialized applys */
2220
override def transformInfo(tp: Type, sym: Symbol)(using Context) = tp match {
23-
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) =>
24-
var newApplys = Map.empty[Name, Symbol]
21+
case tp: ClassInfo if !sym.is(Flags.Package) && derivesFromFn012(sym) =>
22+
val newApplys = new mutable.ListBuffer[Symbol]
2523

2624
var arity = 0
2725
while (arity < 3) {
2826
val func = defn.FunctionClass(arity)
2927
if (tp.derivesFrom(func)) {
30-
val typeParams = tp.cls.typeRef.baseType(func).argInfos
28+
val paramTypes = tp.cls.typeRef.baseType(func).argInfos
3129
val isSpecializable =
3230
defn.isSpecializableFunction(
3331
sym.asClass,
34-
typeParams.init,
35-
typeParams.last
32+
paramTypes.init,
33+
paramTypes.last
34+
)
35+
36+
val apply = tp.decls.lookup(nme.apply)
37+
if (isSpecializable && apply.exists) {
38+
val specializedMethodName = specializedName(nme.apply, paramTypes)
39+
val applySpecialized = newSymbol(
40+
sym,
41+
specializedMethodName,
42+
Flags.Override | Flags.Method | Flags.Synthetic,
43+
apply.info
3644
)
3745

38-
if (isSpecializable && tp.decls.lookup(nme.apply).exists) {
39-
val interface = specInterface(typeParams)
40-
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
41-
newApplys += (specializedMethodName -> interface)
46+
newApplys += applySpecialized
4247
}
4348
}
4449
arity += 1
4550
}
4651

47-
def newDecls =
48-
newApplys.toList.map { case (name, interface) =>
49-
newSymbol(
50-
sym,
51-
name,
52-
Flags.Override | Flags.Method | Flags.Synthetic,
53-
interface.info.decls.lookup(name).info
54-
)
55-
}
56-
.foldLeft(tp.decls.cloneScope) {
57-
(scope, sym) => scope.enter(sym); scope
58-
}
59-
6052
if (newApplys.isEmpty) tp
61-
else tp.derivedClassInfo(decls = newDecls)
53+
else {
54+
val scope = tp.decls.cloneScope
55+
newApplys.toList.foreach { sym => scope.enter(sym) }
56+
tp.derivedClassInfo(decls = scope)
57+
}
6258

6359
case _ => tp
6460
}
@@ -69,58 +65,55 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
6965
*/
7066
override def transformTemplate(tree: Template)(using Context) = {
7167
val cls = tree.symbol.enclosingClass.asClass
72-
if (derivesFromFn012(cls)) {
73-
val applyBuf = new mutable.ListBuffer[Tree]
74-
val newBody = tree.body.mapConserve {
75-
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 =>
76-
val typeParams = dt.vparamss.head.map(_.symbol.info)
77-
val retType = dt.tpe.widen.finalResultType
78-
79-
val specName = specializedName(nme.apply, typeParams :+ retType)
80-
val specializedApply = cls.info.decls.lookup(specName)
81-
if (specializedApply.exists) {
82-
val apply = specializedApply.asTerm
83-
val specializedDecl =
84-
polyDefDef(apply, trefs => vrefss => {
85-
dt.rhs
86-
.changeOwner(dt.symbol, apply)
87-
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
88-
})
89-
applyBuf += specializedDecl
90-
91-
// create a forwarding to the specialized apply
92-
cpy.DefDef(dt)(rhs = {
93-
tpd
94-
.ref(apply)
95-
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
68+
69+
if (!derivesFromFn012(cls)) return tree
70+
71+
val applyBuf = new mutable.ListBuffer[Tree]
72+
val newBody = tree.body.mapConserve {
73+
case ddef: DefDef if ddef.name == nme.apply && ddef.vparamss.length == 1 =>
74+
val paramTypes = ddef.vparamss.head.map(_.symbol.info)
75+
val retType = ddef.tpe.widen.finalResultType
76+
77+
val specName = specializedName(nme.apply, paramTypes :+ retType)
78+
val specializedApply = cls.info.decls.lookup(specName)
79+
if (specializedApply.exists) {
80+
val specializedDecl =
81+
DefDef(specializedApply.asTerm, vparamss => {
82+
ddef.rhs
83+
.changeOwner(ddef.symbol, specializedApply)
84+
.subst(ddef.vparamss.head.map(_.symbol), vparamss.head.map(_.symbol))
9685
})
97-
} else dt
86+
applyBuf += specializedDecl
9887

99-
case x => x
100-
}
88+
// create a forwarding to the specialized apply
89+
val args = ddef.vparamss.head.map(vparam => ref(vparam.symbol))
90+
val rhs = This(cls).select(specializedApply).appliedToArgs(args)
91+
cpy.DefDef(ddef)(rhs = rhs)
92+
} else ddef
93+
94+
case x => x
95+
}
10196

102-
cpy.Template(tree)(
103-
body = applyBuf.toList ::: newBody
104-
)
105-
} else tree
97+
cpy.Template(tree)(body = applyBuf.toList ::: newBody)
10698
}
10799

108100
/** Dispatch to specialized `apply`s in user code when available */
109101
override def transformApply(tree: Apply)(using Context) =
110102
tree match {
111103
case Apply(fun, args)
112104
if fun.symbol.name == nme.apply &&
113-
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
105+
derivesFromFn012(fun.symbol.owner)
114106
=>
115-
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
107+
val paramTypes = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
116108
val isSpecializable =
117109
defn.isSpecializableFunction(
118110
fun.symbol.owner.asClass,
119-
params.init,
120-
params.last)
111+
paramTypes.init,
112+
paramTypes.last
113+
)
121114

122-
if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) {
123-
val specializedApply = specializedName(nme.apply, params)
115+
if (isSpecializable && !paramTypes.exists(_.isInstanceOf[ExprType])) {
116+
val specializedApply = specializedName(nme.apply, paramTypes)
124117
val newSel = fun match {
125118
case Select(qual, _) =>
126119
qual.select(specializedApply)
@@ -143,12 +136,6 @@ class SpecializeFunctions extends MiniPhase with InfoTransformer {
143136
private def specializedName(name: Name, args: List[Type])(using Context) =
144137
name.specializedFunction(args.last, args.init)
145138

146-
private def functionName(typeParams: List[Type])(using Context) =
147-
jFunction ++ (typeParams.length - 1).toString
148-
149-
private def specInterface(typeParams: List[Type])(using Context) =
150-
getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init))
151-
152139
private def derivesFromFn012(sym: Symbol)(using Context): Boolean =
153140
sym.derivesFrom(defn.FunctionClass(0)) ||
154141
sym.derivesFrom(defn.FunctionClass(1)) ||

compiler/src/dotty/tools/dotc/transform/SpecializedApplyMethods.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class SpecializedApplyMethods extends MiniPhase with InfoTransformer {
7777

7878
/** Create bridge methods for FunctionN with specialized applys */
7979
override def transformTemplate(tree: Template)(using Context) = {
80-
val cls = tree.symbol.owner.asInstanceOf[ClassSymbol]
80+
val cls = tree.symbol.owner.asClass
8181

8282
if (!defn.isPlainFunctionClass(cls)) return tree
8383

0 commit comments

Comments
 (0)