Skip to content

Commit 9116ecb

Browse files
Merge pull request #6602 from dotty-staging/add-enum-constrs
Add mini-phase to fix constructors for enums extending java.lang.Enum
2 parents fa62051 + a8474fd commit 9116ecb

File tree

14 files changed

+274
-71
lines changed

14 files changed

+274
-71
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ class Compiler {
5656
List(new FirstTransform, // Some transformations to put trees into a canonical form
5757
new CheckReentrant, // Internal use only: Check that compiled program has no data races involving global vars
5858
new ElimPackagePrefixes, // Eliminate references to package prefixes in Select nodes
59-
new CookComments) :: // Cook the comments: expand variables, doc, etc.
59+
new CookComments, // Cook the comments: expand variables, doc, etc.
60+
new CompleteJavaEnums) :: // Fill in constructors for Java enums
6061
List(new CheckStatic, // Check restrictions that apply to @static members
6162
new ElimRepeated, // Rewrite vararg parameters and arguments
6263
new ExpandSAMs, // Expand single abstract method closures to anonymous classes

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,15 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
596596
loop(tree, Nil, Nil)
597597
}
598598

599+
/** Decompose a template body into parameters and other statements */
600+
def decomposeTemplateBody(body: List[Tree])(implicit ctx: Context): (List[Tree], List[Tree]) =
601+
body.partition {
602+
case stat: TypeDef => stat.symbol is Flags.Param
603+
case stat: ValOrDefDef =>
604+
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
605+
case _ => false
606+
}
607+
599608
/** An extractor for closures, either contained in a block or standalone.
600609
*/
601610
object closure {

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,6 @@ class Definitions {
299299
val companion = JavaLangPackageVal.info.decl(nme.Object).symbol
300300
companion.moduleClass.info = NoType // to indicate that it does not really exist
301301
companion.info = NoType // to indicate that it does not really exist
302-
303302
completeClass(cls)
304303
}
305304
def ObjectType: TypeRef = ObjectClass.typeRef
@@ -674,6 +673,8 @@ class Definitions {
674673
def NoneClass(implicit ctx: Context): ClassSymbol = NoneModuleRef.symbol.moduleClass.asClass
675674
lazy val EnumType: TypeRef = ctx.requiredClassRef("scala.Enum")
676675
def EnumClass(implicit ctx: Context): ClassSymbol = EnumType.symbol.asClass
676+
lazy val JEnumType: TypeRef = ctx.requiredClassRef("scala.compat.JEnum")
677+
def JEnumClass(implicit ctx: Context): ClassSymbol = JEnumType.symbol.asClass
677678
lazy val EnumValuesType: TypeRef = ctx.requiredClassRef("scala.runtime.EnumValues")
678679
def EnumValuesClass(implicit ctx: Context): ClassSymbol = EnumValuesType.symbol.asClass
679680
lazy val ProductType: TypeRef = ctx.requiredClassRef("scala.Product")

compiler/src/dotty/tools/dotc/core/Denotations.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,11 @@ object Denotations {
814814
def invalidateInheritedInfo(): Unit = ()
815815

816816
private def updateValidity()(implicit ctx: Context): this.type = {
817-
assert(ctx.runId >= validFor.runId || ctx.settings.YtestPickler.value, // mixing test pickler with debug printing can travel back in time
818-
s"denotation $this invalid in run ${ctx.runId}. ValidFor: $validFor")
817+
assert(
818+
ctx.runId >= validFor.runId ||
819+
ctx.settings.YtestPickler.value || // mixing test pickler with debug printing can travel back in time
820+
symbol.is(Permanent), // Permanent symbols are valid in all runIds
821+
s"denotation $this invalid in run ${ctx.runId}. ValidFor: $validFor")
819822
var d: SingleDenotation = this
820823
do {
821824
d.validFor = Period(ctx.period.runId, d.validFor.firstPhaseId, d.validFor.lastPhaseId)

compiler/src/dotty/tools/dotc/core/tasty/TastyFormat.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,7 @@ object TastyFormat {
490490
| STATIC
491491
| OBJECT
492492
| TRAIT
493+
| ENUM
493494
| LOCAL
494495
| SYNTHETIC
495496
| ARTIFACT

compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala

Lines changed: 40 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -509,12 +509,7 @@ class TreePickler(pickler: TastyPickler) {
509509
case tree: Template =>
510510
registerDef(tree.symbol)
511511
writeByte(TEMPLATE)
512-
val (params, rest) = tree.body partition {
513-
case stat: TypeDef => stat.symbol is Flags.Param
514-
case stat: ValOrDefDef =>
515-
stat.symbol.is(Flags.ParamAccessor) && !stat.symbol.isSetter
516-
case _ => false
517-
}
512+
val (params, rest) = decomposeTemplateBody(tree.body)
518513
withLength {
519514
pickleParams(params)
520515
tree.parents.foreach(pickleTree)
@@ -635,44 +630,48 @@ class TreePickler(pickler: TastyPickler) {
635630

636631
def pickleFlags(flags: FlagSet, isTerm: Boolean)(implicit ctx: Context): Unit = {
637632
import Flags._
638-
if (flags is Private) writeByte(PRIVATE)
639-
if (flags is Protected) writeByte(PROTECTED)
640-
if (flags.is(Final, butNot = Module)) writeByte(FINAL)
641-
if (flags is Case) writeByte(CASE)
642-
if (flags is Override) writeByte(OVERRIDE)
643-
if (flags is Inline) writeByte(INLINE)
644-
if (flags is InlineProxy) writeByte(INLINEPROXY)
645-
if (flags is Macro) writeByte(MACRO)
646-
if (flags is JavaStatic) writeByte(STATIC)
647-
if (flags is Module) writeByte(OBJECT)
648-
if (flags is Enum) writeByte(ENUM)
649-
if (flags is Local) writeByte(LOCAL)
650-
if (flags is Synthetic) writeByte(SYNTHETIC)
651-
if (flags is Artifact) writeByte(ARTIFACT)
652-
if (flags is Scala2x) writeByte(SCALA2X)
633+
def writeModTag(tag: Int) = {
634+
assert(isModifierTag(tag))
635+
writeByte(tag)
636+
}
637+
if (flags is Private) writeModTag(PRIVATE)
638+
if (flags is Protected) writeModTag(PROTECTED)
639+
if (flags.is(Final, butNot = Module)) writeModTag(FINAL)
640+
if (flags is Case) writeModTag(CASE)
641+
if (flags is Override) writeModTag(OVERRIDE)
642+
if (flags is Inline) writeModTag(INLINE)
643+
if (flags is InlineProxy) writeModTag(INLINEPROXY)
644+
if (flags is Macro) writeModTag(MACRO)
645+
if (flags is JavaStatic) writeModTag(STATIC)
646+
if (flags is Module) writeModTag(OBJECT)
647+
if (flags is Enum) writeModTag(ENUM)
648+
if (flags is Local) writeModTag(LOCAL)
649+
if (flags is Synthetic) writeModTag(SYNTHETIC)
650+
if (flags is Artifact) writeModTag(ARTIFACT)
651+
if (flags is Scala2x) writeModTag(SCALA2X)
653652
if (isTerm) {
654-
if (flags is Implicit) writeByte(IMPLICIT)
655-
if (flags is Implied) writeByte(IMPLIED)
656-
if (flags is Erased) writeByte(ERASED)
657-
if (flags.is(Lazy, butNot = Module)) writeByte(LAZY)
658-
if (flags is AbsOverride) { writeByte(ABSTRACT); writeByte(OVERRIDE) }
659-
if (flags is Mutable) writeByte(MUTABLE)
660-
if (flags is Accessor) writeByte(FIELDaccessor)
661-
if (flags is CaseAccessor) writeByte(CASEaccessor)
662-
if (flags is DefaultParameterized) writeByte(DEFAULTparameterized)
663-
if (flags is StableRealizable) writeByte(STABLE)
664-
if (flags is Extension) writeByte(EXTENSION)
665-
if (flags is Given) writeByte(GIVEN)
666-
if (flags is ParamAccessor) writeByte(PARAMsetter)
667-
if (flags is Exported) writeByte(EXPORTED)
653+
if (flags is Implicit) writeModTag(IMPLICIT)
654+
if (flags is Implied) writeModTag(IMPLIED)
655+
if (flags is Erased) writeModTag(ERASED)
656+
if (flags.is(Lazy, butNot = Module)) writeModTag(LAZY)
657+
if (flags is AbsOverride) { writeModTag(ABSTRACT); writeModTag(OVERRIDE) }
658+
if (flags is Mutable) writeModTag(MUTABLE)
659+
if (flags is Accessor) writeModTag(FIELDaccessor)
660+
if (flags is CaseAccessor) writeModTag(CASEaccessor)
661+
if (flags is DefaultParameterized) writeModTag(DEFAULTparameterized)
662+
if (flags is StableRealizable) writeModTag(STABLE)
663+
if (flags is Extension) writeModTag(EXTENSION)
664+
if (flags is Given) writeModTag(GIVEN)
665+
if (flags is ParamAccessor) writeModTag(PARAMsetter)
666+
if (flags is Exported) writeModTag(EXPORTED)
668667
assert(!(flags is Label))
669668
} else {
670-
if (flags is Sealed) writeByte(SEALED)
671-
if (flags is Abstract) writeByte(ABSTRACT)
672-
if (flags is Trait) writeByte(TRAIT)
673-
if (flags is Covariant) writeByte(COVARIANT)
674-
if (flags is Contravariant) writeByte(CONTRAVARIANT)
675-
if (flags is Opaque) writeByte(OPAQUE)
669+
if (flags is Sealed) writeModTag(SEALED)
670+
if (flags is Abstract) writeModTag(ABSTRACT)
671+
if (flags is Trait) writeModTag(TRAIT)
672+
if (flags is Covariant) writeModTag(COVARIANT)
673+
if (flags is Contravariant) writeModTag(CONTRAVARIANT)
674+
if (flags is Opaque) writeModTag(OPAQUE)
676675
}
677676
}
678677

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
851851
else if (sym.is(ModuleClass))
852852
nameString(sym.name.stripModuleClassSuffix)
853853
else if (hasMeaninglessName(sym))
854-
simpleNameString(sym.owner)
854+
simpleNameString(sym.owner) + idString(sym)
855855
else
856856
nameString(sym)
857857
(keywordText(kindString(sym)) ~~ {
Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,161 @@
1+
package dotty.tools.dotc
2+
package transform
3+
4+
import core._
5+
import Names._
6+
import StdNames.{nme, tpnme}
7+
import Types._
8+
import dotty.tools.dotc.transform.MegaPhase._
9+
import Flags._
10+
import Contexts.Context
11+
import Symbols._
12+
import Constants._
13+
import Decorators._
14+
import DenotTransformers._
15+
16+
object CompleteJavaEnums {
17+
val name: String = "completeJavaEnums"
18+
19+
private val nameParamName: TermName = "$name".toTermName
20+
private val ordinalParamName: TermName = "$ordinal".toTermName
21+
}
22+
23+
/** For Scala enums that inherit from java.lang.Enum:
24+
* Add constructor parameters for `name` and `ordinal` to pass from each
25+
* case to the java.lang.Enum class.
26+
*/
27+
class CompleteJavaEnums extends MiniPhase with InfoTransformer { thisPhase =>
28+
import CompleteJavaEnums._
29+
import ast.tpd._
30+
31+
override def phaseName: String = CompleteJavaEnums.name
32+
33+
override def relaxedTypingInGroup: Boolean = true
34+
// Because it adds additional parameters to some constructors
35+
36+
def transformInfo(tp: Type, sym: Symbol)(implicit ctx: Context): Type =
37+
if (sym.isConstructor && derivesFromJEnum(sym.owner)) addConstrParams(sym.info)
38+
else tp
39+
40+
/** Is `sym` a Scala enum class that derives (directly) from `java.lang.Enum`?
41+
*/
42+
private def derivesFromJEnum(sym: Symbol)(implicit ctx: Context) =
43+
sym.is(Enum, butNot = Case) &&
44+
sym.info.parents.exists(p => p.typeSymbol == defn.JEnumClass)
45+
46+
/** Add constructor parameters `$name: String` and `$ordinal: Int` to the end of
47+
* the last parameter list of (method- or poly-) type `tp`.
48+
*/
49+
private def addConstrParams(tp: Type)(implicit ctx: Context): Type = tp match {
50+
case tp: PolyType =>
51+
tp.derivedLambdaType(resType = addConstrParams(tp.resType))
52+
case tp: MethodType =>
53+
tp.resType match {
54+
case restpe: MethodType =>
55+
tp.derivedLambdaType(resType = addConstrParams(restpe))
56+
case _ =>
57+
tp.derivedLambdaType(
58+
paramNames = tp.paramNames ++ List(nameParamName, ordinalParamName),
59+
paramInfos = tp.paramInfos ++ List(defn.StringType, defn.IntType))
60+
}
61+
}
62+
63+
/** The list of parameter definitions `$name: String, $ordinal: Int`, in given `owner`
64+
* with given flags (either `Param` or `ParamAccessor`)
65+
*/
66+
private def addedParams(owner: Symbol, flag: FlagSet)(implicit ctx: Context): List[ValDef] = {
67+
val nameParam = ctx.newSymbol(owner, nameParamName, flag | Synthetic, defn.StringType, coord = owner.span)
68+
val ordinalParam = ctx.newSymbol(owner, ordinalParamName, flag | Synthetic, defn.IntType, coord = owner.span)
69+
List(ValDef(nameParam), ValDef(ordinalParam))
70+
}
71+
72+
/** Add arguments `args` to the parent constructor application in `parents` that invokes
73+
* a constructor of `targetCls`,
74+
*/
75+
private def addEnumConstrArgs(targetCls: Symbol, parents: List[Tree], args: List[Tree])(implicit ctx: Context): List[Tree] =
76+
parents.map {
77+
case app @ Apply(fn, args0) if fn.symbol.owner == targetCls => cpy.Apply(app)(fn, args0 ++ args)
78+
case p => p
79+
}
80+
81+
/** 1. If this is a constructor of a enum class that extends, add $name and $ordinal parameters to it.
82+
*
83+
* 2. If this is a $new method that creates simple cases, pass $name and $ordinal parameters
84+
* to the enum superclass. The $new method looks like this:
85+
*
86+
* def $new(..., enumTag: Int, name: String) = {
87+
* class $anon extends E(...) { ... }
88+
* new $anon
89+
* }
90+
*
91+
* After the transform it is expanded to
92+
*
93+
* def $new(..., enumTag: Int, name: String) = {
94+
* class $anon extends E(..., name, enumTag) { ... }
95+
* new $anon
96+
* }
97+
*/
98+
override def transformDefDef(tree: DefDef)(implicit ctx: Context): DefDef = {
99+
val sym = tree.symbol
100+
if (sym.isConstructor && derivesFromJEnum(sym.owner))
101+
cpy.DefDef(tree)(
102+
vparamss = tree.vparamss.init :+ (tree.vparamss.last ++ addedParams(sym, Param)))
103+
else if (sym.name == nme.DOLLAR_NEW && derivesFromJEnum(sym.owner.linkedClass)) {
104+
val Block((tdef @ TypeDef(tpnme.ANON_CLASS, templ: Template)) :: Nil, call) = tree.rhs
105+
val args = tree.vparamss.last.takeRight(2).map(param => ref(param.symbol)).reverse
106+
val templ1 = cpy.Template(templ)(
107+
parents = addEnumConstrArgs(sym.owner.linkedClass, templ.parents, args))
108+
cpy.DefDef(tree)(
109+
rhs = cpy.Block(tree.rhs)(cpy.TypeDef(tdef)(tdef.name, templ1) :: Nil, call))
110+
}
111+
else tree
112+
}
113+
114+
/** 1. If this is an enum class, add $name and $ordinal parameters to its
115+
* parameter accessors and pass them on to the java.lang.Enum constructor,
116+
* replacing the dummy arguments that were passed before.
117+
*
118+
* 2. If this is an anonymous class that implement a value enum case,
119+
* pass $name and $ordinal parameters to the enum superclass. The class
120+
* looks like this:
121+
*
122+
* class $anon extends E(...) {
123+
* ...
124+
* def enumTag = N
125+
* def toString = S
126+
* ...
127+
* }
128+
*
129+
* After the transform it is expanded to
130+
*
131+
* class $anon extends E(..., N, S) {
132+
* "same as before"
133+
* }
134+
*/
135+
override def transformTemplate(templ: Template)(implicit ctx: Context): Template = {
136+
val cls = templ.symbol.owner
137+
if (derivesFromJEnum(cls)) {
138+
val (params, rest) = decomposeTemplateBody(templ.body)
139+
val addedDefs = addedParams(cls, ParamAccessor)
140+
val addedSyms = addedDefs.map(_.symbol.entered)
141+
val parents1 = templ.parents.map {
142+
case app @ Apply(fn, _) if fn.symbol.owner == defn.JEnumClass =>
143+
cpy.Apply(app)(fn, addedSyms.map(ref))
144+
case p => p
145+
}
146+
cpy.Template(templ)(
147+
parents = parents1,
148+
body = params ++ addedDefs ++ rest)
149+
}
150+
else if (cls.isAnonymousClass && cls.owner.is(EnumCase) && derivesFromJEnum(cls.owner.owner.linkedClass)) {
151+
def rhsOf(name: TermName) =
152+
templ.body.collect {
153+
case mdef: DefDef if mdef.name == name => mdef.rhs
154+
}.head
155+
val args = List(rhsOf(nme.toString_), rhsOf(nme.enumTag))
156+
cpy.Template(templ)(
157+
parents = addEnumConstrArgs(cls.owner.owner.linkedClass, templ.parents, args))
158+
}
159+
else templ
160+
}
161+
}

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1046,20 +1046,26 @@ trait Checking {
10461046
ctx.error(em"$what $msg", posd.sourcePos)
10471047
}
10481048

1049-
/** Check that all case classes that extend `scala.Enum` are `enum` cases */
1049+
/** 1. Check that all case classes that extend `scala.Enum` are `enum` cases
1050+
* 2. Check that case class `enum` cases do not extend java.lang.Enum.
1051+
*/
10501052
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(implicit ctx: Context): Unit = {
10511053
import untpd.modsDeco
10521054
def isEnumAnonCls =
10531055
cls.isAnonymousClass &&
10541056
cls.owner.isTerm &&
10551057
(cls.owner.flagsUNSAFE.is(Case) || cls.owner.name == nme.DOLLAR_NEW)
1056-
if (!cdef.mods.isEnumCase && !isEnumAnonCls) {
1057-
// Since enums are classes and Namer checks that classes don't extend multiple classes, we only check the class
1058-
// parent.
1059-
//
1060-
// Unlike firstParent.derivesFrom(defn.EnumClass), this test allows inheriting from `Enum` by hand;
1061-
// see enum-List-control.scala.
1062-
if (cls.is(Case) || firstParent.is(Enum))
1058+
if (!isEnumAnonCls) {
1059+
if (cdef.mods.isEnumCase) {
1060+
if (cls.derivesFrom(defn.JEnumClass))
1061+
ctx.error(em"parameterized case is not allowed in an enum that extends java.lang.Enum", cdef.sourcePos)
1062+
}
1063+
else if (cls.is(Case) || firstParent.is(Enum))
1064+
// Since enums are classes and Namer checks that classes don't extend multiple classes, we only check the class
1065+
// parent.
1066+
//
1067+
// Unlike firstParent.derivesFrom(defn.EnumClass), this test allows inheriting from `Enum` by hand;
1068+
// see enum-List-control.scala.
10631069
ctx.error(ClassCannotExtendEnum(cls, firstParent), cdef.sourcePos)
10641070
}
10651071
}

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -498,23 +498,9 @@ class Namer { typer: Typer =>
498498
recur(expanded(origStat))
499499
}
500500

501-
/** Determines whether this field holds an enum constant.
502-
* To qualify, the following conditions must be met:
503-
* - The field's class has the ENUM flag set
504-
* - The field's class extends java.lang.Enum
505-
* - The field has the ENUM flag set
506-
* - The field is static
507-
* - The field is stable
508-
*/
509-
def isEnumConstant(vd: ValDef)(implicit ctx: Context): Boolean = {
510-
// val ownerHasEnumFlag =
511-
// Necessary to check because scalac puts Java's static members into the companion object
512-
// while Scala's enum constants live directly in the class.
513-
// We don't check for clazz.superClass == JavaEnumClass, because this causes a illegal
514-
// cyclic reference error. See the commit message for details.
515-
// if (ctx.compilationUnit.isJava) ctx.owner.companionClass.is(Enum) else ctx.owner.is(Enum)
516-
vd.mods.is(JavaEnumValue) // && ownerHasEnumFlag
517-
}
501+
/** Determines whether this field holds an enum constant. */
502+
def isEnumConstant(vd: ValDef)(implicit ctx: Context): Boolean =
503+
vd.mods.is(JavaEnumValue)
518504

519505
/** Add child annotation for `child` to annotations of `cls`. The annotation
520506
* is added at the correct insertion point, so that Child annotations appear

0 commit comments

Comments
 (0)