Skip to content

Commit 1f835f5

Browse files
committed
Implement enum desugaring
1 parent d691126 commit 1f835f5

File tree

6 files changed

+214
-35
lines changed

6 files changed

+214
-35
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
1414

1515
object desugar {
1616
import untpd._
17+
import DesugarEnums._
1718

1819
/** Tags a .withFilter call generated by desugaring a for expression.
1920
* Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
263264
val className = checkNotReservedName(cdef).asTypeName
264265
val impl @ Template(constr0, parents, self, _) = cdef.rhs
265266
val mods = cdef.mods
266-
val companionMods = mods.withFlags((mods.flags & AccessFlags).toCommonFlags)
267+
val companionMods = mods
268+
.withFlags((mods.flags & AccessFlags).toCommonFlags)
269+
.withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase]))
267270

268271
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match {
269272
case meth: DefDef => (meth, Nil)
@@ -288,17 +291,22 @@ object desugar {
288291
}
289292

290293
val isCaseClass = mods.is(Case) && !mods.is(Module)
294+
val isEnum = mods.hasMod[Mod.Enum]
295+
val isEnumCase = isLegalEnumCase(cdef)
291296
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
292297
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
293298

294-
val constrTparams = constr1.tparams map toDefParam
299+
val originalTparams =
300+
if (isEnumCase && parents.isEmpty) reconstitutedEnumTypeParams(cdef.pos.startPos)
301+
else constr1.tparams
302+
val originalVparamss = constr1.vparamss
303+
val constrTparams = originalTparams.map(toDefParam)
295304
val constrVparamss =
296-
if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
297-
if (isCaseClass)
298-
ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
305+
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
306+
if (isCaseClass) ctx.error(CaseClassMissingParamList(cdef), cdef.namePos)
299307
ListOfNil
300308
}
301-
else constr1.vparamss.nestedMap(toDefParam)
309+
else originalVparamss.nestedMap(toDefParam)
302310
val constr = cpy.DefDef(constr1)(tparams = constrTparams, vparamss = constrVparamss)
303311

304312
// Add constructor type parameters and evidence implicit parameters
@@ -312,21 +320,22 @@ object desugar {
312320
stat
313321
}
314322

315-
val derivedTparams = constrTparams map derivedTypeParam
323+
val derivedTparams =
324+
if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
316325
val derivedVparamss = constrVparamss nestedMap derivedTermParam
317326
val arity = constrVparamss.head.length
318327

319-
var classTycon: Tree = EmptyTree
328+
val classTycon: Tree = new TypeRefTree // watching is set at end of method
320329

321-
// a reference to the class type, with all parameters given.
322-
val classTypeRef/*: Tree*/ = {
323-
// -language:keepUnions difference: classTypeRef needs type annotation, otherwise
324-
// infers Ident | AppliedTypeTree, which
325-
// renders the :\ in companions below untypable.
326-
classTycon = (new TypeRefTree) withPos cdef.pos.startPos // watching is set at end of method
327-
val tparams = impl.constr.tparams
328-
if (tparams.isEmpty) classTycon else AppliedTypeTree(classTycon, tparams map refOfDef)
329-
}
330+
def appliedRef(tycon: Tree) =
331+
(if (constrTparams.isEmpty) tycon
332+
else AppliedTypeTree(tycon, constrTparams map refOfDef))
333+
.withPos(cdef.pos.startPos)
334+
335+
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
336+
val classTypeRef = appliedRef(classTycon)
337+
// a refereence to `enumClass`, with type parameters coming from the constructor
338+
lazy val enumClassTypeRef = appliedRef(enumClassRef)
330339

331340
// new C[Ts](paramss)
332341
lazy val creatorExpr = New(classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +383,9 @@ object desugar {
374383
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
375384
.withMods(synthetic) :: Nil
376385
}
377-
copyMeths ::: productElemMeths.toList
386+
387+
val enumTagMeths = if (isEnumCase) enumTagMeth :: Nil else Nil
388+
copyMeths ::: enumTagMeths ::: productElemMeths.toList
378389
}
379390
else Nil
380391

@@ -387,8 +398,12 @@ object desugar {
387398

388399
// Case classes and case objects get a ProductN parent
389400
var parents1 = parents
401+
if (isEnumCase && parents.isEmpty)
402+
parents1 = enumClassTypeRef :: Nil
390403
if (mods.is(Case) && arity <= Definitions.MaxTupleArity)
391-
parents1 = parents1 :+ productConstr(arity)
404+
parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
405+
if (isEnum)
406+
parents1 = parents1 :+ ref(defn.EnumType)
392407

393408
// The thicket which is the desugared version of the companion object
394409
// synthetic object C extends parentTpt { defs }
@@ -419,9 +434,11 @@ object desugar {
419434
else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function(vparams map (_.tpt), restpe))
420435
val applyMeths =
421436
if (mods is Abstract) Nil
422-
else
423-
DefDef(nme.apply, derivedTparams, derivedVparamss, TypeTree(), creatorExpr)
437+
else {
438+
val restpe = if (isEnumCase) enumClassTypeRef else TypeTree()
439+
DefDef(nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr)
424440
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized)) :: Nil
441+
}
425442
val unapplyMeth = {
426443
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
427444
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
@@ -464,12 +481,12 @@ object desugar {
464481
else cpy.ValDef(self)(tpt = selfType).withMods(self.mods | SelfName)
465482
}
466483

467-
val cdef1 = {
468-
val originalTparams = constr1.tparams.toIterator
469-
val originalVparams = constr1.vparamss.toIterator.flatten
470-
val tparamAccessors = derivedTparams.map(_.withMods(originalTparams.next.mods))
484+
val cdef1 = addEnumFlags {
485+
val originalTparamsIt = originalTparams.toIterator
486+
val originalVparamsIt = originalVparamss.toIterator.flatten
487+
val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next.mods))
471488
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
472-
val vparamAccessors = derivedVparamss.flatten.map(_.withMods(originalVparams.next.mods | caseAccessor))
489+
val vparamAccessors = derivedVparamss.flatten.map(_.withMods(originalVparamsIt.next.mods | caseAccessor))
473490
cpy.TypeDef(cdef)(
474491
name = className,
475492
rhs = cpy.Template(impl)(constr, parents1, self1,
@@ -497,23 +514,26 @@ object desugar {
497514
*/
498515
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
499516
val moduleName = checkNotReservedName(mdef).asTermName
500-
val tmpl = mdef.impl
517+
val impl = mdef.impl
501518
val mods = mdef.mods
519+
lazy val isEnumCase = isLegalEnumCase(mdef)
502520
if (mods is Package)
503-
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, tmpl).withMods(mods &~ Package) :: Nil)
521+
PackageDef(Ident(moduleName), cpy.ModuleDef(mdef)(nme.PACKAGE, impl).withMods(mods &~ Package) :: Nil)
522+
else if (isEnumCase)
523+
expandEnumModule(moduleName, impl, mods, mdef.pos)
504524
else {
505525
val clsName = moduleName.moduleClassName
506526
val clsRef = Ident(clsName)
507527
val modul = ValDef(moduleName, clsRef, New(clsRef, Nil))
508528
.withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags)
509529
.withPos(mdef.pos)
510-
val ValDef(selfName, selfTpt, _) = tmpl.self
511-
val selfMods = tmpl.self.mods
512-
if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), tmpl.self.pos)
513-
val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), tmpl.self.rhs)
530+
val ValDef(selfName, selfTpt, _) = impl.self
531+
val selfMods = impl.self.mods
532+
if (!selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType(mdef), impl.self.pos)
533+
val clsSelf = ValDef(selfName, SingletonTypeTree(Ident(moduleName)), impl.self.rhs)
514534
.withMods(selfMods)
515-
.withPos(tmpl.self.pos orElse tmpl.pos.startPos)
516-
val clsTmpl = cpy.Template(tmpl)(self = clsSelf, body = tmpl.body)
535+
.withPos(impl.self.pos orElse impl.pos.startPos)
536+
val clsTmpl = cpy.Template(impl)(self = clsSelf, body = impl.body)
517537
val cls = TypeDef(clsName, clsTmpl)
518538
.withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags)
519539
Thicket(modul, classDef(cls).withPos(mdef.pos))
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package dotty.tools
2+
package dotc
3+
package ast
4+
5+
import core._
6+
import util.Positions._, Types._, Contexts._, Constants._, Names._, NameOps._, Flags._
7+
import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._
8+
import Decorators._
9+
import collection.mutable.ListBuffer
10+
import util.Property
11+
import reporting.diagnostic.messages._
12+
13+
object DesugarEnums {
14+
import untpd._
15+
import desugar.DerivedFromParamTree
16+
17+
val EnumCaseCount = new Property.Key[Int]
18+
19+
def enumClass(implicit ctx: Context) = ctx.owner.linkedClass
20+
21+
def nextEnumTag(implicit ctx: Context): Int = {
22+
val result = ctx.tree.removeAttachment(EnumCaseCount).getOrElse(0)
23+
ctx.tree.pushAttachment(EnumCaseCount, result + 1)
24+
result
25+
}
26+
27+
def isLegalEnumCase(tree: MemberDef)(implicit ctx: Context): Boolean = {
28+
tree.mods.hasMod[Mod.EnumCase] &&
29+
( ctx.owner.is(ModuleClass) && enumClass.derivesFrom(defn.EnumClass)
30+
|| { ctx.error(em"case not allowed here, since owner ${ctx.owner} is not an `enum' object", tree.pos)
31+
false
32+
}
33+
)
34+
}
35+
36+
/** Type parameters reconstituted from the constructor
37+
* of the `enum' class corresponding to an enum case
38+
*/
39+
def reconstitutedEnumTypeParams(pos: Position)(implicit ctx: Context) = {
40+
val tparams = enumClass.primaryConstructor.info match {
41+
case info: PolyType =>
42+
ctx.newTypeParams(ctx.newLocalDummy(enumClass), info.paramNames, EmptyFlags, info.instantiateBounds)
43+
case _ =>
44+
Nil
45+
}
46+
for (tparam <- tparams) yield {
47+
val tbounds = new DerivedFromParamTree
48+
tbounds.pushAttachment(OriginalSymbol, tparam)
49+
TypeDef(tparam.name, tbounds)
50+
.withFlags(Param | PrivateLocal).withPos(pos)
51+
}
52+
}
53+
54+
def enumTagMeth(implicit ctx: Context) =
55+
DefDef(nme.enumTag, Nil, Nil, TypeTree(), Literal(Constant(nextEnumTag)))
56+
57+
def enumClassRef(implicit ctx: Context) = TypeTree(enumClass.typeRef)
58+
59+
def addEnumFlags(cdef: TypeDef)(implicit ctx: Context) =
60+
if (cdef.mods.hasMod[Mod.Enum]) cdef.withFlags(cdef.mods.flags | Abstract | Sealed)
61+
else if (isLegalEnumCase(cdef)) cdef.withFlags(cdef.mods.flags | Final)
62+
else cdef
63+
64+
/** The following lists of definitions for an enum type E:
65+
*
66+
* private val $values = new EnumValues[E]
67+
* def valueOf: Int => E = $values
68+
* def values = $values.values
69+
*
70+
* private def $new(tag: Int, name: String) = new E {
71+
* def enumTag = tag
72+
* override def toString = name
73+
* $values.register(this)
74+
* }
75+
*/
76+
private def enumScaffolding(implicit ctx: Context): List[Tree] = {
77+
val valsRef = Ident(nme.DOLLAR_VALUES)
78+
def param(name: TermName, typ: Type) =
79+
ValDef(name, TypeTree(typ), EmptyTree).withFlags(Param)
80+
val privateValuesDef =
81+
ValDef(nme.DOLLAR_VALUES, TypeTree(),
82+
New(TypeTree(defn.EnumValuesType.appliedTo(enumClass.typeRef :: Nil)), ListOfNil))
83+
.withFlags(Private)
84+
val valueOfDef =
85+
DefDef(nme.valueOf, Nil, Nil,
86+
TypeTree(defn.FunctionOf(defn.IntType :: Nil, enumClass.typeRef)), valsRef)
87+
val valuesDef =
88+
DefDef(nme.values, Nil, Nil, TypeTree(), Select(valsRef, nme.values))
89+
val enumTagDef =
90+
DefDef(nme.enumTag, Nil, Nil, TypeTree(), Ident(nme.tag))
91+
val toStringDef =
92+
DefDef(nme.toString_, Nil, Nil, TypeTree(), Ident(nme.name))
93+
.withFlags(Override)
94+
val registerStat =
95+
Apply(Select(valsRef, nme.register), This(EmptyTypeIdent) :: Nil)
96+
def creator = New(Template(emptyConstructor, enumClassRef :: Nil, EmptyValDef,
97+
List(enumTagDef, toStringDef, registerStat)))
98+
val newDef =
99+
DefDef(nme.DOLLAR_NEW, Nil,
100+
List(List(param(nme.tag, defn.IntType), param(nme.name, defn.StringType))),
101+
TypeTree(), creator)
102+
List(privateValuesDef, valueOfDef, valuesDef, newDef)
103+
}
104+
105+
def expandEnumModule(name: TermName, impl: Template, mods: Modifiers, pos: Position)(implicit ctx: Context): Tree = {
106+
def nameLit = Literal(Constant(name.toString))
107+
if (impl.parents.isEmpty) {
108+
if (reconstitutedEnumTypeParams(pos).nonEmpty)
109+
ctx.error(i"illegal enum value of generic $enumClass: an explicit `extends' clause is needed", pos)
110+
val tag = nextEnumTag
111+
val prefix = if (tag == 0) enumScaffolding else Nil
112+
val creator = Apply(Ident(nme.DOLLAR_NEW), List(Literal(Constant(tag)), nameLit))
113+
val vdef = ValDef(name, enumClassRef, creator).withMods(mods | Final).withPos(pos)
114+
flatTree(prefix ::: vdef :: Nil).withPos(pos.startPos)
115+
} else {
116+
def toStringMeth =
117+
DefDef(nme.toString_, Nil, Nil, TypeTree(defn.StringType), nameLit)
118+
.withFlags(Override)
119+
val impl1 = cpy.Template(impl)(body =
120+
impl.body ++ List(enumTagMeth, toStringMeth))
121+
ValDef(name, TypeTree(), New(impl1)).withMods(mods | Final).withPos(pos)
122+
}
123+
}
124+
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,10 @@ class Definitions {
507507
def DynamicClass(implicit ctx: Context) = DynamicType.symbol.asClass
508508
lazy val OptionType: TypeRef = ctx.requiredClassRef("scala.Option")
509509
def OptionClass(implicit ctx: Context) = OptionType.symbol.asClass
510+
lazy val EnumType: TypeRef = ctx.requiredClassRef("scala.Enum")
511+
def EnumClass(implicit ctx: Context) = EnumType.symbol.asClass
512+
lazy val EnumValuesType: TypeRef = ctx.requiredClassRef("scala.runtime.EnumValues")
513+
def EnumValuesClass(implicit ctx: Context) = EnumValuesType.symbol.asClass
510514
lazy val ProductType: TypeRef = ctx.requiredClassRef("scala.Product")
511515
def ProductClass(implicit ctx: Context) = ProductType.symbol.asClass
512516
lazy val Product_canEqualR = ProductClass.requiredMethodRef(nme.canEqual_)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ object StdNames {
130130
val COMPANION_CLASS_METHOD: N = "companion$class"
131131
val TRAIT_SETTER_SEPARATOR: N = "$_setter_$"
132132
val DIRECT_SUFFIX: N = "$direct"
133+
val DOLLAR_VALUES: N = "$values"
134+
val DOLLAR_NEW: N = "$new"
133135

134136
// value types (and AnyRef) are all used as terms as well
135137
// as (at least) arguments to the @specialize annotation.
@@ -393,6 +395,7 @@ object StdNames {
393395
val elem: N = "elem"
394396
val emptyValDef: N = "emptyValDef"
395397
val ensureAccessible : N = "ensureAccessible"
398+
val enumTag: N = "enumTag"
396399
val eq: N = "eq"
397400
val equalsNumChar : N = "equalsNumChar"
398401
val equalsNumNum : N = "equalsNumNum"
@@ -472,6 +475,7 @@ object StdNames {
472475
val productPrefix: N = "productPrefix"
473476
val readResolve: N = "readResolve"
474477
val reflect : N = "reflect"
478+
val register: N = "register"
475479
val reify : N = "reify"
476480
val rootMirror : N = "rootMirror"
477481
val runOrElse: N = "runOrElse"
@@ -497,6 +501,7 @@ object StdNames {
497501
val staticModule : N = "staticModule"
498502
val staticPackage : N = "staticPackage"
499503
val synchronized_ : N = "synchronized"
504+
val tag: N = "tag"
500505
val tail: N = "tail"
501506
val `then` : N = "then"
502507
val this_ : N = "this"
@@ -521,7 +526,7 @@ object StdNames {
521526
val updateDynamic: N = "updateDynamic"
522527
val value: N = "value"
523528
val valueOf : N = "valueOf"
524-
val values : N = "values"
529+
val values: N = "values"
525530
val view_ : N = "view"
526531
val wait_ : N = "wait"
527532
val withFilter: N = "withFilter"

library/src/scala/Enum.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package scala
2+
3+
/** A base trait of all enum classes */
4+
trait Enum {
5+
6+
/** A number uniquely identifying a case of an enum */
7+
def enumTag: Int
8+
}
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
package scala.runtime
2+
3+
import scala.collection.immutable.Seq
4+
import scala.collection.mutable.ResizableArray
5+
6+
class EnumValues[E <: Enum] extends ResizableArray[E] {
7+
private var valuesCache: List[E] = Nil
8+
def register(v: E) = {
9+
ensureSize(v.enumTag + 1)
10+
size0 = size0 max (v.enumTag + 1)
11+
array(v.enumTag) = v
12+
valuesCache = null
13+
}
14+
def values: Seq[E] = {
15+
if (valuesCache == null) valuesCache = array.filter(_ != null).toList.asInstanceOf[scala.List[E]]
16+
valuesCache
17+
}
18+
}

0 commit comments

Comments
 (0)