Skip to content

Commit 371d8f5

Browse files
authored
Merge pull request #9932 from dotty-staging/change-enum-apply
An alternative scheme for precise apply methods of enums
2 parents 19b60fb + fe1355a commit 371d8f5

File tree

18 files changed

+211
-116
lines changed

18 files changed

+211
-116
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -658,15 +658,6 @@ object desugar {
658658
// For all other classes, the parent is AnyRef.
659659
val companions =
660660
if (isCaseClass) {
661-
// The return type of the `apply` method, and an (empty or singleton) list
662-
// of widening coercions
663-
val (applyResultTpt, widenDefs) =
664-
if (!isEnumCase)
665-
(TypeTree(), Nil)
666-
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
667-
(enumClassTypeRef, Nil)
668-
else
669-
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
670661

671662
// true if access to the apply method has to be restricted
672663
// i.e. if the case class constructor is either private or qualified private
@@ -697,8 +688,6 @@ object desugar {
697688
then anyRef
698689
else
699690
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
700-
def widenedCreatorExpr =
701-
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
702691
val applyMeths =
703692
if (mods.is(Abstract)) Nil
704693
else {
@@ -711,9 +700,8 @@ object desugar {
711700
val appParamss =
712701
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
713702
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
714-
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
715-
.withMods(appMods)
716-
app :: widenDefs
703+
DefDef(nme.apply, derivedTparams, appParamss, TypeTree(), creatorExpr)
704+
.withMods(appMods) :: Nil
717705
}
718706
val unapplyMeth = {
719707
val hasRepeatedParam = constrVparamss.head.exists {

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 0 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -201,52 +201,6 @@ object DesugarEnums {
201201
TypeTree(), creator).withFlags(Private | Synthetic)
202202
}
203203

204-
/** The return type of an enum case apply method and any widening methods in which
205-
* the apply's right hand side will be wrapped. For parents of the form
206-
*
207-
* extends E(args) with T1(args1) with ... TN(argsN)
208-
*
209-
* and type parameters `tparams` the generated widen method is
210-
*
211-
* def C$to$E[tparams](x$1: E[tparams] with T1 with ... TN) = x$1
212-
*
213-
* @param cdef The case definition
214-
* @param parents The declared parents of the enum case
215-
* @param tparams The type parameters of the enum case
216-
* @param appliedEnumRef The enum class applied to `tparams`.
217-
*/
218-
def enumApplyResult(
219-
cdef: TypeDef,
220-
parents: List[Tree],
221-
tparams: List[TypeDef],
222-
appliedEnumRef: Tree)(using Context): (Tree, List[DefDef]) = {
223-
224-
def extractType(t: Tree): Tree = t match {
225-
case Apply(t1, _) => extractType(t1)
226-
case TypeApply(t1, ts) => AppliedTypeTree(extractType(t1), ts)
227-
case Select(t1, nme.CONSTRUCTOR) => extractType(t1)
228-
case New(t1) => t1
229-
case t1 => t1
230-
}
231-
232-
val parentTypes = parents.map(extractType)
233-
parentTypes.head match {
234-
case parent: RefTree if parent.name == enumClass.name =>
235-
// need a widen method to compute correct type parameters for enum base class
236-
val widenParamType = parentTypes.tail.foldLeft(appliedEnumRef)(makeAndType)
237-
val widenParam = makeSyntheticParameter(tpt = widenParamType)
238-
val widenDef = DefDef(
239-
name = s"${cdef.name}$$to$$${enumClass.name}".toTermName,
240-
tparams = tparams,
241-
vparamss = (widenParam :: Nil) :: Nil,
242-
tpt = TypeTree(),
243-
rhs = Ident(widenParam.name))
244-
(TypeTree(), widenDef :: Nil)
245-
case _ =>
246-
(parentTypes.reduceLeft(makeAndType), Nil)
247-
}
248-
}
249-
250204
/** Is a type parameter in `enumTypeParams` referenced from an enum class case that has
251205
* given type parameters `caseTypeParams`, value parameters `vparamss` and parents `parents`?
252206
* Issues an error if that is the case but the reference is illegal.

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -286,18 +286,53 @@ trait ConstraintHandling {
286286
}
287287
}
288288

289+
/** If `tp` is an intersection such that some operands are super trait instances
290+
* and others are not, replace as many super trait instances as possible with Any
291+
* as long as the result is still a subtype of `bound`. But fall back to the
292+
* original type if the resulting widened type is a supertype of all dropped
293+
* types (since in this case the type was not a true intersection of super traits
294+
* and other types to start with).
295+
*/
296+
def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
297+
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
298+
var dropped: List[Type] = List() // the types dropped so far, last one on top
299+
300+
def dropOneSuperTrait(tp: Type): Type =
301+
val tpd = tp.dealias
302+
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
303+
dropped = tpd :: dropped
304+
defn.AnyType
305+
else tpd match
306+
case AndType(tp1, tp2) =>
307+
val tp1w = dropOneSuperTrait(tp1)
308+
if tp1w ne tp1 then tp1w & tp2
309+
else
310+
val tp2w = dropOneSuperTrait(tp2)
311+
if tp2w ne tp2 then tp1 & tp2w
312+
else tpd
313+
case _ =>
314+
tp
315+
316+
def recur(tp: Type): Type =
317+
val tpw = dropOneSuperTrait(tp)
318+
if tpw eq tp then tp
319+
else if tpw <:< bound then recur(tpw)
320+
else
321+
kept += dropped.head
322+
dropped = dropped.tail
323+
recur(tp)
324+
325+
val tpw = recur(tp)
326+
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
327+
end dropSuperTraits
328+
289329
/** Widen inferred type `inst` with upper `bound`, according to the following rules:
290330
* 1. If `inst` is a singleton type, or a union containing some singleton types,
291331
* widen (all) the singleton type(s), provided the result is a subtype of `bound`
292332
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
293333
* 2. If `inst` is a union type, approximate the union type from above by an intersection
294334
* of all common base types, provided the result is a subtype of `bound`.
295-
* 3. If `inst` is an intersection such that some operands are super trait instances
296-
* and others are not, replace as many super trait instances as possible with Any
297-
* as long as the result is still a subtype of `bound`. But fall back to the
298-
* original type if the resulting widened type is a supertype of all dropped
299-
* types (since in this case the type was not a true intersection of super traits
300-
* and other types to start with).
335+
* 3. drop super traits from intersections (see @dropSuperTraits)
301336
*
302337
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
303338
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -308,40 +343,6 @@ trait ConstraintHandling {
308343
* as those could leak the annotation to users (see run/inferred-repeated-result).
309344
*/
310345
def widenInferred(inst: Type, bound: Type)(using Context): Type =
311-
312-
def dropSuperTraits(tp: Type): Type =
313-
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
314-
var dropped: List[Type] = List() // the types dropped so far, last one on top
315-
316-
def dropOneSuperTrait(tp: Type): Type =
317-
val tpd = tp.dealias
318-
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
319-
dropped = tpd :: dropped
320-
defn.AnyType
321-
else tpd match
322-
case AndType(tp1, tp2) =>
323-
val tp1w = dropOneSuperTrait(tp1)
324-
if tp1w ne tp1 then tp1w & tp2
325-
else
326-
val tp2w = dropOneSuperTrait(tp2)
327-
if tp2w ne tp2 then tp1 & tp2w
328-
else tpd
329-
case _ =>
330-
tp
331-
332-
def recur(tp: Type): Type =
333-
val tpw = dropOneSuperTrait(tp)
334-
if tpw eq tp then tp
335-
else if tpw <:< bound then recur(tpw)
336-
else
337-
kept += dropped.head
338-
dropped = dropped.tail
339-
recur(tp)
340-
341-
val tpw = recur(tp)
342-
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
343-
end dropSuperTraits
344-
345346
def widenOr(tp: Type) =
346347
val tpw = tp.widenUnion
347348
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
@@ -356,7 +357,7 @@ trait ConstraintHandling {
356357

357358
val wideInst =
358359
if isSingleton(bound) then inst
359-
else dropSuperTraits(widenOr(widenSingle(inst)))
360+
else dropSuperTraits(widenOr(widenSingle(inst)), bound)
360361
wideInst match
361362
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
362363
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2625,6 +2625,9 @@ object TypeComparer {
26252625
def widenInferred(inst: Type, bound: Type)(using Context): Type =
26262626
comparing(_.widenInferred(inst, bound))
26272627

2628+
def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
2629+
comparing(_.dropSuperTraits(tp, bound))
2630+
26282631
def constrainPatternType(pat: Type, scrut: Type)(using Context): Boolean =
26292632
comparing(_.constrainPatternType(pat, scrut))
26302633

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ object SymUtils {
160160
def isField(using Context): Boolean =
161161
self.isTerm && !self.is(Method)
162162

163+
def isEnumCase(using Context): Boolean =
164+
self.isAllOf(EnumCase, butNot = JavaDefined)
165+
163166
def annotationsCarrying(meta: ClassSymbol)(using Context): List[Annotation] =
164167
self.annotations.filter(_.symbol.hasAnnotation(meta))
165168

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ProtoTypes._
2525
import Inferencing._
2626
import reporting._
2727
import transform.TypeUtils._
28+
import transform.SymUtils._
2829
import Nullables.{postProcessByNameArgs, given _}
2930
import config.Feature
3031

@@ -891,7 +892,9 @@ trait Applications extends Compatibility {
891892
case funRef: TermRef =>
892893
val app = ApplyTo(tree, fun1, funRef, proto, pt)
893894
convertNewGenericArray(
894-
postProcessByNameArgs(funRef, app).computeNullable())
895+
widenEnumCase(
896+
postProcessByNameArgs(funRef, app).computeNullable(),
897+
pt))
895898
case _ =>
896899
handleUnexpectedFunType(tree, fun1)
897900
}
@@ -1091,7 +1094,7 @@ trait Applications extends Compatibility {
10911094
* It is performed during typer as creation of generic arrays needs a classTag.
10921095
* we rely on implicit search to find one.
10931096
*/
1094-
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
1097+
def convertNewGenericArray(tree: Tree)(using Context): Tree = tree match {
10951098
case Apply(TypeApply(tycon, targs@(targ :: Nil)), args) if tycon.symbol == defn.ArrayConstructor =>
10961099
fullyDefinedType(tree.tpe, "array", tree.span)
10971100

@@ -1107,6 +1110,28 @@ trait Applications extends Compatibility {
11071110
tree
11081111
}
11091112

1113+
/** If `tree` is a complete application of a compiler-generated `apply`
1114+
* or `copy` method of an enum case, widen its type to the underlying
1115+
* type by means of a type ascription, as long as the widened type is
1116+
* still compatible with the expected type.
1117+
* The underlying type is the intersection of all class parents of the
1118+
* orginal type.
1119+
*/
1120+
def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree =
1121+
val sym = tree.symbol
1122+
def isEnumCopy = sym.name == nme.copy && sym.owner.isEnumCase
1123+
def isEnumApply = sym.name == nme.apply && sym.owner.linkedClass.isEnumCase
1124+
if sym.is(Synthetic) && (isEnumApply || isEnumCopy)
1125+
&& tree.tpe.classSymbol.isEnumCase
1126+
&& tree.tpe.widen.isValueType
1127+
then
1128+
val widened = TypeComparer.dropSuperTraits(
1129+
tree.tpe.parents.reduceLeft(TypeComparer.andType(_, _)),
1130+
pt)
1131+
if widened <:< pt then Typed(tree, TypeTree(widened))
1132+
else tree
1133+
else tree
1134+
11101135
/** Does `state` contain a "NotAMember" or "MissingIdent" message as
11111136
* first pending error message? That message would be
11121137
* `$memberName is not a member of ...` or `Not found: $memberName`.

compiler/src/dotty/tools/dotc/typer/ReTyper.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ class ReTyper extends Typer with ReChecking {
132132
override def inferView(from: Tree, to: Type)(using Context): Implicits.SearchResult =
133133
Implicits.NoMatchingImplicitsFailure
134134
override def checkCanEqual(ltp: Type, rtp: Type, span: Span)(using Context): Unit = ()
135+
136+
override def widenEnumCase(tree: Tree, pt: Type)(using Context): Tree = tree
137+
135138
override protected def addAccessorDefs(cls: Symbol, body: List[Tree])(using Context): List[Tree] = body
136139
override protected def checkEqualityEvidence(tree: tpd.Tree, pt: Type)(using Context): Unit = ()
137140
override protected def matchingApply(methType: MethodOrPoly, pt: FunProto)(using Context): Boolean = true

docs/docs/reference/enums/adts.md

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,13 @@ scala> Option.None
4848
val res2: t2.Option[Nothing] = None
4949
```
5050

51-
Note that the type of the expressions above is always `Option`. That
52-
is, the implementation case classes are not visible in the result
53-
types of their `apply` methods. This is a subtle difference with
54-
respect to normal case classes. The classes making up the cases do
55-
exist, and can be unveiled by constructing them directly with a `new`.
51+
Note that the type of the expressions above is always `Option`. Generally, the type of a enum case constructor application will be widened to the underlying enum type, unless a more specific type is expected. This is a subtle difference with respect to normal case classes. The classes making up the cases do exist, and can be unveiled, either by constructing them directly with a `new`, or by explicitly providing an expected type.
5652

5753
```scala
5854
scala> new Option.Some(2)
59-
val res3: t2.Option.Some[Int] = Some(2)
55+
val res3: Option.Some[Int] = Some(2)
56+
scala> val x: Option.Some[Int] = Option.Some(3)
57+
val res4: Option.Some[Int] = Some(3)
6058
```
6159

6260
As all other enums, ADTs can define methods. For instance, here is `Option` again, with an

docs/docs/reference/enums/desugarEnums.md

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,7 @@ map into `case class`es or `val`s.
139139
```scala
140140
final case class C <params> extends <parents>
141141
```
142-
However, unlike for a regular case class, the return type of the associated
143-
`apply` method is a fully parameterized type instance of the enum class `E`
144-
itself instead of `C`. Also the enum case defines an `ordinal` method of
145-
the form
142+
The enum case defines an `ordinal` method of the form
146143
```scala
147144
def ordinal = n
148145
```
@@ -153,6 +150,14 @@ map into `case class`es or `val`s.
153150
in a parameter type in `<params>` or in a type argument of `<parents>`, unless that parameter is already
154151
a type parameter of the case, i.e. the parameter name is defined in `<params>`.
155152

153+
The compiler-generated `apply` and `copy` methods of an enum case
154+
```scala
155+
case C(ps) extends P1, ..., Pn
156+
```
157+
are treated specially. A call `C(ts)` of the apply method is ascribed the underlying type
158+
`P1 & ... & Pn` (dropping any [super traits](../other-new-features/super-traits.html))
159+
as long as that type is still compatible with the expected type at the point of application.
160+
A call `t.copy(ts)` of `C`'s `copy` method is treated in the same way.
156161

157162
### Translation of Enums with Singleton Cases
158163

tests/pos/enum-widen.scala

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
object test:
2+
3+
enum Option[+T]:
4+
case Some[T](x: T) extends Option[T]
5+
case None
6+
7+
import Option._
8+
9+
var x = Some(1)
10+
val y: Some[Int] = Some(2)
11+
var xc = y.copy(3)
12+
val yc: Some[Int] = y.copy(3)
13+
x = None
14+
xc = None
15+
16+
enum Nat:
17+
case Z
18+
case S[N <: Z.type | S[_]](pred: N)
19+
import Nat._
20+
21+
val two = S(S(Z))

tests/pos/i3935.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
enum Foo3[T](x: T) {
2+
case Bar[S, T](y: T) extends Foo3[y.type](y)
3+
}
4+
5+
val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
6+
val bar = foo
7+
8+
def baz[T](f: Foo3[T]): f.type = f
9+
10+
val qux = baz(bar) // existentials are back in Dotty?
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import Nat._
2+
3+
inline def toIntMacro(inline nat: Nat): Int = ${ Macros.toIntImpl('nat) }
4+
inline def ZeroMacro: Zero.type = ${ Macros.natZero }
5+
transparent inline def toNatMacro(inline int: Int): Nat = ${ Macros.toNatImpl('int) }
6+
7+
object Macros:
8+
import quoted._
9+
10+
def toIntImpl(nat: Expr[Nat])(using QuoteContext): Expr[Int] =
11+
12+
def inner(nat: Expr[Nat], acc: Int): Int = nat match
13+
case '{ Succ($nat) } => inner(nat, acc + 1)
14+
case '{ Zero } => acc
15+
16+
Expr(inner(nat, 0))
17+
18+
def natZero(using QuoteContext): Expr[Nat.Zero.type] = '{Zero}
19+
20+
def toNatImpl(int: Expr[Int])(using QuoteContext): Expr[Nat] =
21+
22+
// it seems even with the bound that the arg will always widen to Expr[Nat] unless explicit
23+
24+
def inner[N <: Nat: Type](int: Int, acc: Expr[N]): Expr[Nat] = int match
25+
case 0 => acc
26+
case n => inner[Succ[N]](n - 1, '{Succ($acc)})
27+
28+
val Const(i) = int
29+
require(i >= 0)
30+
inner[Zero.type](i, '{Zero})

0 commit comments

Comments
 (0)