Skip to content

fix #9873: no longer use scala.Enum as parents of enums #9877

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,7 @@ object desugar {
val isCaseObject = mods.is(Case) && isObject
val isEnum = mods.isEnumClass && !mods.is(Module)
def isEnumCase = mods.isEnumCase
def isNonEnumCase = !isEnumCase && (isCaseClass || isCaseObject)
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.

Expand Down Expand Up @@ -483,7 +484,8 @@ object desugar {
val enumCompanionRef = TermRefTree()
val enumImport =
Import(enumCompanionRef, enumCases.flatMap(caseIds).map(ImportSelector(_)))
(enumImport :: enumStats, enumCases, enumCompanionRef)
val enumSpecMethods = EnumGetters()
(enumImport :: enumSpecMethods :: enumStats, enumCases, enumCompanionRef)
}
else (stats, Nil, EmptyTree)
}
Expand Down Expand Up @@ -621,10 +623,8 @@ object desugar {
var parents1 = parents
if (isEnumCase && parents.isEmpty)
parents1 = enumClassTypeRef :: Nil
if (isCaseClass | isCaseObject)
if (isNonEnumCase || isEnum)
parents1 = parents1 :+ scalaDot(str.Product.toTypeName) :+ scalaDot(nme.Serializable.toTypeName)
if (isEnum)
parents1 = parents1 :+ ref(defn.EnumClass.typeRef)

// derived type classes of non-module classes go to their companions
val (clsDerived, companionDerived) =
Expand Down Expand Up @@ -890,6 +890,9 @@ object desugar {
}
}

def enumGetters(getters: EnumGetters)(using Context): Tree =
flatTree(DesugarEnums.enumBaseMeths).withSpan(getters.span)

/** Transform extension construct to list of extension methods */
def extMethods(ext: ExtMethods)(using Context): Tree = flatTree {
for mdef <- ext.methods yield
Expand Down
7 changes: 7 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,13 @@ object DesugarEnums {
(ordinal, Nil)
}

def enumBaseMeths(using Context): List[Tree] =
val ordinalDef = DefDef(nme.ordinal, Nil, Nil, ref(defn.IntType), EmptyTree)
val enumLabelDef = DefDef(nme.enumLabel, Nil, Nil, ref(defn.StringClass.typeRef), EmptyTree)
val base = enumLabelDef :: Nil
if isJavaEnum then base
else ordinalDef :: base

def param(name: TermName, typ: Type)(using Context): ValDef = param(name, TypeTree(typ))
def param(name: TermName, tpt: Tree)(using Context): ValDef = ValDef(name, tpt, EmptyTree).withFlags(Param)

Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
case class Export(expr: Tree, selectors: List[ImportSelector])(implicit @constructorOnly src: SourceFile) extends Tree
case class ExtMethods(tparams: List[TypeDef], vparamss: List[List[ValDef]], methods: List[DefDef])(implicit @constructorOnly src: SourceFile) extends Tree
case class MacroTree(expr: Tree)(implicit @constructorOnly src: SourceFile) extends Tree
case class EnumGetters()(implicit @constructorOnly src: SourceFile) extends Tree

case class ImportSelector(imported: Ident, renamed: Tree = EmptyTree, bound: Tree = EmptyTree)(implicit @constructorOnly src: SourceFile) extends Tree {
// TODO: Make bound a typed tree?
Expand Down Expand Up @@ -700,6 +701,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
cpy.Export(tree)(transform(expr), selectors)
case ExtMethods(tparams, vparamss, methods) =>
cpy.ExtMethods(tree)(transformSub(tparams), vparamss.mapConserve(transformSub(_)), transformSub(methods))
case enums: EnumGetters => enums
case ImportSelector(imported, renamed, bound) =>
cpy.ImportSelector(tree)(transformSub(imported), transform(renamed), transform(bound))
case Number(_, _) | TypedSplice(_) =>
Expand Down Expand Up @@ -761,6 +763,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
this(x, expr)
case ExtMethods(tparams, vparamss, methods) =>
this(vparamss.foldLeft(this(x, tparams))(apply), methods)
case EnumGetters() =>
x
case ImportSelector(imported, renamed, bound) =>
this(this(this(x, imported), renamed), bound)
case Number(_, _) =>
Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/reporting/ErrorMessageID.scala
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ enum ErrorMessageID extends java.lang.Enum[ErrorMessageID] {
ModifierNotAllowedForDefinitionID,
CannotExtendJavaEnumID,
InvalidReferenceInImplicitNotFoundAnnotationID,
TraitMayNotDefineNativeMethodID
TraitMayNotDefineNativeMethodID,
EnumGettersRedefinitionID

def errorNumber = ordinal - 2
}
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/reporting/messages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2031,6 +2031,10 @@ import ast.tpd
def explain = ""
}

class EnumGettersRedefinition(decl: Symbol)(using Context) extends NamingMsg(EnumGettersRedefinitionID):
def msg = em"redefinition of $decl: ${decl.info} in an ${hl("enum")}"
def explain = em"users may not may not supply their own definition for $decl when inside an ${hl("enum")}"

class DoubleDefinition(decl: Symbol, previousDecl: Symbol, base: Symbol)(using Context) extends NamingMsg(DoubleDefinitionID) {
def msg = {
def nameAnd = if (decl.name != previousDecl.name) " name and" else ""
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,8 @@ object SymUtils {
self
}

def isScalaEnum(using Context): Boolean = self.is(Enum, butNot=JavaDefined)

/** Does this symbol refer to anonymous classes synthesized by enum desugaring? */
def isEnumAnonymClass(using Context): Boolean =
self.isAnonymousClass && (self.owner.name.eq(nme.DOLLAR_NEW) || self.owner.is(CaseVal))
Expand Down
32 changes: 31 additions & 1 deletion compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
lazy val accessors =
if (isDerivedValueClass(clazz)) clazz.paramAccessors.take(1) // Tail parameters can only be `erased`
else clazz.caseAccessors
val isEnumCase = clazz.derivesFrom(defn.EnumClass) && clazz != defn.EnumClass
val isEnumCase = clazz.classParents.exists(_.typeSymbol.isScalaEnum)
val isEnumValue = isEnumCase && clazz.isAnonymousClass && clazz.classParents.head.classSymbol.is(Enum)
val isNonJavaEnumValue = isEnumValue && !clazz.derivesFrom(defn.JavaEnumClass)

Expand Down Expand Up @@ -513,6 +513,34 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
Match(param, cases)
}

/** For an enum T:
*
* def enumLabel(x: MirroredMonoType) = x.enumLabel
*
* For sealed trait with children of normalized types C_1, ..., C_n:
*
* def enumLabel(x: MirroredMonoType) = x match {
* case _: C_1 => "C_1"
* ...
* case _: C_n => "C_n"
* }
*
* Here, the normalized type of a class C is C[?, ...., ?] with
* a wildcard for each type parameter. The normalized type of an object
* O is O.type.
*/
def enumLabelBody(cls: Symbol, param: Tree)(using Context): Tree =
if (cls.is(Enum)) param.select(nme.enumLabel).ensureApplied
else {
val cases =
for ((child, idx) <- cls.children.zipWithIndex) yield {
val patType = if (child.isTerm) child.termRef else child.rawTypeRef
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
CaseDef(pat, EmptyTree, Literal(Constant(child.name.toString)))
}
Match(param, cases)
}

/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
* and `MirroredMonoType` and `ordinal` members.
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
Expand Down Expand Up @@ -564,6 +592,8 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
addParent(defn.Mirror_SumClass.typeRef)
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
ordinalBody(_, _))
addMethod(nme.enumLabel, MethodType(monoType.typeRef :: Nil, defn.StringType), cls,
enumLabelBody(_, _))
}

if (clazz.is(Module)) {
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -889,7 +889,12 @@ trait Checking {
if (decl.matches(other) && !javaFieldMethodPair) {
def doubleDefError(decl: Symbol, other: Symbol): Unit =
if (!decl.info.isErroneous && !other.info.isErroneous)
report.error(DoubleDefinition(decl, other, cls), decl.srcPos)
if decl.owner.is(Enum, butNot=JavaDefined|Case) && decl.span.isSynthetic && (
decl.name == nme.ordinal || decl.name == nme.enumLabel)
then
report.error(EnumGettersRedefinition(decl), other.srcPos)
else
report.error(DoubleDefinition(decl, other, cls), decl.srcPos)
if (decl is Synthetic) doubleDefError(other, decl)
else doubleDefError(decl, other)
}
Expand Down
24 changes: 17 additions & 7 deletions compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -343,17 +343,20 @@ class Namer { typer: Typer =>
tree.pushAttachment(ExpandedTree, expanded)
}
tree match {
case tree: DefTree => record(desugar.defTree(tree))
case tree: PackageDef => record(desugar.packageDef(tree))
case tree: ExtMethods => record(desugar.extMethods(tree))
case _ =>
case tree: DefTree => record(desugar.defTree(tree))
case tree: PackageDef => record(desugar.packageDef(tree))
case tree: ExtMethods => record(desugar.extMethods(tree))
case tree: EnumGetters => record(desugar.enumGetters(tree))
case _ =>
}
}

/** The expanded version of this tree, or tree itself if not expanded */
def expanded(tree: Tree)(using Context): Tree = tree match {
case _: DefTree | _: PackageDef | _: ExtMethods => tree.attachmentOrElse(ExpandedTree, tree)
case _ => tree
case _: DefTree | _: PackageDef | _: ExtMethods | _: EnumGetters =>
tree.attachmentOrElse(ExpandedTree, tree)
case _ =>
tree
}

/** For all class definitions `stat` in `xstats`: If the companion class is
Expand Down Expand Up @@ -925,11 +928,17 @@ class Namer { typer: Typer =>

val TypeDef(name, impl @ Template(constr, _, self, _)) = original

private val (params, rest): (List[Tree], List[Tree]) = impl.body.span {
private val (params, restOfBody): (List[Tree], List[Tree]) = impl.body.span {
case td: TypeDef => td.mods.is(Param)
case vd: ValDef => vd.mods.is(ParamAccessor)
case _ => false
}
private val (restAfterParents, rest): (List[Tree], List[Tree]) =
if original.mods.isEnumClass then
val (imports :: getters :: Nil, stats): @unchecked = restOfBody.splitAt(2)
(getters :: Nil, imports :: stats) // enum getters desugaring needs to test if a parent is java.lang.Enum
else
(Nil, restOfBody)

def init(): Context = index(params)

Expand Down Expand Up @@ -1196,6 +1205,7 @@ class Namer { typer: Typer =>
cls.setNoInitsFlags(parentsKind(parents), untpd.bodyKind(rest))
if (cls.isNoInitsClass) cls.primaryConstructor.setFlag(StableRealizable)
processExports(using localCtx)
index(restAfterParents)(using localCtx)
}
}

Expand Down
5 changes: 4 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2122,7 +2122,7 @@ class Typer extends Namer
.withType(dummy.termRef)
if (!cls.isOneOf(AbstractOrTrait) && !ctx.isAfterTyper)
checkRealizableBounds(cls, cdef.sourcePos.withSpan(cdef.nameSpan))
if cls.derivesFrom(defn.EnumClass) then
if cls.isScalaEnum || firstParent.isScalaEnum then
checkEnum(cdef, cls, firstParent)
val cdef1 = assignType(cpy.TypeDef(cdef)(name, impl1), cls)

Expand Down Expand Up @@ -2635,6 +2635,9 @@ class Typer extends Namer
case (stat: untpd.ExtMethods) :: rest =>
val xtree = stat.removeAttachment(ExpandedTree).get
traverse(xtree :: rest)
case (stat: untpd.EnumGetters) :: rest =>
val xtree = stat.removeAttachment(ExpandedTree).get
traverse(xtree :: rest)
case stat :: rest =>
val stat1 = typed(stat)(using ctx.exprContext(stat, exprOwner))
checkStatementPurity(stat1)(stat, exprOwner)
Expand Down
3 changes: 2 additions & 1 deletion library/src-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scala

/** A base trait of all enum classes */
/** A Product that also describes a label and ordinal */
@deprecated("scala.Enum is no longer supported", "3.0.0-M1")
trait Enum extends Product, Serializable:

/** A string uniquely identifying a case of an enum */
Expand Down
76 changes: 76 additions & 0 deletions library/src-bootstrapped/scala/deriving.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package scala

import quoted._

object deriving {

/** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents.
*/
sealed trait Mirror {

/** The mirrored *-type */
type MirroredMonoType

/** The name of the type */
type MirroredLabel <: String

/** The names of the product elements */
type MirroredElemLabels <: Tuple
}

object Mirror {

/** The Mirror for a sum type */
trait Sum extends Mirror { self =>
/** The ordinal number of the case class of `x`. For enums, `ordinal(x) == x.ordinal` */
def ordinal(x: MirroredMonoType): Int
/** The case label of the case class of `x`. For enums, `enumLabel(x) == x.enumLabel` */
def enumLabel(x: MirroredMonoType): String
}

/** The Mirror for a product type */
trait Product extends Mirror {

/** Create a new instance of type `T` with elements taken from product `p`. */
def fromProduct(p: scala.Product): MirroredMonoType
}

trait Singleton extends Product {
type MirroredMonoType = this.type
type MirroredType = this.type
type MirroredElemTypes = EmptyTuple
type MirroredElemLabels = EmptyTuple
def fromProduct(p: scala.Product) = this
}

/** A proxy for Scala 2 singletons, which do not inherit `Singleton` directly */
class SingletonProxy(val value: AnyRef) extends Product {
type MirroredMonoType = value.type
type MirroredType = value.type
type MirroredElemTypes = EmptyTuple
type MirroredElemLabels = EmptyTuple
def fromProduct(p: scala.Product) = value
}

type Of[T] = Mirror { type MirroredType = T; type MirroredMonoType = T ; type MirroredElemTypes <: Tuple }
type ProductOf[T] = Mirror.Product { type MirroredType = T; type MirroredMonoType = T ; type MirroredElemTypes <: Tuple }
type SumOf[T] = Mirror.Sum { type MirroredType = T; type MirroredMonoType = T; type MirroredElemTypes <: Tuple }
}

/** Helper class to turn arrays into products */
class ArrayProduct(val elems: Array[AnyRef]) extends Product {
def this(size: Int) = this(new Array[AnyRef](size))
def canEqual(that: Any): Boolean = true
def productElement(n: Int) = elems(n)
def productArity = elems.length
override def productIterator: Iterator[Any] = elems.iterator
def update(n: Int, x: Any) = elems(n) = x.asInstanceOf[AnyRef]
}

/** The empty product */
object EmptyProduct extends ArrayProduct(Array.emptyObjectArray)

/** Helper method to select a product element */
def productElement[T](x: Any, idx: Int) =
x.asInstanceOf[Product].productElement(idx).asInstanceOf[T]
}
3 changes: 3 additions & 0 deletions library/src-non-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,8 @@ package scala
/** A base trait of all enum classes */
trait Enum extends Product, Serializable:

/** A string uniquely identifying a case of an enum */
def enumLabel: String

/** A number uniquely identifying a case of an enum */
def ordinal: Int

This file was deleted.

Loading