Skip to content

Commit 5ad4de1

Browse files
committed
Make enums implement Eq
1 parent 6e4c3b7 commit 5ad4de1

File tree

4 files changed

+66
-16
lines changed

4 files changed

+66
-16
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ object desugar {
7272
val defctx = ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next
7373
var local = defctx.denotNamed(tp.name).suchThat(_ is ParamOrAccessor).symbol
7474
if (local.exists) (defctx.owner.thisType select local).dealias
75-
else throw new java.lang.Error(
76-
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
77-
)
75+
else {
76+
def msg =
77+
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
78+
if (ctx.reporter.errorsReported) new ErrorType(msg)
79+
else throw new java.lang.Error(msg)
80+
}
7881
case _ =>
7982
mapOver(tp)
8083
}
@@ -124,7 +127,7 @@ object desugar {
124127
else vdef
125128
}
126129

127-
def makeImplicitParameters(tpts: List[Tree], forPrimaryConstructor: Boolean)(implicit ctx: Context) =
130+
def makeImplicitParameters(tpts: List[Tree], forPrimaryConstructor: Boolean = false)(implicit ctx: Context) =
128131
for (tpt <- tpts) yield {
129132
val paramFlags: FlagSet = if (forPrimaryConstructor) PrivateLocalParamAccessor else Param
130133
val epname = EvidenceParamName.fresh()
@@ -265,7 +268,7 @@ object desugar {
265268
val mods = cdef.mods
266269
val companionMods = mods
267270
.withFlags((mods.flags & AccessFlags).toCommonFlags)
268-
.withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase]))
271+
.withMods(Nil)
269272

270273
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match {
271274
case meth: DefDef => (meth, Nil)
@@ -291,7 +294,7 @@ object desugar {
291294

292295
val isCaseClass = mods.is(Case) && !mods.is(Module)
293296
val isCaseObject = mods.is(Case) && mods.is(Module)
294-
val isEnum = mods.hasMod[Mod.Enum]
297+
val isEnum = mods.hasMod[Mod.Enum] && !mods.is(Module)
295298
val isEnumCase = isLegalEnumCase(cdef)
296299
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
297300
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
@@ -326,10 +329,12 @@ object desugar {
326329

327330
val classTycon: Tree = new TypeRefTree // watching is set at end of method
328331

329-
def appliedRef(tycon: Tree) =
330-
(if (constrTparams.isEmpty) tycon
331-
else AppliedTypeTree(tycon, constrTparams map refOfDef))
332-
.withPos(cdef.pos.startPos)
332+
def appliedTypeTree(tycon: Tree, args: List[Tree]) =
333+
(if (args.isEmpty) tycon else AppliedTypeTree(tycon, args))
334+
.withPos(cdef.pos.startPos)
335+
336+
def appliedRef(tycon: Tree, tparams: List[TypeDef] = constrTparams) =
337+
appliedTypeTree(tycon, tparams map refOfDef)
333338

334339
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
335340
val classTypeRef = appliedRef(classTycon)
@@ -344,8 +349,7 @@ object desugar {
344349
else {
345350
ctx.error(i"explicit extends clause needed because type parameters of case and enum class differ"
346351
, cdef.pos.startPos)
347-
AppliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
348-
.withPos(cdef.pos.startPos)
352+
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
349353
}
350354
case _ =>
351355
enumClassRef
@@ -411,6 +415,31 @@ object desugar {
411415
if (isEnum)
412416
parents1 = parents1 :+ ref(defn.EnumType)
413417

418+
// The Eq instance for an Enum class. For an enum class
419+
//
420+
// enum class C[T1, ..., Tn]
421+
//
422+
// we generate:
423+
//
424+
// implicit def eqInstance[T1$1, ..., Tn$1, T1$2, ..., Tn$2](implicit
425+
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
426+
// : Eq[C[T1$1, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
427+
def eqInstance = {
428+
def append(tdef: TypeDef, str: String) = cpy.TypeDef(tdef)(name = tdef.name ++ str)
429+
val leftParams = derivedTparams.map(append(_, "$1"))
430+
val rightParams = derivedTparams.map(append(_, "$2"))
431+
val subInstances = (leftParams, rightParams).zipped.map((param1, param2) =>
432+
appliedRef(ref(defn.EqType), List(param1, param2)))
433+
DefDef(
434+
name = nme.eqInstance,
435+
tparams = leftParams ++ rightParams,
436+
vparamss = List(makeImplicitParameters(subInstances)),
437+
tpt = appliedTypeTree(ref(defn.EqType),
438+
appliedRef(classTycon, leftParams) :: appliedRef(classTycon, rightParams) :: Nil),
439+
rhs = ref(defn.EqModule.termRef)).withFlags(Synthetic | Implicit)
440+
}
441+
def eqInstances = if (isEnum) eqInstance :: Nil else Nil
442+
414443
// The thicket which is the desugared version of the companion object
415444
// synthetic object C extends parentTpt { defs }
416445
def companionDefs(parentTpt: Tree, defs: List[Tree]) =
@@ -420,6 +449,8 @@ object desugar {
420449
.withMods(companionMods | Synthetic))
421450
.withPos(cdef.pos).toList
422451

452+
val companionMeths = defaultGetters ::: eqInstances
453+
423454
// The companion object definitions, if a companion is needed, Nil otherwise.
424455
// companion definitions include:
425456
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
@@ -465,10 +496,10 @@ object desugar {
465496
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
466497
.withMods(synthetic)
467498
}
468-
companionDefs(parent, applyMeths ::: unapplyMeth :: defaultGetters)
499+
companionDefs(parent, applyMeths ::: unapplyMeth :: companionMeths)
469500
}
470-
else if (defaultGetters.nonEmpty)
471-
companionDefs(anyRef, defaultGetters)
501+
else if (companionMeths.nonEmpty)
502+
companionDefs(anyRef, companionMeths)
472503
else if (isValueClass) {
473504
constr0.vparamss match {
474505
case List(_ :: Nil) => companionDefs(anyRef, Nil)
@@ -739,7 +770,7 @@ object desugar {
739770
}
740771

741772
def makeImplicitFunction(formals: List[Type], body: Tree)(implicit ctx: Context): Tree = {
742-
val params = makeImplicitParameters(formals.map(TypeTree), forPrimaryConstructor = false)
773+
val params = makeImplicitParameters(formals.map(TypeTree))
743774
new ImplicitFunction(params, body)
744775
}
745776

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ object StdNames {
400400
val ensureAccessible : N = "ensureAccessible"
401401
val enumTag: N = "enumTag"
402402
val eq: N = "eq"
403+
val eqInstance: N = "eqInstance"
403404
val equalsNumChar : N = "equalsNumChar"
404405
val equalsNumNum : N = "equalsNumNum"
405406
val equalsNumObject : N = "equalsNumObject"

tests/neg/enums.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
package enums
2+
13
enum List[+T] {
24
case Cons[T](x: T, xs: List[T]) // ok
35
case Snoc[U](xs: List[U], x: U) // error: different type parameters
@@ -18,3 +20,18 @@ enum E2[+T, +U >: T] {
1820
enum E3[-T <: Ordered[T]] {
1921
case C // error: cannot determine type argument
2022
}
23+
24+
enum Option[+T] {
25+
case Some[T](x: T)
26+
case None
27+
}
28+
29+
object Test {
30+
31+
class Unrelated
32+
33+
val x: Option[Int] = Option.Some(1)
34+
x == new Unrelated // error: cannot compare
35+
36+
}
37+

tests/run/enum-Option.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,5 +16,6 @@ object Test {
1616
def main(args: Array[String]) = {
1717
assert(Some(None).isDefined)
1818
Option("22") match { case Option.Some(x) => assert(x == "22") }
19+
assert(Some(None) != None)
1920
}
2021
}

0 commit comments

Comments
 (0)