Skip to content

Commit 54929c2

Browse files
authored
Merge pull request #12654 from dotty-staging/soft-unions-2
Preserve hard unions in more situations
2 parents b955f9a + ace235c commit 54929c2

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
309309
val r1 = recur(tp.tp1, fromBelow)
310310
val r2 = recur(tp.tp2, fromBelow)
311311
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
312-
else if tp.isAnd then r1 & r2
313-
else r1 | r2
312+
else tp.match
313+
case tp: OrType =>
314+
TypeComparer.lub(r1, r2, isSoft = tp.isSoft)
315+
case _ =>
316+
r1 & r2
314317
case tp: TypeParamRef =>
315318
if tp eq param then
316319
if fromBelow then defn.NothingType else defn.AnyType

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,12 +2018,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20182018
val tp2a = dropIfSuper(tp2, tp1)
20192019
if tp2a ne tp2 then glb(tp1, tp2a)
20202020
else tp2 match // normalize to disjunctive normal form if possible.
2021-
case OrType(tp21, tp22) =>
2022-
tp1 & tp21 | tp1 & tp22
2021+
case tp2 @ OrType(tp21, tp22) =>
2022+
lub(tp1 & tp21, tp1 & tp22, isSoft = tp2.isSoft)
20232023
case _ =>
20242024
tp1 match
2025-
case OrType(tp11, tp12) =>
2026-
tp11 & tp2 | tp12 & tp2
2025+
case tp1 @ OrType(tp11, tp12) =>
2026+
lub(tp11 & tp2, tp12 & tp2, isSoft = tp1.isSoft)
20272027
case tp1: ConstantType =>
20282028
tp2 match
20292029
case tp2: ConstantType =>
@@ -2045,9 +2045,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20452045

20462046
/** The least upper bound of two types
20472047
* @param canConstrain If true, new constraints might be added to simplify the lub.
2048+
* @param isSoft If the lub is a union, this determines whether it's a soft union.
20482049
* @note We do not admit singleton types in or-types as lubs.
20492050
*/
2050-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
2051+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain, isSoft=$isSoft)", subtyping, show = true) /*<|<*/ {
20512052
if (tp1 eq tp2) tp1
20522053
else if (!tp1.exists) tp1
20532054
else if (!tp2.exists) tp2
@@ -2073,8 +2074,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20732074
def widen(tp: Type) = if (widenInUnions) tp.widen else tp.widenIfUnstable
20742075
val tp1w = widen(tp1)
20752076
val tp2w = widen(tp2)
2076-
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain)
2077-
else orType(tp1w, tp2w) // no need to check subtypes again
2077+
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain = canConstrain, isSoft = isSoft)
2078+
else orType(tp1w, tp2w, isSoft = isSoft) // no need to check subtypes again
20782079
}
20792080
mergedLub(tp1.stripLazyRef, tp2.stripLazyRef)
20802081
}
@@ -2183,11 +2184,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
21832184
case tp2 @ OrType(tp21, tp22) =>
21842185
val higher1 = mergeIfSuper(tp1, tp21, canConstrain)
21852186
if (higher1 eq tp21) tp2
2186-
else if (higher1.exists) higher1 | tp22
2187+
else if (higher1.exists) lub(higher1, tp22, isSoft = tp2.isSoft)
21872188
else {
21882189
val higher2 = mergeIfSuper(tp1, tp22, canConstrain)
21892190
if (higher2 eq tp22) tp2
2190-
else if (higher2.exists) tp21 | higher2
2191+
else if (higher2.exists) lub(tp21, higher2, isSoft = tp2.isSoft)
21912192
else NoType
21922193
}
21932194
case _ =>
@@ -2235,17 +2236,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
22352236
* ExprType, LambdaType). Also, when forming an `|`,
22362237
* instantiated TypeVars are dereferenced and annotations are stripped.
22372238
*
2239+
* @param isSoft If the result is a union, this determines whether it's a soft union.
22382240
* @param isErased Apply erasure semantics. If erased is true, instead of creating
22392241
* an OrType, the lub will be computed using TypeCreator#erasedLub.
22402242
*/
2241-
final def orType(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = {
2242-
val t1 = distributeOr(tp1, tp2)
2243+
final def orType(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = {
2244+
val t1 = distributeOr(tp1, tp2, isSoft)
22432245
if (t1.exists) t1
22442246
else {
2245-
val t2 = distributeOr(tp2, tp1)
2247+
val t2 = distributeOr(tp2, tp1, isSoft)
22462248
if (t2.exists) t2
22472249
else if (isErased) erasedLub(tp1, tp2)
2248-
else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _)
2250+
else liftIfHK(tp1, tp2, OrType(_, _, soft = isSoft), _ | _, _ & _)
22492251
}
22502252
}
22512253

@@ -2333,18 +2335,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
23332335
*
23342336
* The rhs is a proper supertype of the lhs.
23352337
*/
2336-
private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match {
2338+
private def distributeOr(tp1: Type, tp2: Type, isSoft: Boolean = true): Type = tp1 match {
23372339
case ExprType(rt1) =>
23382340
tp2 match {
23392341
case ExprType(rt2) =>
2340-
ExprType(rt1 | rt2)
2342+
ExprType(lub(rt1, rt2, isSoft = isSoft))
23412343
case _ =>
23422344
NoType
23432345
}
23442346
case tp1: TypeVar if tp1.isInstantiated =>
2345-
tp1.underlying | tp2
2347+
lub(tp1.underlying, tp2, isSoft = isSoft)
23462348
case tp1: AnnotatedType if !tp1.isRefining =>
2347-
tp1.underlying | tp2
2349+
lub(tp1.underlying, tp2, isSoft = isSoft)
23482350
case _ =>
23492351
NoType
23502352
}
@@ -2699,8 +2701,8 @@ object TypeComparer {
26992701
def matchingMethodParams(tp1: MethodType, tp2: MethodType)(using Context): Boolean =
27002702
comparing(_.matchingMethodParams(tp1, tp2))
27012703

2702-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false)(using Context): Type =
2703-
comparing(_.lub(tp1, tp2, canConstrain))
2704+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true)(using Context): Type =
2705+
comparing(_.lub(tp1, tp2, canConstrain = canConstrain, isSoft = isSoft))
27042706

27052707
/** The least upper bound of a list of types */
27062708
final def lub(tps: List[Type])(using Context): Type =
@@ -2716,8 +2718,8 @@ object TypeComparer {
27162718
def glb(tps: List[Type])(using Context): Type =
27172719
tps.foldLeft(defn.AnyType: Type)(glb)
27182720

2719-
def orType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
2720-
comparing(_.orType(tp1, tp2, isErased))
2721+
def orType(using Context)(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type =
2722+
comparing(_.orType(tp1, tp2, isSoft = isSoft, isErased = isErased))
27212723

27222724
def andType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
27232725
comparing(_.andType(tp1, tp2, isErased))
@@ -2946,9 +2948,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
29462948
super.hasMatchingMember(name, tp1, tp2)
29472949
}
29482950

2949-
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type =
2950-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") {
2951-
super.lub(tp1, tp2, canConstrain)
2951+
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean, isSoft: Boolean): Type =
2952+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, isSoft=$isSoft)") {
2953+
super.lub(tp1, tp2, canConstrain, isSoft)
29522954
}
29532955

29542956
override def glb(tp1: Type, tp2: Type): Type =

tests/pos/preserve-union.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class A {
2+
val a: Int | String = 1
3+
val b: AnyVal = 2
4+
5+
val c = List(a, b)
6+
val c1: List[AnyVal | String] = c
7+
}

0 commit comments

Comments
 (0)