Skip to content

Commit 1821d22

Browse files
authored
Merge pull request #9539 from dotty-staging/topic/enum-reduce-bytecode
remove dollar ordinal from Enum
2 parents f46c030 + 58ca371 commit 1821d22

File tree

12 files changed

+177
-54
lines changed

12 files changed

+177
-54
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,20 +125,24 @@ object DesugarEnums {
125125
/** A creation method for a value of enum type `E`, which is defined as follows:
126126
*
127127
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
128-
* def $ordinal = $tag
129-
* override def toString = $name
128+
* def ordinal = _$ordinal // if `E` does not derive from jl.Enum
129+
* override def toString = $name // if `E` does not derive from jl.Enum
130130
* $values.register(this)
131131
* }
132132
*/
133133
private def enumValueCreator(using Context) = {
134-
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
135-
val toStringDef = toStringMeth(Ident(nme.nameDollar))
134+
val fieldMethods =
135+
if isJavaEnum then Nil
136+
else
137+
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
138+
val toStringDef = toStringMeth(Ident(nme.nameDollar))
139+
List(ordinalDef, toStringDef)
136140
val creator = New(Template(
137141
constr = emptyConstructor,
138142
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139143
derived = Nil,
140144
self = EmptyValDef,
141-
body = ordinalDef :: toStringDef :: registerCall :: Nil
145+
body = fieldMethods ::: registerCall :: Nil
142146
).withAttachment(ExtendsSingletonMirror, ()))
143147
DefDef(nme.DOLLAR_NEW, Nil,
144148
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -264,8 +268,10 @@ object DesugarEnums {
264268
def param(name: TermName, typ: Type)(using Context) =
265269
ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param)
266270

271+
private def isJavaEnum(using Context): Boolean = ctx.owner.linkedClass.derivesFrom(defn.JavaEnumClass)
272+
267273
def ordinalMeth(body: Tree)(using Context): DefDef =
268-
DefDef(nme.ordinalDollar, Nil, Nil, TypeTree(defn.IntType), body)
274+
DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body)
269275

270276
def toStringMeth(body: Tree)(using Context): DefDef =
271277
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
@@ -284,12 +290,16 @@ object DesugarEnums {
284290
expandSimpleEnumCase(name, mods, span)
285291
else {
286292
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
287-
val ordinalDef = ordinalMethLit(tag)
288-
val toStringDef = toStringMethLit(name.toString)
293+
val fieldMethods =
294+
if isJavaEnum then Nil
295+
else
296+
val ordinalDef = ordinalMethLit(tag)
297+
val toStringDef = toStringMethLit(name.toString)
298+
List(ordinalDef, toStringDef)
289299
val impl1 = cpy.Template(impl)(
290300
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291-
body = ordinalDef :: toStringDef :: registerCall :: Nil
292-
).withAttachment(ExtendsSingletonMirror, ())
301+
body = fieldMethods ::: registerCall :: Nil)
302+
.withAttachment(ExtendsSingletonMirror, ())
293303
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
294304
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
295305
}

compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import DenotTransformers._
1515
import dotty.tools.dotc.ast.Trees._
1616
import SymUtils._
1717

18+
import annotation.threadUnsafe
19+
1820
object CompleteJavaEnums {
1921
val name: String = "completeJavaEnums"
2022

@@ -62,9 +64,10 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
6264
/** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
6365
* with given flags (either `Param` or `ParamAccessor`)
6466
*/
65-
private def addedParams(owner: Symbol, flag: FlagSet)(using Context): List[ValDef] = {
66-
val nameParam = newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span)
67-
val ordinalParam = newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span)
67+
private def addedParams(owner: Symbol, isLocal: Boolean, flag: FlagSet)(using Context): List[ValDef] = {
68+
val flags = flag | Synthetic | (if isLocal then Private | Deferred else EmptyFlags)
69+
val nameParam = newSymbol(owner, nameParamName, flags, defn.StringType, coord = owner.span)
70+
val ordinalParam = newSymbol(owner, ordinalParamName, flags, defn.IntType, coord = owner.span)
6871
List(ValDef(nameParam), ValDef(ordinalParam))
6972
}
7073

@@ -85,7 +88,7 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
8588
val sym = tree.symbol
8689
if (sym.isConstructor && sym.owner.derivesFromJavaEnum)
8790
val tree1 = cpy.DefDef(tree)(
88-
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
91+
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, isLocal=false, Param)))
8992
sym.setParamssFromDefs(tree1.tparams, tree1.vparamss)
9093
tree1
9194
else tree
@@ -107,47 +110,68 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
107110
}
108111
}
109112

113+
private def isJavaEnumValueImpl(cls: Symbol)(using Context): Boolean =
114+
cls.isAnonymousClass
115+
&& (((cls.owner.name eq nme.DOLLAR_NEW) && cls.owner.isAllOf(Private|Synthetic)) || cls.owner.isAllOf(EnumCase))
116+
&& cls.owner.owner.linkedClass.derivesFromJavaEnum
117+
118+
private val enumCaseOrdinals: MutableSymbolMap[Int] = newMutableSymbolMap
119+
120+
private def registerEnumClass(cls: Symbol)(using Context): Unit =
121+
cls.children.zipWithIndex.foreach(enumCaseOrdinals.put)
122+
123+
private def ordinalFor(enumCase: Symbol): Int =
124+
enumCaseOrdinals.remove(enumCase).get
125+
110126
/** 1. If this is an enum class, add $name and $ordinal parameters to its
111127
* parameter accessors and pass them on to the java.lang.Enum constructor.
112128
*
113-
* 2. If this is an anonymous class that implement a value enum case,
129+
* 2. If this is an anonymous class that implement a singleton enum case,
114130
* pass $name and $ordinal parameters to the enum superclass. The class
115131
* looks like this:
116132
*
117133
* class $anon extends E(...) {
118134
* ...
119-
* def ordinal = N
120-
* def toString = S
121-
* ...
122135
* }
123136
*
124137
* After the transform it is expanded to
125138
*
126-
* class $anon extends E(..., N, S) {
127-
* "same as before"
139+
* class $anon extends E(..., $name, _$ordinal) { // if class implements a simple enum case
140+
* "same as before"
141+
* }
142+
*
143+
* class $anon extends E(..., "A", 0) { // if class implements a value enum case `A` with ordinal 0
144+
* "same as before"
128145
* }
129146
*/
130-
override def transformTemplate(templ: Template)(using Context): Template = {
147+
override def transformTemplate(templ: Template)(using Context): Tree = {
131148
val cls = templ.symbol.owner
132-
if (cls.derivesFromJavaEnum) {
149+
if cls.derivesFromJavaEnum then
150+
registerEnumClass(cls) // invariant: class is visited before cases: see tests/pos/enum-companion-first.scala
133151
val (params, rest) = decomposeTemplateBody(templ.body)
134-
val addedDefs = addedParams(cls, ParamAccessor)
152+
val addedDefs = addedParams(cls, isLocal=true, ParamAccessor)
135153
val addedSyms = addedDefs.map(_.symbol.entered)
136154
val addedForwarders = addedEnumForwarders(cls)
137155
cpy.Template(templ)(
138156
parents = addEnumConstrArgs(defn.JavaEnumClass, templ.parents, addedSyms.map(ref)),
139157
body = params ++ addedDefs ++ addedForwarders ++ rest)
140-
}
141-
else if (cls.isAnonymousClass && ((cls.owner.name eq nme.DOLLAR_NEW) || cls.owner.isAllOf(EnumCase)) &&
142-
cls.owner.owner.linkedClass.derivesFromJavaEnum) {
143-
def rhsOf(name: TermName) =
144-
templ.body.collect {
145-
case mdef: DefDef if mdef.name == name => mdef.rhs
146-
}.head
147-
val args = List(rhsOf(nme.toString_), rhsOf(nme.ordinalDollar))
158+
else if isJavaEnumValueImpl(cls) then
159+
def creatorParamRef(name: TermName) =
160+
ref(cls.owner.paramSymss.head.find(_.name == name).get)
161+
val args =
162+
if cls.owner.isAllOf(EnumCase) then
163+
List(Literal(Constant(cls.owner.name.toString)), Literal(Constant(ordinalFor(cls.owner))))
164+
else
165+
List(creatorParamRef(nme.nameDollar), creatorParamRef(nme.ordinalDollar_))
148166
cpy.Template(templ)(
149-
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
150-
}
167+
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args),
168+
)
169+
else if cls.linkedClass.derivesFromJavaEnum then
170+
enumCaseOrdinals.clear() // remove simple cases // invariant: companion is visited after cases
171+
templ
151172
else templ
152173
}
174+
175+
override def checkPostCondition(tree: Tree)(using Context): Unit =
176+
assert(enumCaseOrdinals.isEmpty, "Java based enum ordinal cache was not cleared")
153177
}

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
5757
private var myValueSymbols: List[Symbol] = Nil
5858
private var myCaseSymbols: List[Symbol] = Nil
5959
private var myCaseModuleSymbols: List[Symbol] = Nil
60-
private var myEnumCaseSymbols: List[Symbol] = Nil
6160

6261
private def initSymbols(using Context) =
6362
if (myValueSymbols.isEmpty) {
@@ -66,13 +65,11 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
6665
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement,
6766
defn.Product_productElementName)
6867
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
69-
myEnumCaseSymbols = List(defn.Enum_ordinal)
7068
}
7169

7270
def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
7371
def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols }
7472
def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
75-
def enumCaseSymbols(using Context): List[Symbol] = { initSymbols; myEnumCaseSymbols }
7673

7774
private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = {
7875
val existing = sym.matchingMember(clazz.thisType)
@@ -96,9 +93,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
9693
val symbolsToSynthesize: List[Symbol] =
9794
if (clazz.is(Case))
9895
if (clazz.is(Module)) caseModuleSymbols
99-
else if (isEnumCase) caseSymbols ++ enumCaseSymbols
10096
else caseSymbols
101-
else if (isEnumCase) enumCaseSymbols
10297
else if (isDerivedValueClass(clazz)) valueSymbols
10398
else Nil
10499

@@ -128,7 +123,6 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
128123
case nme.productPrefix => ownName
129124
case nme.productElement => productElementBody(accessors.length, vrefss.head.head)
130125
case nme.productElementName => productElementNameBody(accessors.length, vrefss.head.head)
131-
case nme.ordinal => Select(This(clazz), nme.ordinalDollar)
132126
}
133127
report.log(s"adding $synthetic to $clazz at ${ctx.phase}")
134128
synthesizeDef(synthetic, syntheticRHS)

docs/docs/reference/enums/desugarEnums.md

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ map into `case class`es or `val`s.
3636
```
3737
expands to a `sealed abstract` class that extends the `scala.Enum` trait and
3838
an associated companion object that contains the defined cases, expanded according
39-
to rules (2 - 8). The enum trait starts with a compiler-generated import that imports
40-
the names `<caseIds>` of all cases so that they can be used without prefix in the trait.
39+
to rules (2 - 8). The enum class starts with a compiler-generated import that imports
40+
the names `<caseIds>` of all cases so that they can be used without prefix in the class.
4141
```scala
4242
sealed abstract class E ... extends <parents> with scala.Enum {
4343
import E.{ <caseIds> }
@@ -174,13 +174,15 @@ If `E` contains at least one simple case, its companion object will define in ad
174174
follows.
175175
```scala
176176
private def $new(_$ordinal: Int, $name: String) = new E with runtime.EnumValue {
177-
def $ordinal = $_ordinal
178-
override def toString = $name
177+
def ordinal = _$ordinal // if `E` does not have `java.lang.Enum` as a parent
178+
override def toString = $name // if `E` does not have `java.lang.Enum` as a parent
179179
$values.register(this) // register enum value so that `valueOf` and `values` can return it.
180180
}
181181
```
182182

183-
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.
183+
The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
184+
The `ordinal` method is only generated if the enum does not extend from `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it. Similarly there is no need to override `toString` as that is defined in terms of `name` in
185+
`java.lang.Enum`.
184186

185187
### Scopes for Enum Cases
186188

library/src-bootstrapped/scala/Enum.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,3 @@ trait Enum extends Product, Serializable:
55

66
/** A number uniquely identifying a case of an enum */
77
def ordinal: Int
8-
protected def $ordinal: Int
9-

library/src-non-bootstrapped/scala/Enum.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package scala
22

33
/** A base trait of all enum classes */
4-
trait Enum:
4+
trait Enum extends Product, Serializable:
55

66
/** A number uniquely identifying a case of an enum */
77
def ordinal: Int

tests/pos/enum-List-control.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
abstract sealed class List[T] extends Enum
22
object List {
33
final class Cons[T](x: T, xs: List[T]) extends List[T] {
4-
def $ordinal = 0
4+
def ordinal = 0
55
def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]]
66
def productArity: Int = 2
77
def productElement(n: Int): Any = n match
@@ -12,7 +12,7 @@ object List {
1212
def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs)
1313
}
1414
final class Nil[T]() extends List[T], runtime.EnumValue {
15-
def $ordinal = 1
15+
def ordinal = 1
1616
}
1717
object Nil {
1818
def apply[T](): List[T] = new Nil()

tests/pos/enum-companion-first.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
object Planet:
2+
final val G = 6.67300E-11
3+
4+
enum Planet(mass: Double, radius: Double) extends java.lang.Enum[Planet]:
5+
def surfaceGravity = Planet.G * mass / (radius * radius)
6+
def surfaceWeight(otherMass: Double) = otherMass * surfaceGravity
7+
8+
case Mercury extends Planet(3.303e+23, 2.4397e6)
9+
case Venus extends Planet(4.869e+24, 6.0518e6)

tests/run/enum-ordinal-java/Lib.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
object Lib1:
2+
trait MyJavaEnum[E <: java.lang.Enum[E]] extends java.lang.Enum[E]
3+
4+
object Lib2:
5+
type JavaEnumAlias[E <: java.lang.Enum[E]] = java.lang.Enum[E]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
enum Color1 extends Lib1.MyJavaEnum[Color1]:
2+
case Red, Green, Blue
3+
4+
enum Color2 extends Lib2.JavaEnumAlias[Color2]:
5+
case Red, Green, Blue
6+
7+
@main def Test =
8+
assert(Color1.Green.ordinal == 1)
9+
assert(Color2.Blue.ordinal == 2)

tests/run/enum-values-order.scala

Lines changed: 75 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,80 @@
11
/** immutable hashmaps (as of 2.13 collections) only store up to 4 entries in insertion order */
22
enum LatinAlphabet { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }
33

4+
enum LatinAlphabet2 extends java.lang.Enum[LatinAlphabet2] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }
5+
6+
enum LatinAlphabet3[+T] extends java.lang.Enum[LatinAlphabet3[_]] { case A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z }
7+
8+
object Color:
9+
trait Pretty
10+
enum Color extends java.lang.Enum[Color]:
11+
case Red, Green, Blue
12+
case Aqua extends Color with Color.Pretty
13+
case Grey, Black, White
14+
case Emerald extends Color with Color.Pretty
15+
case Brown
16+
417
@main def Test =
5-
import LatinAlphabet._
6-
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)
718

8-
assert(ordered sameElements LatinAlphabet.values)
19+
20+
def testLatin() =
21+
22+
val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25)
23+
val labels = Seq("A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z")
24+
25+
def testLatin1() =
26+
import LatinAlphabet._
27+
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)
28+
29+
assert(ordered sameElements LatinAlphabet.values)
30+
assert(ordinals == ordered.map(_.ordinal))
31+
assert(labels == ordered.map(_.productPrefix))
32+
33+
def testLatin2() =
34+
import LatinAlphabet2._
35+
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)
36+
37+
assert(ordered sameElements LatinAlphabet2.values)
38+
assert(ordinals == ordered.map(_.ordinal))
39+
assert(labels == ordered.map(_.name))
40+
41+
def testLatin3() =
42+
import LatinAlphabet3._
43+
val ordered = Seq(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V, W, X, Y, Z)
44+
45+
assert(ordered sameElements LatinAlphabet3.values)
46+
assert(ordinals == ordered.map(_.ordinal))
47+
assert(labels == ordered.map(_.name))
48+
49+
testLatin1()
50+
testLatin2()
51+
testLatin3()
52+
53+
end testLatin
54+
55+
def testColor() =
56+
import Color._
57+
val ordered = Seq(Red, Green, Blue, Aqua, Grey, Black, White, Emerald, Brown)
58+
val ordinals = Seq(0, 1, 2, 3, 4, 5, 6, 7, 8)
59+
val labels = Seq("Red", "Green", "Blue", "Aqua", "Grey", "Black", "White", "Emerald", "Brown")
60+
61+
assert(ordered sameElements Color.values)
62+
assert(ordinals == ordered.map(_.ordinal))
63+
assert(labels == ordered.map(_.name))
64+
65+
def isPretty(c: Color): Boolean = c match
66+
case _: Pretty => true
67+
case _ => false
68+
69+
assert(!isPretty(Brown))
70+
assert(!isPretty(Grey))
71+
assert(isPretty(Aqua))
72+
assert(isPretty(Emerald))
73+
assert(Emerald.getClass != Aqua.getClass)
74+
assert(Aqua.getClass != Grey.getClass)
75+
assert(Grey.getClass == Brown.getClass)
76+
77+
end testColor
78+
79+
testLatin()
80+
testColor()

0 commit comments

Comments
 (0)