Skip to content

Fix #9011: Make single enum values inherit from Product #9018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 22, 2020
Merged
8 changes: 5 additions & 3 deletions compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ object DesugarEnums {

/** A creation method for a value of enum type `E`, which is defined as follows:
*
* private def $new(_$ordinal: Int, $name: String) = new E {
* private def $new(_$ordinal: Int, $name: String) = new E with scala.runtime.EnumValue {
* def $ordinal = $tag
* override def toString = $name
* $values.register(this)
Expand All @@ -135,7 +135,7 @@ object DesugarEnums {
val toStringDef = toStringMeth(Ident(nme.nameDollar))
val creator = New(Template(
constr = emptyConstructor,
parents = enumClassRef :: Nil,
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
derived = Nil,
self = EmptyValDef,
body = List(ordinalDef, toStringDef) ++ registerCall
Expand Down Expand Up @@ -286,7 +286,9 @@ object DesugarEnums {
val (tag, scaffolding) = nextOrdinal(CaseKind.Object)
val ordinalDef = ordinalMethLit(tag)
val toStringDef = toStringMethLit(name.toString)
val impl1 = cpy.Template(impl)(body = List(ordinalDef, toStringDef) ++ registerCall)
val impl1 = cpy.Template(impl)(
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
body = List(ordinalDef, toStringDef) ++ registerCall)
.withAttachment(ExtendsSingletonMirror, ())
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
flatTree(scaffolding ::: vdef :: Nil).withSpan(span)
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def rootDot(name: Name)(implicit src: SourceFile): Select = Select(Ident(nme.ROOTPKG), name)
def scalaDot(name: Name)(implicit src: SourceFile): Select = Select(rootDot(nme.scala), name)
def scalaAnnotationDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.annotation), name)
def scalaRuntimeDot(name: Name)(using SourceFile): Select = Select(scalaDot(nme.runtime), name)
def scalaUnit(implicit src: SourceFile): Select = scalaDot(tpnme.Unit)
def scalaAny(implicit src: SourceFile): Select = scalaDot(tpnme.Any)
def javaDotLangDot(name: Name)(implicit src: SourceFile): Select = Select(Select(Ident(nme.java), nme.lang), name)
Expand Down
38 changes: 31 additions & 7 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ trait ConstraintHandling[AbstractContext] {
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
* 2. If `inst` is a union type, approximate the union type from above by an intersection
* of all common base types, provided the result is a subtype of `bound`.
* 3. (currently not enabled, see #9028) If `inst` is an intersection with some restricted base types, drop
* the restricted base types from the intersection, provided the result is a subtype of `bound`.
*
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
* Also, if the result of these widenings is a TypeRef to a module class,
Expand All @@ -309,26 +311,48 @@ trait ConstraintHandling[AbstractContext] {
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
def widenOr(tp: Type) = {
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =

def isRestricted(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later

def dropRestricted(tp: Type): Type = tp.dealias match
case tpd @ AndType(tp1, tp2) =>
if isRestricted(tp1) then tp2
else if isRestricted(tp2) then tp1
else
val tpw = tpd.derivedAndType(dropRestricted(tp1), dropRestricted(tp2))
if tpw ne tpd then tpw else tp
case _ =>
tp

def widenRestricted(tp: Type) =
val tpw = dropRestricted(tp)
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenOr(tp: Type) =
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
}
def widenSingle(tp: Type) = {

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
}

def isSingleton(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

val wideInst =
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
if isSingleton(bound) then inst
else /*widenRestricted*/(widenOr(widenSingle(inst)))
// widenRestricted is currently not called since it's special cased in `dropEnumValue`
// in `Namer`. It's left in here in case we want to generalize the scheme to other
// "protected inheritance" classes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we should leave dead code like this around without a clear idea of what we want to do with it. Is this something you're planning to experiment with more? I think it would make sense to use this mechanism for getting rid of Serializable and Product so that they don't show up when lubbing case classes.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #9028.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice, thanks! Can you leave a reference to this issue number in the comment?

wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
case _ =>
wideInst.dropRepeatedAnnot
}
end widenInferred

/** The instance type of `param` in the current constraint (which contains `param`).
* If `fromBelow` is true, the instance type is the lub of the parameter's
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ class Definitions {
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)

@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValue")
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ object Flags {
* TODO: Should check that FromStartFlags do not change in completion
*/
val FromStartFlags: FlagSet = commonFlags(
Module, Package, Deferred, Method, Case,
Module, Package, Deferred, Method, Case, Enum,
HigherKinded, Param, ParamAccessor,
Scala2SpecialFlags, MutableOrOpen, Opaque, Touched, JavaStatic,
OuterOrCovariant, LabelOrContravariant, CaseAccessor,
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -357,15 +357,14 @@ object StdNames {
val CAP: N = "CAP"
val Constant: N = "Constant"
val ConstantType: N = "ConstantType"
val doubleHash: N = "doubleHash"
val EnumValue: N = "EnumValue"
val ExistentialTypeTree: N = "ExistentialTypeTree"
val Flag : N = "Flag"
val floatHash: N = "floatHash"
val Ident: N = "Ident"
val Import: N = "Import"
val Literal: N = "Literal"
val LiteralAnnotArg: N = "LiteralAnnotArg"
val longHash: N = "longHash"
val MatchCase: N = "MatchCase"
val MirroredElemTypes: N = "MirroredElemTypes"
val MirroredElemLabels: N = "MirroredElemLabels"
Expand Down Expand Up @@ -443,6 +442,7 @@ object StdNames {
val delayedInitArg: N = "delayedInit$body"
val derived: N = "derived"
val derives: N = "derives"
val doubleHash: N = "doubleHash"
val drop: N = "drop"
val dynamics: N = "dynamics"
val elem: N = "elem"
Expand Down Expand Up @@ -505,6 +505,7 @@ object StdNames {
val language: N = "language"
val length: N = "length"
val lengthCompare: N = "lengthCompare"
val longHash: N = "longHash"
val macroThis : N = "_this"
val macroContext : N = "c"
val main: N = "main"
Expand Down
17 changes: 16 additions & 1 deletion compiler/src/dotty/tools/dotc/typer/Namer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1439,6 +1439,19 @@ class Namer { typer: Typer =>
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
def isInlineVal = sym.isOneOf(FinalOrInline, butNot = Method | Mutable)

def isEnumValue(tp: Type) = tp.typeSymbol == defn.EnumValueClass

// Drop EnumValue parents from inferred types of enum constants
def dropEnumValue(tp: Type): Type = tp.dealias match
case tpd @ AndType(tp1, tp2) =>
if isEnumValue(tp1) then tp2
else if isEnumValue(tp2) then tp1
else
val tpw = tpd.derivedAndType(dropEnumValue(tp1), dropEnumValue(tp2))
if tpw ne tpd then tpw else tp
case _ =>
tp

// Widen rhs type and eliminate `|' but keep ConstantTypes if
// definition is inline (i.e. final in Scala2) and keep module singleton types
// instead of widening to the underlying module class types.
Expand All @@ -1447,7 +1460,9 @@ class Namer { typer: Typer =>
def widenRhs(tp: Type): Type =
tp.widenTermRefExpr.simplified match
case ctp: ConstantType if isInlineVal => ctp
case tp => ctx.typeComparer.widenInferred(tp, rhsProto)
case tp =>
val tp1 = ctx.typeComparer.widenInferred(tp, rhsProto)
if sym.is(Enum) then dropEnumValue(tp1) else tp1

// Replace aliases to Unit by Unit itself. If we leave the alias in
// it would be erased to BoxedUnit.
Expand Down
5 changes: 4 additions & 1 deletion docs/docs/reference/enums/desugarEnums.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ map into `case class`es or `val`s.
where `n` is the ordinal number of the case in the companion object,
starting from 0. The statement `$values.register(this)` registers the value
as one of the `values` of the enumeration (see below). `$values` is a
compiler-defined private value in the companion object.
compiler-defined private value in the companion object. The anonymous class also
implements the abstract `Product` methods that it inherits from `Enum`.


It is an error if a value case refers to a type parameter of the enclosing `enum`
in a type argument of `<parents>`.
Expand Down Expand Up @@ -178,6 +180,7 @@ Companion objects of enumerations that contain at least one simple case define i
}
```

The anonymous class also implements the abstract `Product` methods that it inherits from `Enum`.
The `$ordinal` method above is used to generate the `ordinal` method if the enum does not extend a `java.lang.Enum` (as Scala enums do not extend `java.lang.Enum`s unless explicitly specified). In case it does, there is no need to generate `ordinal` as `java.lang.Enum` defines it.

### Scopes for Enum Cases
Expand Down
4 changes: 2 additions & 2 deletions docs/docs/reference/enums/enums.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ If you want to use the Scala-defined enums as Java enums, you can do so by exten
enum Color extends java.lang.Enum[Color] { case Red, Green, Blue }
```

The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
The type parameter comes from the Java enum [definition](https://docs.oracle.com/javase/8/docs/api/index.html?java/lang/Enum.html) and should be the same as the type of the enum.
There is no need to provide constructor arguments (as defined in the Java API docs) to `java.lang.Enum` when extending it – the compiler will generate them automatically.

After defining `Color` like that, you can use it like you would a Java enum:
Expand All @@ -116,7 +116,7 @@ This trait defines a single public method, `ordinal`:
package scala

/** A base trait of all enum classes */
trait Enum {
trait Enum extends Product with Serializable {

/** A number uniquely identifying a case of an enum */
def ordinal: Int
Expand Down
9 changes: 9 additions & 0 deletions library/src-bootstrapped/scala/Enum.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package scala

/** A base trait of all enum classes */
trait Enum extends Product, Serializable:

/** A number uniquely identifying a case of an enum */
def ordinal: Int
protected def $ordinal: Int

Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
package scala

/** A base trait of all enum classes */
trait Enum {
trait Enum:

/** A number uniquely identifying a case of an enum */
def ordinal: Int
protected def $ordinal: Int
}
10 changes: 10 additions & 0 deletions library/src/scala/runtime/EnumValue.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package scala.runtime

trait EnumValue extends Product, Serializable:
override def canEqual(that: Any) = this eq that.asInstanceOf[AnyRef]
override def productArity: Int = 0
override def productPrefix: String = toString
override def productElement(n: Int): Any =
throw IndexOutOfBoundsException(n.toString)
override def productElementName(n: Int): String =
throw IndexOutOfBoundsException(n.toString)
2 changes: 1 addition & 1 deletion tests/fuzzy/b82054893e0db44e31ae82d696c19c1fbc7be55c.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
object main {
object _main {
def i0 = {
class i1 {
private[i0] var i2: _ > 0 private def i3: List[Int]
Expand Down
15 changes: 15 additions & 0 deletions tests/neg/enumvalues.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
enum Color:
case Red, Green, Blue

enum Option[+T]:
case None extends Option[Nothing]

import scala.runtime.EnumValue

@main def Test(c: Boolean) =
// Verify that enum constants don't leak the scala.runtime.EnumValue trait
val x: EnumValue = if c then Color.Red else Color.Blue // error // error
val y: EnumValue = Color.Green // error
val z: EnumValue = Option.None // error


7 changes: 6 additions & 1 deletion tests/pos/enum-List-control.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@ abstract sealed class List[T] extends Enum
object List {
final class Cons[T](x: T, xs: List[T]) extends List[T] {
def $ordinal = 0
def canEqual(that: Any): Boolean = that.isInstanceOf[Cons[_]]
def productArity: Int = 2
def productElement(n: Int): Any = n match
case 0 => x
case 1 => xs
}
object Cons {
def apply[T](x: T, xs: List[T]): List[T] = new Cons(x, xs)
}
final class Nil[T]() extends List[T] {
final class Nil[T]() extends List[T], runtime.EnumValue {
def $ordinal = 1
}
object Nil {
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/localmodules.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package test;

object main {
object _main {

class a {

Expand Down
2 changes: 1 addition & 1 deletion tests/pos/t0002.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
object main {
object _main {
def main(args: Array[String]) = {
val b = true;
while (b == true) { }
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/t789.scala
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
object main { // don't do this at home
object _main { // don't do this at home

trait Impl

Expand Down
2 changes: 1 addition & 1 deletion tests/pos/typealiases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ trait Test[T] {
def check2[S](xs: Array[S], c: Check[S]) = c(xs)
}

object main extends Test[Int] {
object _main extends Test[Int] {
val pair1 = (1,1)

implicit def topair(x: Int): Tuple2[Int, Int] = (x,x)
Expand Down
57 changes: 57 additions & 0 deletions tests/run/i9011.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
enum Opt[+T] derives Eq:
case Sm(t: T)
case Nn

import scala.deriving._
import scala.compiletime.{erasedValue, summonInline}

trait Eq[T] {
def eqv(x: T, y: T): Boolean
}

object Eq {
given Eq[Int] {
def eqv(x: Int, y: Int) = x == y
}

inline def summonAll[T <: Tuple]: List[Eq[_]] = inline erasedValue[T] match {
case _: Unit => Nil
case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts]
}

def check(elem: Eq[_])(x: Any, y: Any): Boolean =
elem.asInstanceOf[Eq[Any]].eqv(x, y)

def iterator[T](p: T) = p.asInstanceOf[Product].productIterator
Copy link
Member

@smarter smarter May 22, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Going back to the original issue (#9011), is this cast actually safe?

The example in the docs suggests that if we have a Mirror.ProductOf[T] then we can safely cast an instance of T to Product (in order to access productIterator etc.). But this is not the case.

If this isn't actually guaranteed, then the documentation needs to be updated to not rely on this cast in an example.


def eqSum[T](s: Mirror.SumOf[T], elems: List[Eq[_]]): Eq[T] =
new Eq[T] {
def eqv(x: T, y: T): Boolean = {
val ordx = s.ordinal(x)
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y)
}
}

def eqProduct[T](p: Mirror.ProductOf[T], elems: List[Eq[_]]): Eq[T] =
new Eq[T] {
def eqv(x: T, y: T): Boolean =
iterator(x).zip(iterator(y)).zip(elems.iterator).forall {
case ((x, y), elem) => check(elem)(x, y)
}
}

inline given derived[T](using m: Mirror.Of[T]) as Eq[T] = {
val elemInstances = summonAll[m.MirroredElemTypes]
inline m match {
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
}
}
}

object Test extends App {
import Opt._
val eqoi = summon[Eq[Opt[Int]]]
assert(eqoi.eqv(Sm(23), Sm(23)))
assert(eqoi.eqv(Nn, Nn))
}
Loading