diff --git a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala index 6a2fa80d6e66..acea4610fdff 100644 --- a/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala +++ b/compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala @@ -20,10 +20,17 @@ object DesugarEnums { val Simple, Object, Class: Value = Value } + final case class EnumConstraints(minKind: CaseKind.Value, maxKind: CaseKind.Value, enumCases: List[(Int, RefTree)]): + require(minKind <= maxKind && !(cached && enumCases.isEmpty)) + def requiresCreator = minKind == CaseKind.Simple + def isEnumeration = maxKind < CaseKind.Class + def cached = minKind < CaseKind.Class + end EnumConstraints + /** Attachment containing the number of enum cases, the smallest kind that was seen so far, * and a list of all the value cases with their ordinals. */ - val EnumCaseCount: Property.Key[(Int, CaseKind.Value, List[(Int, TermName)])] = Property.Key() + val EnumCaseCount: Property.Key[(Int, CaseKind.Value, CaseKind.Value, List[(Int, TermName)])] = Property.Key() /** Attachment signalling that when this definition is desugared, it should add any additional * lookup methods for enums. @@ -39,6 +46,11 @@ object DesugarEnums { if (cls.is(Module)) cls.linkedClass else cls } + def enumCompanion(using Context): Symbol = { + val cls = ctx.owner + if (cls.is(Module)) cls.sourceModule else cls.linkedClass.sourceModule + } + /** Is `tree` an (untyped) enum case? */ def isEnumCase(tree: Tree)(using Context): Boolean = tree match { case tree: MemberDef => tree.mods.isEnumCase @@ -84,65 +96,73 @@ object DesugarEnums { private def valuesDot(name: PreName)(implicit src: SourceFile) = Select(Ident(nme.DOLLAR_VALUES), name.toTermName) - private def registerCall(using Context): Tree = - Apply(valuesDot("register"), This(EmptyTypeIdent) :: Nil) + private def ArrayLiteral(values: List[Tree], tpt: Tree)(using Context): Tree = + val clazzOf = TypeApply(ref(defn.Predef_classOf.termRef), tpt :: Nil) + val ctag = Apply(TypeApply(ref(defn.ClassTagModule_apply.termRef), tpt :: Nil), clazzOf :: Nil) + val apply = Select(ref(defn.ArrayModule.termRef), nme.apply) + Apply(Apply(TypeApply(apply, tpt :: Nil), values), ctag :: Nil) - /** The following lists of definitions for an enum type E: + /** The following lists of definitions for an enum type E and known value cases e_0, ..., e_n: * - * private val $values = new EnumValues[E] - * def values = $values.values.toArray - * def valueOf($name: String) = - * try $values.fromName($name) catch - * { - * case ex$:NoSuchElementException => - * throw new IllegalArgumentException("key not found: ".concat(name)) - * } + * private val $values = Array[E](e_0,...,e_n)(ClassTag[E](classOf[E])) + * def values = $values.clone + * def valueOf($name: String) = $name match { + * case "e_0" => e_0 + * ... + * case "e_n" => e_n + * case _ => throw new IllegalArgumentException("case not found: " + $name) + * } */ - private def enumScaffolding(using Context): List[Tree] = { + private def enumScaffolding(enumValues: List[RefTree])(using Context): List[Tree] = { val rawEnumClassRef = rawRef(enumClass.typeRef) extension (tpe: NamedType) def ofRawEnum = AppliedTypeTree(ref(tpe), rawEnumClassRef) + + val lazyFlagOpt = if enumCompanion.owner.isStatic then EmptyFlags else Lazy + val privateValuesDef = ValDef(nme.DOLLAR_VALUES, TypeTree(), ArrayLiteral(enumValues, rawEnumClassRef)) + .withFlags(Private | Synthetic | lazyFlagOpt) + val valuesDef = - DefDef(nme.values, Nil, Nil, defn.ArrayType.ofRawEnum, Select(valuesDot(nme.values), nme.toArray)) + DefDef(nme.values, Nil, Nil, defn.ArrayType.ofRawEnum, valuesDot(nme.clone_)) .withFlags(Synthetic) - val privateValuesDef = - ValDef(nme.DOLLAR_VALUES, TypeTree(), New(defn.EnumValuesClass.typeRef.ofRawEnum, ListOfNil)) - .withFlags(Private | Synthetic) - - val valuesOfExnMessage = Apply( - Select(Literal(Constant("key not found: ")), "concat".toTermName), - Ident(nme.nameDollar) :: Nil) - val valuesOfBody = Try( - expr = Apply(valuesDot("fromName"), Ident(nme.nameDollar) :: Nil), - cases = CaseDef( - pat = Typed(Ident(nme.DEFAULT_EXCEPTION_NAME), TypeTree(defn.NoSuchElementExceptionType)), - guard = EmptyTree, - body = Throw(New(TypeTree(defn.IllegalArgumentExceptionType), List(valuesOfExnMessage :: Nil))) - ) :: Nil, - finalizer = EmptyTree - ) + + val valuesOfBody: Tree = + val defaultCase = + val msg = Apply(Select(Literal(Constant("enum case not found: ")), nme.PLUS), Ident(nme.nameDollar)) + CaseDef(Ident(nme.WILDCARD), EmptyTree, + Throw(New(TypeTree(defn.IllegalArgumentExceptionType), List(msg :: Nil)))) + val stringCases = enumValues.map(enumValue => + CaseDef(Literal(Constant(enumValue.name.toString)), EmptyTree, enumValue) + ) ::: defaultCase :: Nil + Match(Ident(nme.nameDollar), stringCases) val valueOfDef = DefDef(nme.valueOf, Nil, List(param(nme.nameDollar, defn.StringType) :: Nil), TypeTree(), valuesOfBody) .withFlags(Synthetic) - valuesDef :: privateValuesDef :: + valuesDef :: valueOfDef :: Nil } - private def enumLookupMethods(cases: List[(Int, TermName)])(using Context): List[Tree] = - if isJavaEnum || cases.isEmpty then Nil - else - val defaultCase = - val ord = Ident(nme.ordinal) - val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil))) - CaseDef(ord, EmptyTree, err) - val valueCases = cases.map((i, name) => - CaseDef(Literal(Constant(i)), EmptyTree, Ident(name)) - ) ::: defaultCase :: Nil - val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil), - rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases)) - .withFlags(Synthetic | Private) - fromOrdinalDef :: Nil + private def enumLookupMethods(constraints: EnumConstraints)(using Context): List[Tree] = + def scaffolding: List[Tree] = if constraints.cached then enumScaffolding(constraints.enumCases.map(_._2)) else Nil + def valueCtor: List[Tree] = if constraints.requiresCreator then enumValueCreator :: Nil else Nil + def byOrdinal: List[Tree] = + if isJavaEnum || !constraints.cached then Nil + else + val defaultCase = + val ord = Ident(nme.ordinal) + val err = Throw(New(TypeTree(defn.IndexOutOfBoundsException.typeRef), List(Select(ord, nme.toString_) :: Nil))) + CaseDef(ord, EmptyTree, err) + val valueCases = constraints.enumCases.map((i, enumValue) => + CaseDef(Literal(Constant(i)), EmptyTree, enumValue) + ) ::: defaultCase :: Nil + val fromOrdinalDef = DefDef(nme.fromOrdinalDollar, Nil, List(param(nme.ordinalDollar_, defn.IntType) :: Nil), + rawRef(enumClass.typeRef), Match(Ident(nme.ordinalDollar_), valueCases)) + .withFlags(Synthetic | Private) + fromOrdinalDef :: Nil + + scaffolding ::: valueCtor ::: byOrdinal + end enumLookupMethods /** A creation method for a value of enum type `E`, which is defined as follows: * @@ -167,7 +187,7 @@ object DesugarEnums { parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil, derived = Nil, self = EmptyValDef, - body = fieldMethods ::: registerCall :: Nil + body = fieldMethods ).withAttachment(ExtendsSingletonMirror, ())) DefDef(nme.DOLLAR_NEW, Nil, List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))), @@ -279,27 +299,26 @@ object DesugarEnums { * unless that scaffolding was already generated by a previous call to `nextEnumKind`. */ def nextOrdinal(name: Name, kind: CaseKind.Value, definesLookups: Boolean)(using Context): (Int, List[Tree]) = { - val (ordinal, seenKind, seenCases) = ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, Nil)) - val minKind = if kind < seenKind then kind else seenKind + val (ordinal, seenMinKind, seenMaxKind, seenCases) = + ctx.tree.removeAttachment(EnumCaseCount).getOrElse((0, CaseKind.Class, CaseKind.Simple, Nil)) + val minKind = if kind < seenMinKind then kind else seenMinKind + val maxKind = if kind > seenMaxKind then kind else seenMaxKind val cases = name match case name: TermName => (ordinal, name) :: seenCases case _ => seenCases - ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, cases)) - val scaffolding0 = - if (kind >= seenKind) Nil - else if (kind == CaseKind.Object) enumScaffolding - else if (seenKind == CaseKind.Object) enumValueCreator :: Nil - else enumScaffolding :+ enumValueCreator - val scaffolding = - if definesLookups then scaffolding0 ::: enumLookupMethods(cases.reverse) - else scaffolding0 - (ordinal, scaffolding) + if definesLookups then + val companionRef = ref(enumCompanion.termRef) + val cachedValues = cases.reverse.map((i, name) => (i, Select(companionRef, name))) + (ordinal, enumLookupMethods(EnumConstraints(minKind, maxKind, cachedValues))) + else + ctx.tree.pushAttachment(EnumCaseCount, (ordinal + 1, minKind, maxKind, cases)) + (ordinal, Nil) } - def param(name: TermName, typ: Type)(using Context) = - ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param) + def param(name: TermName, typ: Type)(using Context): ValDef = param(name, TypeTree(typ)) + def param(name: TermName, tpt: Tree)(using Context): ValDef = ValDef(name, tpt, EmptyTree).withFlags(Param) - private def isJavaEnum(using Context): Boolean = ctx.owner.linkedClass.derivesFrom(defn.JavaEnumClass) + private def isJavaEnum(using Context): Boolean = enumClass.derivesFrom(defn.JavaEnumClass) def ordinalMeth(body: Tree)(using Context): DefDef = DefDef(nme.ordinal, Nil, Nil, TypeTree(defn.IntType), body) @@ -325,10 +344,10 @@ object DesugarEnums { val enumLabelDef = enumLabelLit(name.toString) val impl1 = cpy.Template(impl)( parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue), - body = ordinalDef ::: enumLabelDef :: registerCall :: Nil + body = ordinalDef ::: enumLabelDef :: Nil ).withAttachment(ExtendsSingletonMirror, ()) val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span)) - flatTree(scaffolding ::: vdef :: Nil).withSpan(span) + flatTree(vdef :: scaffolding).withSpan(span) } } @@ -344,6 +363,6 @@ object DesugarEnums { val (tag, scaffolding) = nextOrdinal(name, CaseKind.Simple, definesLookups) val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), Literal(Constant(name.toString)))) val vdef = ValDef(name, enumClassRef, creator).withMods(mods.withAddedFlags(EnumValue, span)) - flatTree(scaffolding ::: vdef :: Nil).withSpan(span) + flatTree(vdef :: scaffolding).withSpan(span) } } diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 03d7fd3c6e98..69f9d19ea0a5 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -749,8 +749,6 @@ class Definitions { @tu lazy val EnumClass: ClassSymbol = requiredClass("scala.Enum") - @tu lazy val EnumValuesClass: ClassSymbol = requiredClass("scala.runtime.EnumValues") - @tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy") @tu lazy val EnumValueSerializationProxyConstructor: TermSymbol = EnumValueSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty), IntType)) diff --git a/compiler/src/dotty/tools/dotc/transform/CheckReentrant.scala b/compiler/src/dotty/tools/dotc/transform/CheckReentrant.scala index 172e611f260f..c0bd8861d217 100644 --- a/compiler/src/dotty/tools/dotc/transform/CheckReentrant.scala +++ b/compiler/src/dotty/tools/dotc/transform/CheckReentrant.scala @@ -46,12 +46,9 @@ class CheckReentrant extends MiniPhase { def isIgnored(sym: Symbol)(using Context): Boolean = sym.hasAnnotation(sharableAnnot()) || sym.hasAnnotation(unsharedAnnot()) || - sym.topLevelClass.owner == scalaJSIRPackageClass() || + sym.topLevelClass.owner == scalaJSIRPackageClass() // We would add @sharable annotations on ScalaJSVersions and // VersionChecks but we do not have control over that code - sym.owner == defn.EnumValuesClass - // enum values are initialized eagerly before use - // in the long run, we should make them vals def scanning(sym: Symbol)(op: => Unit)(using Context): Unit = { report.log(i"${" " * indent}scanning $sym") diff --git a/compiler/src/dotty/tools/dotc/transform/init/Env.scala b/compiler/src/dotty/tools/dotc/transform/init/Env.scala index 65b3684bd7f9..cd36dc1d867d 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Env.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Env.scala @@ -25,7 +25,6 @@ case class Env(ctx: Context) { // Methods that should be ignored in the checking lazy val ignoredMethods: Set[Symbol] = Set( - requiredClass("scala.runtime.EnumValues").requiredMethod("register"), defn.Any_getClass, defn.Any_isInstanceOf, defn.Object_eq, diff --git a/docs/docs/reference/enums/desugarEnums.md b/docs/docs/reference/enums/desugarEnums.md index 3ce196f6dc33..9de02d5f4079 100644 --- a/docs/docs/reference/enums/desugarEnums.md +++ b/docs/docs/reference/enums/desugarEnums.md @@ -121,12 +121,10 @@ map into `case class`es or `val`s. ``` expands to a value definition in `E`'s companion object: ```scala - val C = new { ; def ordinal = n; $values.register(this) } + val C = new { ; def ordinal = n } ``` where `n` is the ordinal number of the case in the companion object, - starting from 0. The statement `$values.register(this)` registers the value - as one of the `values` of the enumeration (see below). `$values` is a - compiler-defined private value in the companion object. The anonymous class also + starting from 0. The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`. @@ -162,8 +160,7 @@ An enum `E` (possibly generic) that defines one or more singleton cases will define the following additional synthetic members in its companion object (where `E'` denotes `E` with any type parameters replaced by wildcards): - - A method `valueOf(name: String): E'`. It returns the singleton case value whose - `toString` representation is `name`. + - A method `valueOf(name: String): E'`. It returns the singleton case value whose `enumLabel` is `name`. - A method `values` which returns an `Array[E']` of all singleton case values defined by `E`, in the order of their definitions. @@ -178,7 +175,6 @@ If `E` contains at least one simple case, its companion object will define in ad def enumLabel = $name override def productPrefix = enumLabel // if not overridden in `E` override def toString = enumLabel // if not overridden in `E` - $values.register(this) // register enum value so that `valueOf` and `values` can return it. } ``` diff --git a/docs/docs/reference/enums/enums.md b/docs/docs/reference/enums/enums.md index 37af6760b224..86350110b03e 100644 --- a/docs/docs/reference/enums/enums.md +++ b/docs/docs/reference/enums/enums.md @@ -136,7 +136,6 @@ val Venus: Planet = def enumLabel: String = "Venus" override def productPrefix: String = enumLabel override def toString: String = enumLabel - // internal code to register value } ``` diff --git a/library/src-bootstrapped/scala/runtime/EnumValues.scala b/library/src-bootstrapped/scala/runtime/EnumValues.scala deleted file mode 100644 index c34f5cd974c0..000000000000 --- a/library/src-bootstrapped/scala/runtime/EnumValues.scala +++ /dev/null @@ -1,21 +0,0 @@ -package scala.runtime - -import scala.collection.immutable.TreeMap - -class EnumValues[E <: Enum] { - private[this] var myMap: Map[Int, E] = TreeMap.empty - private[this] var fromNameCache: Map[String, E] = null - - def register(v: E) = { - require(!myMap.contains(v.ordinal)) - myMap = myMap.updated(v.ordinal, v) - fromNameCache = null - } - - def fromInt: Map[Int, E] = myMap - def fromName: Map[String, E] = { - if (fromNameCache == null) fromNameCache = myMap.values.map(v => v.enumLabel -> v).toMap - fromNameCache - } - def values: Iterable[E] = myMap.values -} diff --git a/tests/neg/enumsLabelDef.scala b/tests/neg/enumsLabelDef.scala index 99ecffc73df1..e7bc10108bb6 100644 --- a/tests/neg/enumsLabelDef.scala +++ b/tests/neg/enumsLabelDef.scala @@ -1,7 +1,6 @@ enum Labelled { case A // error overriding method enumLabel in class Labelled of type => String; - case B(arg: Int) // error overriding method enumLabel in class Labelled of type => String; def enumLabel: String = "nolabel" } @@ -9,7 +8,11 @@ enum Labelled { trait Mixin { def enumLabel: String = "mixin" } enum Mixed extends Mixin { - case C // error overriding method enumLabel in trait Mixin of type => String; + case B // error overriding method enumLabel in trait Mixin of type => String; +} + +enum MixedAlso { + case C extends MixedAlso with Mixin // error overriding method enumLabel in trait Mixin of type => String; } trait HasEnumLabel { def enumLabel: String } diff --git a/tests/run/enum-java.check b/tests/run/enum-java.check index 3cd188c5d5f0..667e6a4df37c 100644 --- a/tests/run/enum-java.check +++ b/tests/run/enum-java.check @@ -26,7 +26,7 @@ MONDAY : 0 TUESDAY : 1 SATURDAY : 2 By-name value: MONDAY -Correctly failed to retrieve illegal name, message: key not found: stuff +Correctly failed to retrieve illegal name, message: enum case not found: stuff Collections Test Retrieving Monday: workday diff --git a/tests/run/enums-java-compat.check b/tests/run/enums-java-compat.check index 04468ba9b44d..0acf636fe377 100644 --- a/tests/run/enums-java-compat.check +++ b/tests/run/enums-java-compat.check @@ -6,4 +6,4 @@ TUESDAY : 1 SATURDAY : 2 Stuff : 3 By-name value: MONDAY -Correctly failed to retrieve illegal name, message: key not found: stuff +Correctly failed to retrieve illegal name, message: enum case not found: stuff diff --git a/tests/run/enums-thunk.scala b/tests/run/enums-thunk.scala new file mode 100644 index 000000000000..8ca9dde31e79 --- /dev/null +++ b/tests/run/enums-thunk.scala @@ -0,0 +1,27 @@ +class Outer { + val thunk = { () => + enum E { case A1 } + E.A1 + } + val thunk2 = { () => + enum E { case A2 } + E.values + } +} + +object Outer2 { + val thunk = { () => + enum E { case B1 } + E.B1 + } + val thunk2 = { () => + enum E { case B2 } + E.values + } +} + +@main def Test = + assert(Outer().thunk().toString == "A1") + assert(Outer().thunk2()(0).toString == "A2") + assert(Outer2.thunk().toString == "B1") + assert(Outer2.thunk2()(0).toString == "B2")