Skip to content

Commit 01f3094

Browse files
authored
Merge pull request #4003 from dotty-staging/simplify-enums
Simplify Enums
2 parents 02725c6 + f07d9ca commit 01f3094

39 files changed

+667
-395
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 75 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,15 @@ object desugar {
3333
// ----- DerivedTypeTrees -----------------------------------
3434

3535
class SetterParamTree extends DerivedTypeTree {
36-
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.info.resultType
36+
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.TypeTree(sym.info.resultType)
3737
}
3838

3939
class TypeRefTree extends DerivedTypeTree {
40-
def derivedType(sym: Symbol)(implicit ctx: Context) = sym.typeRef
40+
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.TypeTree(sym.typeRef)
41+
}
42+
43+
class TermRefTree extends DerivedTypeTree {
44+
def derivedTree(sym: Symbol)(implicit ctx: Context) = tpd.ref(sym)
4145
}
4246

4347
/** A type tree that computes its type from an existing parameter.
@@ -73,7 +77,7 @@ object desugar {
7377
*
7478
* parameter name == reference name ++ suffix
7579
*/
76-
def derivedType(sym: Symbol)(implicit ctx: Context) = {
80+
def derivedTree(sym: Symbol)(implicit ctx: Context) = {
7781
val relocate = new TypeMap {
7882
val originalOwner = sym.owner
7983
def apply(tp: Type) = tp match {
@@ -91,7 +95,7 @@ object desugar {
9195
mapOver(tp)
9296
}
9397
}
94-
relocate(sym.info)
98+
tpd.TypeTree(relocate(sym.info))
9599
}
96100
}
97101

@@ -301,34 +305,56 @@ object desugar {
301305
val isCaseObject = mods.is(Case) && mods.is(Module)
302306
val isImplicit = mods.is(Implicit)
303307
val isEnum = mods.hasMod[Mod.Enum] && !mods.is(Module)
304-
val isEnumCase = isLegalEnumCase(cdef)
308+
val isEnumCase = mods.hasMod[Mod.EnumCase]
305309
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
306-
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
307-
310+
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
308311

309312
val originalTparams = constr1.tparams
310313
val originalVparamss = constr1.vparamss
311-
val constrTparams = originalTparams.map(toDefParam)
314+
lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
315+
val impliedTparams =
316+
if (isEnumCase && originalTparams.isEmpty)
317+
derivedEnumParams.map(tdef => tdef.withFlags(tdef.mods.flags | PrivateLocal))
318+
else
319+
originalTparams
320+
val constrTparams = impliedTparams.map(toDefParam)
312321
val constrVparamss =
313322
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
314-
if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
323+
if (isCaseClass && originalTparams.isEmpty)
324+
ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
315325
ListOfNil
316326
}
317327
else originalVparamss.nestedMap(toDefParam)
318328
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)
319329

320-
// Add constructor type parameters and evidence implicit parameters
321-
// to auxiliary constructors
322-
val normalizedBody = impl.body map {
323-
case ddef: DefDef if ddef.name.isConstructorName =>
324-
decompose(
325-
defDef(
326-
addEvidenceParams(
327-
cpy.DefDef(ddef)(tparams = constrTparams),
328-
evidenceParams(constr1).map(toDefParam))))
329-
case stat =>
330-
stat
330+
val (normalizedBody, enumCases, enumCompanionRef) = {
331+
// Add constructor type parameters and evidence implicit parameters
332+
// to auxiliary constructors; set defaultGetters as a side effect.
333+
def expandConstructor(tree: Tree) = tree match {
334+
case ddef: DefDef if ddef.name.isConstructorName =>
335+
decompose(
336+
defDef(
337+
addEvidenceParams(
338+
cpy.DefDef(ddef)(tparams = constrTparams),
339+
evidenceParams(constr1).map(toDefParam))))
340+
case stat =>
341+
stat
342+
}
343+
// The Identifiers defined by a case
344+
def caseIds(tree: Tree) = tree match {
345+
case tree: MemberDef => Ident(tree.name.toTermName) :: Nil
346+
case PatDef(_, ids, _, _) => ids
347+
}
348+
val stats = impl.body.map(expandConstructor)
349+
if (isEnum) {
350+
val (enumCases, enumStats) = stats.partition(DesugarEnums.isEnumCase)
351+
val enumCompanionRef = new TermRefTree()
352+
val enumImport = Import(enumCompanionRef, enumCases.flatMap(caseIds))
353+
(enumImport :: enumStats, enumCases, enumCompanionRef)
354+
}
355+
else (stats, Nil, EmptyTree)
331356
}
357+
332358
def anyRef = ref(defn.AnyRefAlias.typeRef)
333359

334360
val derivedTparams = constrTparams.map(derivedTypeParam(_))
@@ -361,20 +387,16 @@ object desugar {
361387
val classTypeRef = appliedRef(classTycon)
362388

363389
// a reference to `enumClass`, with type parameters coming from the case constructor
364-
lazy val enumClassTypeRef = enumClass.primaryConstructor.info match {
365-
case info: PolyType =>
366-
if (constrTparams.isEmpty)
367-
interpolatedEnumParent(cdef.pos.startPos)
368-
else if ((constrTparams.corresponds(info.paramNames))((param, name) => param.name == name))
369-
appliedRef(enumClassRef)
370-
else {
371-
ctx.error(i"explicit extends clause needed because type parameters of case and enum class differ"
372-
, cdef.pos.startPos)
373-
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
374-
}
375-
case _ =>
390+
lazy val enumClassTypeRef =
391+
if (enumClass.typeParams.isEmpty)
376392
enumClassRef
377-
}
393+
else if (originalTparams.isEmpty)
394+
appliedRef(enumClassRef)
395+
else {
396+
ctx.error(i"explicit extends clause needed because both enum case and enum class have type parameters"
397+
, cdef.pos.startPos)
398+
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
399+
}
378400

379401
// new C[Ts](paramss)
380402
lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef)
@@ -428,6 +450,7 @@ object desugar {
428450
}
429451

430452
// Case classes and case objects get Product parents
453+
// Enum cases get an inferred parent if no parents are given
431454
var parents1 = parents
432455
if (isEnumCase && parents.isEmpty)
433456
parents1 = enumClassTypeRef :: Nil
@@ -473,7 +496,7 @@ object desugar {
473496
.withMods(companionMods | Synthetic))
474497
.withPos(cdef.pos).toList
475498

476-
val companionMeths = defaultGetters ::: eqInstances
499+
val companionMembers = defaultGetters ::: eqInstances ::: enumCases
477500

478501
// The companion object definitions, if a companion is needed, Nil otherwise.
479502
// companion definitions include:
@@ -486,18 +509,17 @@ object desugar {
486509
// For all other classes, the parent is AnyRef.
487510
val companions =
488511
if (isCaseClass) {
489-
// The return type of the `apply` method
512+
// The return type of the `apply` method, and an (empty or singleton) list
513+
// of widening coercions
490514
val (applyResultTpt, widenDefs) =
491515
if (!isEnumCase)
492516
(TypeTree(), Nil)
493517
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
494518
(enumClassTypeRef, Nil)
495-
else {
496-
val tparams = enumClass.typeParams.map(derivedTypeParam)
497-
enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
498-
}
519+
else
520+
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
499521

500-
val parent =
522+
val companionParent =
501523
if (constrTparams.nonEmpty ||
502524
constrVparamss.length > 1 ||
503525
mods.is(Abstract) ||
@@ -519,10 +541,10 @@ object desugar {
519541
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
520542
.withMods(synthetic)
521543
}
522-
companionDefs(parent, applyMeths ::: unapplyMeth :: companionMeths)
544+
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
523545
}
524-
else if (companionMeths.nonEmpty)
525-
companionDefs(anyRef, companionMeths)
546+
else if (companionMembers.nonEmpty)
547+
companionDefs(anyRef, companionMembers)
526548
else if (isValueClass) {
527549
constr0.vparamss match {
528550
case (_ :: Nil) :: _ => companionDefs(anyRef, Nil)
@@ -531,6 +553,13 @@ object desugar {
531553
}
532554
else Nil
533555

556+
enumCompanionRef match {
557+
case ref: TermRefTree => // have the enum import watch the companion object
558+
val (modVal: ValDef) :: _ = companions
559+
ref.watching(modVal)
560+
case _ =>
561+
}
562+
534563
// For an implicit class C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, .., pMN: TMN), the method
535564
// synthetic implicit C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, ..., pMN: TMN): C[Ts] =
536565
// new C[Ts](p11, ..., p1N) ... (pM1, ..., pMN) =
@@ -563,7 +592,7 @@ object desugar {
563592
}
564593

565594
val cdef1 = addEnumFlags {
566-
val originalTparamsIt = originalTparams.toIterator
595+
val originalTparamsIt = impliedTparams.toIterator
567596
val originalVparamsIt = originalVparamss.toIterator.flatten
568597
val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next().mods))
569598
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
@@ -603,7 +632,7 @@ object desugar {
603632
val moduleName = checkNotReservedName(mdef).asTermName
604633
val impl = mdef.impl
605634
val mods = mdef.mods
606-
lazy val isEnumCase = isLegalEnumCase(mdef)
635+
lazy val isEnumCase = mods.hasMod[Mod.EnumCase]
607636
if (mods is Package)
608637
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil)
609638
else if (isEnumCase)
@@ -650,7 +679,7 @@ object desugar {
650679
*/
651680
def patDef(pdef: PatDef)(implicit ctx: Context): Tree = flatTree {
652681
val PatDef(mods, pats, tpt, rhs) = pdef
653-
if (mods.hasMod[Mod.EnumCase] && enumCaseIsLegal(pdef))
682+
if (mods.hasMod[Mod.EnumCase])
654683
pats map {
655684
case id: Ident =>
656685
expandSimpleEnumCase(id.name.asTermName, mods,

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import core._
66
import util.Positions._, Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
77
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._
88
import Decorators._
9-
import reporting.diagnostic.messages.EnumCaseDefinitionInNonEnumOwner
109
import collection.mutable.ListBuffer
1110
import util.Property
1211
import typer.ErrorReporting._
@@ -23,20 +22,21 @@ object DesugarEnums {
2322
/** Attachment containing the number of enum cases and the smallest kind that was seen so far. */
2423
val EnumCaseCount = new Property.Key[(Int, CaseKind.Value)]
2524

26-
/** the enumeration class that is a companion of the current object */
27-
def enumClass(implicit ctx: Context) = ctx.owner.linkedClass
28-
29-
/** Is this an enum case that's situated in a companion object of an enum class? */
30-
def isLegalEnumCase(tree: MemberDef)(implicit ctx: Context): Boolean =
31-
tree.mods.hasMod[Mod.EnumCase] && enumCaseIsLegal(tree)
25+
/** The enumeration class that belongs to an enum case. This works no matter
26+
* whether the case is still in the enum class or it has been transferred to the
27+
* companion object.
28+
*/
29+
def enumClass(implicit ctx: Context): Symbol = {
30+
val cls = ctx.owner
31+
if (cls.is(Module)) cls.linkedClass else cls
32+
}
3233

33-
/** Is enum case `tree` situated in a companion object of an enum class? */
34-
def enumCaseIsLegal(tree: Tree)(implicit ctx: Context): Boolean = (
35-
ctx.owner.is(ModuleClass) && enumClass.derivesFrom(defn.EnumClass)
36-
|| { ctx.error(EnumCaseDefinitionInNonEnumOwner(ctx.owner), tree.pos)
37-
false
38-
}
39-
)
34+
/** Is `tree` an (untyped) enum case? */
35+
def isEnumCase(tree: Tree)(implicit ctx: Context): Boolean = tree match {
36+
case tree: MemberDef => tree.mods.hasMod[Mod.EnumCase]
37+
case PatDef(mods, _, _, _) => mods.hasMod[Mod.EnumCase]
38+
case _ => false
39+
}
4040

4141
/** A reference to the enum class `E`, possibly followed by type arguments.
4242
* Each covariant type parameter is approximated by its lower bound.
@@ -68,8 +68,8 @@ object DesugarEnums {
6868

6969
/** Add implied flags to an enum class or an enum case */
7070
def addEnumFlags(cdef: TypeDef)(implicit ctx: Context) =
71-
if (cdef.mods.hasMod[Mod.Enum]) cdef.withFlags(cdef.mods.flags | Abstract | Sealed)
72-
else if (isLegalEnumCase(cdef)) cdef.withFlags(cdef.mods.flags | Final)
71+
if (cdef.mods.hasMod[Mod.Enum]) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Abstract | Sealed))
72+
else if (isEnumCase(cdef)) cdef.withMods(cdef.mods.withFlags(cdef.mods.flags | Final))
7373
else cdef
7474

7575
private def valuesDot(name: String) = Select(Ident(nme.DOLLAR_VALUES), name.toTermName)
@@ -193,24 +193,20 @@ object DesugarEnums {
193193
}
194194

195195
/** Expand a module definition representing a parameterless enum case */
196-
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree =
196+
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree = {
197+
assert(impl.body.isEmpty)
197198
if (impl.parents.isEmpty)
198-
if (impl.body.isEmpty)
199-
expandSimpleEnumCase(name, mods, pos)
200-
else {
201-
val parent = interpolatedEnumParent(pos)
202-
expandEnumModule(name, cpy.Template(impl)(parents = parent :: Nil), mods, pos)
203-
}
199+
expandSimpleEnumCase(name, mods, pos)
204200
else {
205201
def toStringMeth =
206202
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), Literal(Constant(name.toString)))
207203
.withFlags(Override)
208204
val (tagMeth, scaffolding) = enumTagMeth(CaseKind.Object)
209-
val impl1 = cpy.Template(impl)(body =
210-
impl.body ++ List(tagMeth, toStringMeth) ++ registerCall)
205+
val impl1 = cpy.Template(impl)(body = List(tagMeth, toStringMeth) ++ registerCall)
211206
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods | Final)
212207
flatTree(scaffolding ::: vdef :: Nil).withPos(pos)
213208
}
209+
}
214210

215211
/** Expand a simple enum case */
216212
def expandSimpleEnumCase(name: TermName, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree =

compiler/src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -308,6 +308,7 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
308308
def prefixIsElidable(tp: NamedType)(implicit ctx: Context) = {
309309
val typeIsElidable = tp.prefix match {
310310
case pre: ThisType =>
311+
tp.isType ||
311312
pre.cls.isStaticOwner ||
312313
tp.symbol.isParamOrAccessor && !pre.cls.is(Trait) && ctx.owner.enclosingClass == pre.cls
313314
// was ctx.owner.enclosingClass.derivesFrom(pre.cls) which was not tight enough

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
229229
*/
230230
def ensureCompletions(implicit ctx: Context): Unit = ()
231231

232-
/** The method that computes the type of this tree */
233-
def derivedType(originalSym: Symbol)(implicit ctx: Context): Type
232+
/** The method that computes the tree with the derived type */
233+
def derivedTree(originalSym: Symbol)(implicit ctx: Context): tpd.Tree
234234
}
235235

236236
/** Property key containing TypeTrees whose type is computed

0 commit comments

Comments
 (0)