Skip to content

Commit 29ac3a6

Browse files
committed
Strip EnumValue parent from inferred types
We now strip EnumValue parents from inferred types, unless they are required by the bound. This is analogous to widen unions and singletons. It should be generalized to more types, not just EnumValue.
1 parent 323d228 commit 29ac3a6

File tree

3 files changed

+38
-7
lines changed

3 files changed

+38
-7
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ trait ConstraintHandling[AbstractContext] {
300300
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
301301
* 2. If `inst` is a union type, approximate the union type from above by an intersection
302302
* of all common base types, provided the result is a subtype of `bound`.
303+
* 3. If `inst` is an intersection with some protected base types, drop
304+
* the protected base types from the intersection, provided the result is a subtype of `bound`.
303305
*
304306
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
305307
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -309,26 +311,43 @@ trait ConstraintHandling[AbstractContext] {
309311
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
310312
* as those could leak the annotation to users (see run/inferred-repeated-result).
311313
*/
312-
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type = {
313-
def widenOr(tp: Type) = {
314+
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
315+
316+
def isProtected(tp: Type) = tp.typeSymbol == defn.EnumValueClass // for now, to be generalized later
317+
318+
def dropProtected(tp: Type): Type = tp.dealias match
319+
case tp @ AndType(tp1, tp2) =>
320+
if isProtected(tp1) then tp2
321+
else if isProtected(tp2) then tp1
322+
else tp.derivedAndType(dropProtected(tp1), dropProtected(tp2))
323+
case _ =>
324+
tp
325+
326+
def widenProtected(tp: Type) =
327+
val tpw = dropProtected(tp)
328+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
329+
330+
def widenOr(tp: Type) =
314331
val tpw = tp.widenUnion
315332
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
316-
}
317-
def widenSingle(tp: Type) = {
333+
334+
def widenSingle(tp: Type) =
318335
val tpw = tp.widenSingletons
319336
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
320-
}
337+
321338
def isSingleton(tp: Type): Boolean = tp match
322339
case WildcardType(optBounds) => optBounds.exists && isSingleton(optBounds.bounds.hi)
323340
case _ => isSubTypeWhenFrozen(tp, defn.SingletonType)
341+
324342
val wideInst =
325-
if isSingleton(bound) then inst else widenOr(widenSingle(inst))
343+
if isSingleton(bound) then inst
344+
else widenProtected(widenOr(widenSingle(inst)))
326345
wideInst match
327346
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
328347
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)
329348
case _ =>
330349
wideInst.dropRepeatedAnnot
331-
}
350+
end widenInferred
332351

333352
/** The instance type of `param` in the current constraint (which contains `param`).
334353
* If `fromBelow` is true, the instance type is the lub of the parameter's

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,7 @@ class Definitions {
639639
@tu lazy val EnumClass: ClassSymbol = ctx.requiredClass("scala.Enum")
640640
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
641641

642+
@tu lazy val EnumValueClass: ClassSymbol = ctx.requiredClass("scala.EnumValue")
642643
@tu lazy val EnumValuesClass: ClassSymbol = ctx.requiredClass("scala.runtime.EnumValues")
643644
@tu lazy val ProductClass: ClassSymbol = ctx.requiredClass("scala.Product")
644645
@tu lazy val Product_canEqual : Symbol = ProductClass.requiredMethod(nme.canEqual_)

tests/neg/enumvalues.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
enum Color:
2+
case Red, Green, Blue
3+
4+
@main def Test(c: Boolean) =
5+
// These currently give errors. But maybe we should make the actual
6+
// enum values carry the `EnumValue` trait, and only strip it from
7+
// user-defined vals and defs?
8+
val x: EnumValue = if c then Color.Red else Color.Blue // error // error
9+
val y: EnumValue = Color.Green // error
10+
11+

0 commit comments

Comments
 (0)