Skip to content

Commit ace235c

Browse files
committed
Preserve hard unions in more situations
There's multiple places where we take apart a union, apply some transformation to its parts, then either return the original union if nothing changed or recombine the transformed parts into a new union. The problem is that to check if nothing changed we use referential equality when `=:=` would be more correct, so we sometimes end up returning a union equivalent to the original one except the new one is a soft union. I've seen this lead to subtle type inference differences both when experimenting with changes to constraint solving and when trying to replace our Uniques hashmaps by weak hashmaps. We could fix this by replacing these `eq` calls by `=:=` but instead this commit takes the approach of preserving the softness of a union even when its components are modified, see the test case.
1 parent d099811 commit ace235c

File tree

3 files changed

+38
-26
lines changed

3 files changed

+38
-26
lines changed

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
309309
val r1 = recur(tp.tp1, fromBelow)
310310
val r2 = recur(tp.tp2, fromBelow)
311311
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
312-
else if tp.isAnd then r1 & r2
313-
else r1 | r2
312+
else tp.match
313+
case tp: OrType =>
314+
TypeComparer.lub(r1, r2, isSoft = tp.isSoft)
315+
case _ =>
316+
r1 & r2
314317
case tp: TypeParamRef =>
315318
if tp eq param then
316319
if fromBelow then defn.NothingType else defn.AnyType

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2018,12 +2018,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20182018
val tp2a = dropIfSuper(tp2, tp1)
20192019
if tp2a ne tp2 then glb(tp1, tp2a)
20202020
else tp2 match // normalize to disjunctive normal form if possible.
2021-
case OrType(tp21, tp22) =>
2022-
tp1 & tp21 | tp1 & tp22
2021+
case tp2 @ OrType(tp21, tp22) =>
2022+
lub(tp1 & tp21, tp1 & tp22, isSoft = tp2.isSoft)
20232023
case _ =>
20242024
tp1 match
2025-
case OrType(tp11, tp12) =>
2026-
tp11 & tp2 | tp12 & tp2
2025+
case tp1 @ OrType(tp11, tp12) =>
2026+
lub(tp11 & tp2, tp12 & tp2, isSoft = tp1.isSoft)
20272027
case tp1: ConstantType =>
20282028
tp2 match
20292029
case tp2: ConstantType =>
@@ -2045,9 +2045,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20452045

20462046
/** The least upper bound of two types
20472047
* @param canConstrain If true, new constraints might be added to simplify the lub.
2048+
* @param isSoft If the lub is a union, this determines whether it's a soft union.
20482049
* @note We do not admit singleton types in or-types as lubs.
20492050
*/
2050-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
2051+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain, isSoft=$isSoft)", subtyping, show = true) /*<|<*/ {
20512052
if (tp1 eq tp2) tp1
20522053
else if (!tp1.exists) tp1
20532054
else if (!tp2.exists) tp2
@@ -2073,8 +2074,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20732074
def widen(tp: Type) = if (widenInUnions) tp.widen else tp.widenIfUnstable
20742075
val tp1w = widen(tp1)
20752076
val tp2w = widen(tp2)
2076-
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain)
2077-
else orType(tp1w, tp2w) // no need to check subtypes again
2077+
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain = canConstrain, isSoft = isSoft)
2078+
else orType(tp1w, tp2w, isSoft = isSoft) // no need to check subtypes again
20782079
}
20792080
mergedLub(tp1.stripLazyRef, tp2.stripLazyRef)
20802081
}
@@ -2183,11 +2184,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
21832184
case tp2 @ OrType(tp21, tp22) =>
21842185
val higher1 = mergeIfSuper(tp1, tp21, canConstrain)
21852186
if (higher1 eq tp21) tp2
2186-
else if (higher1.exists) higher1 | tp22
2187+
else if (higher1.exists) lub(higher1, tp22, isSoft = tp2.isSoft)
21872188
else {
21882189
val higher2 = mergeIfSuper(tp1, tp22, canConstrain)
21892190
if (higher2 eq tp22) tp2
2190-
else if (higher2.exists) tp21 | higher2
2191+
else if (higher2.exists) lub(tp21, higher2, isSoft = tp2.isSoft)
21912192
else NoType
21922193
}
21932194
case _ =>
@@ -2235,17 +2236,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
22352236
* ExprType, LambdaType). Also, when forming an `|`,
22362237
* instantiated TypeVars are dereferenced and annotations are stripped.
22372238
*
2239+
* @param isSoft If the result is a union, this determines whether it's a soft union.
22382240
* @param isErased Apply erasure semantics. If erased is true, instead of creating
22392241
* an OrType, the lub will be computed using TypeCreator#erasedLub.
22402242
*/
2241-
final def orType(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = {
2242-
val t1 = distributeOr(tp1, tp2)
2243+
final def orType(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = {
2244+
val t1 = distributeOr(tp1, tp2, isSoft)
22432245
if (t1.exists) t1
22442246
else {
2245-
val t2 = distributeOr(tp2, tp1)
2247+
val t2 = distributeOr(tp2, tp1, isSoft)
22462248
if (t2.exists) t2
22472249
else if (isErased) erasedLub(tp1, tp2)
2248-
else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _)
2250+
else liftIfHK(tp1, tp2, OrType(_, _, soft = isSoft), _ | _, _ & _)
22492251
}
22502252
}
22512253

@@ -2333,18 +2335,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
23332335
*
23342336
* The rhs is a proper supertype of the lhs.
23352337
*/
2336-
private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match {
2338+
private def distributeOr(tp1: Type, tp2: Type, isSoft: Boolean = true): Type = tp1 match {
23372339
case ExprType(rt1) =>
23382340
tp2 match {
23392341
case ExprType(rt2) =>
2340-
ExprType(rt1 | rt2)
2342+
ExprType(lub(rt1, rt2, isSoft = isSoft))
23412343
case _ =>
23422344
NoType
23432345
}
23442346
case tp1: TypeVar if tp1.isInstantiated =>
2345-
tp1.underlying | tp2
2347+
lub(tp1.underlying, tp2, isSoft = isSoft)
23462348
case tp1: AnnotatedType if !tp1.isRefining =>
2347-
tp1.underlying | tp2
2349+
lub(tp1.underlying, tp2, isSoft = isSoft)
23482350
case _ =>
23492351
NoType
23502352
}
@@ -2699,8 +2701,8 @@ object TypeComparer {
26992701
def matchingMethodParams(tp1: MethodType, tp2: MethodType)(using Context): Boolean =
27002702
comparing(_.matchingMethodParams(tp1, tp2))
27012703

2702-
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false)(using Context): Type =
2703-
comparing(_.lub(tp1, tp2, canConstrain))
2704+
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true)(using Context): Type =
2705+
comparing(_.lub(tp1, tp2, canConstrain = canConstrain, isSoft = isSoft))
27042706

27052707
/** The least upper bound of a list of types */
27062708
final def lub(tps: List[Type])(using Context): Type =
@@ -2716,8 +2718,8 @@ object TypeComparer {
27162718
def glb(tps: List[Type])(using Context): Type =
27172719
tps.foldLeft(defn.AnyType: Type)(glb)
27182720

2719-
def orType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
2720-
comparing(_.orType(tp1, tp2, isErased))
2721+
def orType(using Context)(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type =
2722+
comparing(_.orType(tp1, tp2, isSoft = isSoft, isErased = isErased))
27212723

27222724
def andType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
27232725
comparing(_.andType(tp1, tp2, isErased))
@@ -2946,9 +2948,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
29462948
super.hasMatchingMember(name, tp1, tp2)
29472949
}
29482950

2949-
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type =
2950-
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") {
2951-
super.lub(tp1, tp2, canConstrain)
2951+
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean, isSoft: Boolean): Type =
2952+
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, isSoft=$isSoft)") {
2953+
super.lub(tp1, tp2, canConstrain, isSoft)
29522954
}
29532955

29542956
override def glb(tp1: Type, tp2: Type): Type =

tests/pos/preserve-union.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class A {
2+
val a: Int | String = 1
3+
val b: AnyVal = 2
4+
5+
val c = List(a, b)
6+
val c1: List[AnyVal | String] = c
7+
}

0 commit comments

Comments
 (0)