@@ -14,6 +14,7 @@ import reporting.diagnostic.messages._
14
14
15
15
object desugar {
16
16
import untpd ._
17
+ import DesugarEnums ._
17
18
18
19
/** Tags a .withFilter call generated by desugaring a for expression.
19
20
* Such calls can alternatively be rewritten to use filter.
@@ -263,7 +264,9 @@ object desugar {
263
264
val className = checkNotReservedName(cdef).asTypeName
264
265
val impl @ Template (constr0, parents, self, _) = cdef.rhs
265
266
val mods = cdef.mods
266
- val companionMods = mods.withFlags((mods.flags & AccessFlags ).toCommonFlags)
267
+ val companionMods = mods
268
+ .withFlags((mods.flags & AccessFlags ).toCommonFlags)
269
+ .withMods(mods.mods.filter(! _.isInstanceOf [Mod .EnumCase ]))
267
270
268
271
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true ) match {
269
272
case meth : DefDef => (meth, Nil )
@@ -288,17 +291,22 @@ object desugar {
288
291
}
289
292
290
293
val isCaseClass = mods.is(Case ) && ! mods.is(Module )
294
+ val isEnum = mods.hasMod[Mod .Enum ]
295
+ val isEnumCase = isLegalEnumCase(cdef)
291
296
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
292
297
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
293
298
294
- val constrTparams = constr1.tparams map toDefParam
299
+ val originalTparams =
300
+ if (isEnumCase && parents.isEmpty) reconstitutedEnumTypeParams(cdef.pos.startPos)
301
+ else constr1.tparams
302
+ val originalVparamss = constr1.vparamss
303
+ val constrTparams = originalTparams.map(toDefParam)
295
304
val constrVparamss =
296
- if (constr1.vparamss.isEmpty) { // ensure parameter list is non-empty
297
- if (isCaseClass)
298
- ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
305
+ if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
306
+ if (isCaseClass) ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
299
307
ListOfNil
300
308
}
301
- else constr1.vparamss .nestedMap(toDefParam)
309
+ else originalVparamss .nestedMap(toDefParam)
302
310
val constr = cpy.DefDef (constr1)(tparams = constrTparams, vparamss = constrVparamss)
303
311
304
312
// Add constructor type parameters and evidence implicit parameters
@@ -312,21 +320,22 @@ object desugar {
312
320
stat
313
321
}
314
322
315
- val derivedTparams = constrTparams map derivedTypeParam
323
+ val derivedTparams =
324
+ if (isEnumCase) constrTparams else constrTparams map derivedTypeParam
316
325
val derivedVparamss = constrVparamss nestedMap derivedTermParam
317
326
val arity = constrVparamss.head.length
318
327
319
- var classTycon : Tree = EmptyTree
328
+ val classTycon : Tree = new TypeRefTree // watching is set at end of method
320
329
321
- // a reference to the class type, with all parameters given.
322
- val classTypeRef /* : Tree */ = {
323
- // -language:keepUnions difference: classTypeRef needs type annotation, otherwise
324
- // infers Ident | AppliedTypeTree, which
325
- // renders the :\ in companions below untypable.
326
- classTycon = ( new TypeRefTree ) withPos cdef.pos.startPos // watching is set at end of method
327
- val tparams = impl.constr.tparams
328
- if (tparams.isEmpty) classTycon else AppliedTypeTree (classTycon, tparams map refOfDef)
329
- }
330
+ def appliedRef ( tycon : Tree ) =
331
+ ( if (constrTparams.isEmpty) tycon
332
+ else AppliedTypeTree (tycon, constrTparams map refOfDef))
333
+ .withPos(cdef.pos.startPos)
334
+
335
+ // a reference to the class type bound by `cdef`, with type parameters coming from the constructor
336
+ val classTypeRef = appliedRef(classTycon)
337
+ // a refereence to `enumClass`, with type parameters coming from the constructor
338
+ lazy val enumClassTypeRef = appliedRef(enumClassRef)
330
339
331
340
// new C[Ts](paramss)
332
341
lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
@@ -374,7 +383,9 @@ object desugar {
374
383
DefDef (nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree (), creatorExpr)
375
384
.withMods(synthetic) :: Nil
376
385
}
377
- copyMeths ::: productElemMeths.toList
386
+
387
+ val enumTagMeths = if (isEnumCase) enumTagMeth :: Nil else Nil
388
+ copyMeths ::: enumTagMeths ::: productElemMeths.toList
378
389
}
379
390
else Nil
380
391
@@ -387,8 +398,12 @@ object desugar {
387
398
388
399
// Case classes and case objects get a ProductN parent
389
400
var parents1 = parents
401
+ if (isEnumCase && parents.isEmpty)
402
+ parents1 = enumClassTypeRef :: Nil
390
403
if (mods.is(Case ) && arity <= Definitions .MaxTupleArity )
391
- parents1 = parents1 :+ productConstr(arity)
404
+ parents1 = parents1 :+ productConstr(arity) // TODO: This also adds Product0 to caes objects. Do we want that?
405
+ if (isEnum)
406
+ parents1 = parents1 :+ ref(defn.EnumType )
392
407
393
408
// The thicket which is the desugared version of the companion object
394
409
// synthetic object C extends parentTpt { defs }
@@ -419,9 +434,11 @@ object desugar {
419
434
else (constrVparamss :\ classTypeRef) ((vparams, restpe) => Function (vparams map (_.tpt), restpe))
420
435
val applyMeths =
421
436
if (mods is Abstract ) Nil
422
- else
423
- DefDef (nme.apply, derivedTparams, derivedVparamss, TypeTree (), creatorExpr)
437
+ else {
438
+ val restpe = if (isEnumCase) enumClassTypeRef else TypeTree ()
439
+ DefDef (nme.apply, derivedTparams, derivedVparamss, restpe, creatorExpr)
424
440
.withFlags(Synthetic | (constr1.mods.flags & DefaultParameterized )) :: Nil
441
+ }
425
442
val unapplyMeth = {
426
443
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
427
444
val unapplyRHS = if (arity == 0 ) Literal (Constant (true )) else Ident (unapplyParam.name)
@@ -464,12 +481,12 @@ object desugar {
464
481
else cpy.ValDef (self)(tpt = selfType).withMods(self.mods | SelfName )
465
482
}
466
483
467
- val cdef1 = {
468
- val originalTparams = constr1.tparams .toIterator
469
- val originalVparams = constr1.vparamss .toIterator.flatten
470
- val tparamAccessors = derivedTparams.map(_.withMods(originalTparams .next.mods))
484
+ val cdef1 = addEnumFlags {
485
+ val originalTparamsIt = originalTparams .toIterator
486
+ val originalVparamsIt = originalVparamss .toIterator.flatten
487
+ val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt .next.mods))
471
488
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
472
- val vparamAccessors = derivedVparamss.flatten.map(_.withMods(originalVparams .next.mods | caseAccessor))
489
+ val vparamAccessors = derivedVparamss.flatten.map(_.withMods(originalVparamsIt .next.mods | caseAccessor))
473
490
cpy.TypeDef (cdef)(
474
491
name = className,
475
492
rhs = cpy.Template (impl)(constr, parents1, self1,
@@ -497,23 +514,26 @@ object desugar {
497
514
*/
498
515
def moduleDef (mdef : ModuleDef )(implicit ctx : Context ): Tree = {
499
516
val moduleName = checkNotReservedName(mdef).asTermName
500
- val tmpl = mdef.impl
517
+ val impl = mdef.impl
501
518
val mods = mdef.mods
519
+ lazy val isEnumCase = isLegalEnumCase(mdef)
502
520
if (mods is Package )
503
- PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , tmpl).withMods(mods &~ Package ) :: Nil )
521
+ PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , impl).withMods(mods &~ Package ) :: Nil )
522
+ else if (isEnumCase)
523
+ expandEnumModule(moduleName, impl, mods, mdef.pos)
504
524
else {
505
525
val clsName = moduleName.moduleClassName
506
526
val clsRef = Ident (clsName)
507
527
val modul = ValDef (moduleName, clsRef, New (clsRef, Nil ))
508
528
.withMods(mods | ModuleCreationFlags | mods.flags & AccessFlags )
509
529
.withPos(mdef.pos)
510
- val ValDef (selfName, selfTpt, _) = tmpl .self
511
- val selfMods = tmpl .self.mods
512
- if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), tmpl .self.pos)
513
- val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), tmpl .self.rhs)
530
+ val ValDef (selfName, selfTpt, _) = impl .self
531
+ val selfMods = impl .self.mods
532
+ if (! selfTpt.isEmpty) ctx.error(ObjectMayNotHaveSelfType (mdef), impl .self.pos)
533
+ val clsSelf = ValDef (selfName, SingletonTypeTree (Ident (moduleName)), impl .self.rhs)
514
534
.withMods(selfMods)
515
- .withPos(tmpl .self.pos orElse tmpl .pos.startPos)
516
- val clsTmpl = cpy.Template (tmpl )(self = clsSelf, body = tmpl .body)
535
+ .withPos(impl .self.pos orElse impl .pos.startPos)
536
+ val clsTmpl = cpy.Template (impl )(self = clsSelf, body = impl .body)
517
537
val cls = TypeDef (clsName, clsTmpl)
518
538
.withMods(mods.toTypeFlags & RetainedModuleClassFlags | ModuleClassCreationFlags )
519
539
Thicket (modul, classDef(cls).withPos(mdef.pos))
0 commit comments