Skip to content

Precise apply for enum companion objects #9728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 10 commits into from
20 changes: 4 additions & 16 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ object desugar {
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
cpy.ValDef(vparam)(rhs = EmptyTree))
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, classTypeRef, creatorExpr)
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
}
}
Expand Down Expand Up @@ -656,15 +656,6 @@ object desugar {
// For all other classes, the parent is AnyRef.
val companions =
if (isCaseClass) {
// The return type of the `apply` method, and an (empty or singleton) list
// of widening coercions
val (applyResultTpt, widenDefs) =
if (!isEnumCase)
(TypeTree(), Nil)
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
(enumClassTypeRef, Nil)
else
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))

// true if access to the apply method has to be restricted
// i.e. if the case class constructor is either private or qualified private
Expand Down Expand Up @@ -695,8 +686,6 @@ object desugar {
then anyRef
else
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
def widenedCreatorExpr =
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
val applyMeths =
if (mods.is(Abstract)) Nil
else {
Expand All @@ -709,9 +698,8 @@ object desugar {
val appParamss =
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
.withMods(appMods)
app :: widenDefs
DefDef(nme.apply, derivedTparams, appParamss, classTypeRef, creatorExpr)
.withMods(appMods) :: Nil
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
Expand All @@ -720,7 +708,7 @@ object desugar {
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
val unapplyResTp = if arity == 0 then Literal(Constant(true)) else classTypeRef
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
.withMods(synthetic)
}
Expand Down
12 changes: 10 additions & 2 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ trait ConstraintHandling {
val tpw = tp.widenUnion
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenEnum(tp: Type) =
val tpw = tp.widenEnumClass
if (tpw ne tp) && (tpw <:< bound) then tpw else tp

def widenSingle(tp: Type) =
val tpw = tp.widenSingletons
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
Expand All @@ -354,9 +358,13 @@ trait ConstraintHandling {
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)

def isEnum(tp: Type): Boolean = tp match
case WildcardType(optBounds) => optBounds.exists && isEnum(optBounds.bounds.hi)
case _ => tp.typeSymbol.is(Enum, butNot=JavaDefined)

val wideInst =
if isSingleton(bound) then inst
else dropSuperTraits(widenOr(widenSingle(inst)))
if isSingleton(bound) || isEnum(bound) then inst
else dropSuperTraits(widenOr(widenEnum(widenSingle(inst))))
Copy link
Member

@smarter smarter Sep 9, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's still not quite right: if the upper bound is an enum, a singleton should still be widened unless it's the type of an enum case.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean remove singletons except the term ref of a singleton enum case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to get a lot trickier when unions of singletons are involved

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

new steps are widenSingle > widenOr > [widenEnumCase] > dropSuperTraits, where widenOr is intercepted so that singletons of module or enum value do not widen

wideInst match
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
Expand Down
9 changes: 9 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1101,8 +1101,10 @@ object Types {

/** Widen from TermRef to its underlying non-termref
* base type, while also skipping Expr types.
* Preserves references to modules or singleton enum values
*/
final def widenTermRefExpr(using Context): Type = stripTypeVar match {
case tp: TermRef if tp.termSymbol.isAllOf(EnumCase) || tp.termSymbol.is(Module) => tp
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr
case _ => this
}
Expand Down Expand Up @@ -1173,6 +1175,13 @@ object Types {
tp
}

def widenEnumClass(using Context): Type = dealias match {
case tp: (TypeRef | AppliedType) if tp.typeSymbol.isAllOf(EnumCase) =>
tp.parents.head
case _ =>
this
}

/** Widen all top-level singletons reachable by dealiasing
* and going to the operands of & and |.
* Overridden and cached in OrType.
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Scanners.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1399,8 +1399,8 @@ object Scanners {

object IndentWidth {
private inline val MaxCached = 40
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap

def Run(ch: Char, n: Int): Run =
if (n <= MaxCached && ch == ' ') spaces(n)
Expand Down
3 changes: 1 addition & 2 deletions compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
toTextRHS(tp)
case tp: TermRef
if !tp.denotationIsCurrent && !homogenizedView || // always print underlying when testing picklers
tp.symbol.is(Module) || tp.symbol.name == nme.IMPORT =>
tp.symbol.is(Module) || tp.symbol.isAllOf(EnumCase) || tp.symbol.name == nme.IMPORT =>
toTextRef(tp) ~ ".type"
case tp: TermRef if tp.denot.isOverloaded =>
"<overloaded " ~ toTextRef(tp) ~ ">"
Expand Down Expand Up @@ -598,4 +598,3 @@ class PlainPrinter(_ctx: Context) extends Printer {
protected def coloredText(text: Text, color: String): Text =
if (ctx.useColors) color ~ text ~ SyntaxHighlighting.NoColor else text
}

10 changes: 10 additions & 0 deletions tests/pos/i3935.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
enum Foo3[T](x: T) {
case Bar[S, T](y: T) extends Foo3[y.type](y)
}

val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
val bar = foo

def baz[T](f: Foo3[T]): f.type = f

val qux = baz(bar) // existentials are back in Dotty?
4 changes: 4 additions & 0 deletions tests/run-macros/i8007.check
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,9 @@ true

true

true

true

false

3 changes: 1 addition & 2 deletions tests/run-macros/i8007/Macro_3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ object Eq {
$ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y)
}
}

'{
eqSum((x: T, y: T) => ${eqSumBody('x, 'y)})
}
Expand All @@ -76,4 +75,4 @@ object Macro3 {
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
}
}
30 changes: 25 additions & 5 deletions tests/run-macros/i8007/Test_4.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,22 @@ import Macro3.eqGen
case class Person(name: String, age: Int)

enum Opt[+T] {
case Sm(t: T)
case Sm[U](t: U) extends Opt[U]
case Nn
}

enum OptInfer[+T] {
case Sm[+U](t: U) extends OptInfer[U]
case Nn
}

// simulation of Opt using case class hierarchy
sealed abstract class OptCase[+T]
object OptCase {
final case class Sm[T](t: T) extends OptCase[T]
case object Nn extends OptCase[Nothing]
}

@main def Test() = {
import Opt._
import Eq.{given _, _}
Expand All @@ -30,15 +42,23 @@ enum Opt[+T] {
println(t4) // false
println

val t5 = Sm(23) === Sm(23)
val t5 = Opt.Sm[Int](23) === Opt.Sm(23) // same behaviour as case class when using apply
println(t5) // true
println

val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
val t5_2 = OptCase.Sm[Int](23) === OptCase.Sm(23)
println(t5_2) // true
println

val t5_3 = OptInfer.Sm(23) === OptInfer.Sm(23) // covariant `Sm` case means we can avoid explicit type parameter
println(t5_3) // true
println

val t6 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 23))
println(t6) // true
println

val t7 = Sm(Person("Test", 23)) === Sm(Person("Test", 24))
val t7 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 24))
println(t7) // false
println
}
}
34 changes: 34 additions & 0 deletions tests/run/enum-nat.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import Nat._
import compiletime._

enum Nat:
case Zero
case Succ[N <: Nat](n: N)

inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match
case _: Zero.type => 0
case _: Succ[n] => toIntTypeLevel[n] + 1

inline def toInt(inline nat: Nat): Int = inline nat match
case nat: Zero.type => 0
case nat: Succ[n] => toInt(nat.n) + 1

inline def toIntUnapply(inline nat: Nat): Int = inline nat match
case Zero => 0
case Succ(n) => toIntUnapply(n) + 1

inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match
case _: Zero.type => constValue[Acc]
case _: Succ[n] => toIntTypeTailRec[n, S[Acc]]

inline def toIntErased[N <: Nat](inline nat: N): Int = toIntTypeTailRec[N, 0]

@main def Test: Unit =
println("erased value:")
assert(toIntTypeLevel[Succ[Succ[Succ[Zero.type]]]] == 3)
println("type test:")
assert(toInt(Succ(Succ(Succ(Zero)))) == 3)
println("unapply:")
assert(toIntUnapply(Succ(Succ(Succ(Zero)))) == 3)
println("infer erased:")
assert(toIntErased(Succ(Succ(Succ(Zero)))) == 3)
31 changes: 31 additions & 0 deletions tests/run/enum-precise.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
enum NonEmptyList[+T]:
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
case One [+U](value: U) extends NonEmptyList[U]

enum Ast:
case Binding(name: String, tpe: String)
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
case Ident(name: String)
case Apply(fn: Ast, args: NonEmptyList[Ast])

import NonEmptyList._
import Ast._

// This example showcases the widening when inferring enum case types.
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
// the precise class is expected.
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
def Bind(arg: (String, String)): Binding =
val (name, tpe) = arg
Binding(name, tpe)

args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))

@main def Test: Unit =
val OneOfOne: One[1] = One[1](1)
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded

assert(OneOfOne.value == 1)