Skip to content

Specialize Functions #3306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 32 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
caa0a6b
Add phases and initial replacement for super
felixmulder Dec 2, 2016
04024a1
Replace all existing combinations of Function1 with specialized version
felixmulder Dec 2, 2016
c5a5a7e
Do transformations on symbol level too
felixmulder Dec 13, 2016
34836c7
Refactor transformations to be more idiomatic
felixmulder Dec 14, 2016
f315b8d
Add dispatch to specialized applys
felixmulder Dec 14, 2016
1d82065
Add forwarding method for generic case
felixmulder Dec 14, 2016
a932097
Don't specialize Function1 tree when invalid to
felixmulder Dec 14, 2016
1016194
Write test to check for specialized apply
felixmulder Dec 14, 2016
55a35c4
Remove `DispatchToSpecializedApply` phase
felixmulder Dec 23, 2016
9db4d27
SpecializeFunction1: don't roll over parents, use mapConserve
felixmulder Dec 30, 2016
71005ae
Rewrite to handle all specialized functions
felixmulder Feb 14, 2017
9bab6a9
Don't remove parents not being specialized
felixmulder Feb 16, 2017
b0e175a
Add plain function tests to NameOps and Definitions
felixmulder Feb 16, 2017
6a9eabb
Rewrite `SpecializeFunctions` from `DenotTransformer` to `InfoTransfo…
felixmulder Feb 16, 2017
29986fc
Add `MiniPhaseTransform` to add specialized methods to FunctionN
felixmulder Feb 21, 2017
4e845a4
Add synthetic bridge when compiling FunctionN
felixmulder Feb 21, 2017
3e32c22
Fix ordering of specialized names and type parameterized apply
felixmulder Feb 21, 2017
d4f83d0
Add parent types explicitly when specializing
felixmulder Feb 22, 2017
6e79132
Make `ThisName` recursive on `self.ThisName`
felixmulder Apr 12, 2017
bcf83cf
Make sure specialized functions get the correct name
felixmulder Apr 12, 2017
8b6bb63
Fix compilation errors
Duhemm Oct 4, 2017
de7b5f4
Don't change parents in `specializeFunctions`
Duhemm Oct 6, 2017
aa3f588
Cleanup
Duhemm Oct 9, 2017
b9187d2
Move type to JVM tag conversion to Definitions
Duhemm Oct 10, 2017
a607938
Add more tests for function specialization
Duhemm Oct 11, 2017
efdc516
Pass by name should not introduce boxing
Duhemm Oct 11, 2017
654fb14
Cleanup
Duhemm Oct 11, 2017
ed751bd
No specialization for Function3
Duhemm Oct 11, 2017
a1c7109
Adapt to recent changes in transformers and phases
Duhemm Oct 28, 2017
16b881d
Un-split phase group that I previously split
Duhemm Oct 28, 2017
8851312
Address review comments
Duhemm Nov 6, 2017
a32b06f
Optimize phase `SpecializeFunctions`
Duhemm Nov 19, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,22 @@ class Compiler {
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
new ClassOf, // Expand `Predef.classOf` calls.
new SpecializedApplyMethods, // Adds specialized methods to FunctionN
new RefChecks), // Various checks mostly related to abstract members and overriding
List(new TryCatchPatterns, // Compile cases in try/catch
new PatternMatcher, // Compile pattern matches
new ExplicitOuter, // Add accessors to outer classes from nested ones.
new ExplicitSelf, // Make references to non-trivial self types explicit as casts
new ShortcutImplicits, // Allow implicit functions without creating closures
new CrossCastAnd, // Normalize selections involving intersection types.
new Splitter), // Expand selections involving union types into conditionals
new Splitter, // Expand selections involving union types into conditionals
new ElimByName), // Expand by-name parameter references
List(new PhantomArgLift, // Extracts the evaluation of phantom arguments placing them before the call.
new VCInlineMethods, // Inlines calls to value class methods
new SeqLiterals, // Express vararg arguments as arrays
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods
new Getters, // Replace non-private vals and vars with getter defs (fields are added later)
new ElimByName, // Expand by-name parameter references
new SpecializeFunctions, // Specialized Function{0,1,2} by replacing super with specialized super
new ElimOuterSelect, // Expand outer selections
new AugmentScala2Traits, // Expand traits defined in Scala 2.x to simulate old-style rewritings
new ResolveSuper, // Implement super accessors and add forwarders to trait methods
Expand Down
42 changes: 26 additions & 16 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -801,12 +801,17 @@ class Definitions {
def isBottomType(tp: Type) =
tp.derivesFrom(NothingClass) || tp.derivesFrom(NullClass)

/** Is a function class.
/** Is any function class that satisfies:
* - FunctionN for N >= 0
* - ImplicitFunctionN for N > 0
*/
def isFunctionClass(cls: Symbol) = scalaClassName(cls).isFunction

/** Is a function class where
* - FunctionN for N >= 0
*/
def isPlainFunctionClass(cls: Symbol) = scalaClassName(cls).isPlainFunction

/** Is an implicit function class.
* - ImplicitFunctionN for N > 0
*/
Expand Down Expand Up @@ -922,28 +927,33 @@ class Definitions {
}

// Specialized type parameters defined for scala.Function{0,1,2}.
private lazy val Function1SpecializedParams: collection.Set[Type] =
lazy val Function1SpecializedParams: collection.Set[Type] =
Set(IntType, LongType, FloatType, DoubleType)
private lazy val Function2SpecializedParams: collection.Set[Type] =
lazy val Function2SpecializedParams: collection.Set[Type] =
Set(IntType, LongType, DoubleType)
private lazy val Function0SpecializedReturns: collection.Set[Type] =
lazy val Function0SpecializedReturns: collection.Set[Type] =
ScalaNumericValueTypeList.toSet[Type] + UnitType + BooleanType
private lazy val Function1SpecializedReturns: collection.Set[Type] =
lazy val Function1SpecializedReturns: collection.Set[Type] =
Set(UnitType, BooleanType, IntType, FloatType, LongType, DoubleType)
private lazy val Function2SpecializedReturns: collection.Set[Type] =
lazy val Function2SpecializedReturns: collection.Set[Type] =
Function1SpecializedReturns

def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(implicit ctx: Context) =
isFunctionClass(cls) && (paramTypes match {
cls.derivesFrom(FunctionClass(paramTypes.length)) && (paramTypes match {
case Nil =>
Function0SpecializedReturns.contains(retType)
val specializedReturns = Function0SpecializedReturns.map(_.typeSymbol)
specializedReturns.contains(retType.typeSymbol)
case List(paramType0) =>
Function1SpecializedParams.contains(paramType0) &&
Function1SpecializedReturns.contains(retType)
val specializedParams = Function1SpecializedParams.map(_.typeSymbol)
lazy val specializedReturns = Function1SpecializedReturns.map(_.typeSymbol)
specializedParams.contains(paramType0.typeSymbol) &&
specializedReturns.contains(retType.typeSymbol)
case List(paramType0, paramType1) =>
Function2SpecializedParams.contains(paramType0) &&
Function2SpecializedParams.contains(paramType1) &&
Function2SpecializedReturns.contains(retType)
val specializedParams = Function2SpecializedParams.map(_.typeSymbol)
lazy val specializedReturns = Function2SpecializedReturns.map(_.typeSymbol)
specializedParams.contains(paramType0.typeSymbol) &&
specializedParams.contains(paramType1.typeSymbol) &&
specializedReturns.contains(retType.typeSymbol)
case _ =>
false
})
Expand Down Expand Up @@ -987,9 +997,9 @@ class Definitions {
lazy val ScalaNumericValueTypeList = List(
ByteType, ShortType, CharType, IntType, LongType, FloatType, DoubleType)

private lazy val ScalaNumericValueTypes: collection.Set[TypeRef] = ScalaNumericValueTypeList.toSet
private lazy val ScalaValueTypes: collection.Set[TypeRef] = ScalaNumericValueTypes + UnitType + BooleanType
private lazy val ScalaBoxedTypes = ScalaValueTypes map (t => boxedTypes(t.name))
lazy val ScalaNumericValueTypes: collection.Set[TypeRef] = ScalaNumericValueTypeList.toSet
lazy val ScalaValueTypes: collection.Set[TypeRef] = ScalaNumericValueTypes + UnitType + BooleanType
lazy val ScalaBoxedTypes = ScalaValueTypes map (t => boxedTypes(t.name))

val ScalaNumericValueClasses = new PerRun[collection.Set[Symbol]](implicit ctx => ScalaNumericValueTypes.map(_.symbol))
val ScalaValueClasses = new PerRun[collection.Set[Symbol]](implicit ctx => ScalaValueTypes.map(_.symbol))
Expand Down
25 changes: 21 additions & 4 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,17 @@ object NameOps {
if (n == 0) -1 else n
}

/** Is a function name
/** Is any function name that satisfies
* - FunctionN for N >= 0
* - ImplicitFunctionN for N >= 1
* - false otherwise
*/
def isFunction: Boolean = functionArity >= 0

/** Is a function name
* - FunctionN for N >= 0
*/
def isPlainFunction: Boolean = functionArityFor(str.Function) >= 0

/** Is a implicit function name
* - ImplicitFunctionN for N >= 1
* - false otherwise
Expand Down Expand Up @@ -227,8 +231,10 @@ object NameOps {
case nme.clone_ => nme.clone_
}

def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): name.ThisName = {

/** This method is to be used on **type parameters** from a class, since
* this method does sorting based on their names
*/
def specializedFor(classTargs: List[Types.Type], classTargsNames: List[Name], methodTargs: List[Types.Type], methodTarsNames: List[Name])(implicit ctx: Context): Name = {
val methodTags: Seq[Name] = (methodTargs zip methodTarsNames).sortBy(_._2).map(x => defn.typeTag(x._1))
val classTags: Seq[Name] = (classTargs zip classTargsNames).sortBy(_._2).map(x => defn.typeTag(x._1))

Expand All @@ -237,6 +243,17 @@ object NameOps {
classTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix)
}

/** Use for specializing function names ONLY and use it if you are **not**
* creating specialized name from type parameters. The order of names will
* be:
*
* `<return type><first type><second type><...>`
*/
def specializedFunction(ret: Types.Type, args: List[Types.Type])(implicit ctx: Context): Name =
name ++ nme.specializedTypeNames.prefix ++
nme.specializedTypeNames.separator ++ defn.typeTag(ret) ++
args.map(defn.typeTag).fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix

/** If name length exceeds allowable limit, replace part of it by hash */
def compactified(implicit ctx: Context): TermName = termName(compactify(name.toString))

Expand Down
3 changes: 0 additions & 3 deletions compiler/src/dotty/tools/dotc/transform/ElimByName.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ class ElimByName extends TransformByNameApply with InfoTransformer {

override def phaseName: String = "elimByName"

override def runsAfterGroupsOf = Set(classOf[Splitter])
// I got errors running this phase in an earlier group, but I did not track them down.

/** Map `tree` to `tree.apply()` is `ftree` was of ExprType and becomes now a function */
private def applyIfFunction(tree: Tree, ftree: Tree)(implicit ctx: Context) =
if (isByNameRef(ftree))
Expand Down
156 changes: 156 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SpecializeFunctions.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
package dotty.tools.dotc
package transform

import ast.Trees._, ast.tpd, core._
import Contexts.Context, Types._, Decorators._, Symbols._, DenotTransformers._
import SymDenotations._, Scopes._, StdNames._, NameOps._, Names._
import MegaPhase.MiniPhase

import scala.collection.mutable

/** Specializes classes that inherit from `FunctionN` where there exists a
* specialized form.
*/
class SpecializeFunctions extends MiniPhase with InfoTransformer {
import ast.tpd._
val phaseName = "specializeFunctions"
override def runsAfter = Set(classOf[ElimByName])

private val jFunction = "scala.compat.java8.JFunction".toTermName

/** Transforms the type to include decls for specialized applys */
override def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context) = tp match {
case tp: ClassInfo if !sym.is(Flags.Package) && (tp.decls ne EmptyScope) && derivesFromFn012(sym) =>
var newApplys = Map.empty[Name, Symbol]

var arity = 0
while (arity < 3) {
val func = defn.FunctionClass(arity)
if (tp.derivesFrom(func)) {
val typeParams = tp.cls.typeRef.baseType(func).argInfos
val isSpecializable =
defn.isSpecializableFunction(
sym.asClass,
typeParams.init,
typeParams.last
)

if (isSpecializable && tp.decls.lookup(nme.apply).exists) {
val interface = specInterface(typeParams)
val specializedMethodName = nme.apply.specializedFunction(typeParams.last, typeParams.init)
newApplys += (specializedMethodName -> interface)
}
}
arity += 1
}

def newDecls =
newApplys.toList.map { case (name, interface) =>
ctx.newSymbol(
sym,
name,
Flags.Override | Flags.Method | Flags.Synthetic,
interface.info.decls.lookup(name).info
)
}
.foldLeft(tp.decls.cloneScope) {
(scope, sym) => scope.enter(sym); scope
}

if (newApplys.isEmpty) tp
else tp.derivedClassInfo(decls = newDecls)

case _ => tp
}

/** Transforms the `Template` of the classes to contain forwarders from the
* generic applys to the specialized ones. Also inserts the specialized applys
* in the template body.
*/
override def transformTemplate(tree: Template)(implicit ctx: Context) = {
val cls = tree.symbol.enclosingClass.asClass
if (derivesFromFn012(cls)) {
val applyBuf = new mutable.ListBuffer[Tree]
val newBody = tree.body.mapConserve {
case dt: DefDef if dt.name == nme.apply && dt.vparamss.length == 1 =>
val typeParams = dt.vparamss.head.map(_.symbol.info)
val retType = dt.tpe.widen.finalResultType

val specName = specializedName(nme.apply, typeParams :+ retType)
val specializedApply = cls.info.decls.lookup(specName)
if (specializedApply.exists) {
val apply = specializedApply.asTerm
val specializedDecl =
polyDefDef(apply, trefs => vrefss => {
dt.rhs
.changeOwner(dt.symbol, apply)
.subst(dt.vparamss.flatten.map(_.symbol), vrefss.flatten.map(_.symbol))
})
applyBuf += specializedDecl

// create a forwarding to the specialized apply
cpy.DefDef(dt)(rhs = {
tpd
.ref(apply)
.appliedToArgs(dt.vparamss.head.map(vparam => ref(vparam.symbol)))
})
} else dt

case x => x
}

cpy.Template(tree)(
body = applyBuf.toList ::: newBody
)
} else tree
}

/** Dispatch to specialized `apply`s in user code when available */
override def transformApply(tree: Apply)(implicit ctx: Context) =
tree match {
case Apply(fun, args)
if fun.symbol.name == nme.apply &&
fun.symbol.owner.derivesFrom(defn.FunctionClass(args.length))
=>
val params = (fun.tpe.widen.firstParamTypes :+ tree.tpe).map(_.widenSingleton.dealias)
val isSpecializable =
defn.isSpecializableFunction(
fun.symbol.owner.asClass,
params.init,
params.last)

if (isSpecializable && !params.exists(_.isInstanceOf[ExprType])) {
val specializedApply = specializedName(nme.apply, params)
val newSel = fun match {
case Select(qual, _) =>
qual.select(specializedApply)
case _ =>
(fun.tpe: @unchecked) match {
case TermRef(prefix: ThisType, name) =>
tpd.This(prefix.cls).select(specializedApply)
case TermRef(prefix: NamedType, name) =>
tpd.ref(prefix).select(specializedApply)
}
}

newSel.appliedToArgs(args)
}
else tree

case _ => tree
}

private def specializedName(name: Name, args: List[Type])(implicit ctx: Context) =
name.specializedFunction(args.last, args.init)

private def functionName(typeParams: List[Type])(implicit ctx: Context) =
jFunction ++ (typeParams.length - 1).toString

private def specInterface(typeParams: List[Type])(implicit ctx: Context) =
ctx.getClassIfDefined(functionName(typeParams).specializedFunction(typeParams.last, typeParams.init))

private def derivesFromFn012(sym: Symbol)(implicit ctx: Context): Boolean =
sym.derivesFrom(defn.FunctionClass(0)) ||
sym.derivesFrom(defn.FunctionClass(1)) ||
sym.derivesFrom(defn.FunctionClass(2))
}
Loading