Skip to content

Support tuple specialisation #15060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/Compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class Compiler {
new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods
new Getters, // Replace non-private vals and vars with getter defs (fields are added later)
new SpecializeFunctions, // Specialized Function{0,1,2} by replacing super with specialized super
new SpecializeTuples, // Specializes Tuples by replacing tuple construction and selection trees
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization
new ElimOuterSelect, // Expand outer selections
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/Run.scala
Original file line number Diff line number Diff line change
Expand Up @@ -282,8 +282,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint

private def printTree(last: PrintedTree)(using Context): PrintedTree = {
val unit = ctx.compilationUnit
val prevPhase = ctx.phase.prev // can be a mini-phase
val fusedPhase = ctx.base.fusedContaining(prevPhase)
val fusedPhase = ctx.phase.prevMega
val echoHeader = f"[[syntax trees at end of $fusedPhase%25s]] // ${unit.source}"
val tree = if ctx.isAfterTyper then unit.tpdTree else unit.untpdTree
val treeString = tree.show(using ctx.withProperty(XprintMode, Some(())))
Expand Down
27 changes: 27 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1329,6 +1329,12 @@ class Definitions {

@tu lazy val TupleType: Array[TypeRef | Null] = mkArityArray("scala.Tuple", MaxTupleArity, 1)

def isSpecializedTuple(cls: Symbol)(using Context): Boolean =
cls.isClass && TupleSpecializedClasses.exists(tupleCls => cls.name.isSpecializedNameOf(tupleCls.name))

def SpecializedTuple(base: Symbol, args: List[Type])(using Context): Symbol =
base.owner.requiredClass(base.name.specializedName(args))

private class FunType(prefix: String):
private var classRefs: Array[TypeRef | Null] = new Array(22)
def apply(n: Int): TypeRef =
Expand Down Expand Up @@ -1587,6 +1593,20 @@ class Definitions {
def isFunctionType(tp: Type)(using Context): Boolean =
isNonRefinedFunction(tp.dropDependentRefinement)

private def withSpecMethods(cls: ClassSymbol, bases: List[Name], paramTypes: Set[TypeRef]) =
for base <- bases; tp <- paramTypes do
cls.enter(newSymbol(cls, base.specializedName(List(tp)), Method, ExprType(tp)))
cls

@tu lazy val Tuple1: ClassSymbol = withSpecMethods(requiredClass("scala.Tuple1"), List(nme._1), Tuple1SpecializedParamTypes)
@tu lazy val Tuple2: ClassSymbol = withSpecMethods(requiredClass("scala.Tuple2"), List(nme._1, nme._2), Tuple2SpecializedParamTypes)

@tu lazy val TupleSpecializedClasses: Set[Symbol] = Set(Tuple1, Tuple2)
@tu lazy val Tuple1SpecializedParamTypes: Set[TypeRef] = Set(IntType, LongType, DoubleType)
@tu lazy val Tuple2SpecializedParamTypes: Set[TypeRef] = Set(IntType, LongType, DoubleType, CharType, BooleanType)
@tu lazy val Tuple1SpecializedParamClasses: PerRun[Set[Symbol]] = new PerRun(Tuple1SpecializedParamTypes.map(_.symbol))
@tu lazy val Tuple2SpecializedParamClasses: PerRun[Set[Symbol]] = new PerRun(Tuple2SpecializedParamTypes.map(_.symbol))

// Specialized type parameters defined for scala.Function{0,1,2}.
@tu lazy val Function1SpecializedParamTypes: collection.Set[TypeRef] =
Set(IntType, LongType, FloatType, DoubleType)
Expand All @@ -1610,6 +1630,13 @@ class Definitions {
@tu lazy val Function2SpecializedReturnClasses: PerRun[collection.Set[Symbol]] =
new PerRun(Function2SpecializedReturnTypes.map(_.symbol))

def isSpecializableTuple(base: Symbol, args: List[Type])(using Context): Boolean =
args.length <= 2 && base.isClass && TupleSpecializedClasses.exists(base.asClass.derivesFrom) && args.match
case List(x) => Tuple1SpecializedParamClasses().contains(x.classSymbol)
case List(x, y) => Tuple2SpecializedParamClasses().contains(x.classSymbol) && Tuple2SpecializedParamClasses().contains(y.classSymbol)
case _ => false
&& base.owner.denot.info.member(base.name.specializedName(args)).exists // when dotc compiles the stdlib there are no specialised classes

def isSpecializableFunction(cls: ClassSymbol, paramTypes: List[Type], retType: Type)(using Context): Boolean =
paramTypes.length <= 2
&& (cls.derivesFrom(FunctionClass(paramTypes.length)) || isByNameFunctionClass(cls))
Expand Down
24 changes: 24 additions & 0 deletions compiler/src/dotty/tools/dotc/core/NameOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import scala.io.Codec
import Int.MaxValue
import Names._, StdNames._, Contexts._, Symbols._, Flags._, NameKinds._, Types._
import util.Chars.{isOperatorPart, digit2int}
import Decorators.*
import Definitions._
import nme._

Expand Down Expand Up @@ -278,6 +279,29 @@ object NameOps {
classTags.fold(nme.EMPTY)(_ ++ _) ++ nme.specializedTypeNames.suffix)
}

/** Determines if the current name is the specialized name of the given base name.
* For example `typeName("Tuple2$mcII$sp").isSpecializedNameOf(tpnme.Tuple2) == true`
*/
def isSpecializedNameOf(base: N)(using Context): Boolean =
var i = 0
inline def nextString(str: String) = name.startsWith(str, i) && { i += str.length; true }
nextString(base.toString)
&& nextString(nme.specializedTypeNames.prefix.toString)
&& nextString(nme.specializedTypeNames.separator.toString)
&& name.endsWith(nme.specializedTypeNames.suffix.toString)

/** Returns the name of the class specialised to the provided types,
* in the given order. Used for the specialized tuple classes.
*/
def specializedName(args: List[Type])(using Context): N =
val sb = new StringBuilder
sb.append(name.toString)
sb.append(nme.specializedTypeNames.prefix.toString)
sb.append(nme.specializedTypeNames.separator)
args.foreach { arg => sb.append(defn.typeTag(arg)) }
sb.append(nme.specializedTypeNames.suffix)
likeSpacedN(termName(sb.toString))

/** Use for specializing function names ONLY and use it if you are **not**
* creating specialized name from type parameters. The order of names will
* be:
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,9 @@ object Phases {
final def prev: Phase =
if (id > FirstPhaseId) myBase.phases(start - 1) else NoPhase

final def prevMega(using Context): Phase =
ctx.base.fusedContaining(ctx.phase.prev)

final def next: Phase =
if (hasNext) myBase.phases(end + 1) else NoPhase

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ object SymDenotations {
if (myFlags.is(Trait)) NoInitsInterface & bodyFlags // no parents are initialized from a trait
else NoInits & bodyFlags & parentFlags)

final def setStableConstructor()(using Context): Unit =
val ctorStable = if myFlags.is(Trait) then myFlags.is(NoInits) else isNoInitsRealClass
if ctorStable then primaryConstructor.setFlag(StableRealizable)

def isCurrent(fs: FlagSet)(using Context): Boolean =
def knownFlags(info: Type): FlagSet = info match
case _: SymbolLoader | _: ModuleCompleter => FromStartFlags
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,10 @@ class ClassfileParser(
else return unpickleTASTY(bytes)
}

if (scan(tpnme.ScalaATTR) && !scalaUnpickleWhitelist.contains(classRoot.name))
if scan(tpnme.ScalaATTR) && !scalaUnpickleWhitelist.contains(classRoot.name)
&& !(classRoot.name.startsWith("Tuple") && classRoot.name.endsWith("$sp"))
&& !(classRoot.name.startsWith("Product") && classRoot.name.endsWith("$sp"))
then
// To understand the situation, it's helpful to know that:
// - Scalac emits the `ScalaSig` attribute for classfiles with pickled information
// and the `Scala` attribute for everything else.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,9 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
// we need the checkNonCyclic call to insert LazyRefs for F-bounded cycles
else if (!denot.is(Param)) tp1.translateFromRepeated(toArray = false)
else tp1
if (denot.isConstructor) addConstructorTypeParams(denot)
if (denot.isConstructor)
denot.owner.setStableConstructor()
addConstructorTypeParams(denot)
if (atEnd)
assert(!denot.symbol.isSuperAccessor, denot)
else {
Expand Down
11 changes: 10 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/Formatting.scala
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,16 @@ object Formatting {
object ShowAny extends Show[Any]:
def show(x: Any): Shown = x

class ShowImplicits2:
class ShowImplicits3:
given Show[Product] = ShowAny

class ShowImplicits2 extends ShowImplicits3:
given Show[ParamInfo] with
def show(x: ParamInfo) = x match
case x: Symbol => Show[x.type].show(x)
case x: LambdaParam => Show[x.type].show(x)
case _ => ShowAny

class ShowImplicits1 extends ShowImplicits2:
given Show[ImplicitRef] = ShowAny
given Show[Names.Designator] = ShowAny
Expand Down Expand Up @@ -99,6 +106,8 @@ object Formatting {
val sep = StringContext.processEscapes(rawsep)
if (rest.nonEmpty) (arg.map(showArg).mkString(sep), rest.tail)
else (arg, suffix)
case arg: Seq[?] =>
(arg.map(showArg).mkString("[", ", ", "]"), suffix)
case _ =>
(showArg(arg), suffix)
}
Expand Down
53 changes: 53 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SpecializeTuples.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package dotty.tools
package dotc
package transform

import ast.Trees.*, ast.tpd, core.*
import Contexts.*, Types.*, Decorators.*, Symbols.*, DenotTransformers.*
import SymDenotations.*, Scopes.*, StdNames.*, NameOps.*, Names.*
import MegaPhase.MiniPhase
import typer.Inliner.isElideableExpr

/** Specializes Tuples by replacing tuple construction and selection trees.
*
* Specifically:
* 1. Replaces `(1, 1)` (which is `Tuple2.apply[Int, Int](1, 1)`) and
* `new Tuple2[Int, Int](1, 1)` with `new Tuple2$mcII$sp(1, 1)`.
* 2. Replaces `(_: Tuple2[Int, Int])._1` with `(_: Tuple2[Int, Int])._1$mcI$sp`
*/
class SpecializeTuples extends MiniPhase:
import tpd.*

override def phaseName: String = SpecializeTuples.name
override def description: String = SpecializeTuples.description
override def isEnabled(using Context): Boolean = !ctx.settings.scalajs.value

override def transformApply(tree: Apply)(using Context): Tree = tree match
case Apply(TypeApply(fun: NameTree, targs), args)
if fun.symbol.name == nme.apply && fun.symbol.exists && defn.isSpecializableTuple(fun.symbol.owner.companionClass, targs.map(_.tpe))
&& isElideableExpr(tree)
=>
cpy.Apply(tree)(Select(New(defn.SpecializedTuple(fun.symbol.owner.companionClass, targs.map(_.tpe)).typeRef), nme.CONSTRUCTOR), args).withType(tree.tpe)
case Apply(TypeApply(fun: NameTree, targs), args)
if fun.symbol.name == nme.CONSTRUCTOR && fun.symbol.exists && defn.isSpecializableTuple(fun.symbol.owner, targs.map(_.tpe))
&& isElideableExpr(tree)
=>
cpy.Apply(tree)(Select(New(defn.SpecializedTuple(fun.symbol.owner, targs.map(_.tpe)).typeRef), nme.CONSTRUCTOR), args).withType(tree.tpe)
case _ => tree
end transformApply

override def transformSelect(tree: Select)(using Context): Tree = tree match
case Select(qual, nme._1) if isAppliedSpecializableTuple(qual.tpe.widen) =>
Select(qual, nme._1.specializedName(qual.tpe.widen.argInfos.slice(0, 1)))
case Select(qual, nme._2) if isAppliedSpecializableTuple(qual.tpe.widen) =>
Select(qual, nme._2.specializedName(qual.tpe.widen.argInfos.slice(1, 2)))
case _ => tree

private def isAppliedSpecializableTuple(tp: Type)(using Context) = tp match
case AppliedType(tycon, args) => defn.isSpecializableTuple(tycon.classSymbol, args)
case _ => false
end SpecializeTuples

object SpecializeTuples:
val name: String = "specializeTuples"
val description: String = "replaces tuple construction and selection trees"
14 changes: 10 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/TreeChecker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class TreeChecker extends Phase with SymTransformer {
if (ctx.phaseId <= erasurePhase.id) {
val initial = symd.initial
assert(symd == initial || symd.signature == initial.signature,
i"""Signature of ${sym.showLocated} changed at phase ${ctx.base.fusedContaining(ctx.phase.prev)}
i"""Signature of ${sym.showLocated} changed at phase ${ctx.phase.prevMega}
|Initial info: ${initial.info}
|Initial sig : ${initial.signature}
|Current info: ${symd.info}
Expand Down Expand Up @@ -122,8 +122,7 @@ class TreeChecker extends Phase with SymTransformer {
}

def check(phasesToRun: Seq[Phase], ctx: Context): Tree = {
val prevPhase = ctx.phase.prev // can be a mini-phase
val fusedPhase = ctx.base.fusedContaining(prevPhase)
val fusedPhase = ctx.phase.prevMega(using ctx)
report.echo(s"checking ${ctx.compilationUnit} after phase ${fusedPhase}")(using ctx)

inContext(ctx) {
Expand All @@ -145,7 +144,7 @@ class TreeChecker extends Phase with SymTransformer {
catch {
case NonFatal(ex) => //TODO CHECK. Check that we are bootstrapped
inContext(checkingCtx) {
println(i"*** error while checking ${ctx.compilationUnit} after phase ${ctx.phase.prev} ***")
println(i"*** error while checking ${ctx.compilationUnit} after phase ${ctx.phase.prevMega(using ctx)} ***")
}
throw ex
}
Expand Down Expand Up @@ -422,6 +421,13 @@ class TreeChecker extends Phase with SymTransformer {
assert(tree.qual.typeOpt.isInstanceOf[ThisType], i"expect prefix of Super to be This, actual = ${tree.qual}")
super.typedSuper(tree, pt)

override def typedApply(tree: untpd.Apply, pt: Type)(using Context): Tree = tree match
case Apply(Select(qual, nme.CONSTRUCTOR), _)
if !ctx.phase.erasedTypes
&& defn.isSpecializedTuple(qual.typeOpt.typeSymbol) =>
promote(tree) // e.g. `new Tuple2$mcII$sp(7, 8)` should keep its `(7, 8)` type instead of `Tuple2$mcII$sp`
case _ => super.typedApply(tree, pt)

override def typedTyped(tree: untpd.Typed, pt: Type)(using Context): Tree =
val tpt1 = checkSimpleKinded(typedType(tree.tpt))
val expr1 = tree.expr match
Expand Down
Loading