Skip to content

Commit 0dd4b22

Browse files
committed
fix #7227: allow custom toString on enum
productPrefix is now overriden using the enum constant's name and used in the by-name lookup in EnumValues. java based enum values are optimised so that productPrefix will forward to .name in the simple enum case, avoiding an extra field
1 parent a2e7b73 commit 0dd4b22

File tree

6 files changed

+93
-15
lines changed

6 files changed

+93
-15
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,19 +126,19 @@ object DesugarEnums {
126126
*
127127
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
128128
* def $ordinal = $tag
129-
* override def toString = $name
129+
* override def productPrefix = $name // if does not derive from `java.lang.Enum`
130+
* override def productPrefix = this.name // if derives from `java.lang.Enum`
130131
* $values.register(this)
131132
* }
132133
*/
133134
private def enumValueCreator(using Context) = {
134135
val ordinalDef = ordinalMeth(Ident(nme.ordinalDollar_))
135-
val toStringDef = toStringMeth(Ident(nme.nameDollar))
136136
val creator = New(Template(
137137
constr = emptyConstructor,
138138
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
139139
derived = Nil,
140140
self = EmptyValDef,
141-
body = List(ordinalDef, toStringDef) ++ registerCall
141+
body = List(ordinalDef, productPrefixDynamic) ++ registerCall
142142
).withAttachment(ExtendsSingletonMirror, ()))
143143
DefDef(nme.DOLLAR_NEW, Nil,
144144
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -267,14 +267,22 @@ object DesugarEnums {
267267
def ordinalMeth(body: Tree)(using Context): DefDef =
268268
DefDef(nme.ordinalDollar, Nil, Nil, TypeTree(defn.IntType), body)
269269

270-
def toStringMeth(body: Tree)(using Context): DefDef =
271-
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
270+
def productPrefixMeth(body: Tree)(using Context): DefDef =
271+
DefDef(nme.productPrefix, Nil, Nil, TypeTree(defn.StringType), body).withFlags(Override)
272272

273273
def ordinalMethLit(ord: Int)(using Context): DefDef =
274274
ordinalMeth(Literal(Constant(ord)))
275275

276-
def toStringMethLit(name: String)(using Context): DefDef =
277-
toStringMeth(Literal(Constant(name)))
276+
def productPrefixLit(name: String)(using Context): DefDef =
277+
productPrefixMeth(Literal(Constant(name)))
278+
279+
def productPrefixDynamic(using Context): DefDef =
280+
val body =
281+
if ctx.owner.linkedClass.derivesFrom(defn.JavaEnumClass) then
282+
Select(This(Ident(tpnme.EMPTY)), nme.name)
283+
else
284+
Ident(nme.nameDollar)
285+
productPrefixMeth(body)
278286

279287
/** Expand a module definition representing a parameterless enum case */
280288
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, span: Span)(using Context): Tree = {
@@ -285,10 +293,10 @@ object DesugarEnums {
285293
else {
286294
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
287295
val ordinalDef = ordinalMethLit(tag)
288-
val toStringDef = toStringMethLit(name.toString)
296+
val productPrefixDef = productPrefixLit(name.toString)
289297
val impl1 = cpy.Template(impl)(
290298
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
291-
body = List(ordinalDef, toStringDef) ++ registerCall)
299+
body = List(ordinalDef, productPrefixDef) ++ registerCall)
292300
.withAttachment(ExtendsSingletonMirror, ())
293301
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
294302
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,7 @@ class Definitions {
651651
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
652652

653653
@tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues")
654+
@tu lazy val EnumValueClass: ClassSymbol = requiredClass("scala.runtime.EnumValue")
654655
@tu lazy val ProductClass: ClassSymbol = requiredClass("scala.Product")
655656
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
656657
@tu lazy val Product_productArity : Symbol = ProductClass.requiredMethod(nme.productArity)

compiler/src/dotty/tools/dotc/transform/CompleteJavaEnums.scala

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,18 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
117117
* class $anon extends E(...) {
118118
* ...
119119
* def ordinal = N
120-
* def toString = S
120+
* def productPrefix = this.name // if a simple enum case
121+
* def productPrefix = "S" // if a value enum case
121122
* ...
122123
* }
123124
*
124125
* After the transform it is expanded to
125126
*
126-
* class $anon extends E(..., N, S) {
127+
* class $anon extends E(..., N, "S") { // for value enum cases
128+
* "same as before"
129+
* }
130+
*
131+
* class $anon extends E(..., N, $name) { // for simple enum cases, where `$name` comes from the `$new` method
127132
* "same as before"
128133
* }
129134
*/
@@ -144,7 +149,12 @@ class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
144149
templ.body.collect {
145150
case mdef: DefDef if mdef.name == name => mdef.rhs
146151
}.head
147-
val args = List(rhsOf(nme.toString_), rhsOf(nme.ordinalDollar))
152+
def nameArg =
153+
if cls.owner.name eq nme.DOLLAR_NEW then // productPrefix calls .name so we need to find the argument from $new
154+
ref(cls.owner.paramSymss.flatten.find(p => p.isTerm && p.name == nme.nameDollar).map(_.termRef).get)
155+
else
156+
rhsOf(nme.productPrefix)
157+
val args = List(nameArg, rhsOf(nme.ordinalDollar))
148158
cpy.Template(templ)(
149159
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
150160
}

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
5858
private var myCaseSymbols: List[Symbol] = Nil
5959
private var myCaseModuleSymbols: List[Symbol] = Nil
6060
private var myEnumCaseSymbols: List[Symbol] = Nil
61+
private var myEnumValueSymbols: List[Symbol] = Nil
6162

6263
private def initSymbols(using Context) =
6364
if (myValueSymbols.isEmpty) {
@@ -67,12 +68,14 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
6768
defn.Product_productElementName)
6869
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
6970
myEnumCaseSymbols = List(defn.Enum_ordinal)
71+
myEnumValueSymbols = List(defn.Any_toString)
7072
}
7173

7274
def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
7375
def caseSymbols(using Context): List[Symbol] = { initSymbols; myCaseSymbols }
7476
def caseModuleSymbols(using Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
7577
def enumCaseSymbols(using Context): List[Symbol] = { initSymbols; myEnumCaseSymbols }
78+
def enumValueSymbols(using Context): List[Symbol] = { initSymbols; myEnumValueSymbols }
7679

7780
private def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = {
7881
val existing = sym.matchingMember(clazz.thisType)
@@ -92,13 +95,15 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
9295
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
9396
else clazz.caseAccessors
9497
val isEnumCase = clazz.derivesFrom(defn.EnumClass) && clazz != defn.EnumClass
98+
val isEnumValue = isEnumCase && clazz.derivesFrom(defn.EnumValueClass)
99+
val isNonJavaEnumValue = isEnumValue && !clazz.derivesFrom(defn.JavaEnumClass)
95100

96101
val symbolsToSynthesize: List[Symbol] =
97102
if (clazz.is(Case))
98103
if (clazz.is(Module)) caseModuleSymbols
99104
else if (isEnumCase) caseSymbols ++ enumCaseSymbols
100105
else caseSymbols
101-
else if (isEnumCase) enumCaseSymbols
106+
else if (isEnumCase) enumCaseSymbols ++ (if isNonJavaEnumValue then enumValueSymbols else Nil)
102107
else if (isDerivedValueClass(clazz)) valueSymbols
103108
else Nil
104109

@@ -118,10 +123,18 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
118123
def ownName: Tree =
119124
Literal(Constant(clazz.name.stripModuleClassSuffix.toString))
120125

126+
def callProductPrefix: Tree =
127+
Select(This(clazz), nme.productPrefix).ensureApplied
128+
129+
def toStringBody(vrefss: List[List[Tree]]): Tree =
130+
if (clazz.is(ModuleClass)) ownName
131+
else if (isNonJavaEnumValue) callProductPrefix
132+
else forwardToRuntime(vrefss.head)
133+
121134
def syntheticRHS(vrefss: List[List[Tree]])(using Context): Tree = synthetic.name match {
122135
case nme.hashCode_ if isDerivedValueClass(clazz) => valueHashCodeBody
123136
case nme.hashCode_ => chooseHashcode
124-
case nme.toString_ => if (clazz.is(ModuleClass)) ownName else forwardToRuntime(vrefss.head)
137+
case nme.toString_ => toStringBody(vrefss)
125138
case nme.equals_ => equalsBody(vrefss.head.head)
126139
case nme.canEqual_ => canEqualBody(vrefss.head.head)
127140
case nme.productArity => Literal(Constant(accessors.length))

library/src/scala/runtime/EnumValues.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ class EnumValues[E <: Enum] {
1414

1515
def fromInt: Map[Int, E] = myMap
1616
def fromName: Map[String, E] = {
17-
if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.toString -> v).toMap
17+
// TODO remove cast when scala.Enum is bootstrapped
18+
if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.asInstanceOf[Product].productPrefix -> v).toMap
1819
fromNameCache
1920
}
2021
def values: Iterable[E] = myMap.values

tests/run/enum-custom-toString.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
enum ES:
2+
case A
3+
override def toString: String = "overridden"
4+
5+
enum EJ extends java.lang.Enum[EJ]:
6+
case B
7+
override def toString: String = "overridden"
8+
9+
trait Mixin:
10+
override def toString: String = "overridden"
11+
12+
enum EM extends Mixin:
13+
case C
14+
15+
enum ET[T] extends java.lang.Enum[ET[_]]:
16+
case D extends ET[Unit]
17+
override def toString: String = "overridden"
18+
19+
enum EZ:
20+
case E(arg: Int)
21+
override def toString: String = "overridden"
22+
23+
enum EC: // control case
24+
case F
25+
case G(arg: Int)
26+
27+
@main def Test =
28+
assert(ES.A.toString == "overridden", s"ES.A.toString = ${ES.A.toString}")
29+
assert(ES.A.productPrefix == "A", s"ES.A.productPrefix = ${ES.A.productPrefix}")
30+
assert(ES.valueOf("A") == ES.A, s"ES.valueOf(A) = ${ES.valueOf("A")}")
31+
assert(EJ.B.toString == "overridden", s"EJ.B.toString = ${EJ.B.toString}")
32+
assert(EJ.B.productPrefix == "B", s"EJ.B.productPrefix = ${EJ.B.productPrefix}")
33+
assert(EJ.valueOf("B") == EJ.B, s"EJ.valueOf(B) = ${EJ.valueOf("B")}")
34+
assert(EM.C.toString == "overridden", s"EM.C.toString = ${EM.C.toString}")
35+
assert(EM.C.productPrefix == "C", s"EM.C.productPrefix = ${EM.C.productPrefix}")
36+
assert(EM.valueOf("C") == EM.C, s"EM.valueOf(C) = ${EM.valueOf("C")}")
37+
assert(ET.D.toString == "overridden", s"ET.D.toString = ${ET.D.toString}")
38+
assert(ET.D.productPrefix == "D", s"ET.D.productPrefix = ${ET.D.productPrefix}")
39+
assert(EZ.E(0).toString == "overridden", s"EZ.E(0).toString = ${EZ.E(0).toString}")
40+
assert(EZ.E(0).productPrefix == "E", s"EZ.E(0).productPrefix = ${EZ.E(0).productPrefix}")
41+
assert(EC.F.toString == "F", s"EC.F.toString = ${EC.F.toString}")
42+
assert(EC.F.productPrefix == "F", s"EC.F.productPrefix = ${EC.F.productPrefix}")
43+
assert(EC.valueOf("F") == EC.F, s"EC.valueOf(F) = ${EC.valueOf("F")}")
44+
assert(EC.G(0).toString == "G(0)", s"EC.G(0).toString = ${EC.G(0).toString}")
45+
assert(EC.G(0).productPrefix == "G", s"EC.G(0).productPrefix = ${EC.G(0).productPrefix}")

0 commit comments

Comments
 (0)