Skip to content

Commit 6d59b1c

Browse files
committed
Miniphase for adding constructor parameters to Java enums
Miniphase for adding constructor parameters if a Scala enum extends a Java enum.
1 parent 747f927 commit 6d59b1c

File tree

4 files changed

+169
-7
lines changed

4 files changed

+169
-7
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class Compiler {
5656
List(new FirstTransform, // Some transformations to put trees into a canonical form
5757
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
5858
new ElimPackagePrefixes, // Eliminate references to package prefixes in Select nodes
59-
new CookComments) :: // Cook the comments: expand variables, doc, etc.
59+
new CookComments, // Cook the comments: expand variables, doc, etc.
60+
new CompleteJavaEnums) :: // Fill in constructors for Java enums
6061
List(new CheckStatic, // Check restrictions that apply to @static members
6162
new ElimRepeated, // Rewrite vararg parameters and arguments
6263
new ExpandSAMs, // Expand single abstract method closures to anonymous classes

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,15 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
596596
loop(tree, Nil, Nil)
597597
}
598598

599+
/** Decompose a template body into parameters and other statements */
600+
def decomposeTemplateBody(body: List[Tree])(implicit ctx: Context): (List[Tree], List[Tree]) =
601+
body.partition {
602+
case stat: TypeDef => stat.symbol is Flags.Param
603+
case stat: ValOrDefDef =>
604+
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
605+
case _ => false
606+
}
607+
599608
/** An extractor for closures, either contained in a block or standalone.
600609
*/
601610
object closure {

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -509,12 +509,7 @@ class TreePickler(pickler: TastyPickler) {
509509
case tree: Template =>
510510
registerDef(tree.symbol)
511511
writeByte(TEMPLATE)
512-
val (params, rest) = tree.body partition {
513-
case stat: TypeDef => stat.symbol is Flags.Param
514-
case stat: ValOrDefDef =>
515-
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
516-
case _ => false
517-
}
512+
val (params, rest) = decomposeTemplateBody(tree.body)
518513
withLength {
519514
pickleParams(params)
520515
tree.parents.foreach(pickleTree)
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import Names._
6+
import StdNames.{nme, tpnme}
7+
import Types._
8+
import dotty.tools.dotc.transform.MegaPhase._
9+
import Flags._
10+
import Contexts.Context
11+
import Symbols._
12+
import Constants._
13+
import Decorators._
14+
import DenotTransformers._
15+
16+
object CompleteJavaEnums {
17+
val name: String = "completeJavaEnums"
18+
19+
private val nameParamName: TermName = "$name".toTermName
20+
private val ordinalParamName: TermName = "$ordinal".toTermName
21+
}
22+
23+
/** For Scala enums that inherit from java.lang.Enum:
24+
* Add constructor parameters for `name` and `ordinal` to pass from each
25+
* case to the java.lang.Enum class.
26+
*/
27+
class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
28+
import CompleteJavaEnums._
29+
import ast.tpd._
30+
31+
override def phaseName: String = CompleteJavaEnums.name
32+
33+
override def relaxedTypingInGroup: Boolean = true
34+
// Because it adds additional parameters to some constructors
35+
36+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type =
37+
if (sym.isConstructor && (
38+
sym == defn.JavaEnumClass.primaryConstructor ||
39+
derivesFromJavaEnum(sym.owner)))
40+
addConstrParams(sym.info)
41+
else tp
42+
43+
/** Is `sym` a Scala enum class that derives from `java.lang.Enum`?
44+
*/
45+
private def derivesFromJavaEnum(sym: Symbol)(implicit ctx: Context) =
46+
sym.is(Enum, butNot = Case) && sym.derivesFrom(defn.JavaEnumClass)
47+
48+
/** Add constructor parameters `$name: String` and `$ordinal: Int` to the end of
49+
* the last parameter list of (method- or poly-) type `tp`.
50+
*/
51+
private def addConstrParams(tp: Type)(implicit ctx: Context): Type = tp match {
52+
case tp: PolyType =>
53+
tp.derivedLambdaType(resType = addConstrParams(tp.resType))
54+
case tp: MethodType =>
55+
tp.resType match {
56+
case restpe: MethodType =>
57+
tp.derivedLambdaType(resType = addConstrParams(restpe))
58+
case _ =>
59+
tp.derivedLambdaType(
60+
paramNames = tp.paramNames ++ List(nameParamName, ordinalParamName),
61+
paramInfos = tp.paramInfos ++ List(defn.StringType, defn.IntType))
62+
}
63+
}
64+
65+
/** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
66+
* with given flags (either `Param` or `ParamAccessor`)
67+
*/
68+
private def addedParams(owner: Symbol, flag: FlagSet)(implicit ctx: Context): List[ValDef] = {
69+
val nameParam = ctx.newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span)
70+
val ordinalParam = ctx.newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span)
71+
List(ValDef(nameParam), ValDef(ordinalParam))
72+
}
73+
74+
/** Add arguments `args` to the parent constructor application in `parents` that invokes
75+
* a constructor of `targetCls`,
76+
*/
77+
private def addEnumConstrArgs(targetCls: Symbol, parents: List[Tree], args: List[Tree])(implicit ctx: Context): List[Tree] =
78+
parents.map {
79+
case app @ Apply(fn, args0) if fn.symbol.owner == targetCls => cpy.Apply(app)(fn, args0 ++ args)
80+
case p => p
81+
}
82+
83+
/** 1. If this is a constructor of a enum class that extends, add $name and $ordinal parameters to it.
84+
*
85+
* 2. If this is a $new method that creates simple cases, pass $name and $ordinal parameters
86+
* to the enum superclass. The $new method looks like this:
87+
*
88+
* def $new(..., enumTag: Int, name: String) = {
89+
* class $anon extends E(...) { ... }
90+
* new $anon
91+
* }
92+
*
93+
* After the transform it is expanded to
94+
*
95+
* def $new(..., enumTag: Int, name: String) = {
96+
* class $anon extends E(..., name, enumTag) { ... }
97+
* new $anon
98+
* }
99+
*/
100+
override def transformDefDef(tree: DefDef)(implicit ctx: Context): DefDef = {
101+
val sym = tree.symbol
102+
if (sym.isConstructor && derivesFromJavaEnum(sym.owner))
103+
cpy.DefDef(tree)(
104+
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
105+
else if (sym.name == nme.DOLLAR_NEW && derivesFromJavaEnum(sym.owner.linkedClass)) {
106+
val Block((tdef @ TypeDef(tpnme.ANON_CLASS, templ: Template)) :: Nil, call) = tree.rhs
107+
val args = tree.vparamss.last.takeRight(2).map(param => ref(param.symbol)).reverse
108+
val templ1 = cpy.Template(templ)(
109+
parents = addEnumConstrArgs(sym.owner.linkedClass, templ.parents, args))
110+
cpy.DefDef(tree)(
111+
rhs = cpy.Block(tree.rhs)(cpy.TypeDef(tdef)(tdef.name, templ1) :: Nil, call))
112+
}
113+
else tree
114+
}
115+
116+
/** 1. If this is an enum class, add $name and $ordinal parameters to its
117+
* parameter accessors and pass them on to the java.lang.Enum constructor.
118+
*
119+
* 2. If this is an anonymous class that implement a value enum case,
120+
* pass $name and $ordinal parameters to the enum superclass. The class
121+
* looks like this:
122+
*
123+
* class $anon extends E(...) {
124+
* ...
125+
* def enumTag = N
126+
* def toString = S
127+
* ...
128+
* }
129+
*
130+
* After the transform it is expanded to
131+
*
132+
* class $anon extends E(..., N, S) {
133+
* "same as before"
134+
* }
135+
*/
136+
override def transformTemplate(templ: Template)(implicit ctx: Context): Template = {
137+
val cls = templ.symbol.owner
138+
if (derivesFromJavaEnum(cls)) {
139+
val (params, rest) = decomposeTemplateBody(templ.body)
140+
val addedDefs = addedParams(cls, ParamAccessor)
141+
val addedSyms = addedDefs.map(_.symbol.entered)
142+
cpy.Template(templ)(
143+
parents = addEnumConstrArgs(defn.JavaEnumClass, templ.parents, addedSyms.map(ref)),
144+
body = params ++ addedDefs ++ rest)
145+
}
146+
else if (cls.isAnonymousClass && cls.owner.is(EnumCase) && derivesFromJavaEnum(cls.owner.owner.linkedClass)) {
147+
def rhsOf(name: TermName) =
148+
templ.body.collect {
149+
case mdef: DefDef if mdef.name == name => mdef.rhs
150+
}.head
151+
val args = List(rhsOf(nme.toString_), rhsOf(nme.enumTag))
152+
cpy.Template(templ)(
153+
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
154+
}
155+
else templ
156+
}
157+
}

0 commit comments

Comments
 (0)