@@ -305,14 +305,19 @@ object desugar {
305
305
val isCaseObject = mods.is(Case ) && mods.is(Module )
306
306
val isImplicit = mods.is(Implicit )
307
307
val isEnum = mods.hasMod[Mod .Enum ] && ! mods.is(Module )
308
- val isEnumCase = isLegalEnumCase(cdef)
308
+ val isEnumCase = mods.hasMod[ Mod . EnumCase ]
309
309
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
310
- // This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
311
-
310
+ // This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
312
311
313
312
val originalTparams = constr1.tparams
314
313
val originalVparamss = constr1.vparamss
315
- val constrTparams = originalTparams.map(toDefParam)
314
+ lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
315
+ val impliedTparams =
316
+ if (isEnumCase && originalTparams.isEmpty)
317
+ derivedEnumParams.map(tdef => tdef.withFlags(tdef.mods.flags | PrivateLocal ))
318
+ else
319
+ originalTparams
320
+ val constrTparams = impliedTparams.map(toDefParam)
316
321
val constrVparamss =
317
322
if (originalVparamss.isEmpty) { // ensure parameter list is non-empty
318
323
if (isCaseClass) ctx.error(CaseClassMissingParamList (cdef), cdef.namePos)
@@ -321,18 +326,34 @@ object desugar {
321
326
else originalVparamss.nestedMap(toDefParam)
322
327
val constr = cpy.DefDef (constr1)(tparams = constrTparams, vparamss = constrVparamss)
323
328
324
- // Add constructor type parameters and evidence implicit parameters
325
- // to auxiliary constructors
326
- val normalizedBody = impl.body map {
327
- case ddef : DefDef if ddef.name.isConstructorName =>
328
- decompose(
329
- defDef(
330
- addEvidenceParams(
331
- cpy.DefDef (ddef)(tparams = constrTparams),
332
- evidenceParams(constr1).map(toDefParam))))
333
- case stat =>
334
- stat
329
+ val (normalizedBody, enumCases, enumCompanionRef) = {
330
+ // Add constructor type parameters and evidence implicit parameters
331
+ // to auxiliary constructors; set defaultGetters as a side effect.
332
+ def expandConstructor (tree : Tree ) = tree match {
333
+ case ddef : DefDef if ddef.name.isConstructorName =>
334
+ decompose(
335
+ defDef(
336
+ addEvidenceParams(
337
+ cpy.DefDef (ddef)(tparams = constrTparams),
338
+ evidenceParams(constr1).map(toDefParam))))
339
+ case stat =>
340
+ stat
341
+ }
342
+ // The Identifiers defined by a case
343
+ def caseIds (tree : Tree ) = tree match {
344
+ case tree : MemberDef => Ident (tree.name.toTermName) :: Nil
345
+ case PatDef (_, ids, _, _) => ids
346
+ }
347
+ val stats = impl.body.map(expandConstructor)
348
+ if (isEnum) {
349
+ val (enumCases, enumStats) = stats.partition(DesugarEnums .isEnumCase)
350
+ val enumCompanionRef = new TermRefTree ()
351
+ val enumImport = Import (enumCompanionRef, enumCases.flatMap(caseIds))
352
+ (enumImport :: enumStats, enumCases, enumCompanionRef)
353
+ }
354
+ else (stats, Nil , EmptyTree )
335
355
}
356
+
336
357
def anyRef = ref(defn.AnyRefAlias .typeRef)
337
358
338
359
val derivedTparams = constrTparams.map(derivedTypeParam(_))
@@ -365,20 +386,16 @@ object desugar {
365
386
val classTypeRef = appliedRef(classTycon)
366
387
367
388
// a reference to `enumClass`, with type parameters coming from the case constructor
368
- lazy val enumClassTypeRef = enumClass.primaryConstructor.info match {
369
- case info : PolyType =>
370
- if (constrTparams.isEmpty)
371
- interpolatedEnumParent(cdef.pos.startPos)
372
- else if ((constrTparams.corresponds(info.paramNames))((param, name) => param.name == name))
373
- appliedRef(enumClassRef)
374
- else {
375
- ctx.error(i " explicit extends clause needed because type parameters of case and enum class differ "
376
- , cdef.pos.startPos)
377
- appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
378
- }
379
- case _ =>
389
+ lazy val enumClassTypeRef =
390
+ if (enumClass.typeParams.isEmpty)
380
391
enumClassRef
381
- }
392
+ else if (originalTparams.isEmpty)
393
+ appliedRef(enumClassRef)
394
+ else {
395
+ ctx.error(i " explicit extends clause needed because both enum case and enum class have type parameters "
396
+ , cdef.pos.startPos)
397
+ appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
398
+ }
382
399
383
400
// new C[Ts](paramss)
384
401
lazy val creatorExpr = New (classTypeRef, constrVparamss nestedMap refOfDef)
@@ -432,6 +449,7 @@ object desugar {
432
449
}
433
450
434
451
// Case classes and case objects get Product parents
452
+ // Enum cases get an inferred parent if no parents are given
435
453
var parents1 = parents
436
454
if (isEnumCase && parents.isEmpty)
437
455
parents1 = enumClassTypeRef :: Nil
@@ -477,7 +495,7 @@ object desugar {
477
495
.withMods(companionMods | Synthetic ))
478
496
.withPos(cdef.pos).toList
479
497
480
- val companionMeths = defaultGetters ::: eqInstances
498
+ val companionMembers = defaultGetters ::: eqInstances ::: enumCases
481
499
482
500
// The companion object definitions, if a companion is needed, Nil otherwise.
483
501
// companion definitions include:
@@ -490,18 +508,17 @@ object desugar {
490
508
// For all other classes, the parent is AnyRef.
491
509
val companions =
492
510
if (isCaseClass) {
493
- // The return type of the `apply` method
511
+ // The return type of the `apply` method, and an (empty or singleton) list
512
+ // of widening coercions
494
513
val (applyResultTpt, widenDefs) =
495
514
if (! isEnumCase)
496
515
(TypeTree (), Nil )
497
516
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
498
517
(enumClassTypeRef, Nil )
499
- else {
500
- val tparams = enumClass.typeParams.map(derivedTypeParam)
501
- enumApplyResult(cdef, parents, tparams, appliedRef(enumClassRef, tparams))
502
- }
518
+ else
519
+ enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
503
520
504
- val parent =
521
+ val companionParent =
505
522
if (constrTparams.nonEmpty ||
506
523
constrVparamss.length > 1 ||
507
524
mods.is(Abstract ) ||
@@ -523,10 +540,10 @@ object desugar {
523
540
DefDef (nme.unapply, derivedTparams, (unapplyParam :: Nil ) :: Nil , TypeTree (), unapplyRHS)
524
541
.withMods(synthetic)
525
542
}
526
- companionDefs(parent , applyMeths ::: unapplyMeth :: companionMeths )
543
+ companionDefs(companionParent , applyMeths ::: unapplyMeth :: companionMembers )
527
544
}
528
- else if (companionMeths .nonEmpty)
529
- companionDefs(anyRef, companionMeths )
545
+ else if (companionMembers .nonEmpty)
546
+ companionDefs(anyRef, companionMembers )
530
547
else if (isValueClass) {
531
548
constr0.vparamss match {
532
549
case (_ :: Nil ) :: _ => companionDefs(anyRef, Nil )
@@ -535,6 +552,13 @@ object desugar {
535
552
}
536
553
else Nil
537
554
555
+ enumCompanionRef match {
556
+ case ref : TermRefTree => // have the enum import watch the companion object
557
+ val (modVal : ValDef ) :: _ = companions
558
+ ref.watching(modVal)
559
+ case _ =>
560
+ }
561
+
538
562
// For an implicit class C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, .., pMN: TMN), the method
539
563
// synthetic implicit C[Ts](p11: T11, ..., p1N: T1N) ... (pM1: TM1, ..., pMN: TMN): C[Ts] =
540
564
// new C[Ts](p11, ..., p1N) ... (pM1, ..., pMN) =
@@ -567,7 +591,7 @@ object desugar {
567
591
}
568
592
569
593
val cdef1 = addEnumFlags {
570
- val originalTparamsIt = originalTparams .toIterator
594
+ val originalTparamsIt = impliedTparams .toIterator
571
595
val originalVparamsIt = originalVparamss.toIterator.flatten
572
596
val tparamAccessors = derivedTparams.map(_.withMods(originalTparamsIt.next().mods))
573
597
val caseAccessor = if (isCaseClass) CaseAccessor else EmptyFlags
@@ -607,7 +631,7 @@ object desugar {
607
631
val moduleName = checkNotReservedName(mdef).asTermName
608
632
val impl = mdef.impl
609
633
val mods = mdef.mods
610
- lazy val isEnumCase = isLegalEnumCase(mdef)
634
+ lazy val isEnumCase = mods.hasMod[ Mod . EnumCase ]
611
635
if (mods is Package )
612
636
PackageDef (Ident (moduleName), cpy.ModuleDef (mdef)(nme.PACKAGE , impl).withMods(mods &~ Package ) :: Nil )
613
637
else if (isEnumCase)
@@ -654,7 +678,7 @@ object desugar {
654
678
*/
655
679
def patDef (pdef : PatDef )(implicit ctx : Context ): Tree = flatTree {
656
680
val PatDef (mods, pats, tpt, rhs) = pdef
657
- if (mods.hasMod[Mod .EnumCase ] && enumCaseIsLegal(pdef) )
681
+ if (mods.hasMod[Mod .EnumCase ])
658
682
pats map {
659
683
case id : Ident =>
660
684
expandSimpleEnumCase(id.name.asTermName, mods,
0 commit comments