diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index d6971027682b..406ec58b9846 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -309,8 +309,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds, val r1 = recur(tp.tp1, fromBelow) val r2 = recur(tp.tp2, fromBelow) if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp - else if tp.isAnd then r1 & r2 - else r1 | r2 + else tp.match + case tp: OrType => + TypeComparer.lub(r1, r2, isSoft = tp.isSoft) + case _ => + r1 & r2 case tp: TypeParamRef => if tp eq param then if fromBelow then defn.NothingType else defn.AnyType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 47f36255de81..66305a306daa 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2018,12 +2018,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val tp2a = dropIfSuper(tp2, tp1) if tp2a ne tp2 then glb(tp1, tp2a) else tp2 match // normalize to disjunctive normal form if possible. - case OrType(tp21, tp22) => - tp1 & tp21 | tp1 & tp22 + case tp2 @ OrType(tp21, tp22) => + lub(tp1 & tp21, tp1 & tp22, isSoft = tp2.isSoft) case _ => tp1 match - case OrType(tp11, tp12) => - tp11 & tp2 | tp12 & tp2 + case tp1 @ OrType(tp11, tp12) => + lub(tp11 & tp2, tp12 & tp2, isSoft = tp1.isSoft) case tp1: ConstantType => tp2 match case tp2: ConstantType => @@ -2045,9 +2045,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling /** The least upper bound of two types * @param canConstrain If true, new constraints might be added to simplify the lub. + * @param isSoft If the lub is a union, this determines whether it's a soft union. * @note We do not admit singleton types in or-types as lubs. */ - def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ { + def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain, isSoft=$isSoft)", subtyping, show = true) /*<|<*/ { if (tp1 eq tp2) tp1 else if (!tp1.exists) tp1 else if (!tp2.exists) tp2 @@ -2073,8 +2074,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def widen(tp: Type) = if (widenInUnions) tp.widen else tp.widenIfUnstable val tp1w = widen(tp1) val tp2w = widen(tp2) - if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain) - else orType(tp1w, tp2w) // no need to check subtypes again + if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain = canConstrain, isSoft = isSoft) + else orType(tp1w, tp2w, isSoft = isSoft) // no need to check subtypes again } mergedLub(tp1.stripLazyRef, tp2.stripLazyRef) } @@ -2183,11 +2184,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp2 @ OrType(tp21, tp22) => val higher1 = mergeIfSuper(tp1, tp21, canConstrain) if (higher1 eq tp21) tp2 - else if (higher1.exists) higher1 | tp22 + else if (higher1.exists) lub(higher1, tp22, isSoft = tp2.isSoft) else { val higher2 = mergeIfSuper(tp1, tp22, canConstrain) if (higher2 eq tp22) tp2 - else if (higher2.exists) tp21 | higher2 + else if (higher2.exists) lub(tp21, higher2, isSoft = tp2.isSoft) else NoType } case _ => @@ -2235,17 +2236,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * ExprType, LambdaType). Also, when forming an `|`, * instantiated TypeVars are dereferenced and annotations are stripped. * + * @param isSoft If the result is a union, this determines whether it's a soft union. * @param isErased Apply erasure semantics. If erased is true, instead of creating * an OrType, the lub will be computed using TypeCreator#erasedLub. */ - final def orType(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = { - val t1 = distributeOr(tp1, tp2) + final def orType(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = { + val t1 = distributeOr(tp1, tp2, isSoft) if (t1.exists) t1 else { - val t2 = distributeOr(tp2, tp1) + val t2 = distributeOr(tp2, tp1, isSoft) if (t2.exists) t2 else if (isErased) erasedLub(tp1, tp2) - else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _) + else liftIfHK(tp1, tp2, OrType(_, _, soft = isSoft), _ | _, _ & _) } } @@ -2333,18 +2335,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * * The rhs is a proper supertype of the lhs. */ - private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match { + private def distributeOr(tp1: Type, tp2: Type, isSoft: Boolean = true): Type = tp1 match { case ExprType(rt1) => tp2 match { case ExprType(rt2) => - ExprType(rt1 | rt2) + ExprType(lub(rt1, rt2, isSoft = isSoft)) case _ => NoType } case tp1: TypeVar if tp1.isInstantiated => - tp1.underlying | tp2 + lub(tp1.underlying, tp2, isSoft = isSoft) case tp1: AnnotatedType if !tp1.isRefining => - tp1.underlying | tp2 + lub(tp1.underlying, tp2, isSoft = isSoft) case _ => NoType } @@ -2699,8 +2701,8 @@ object TypeComparer { def matchingMethodParams(tp1: MethodType, tp2: MethodType)(using Context): Boolean = comparing(_.matchingMethodParams(tp1, tp2)) - def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false)(using Context): Type = - comparing(_.lub(tp1, tp2, canConstrain)) + def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true)(using Context): Type = + comparing(_.lub(tp1, tp2, canConstrain = canConstrain, isSoft = isSoft)) /** The least upper bound of a list of types */ final def lub(tps: List[Type])(using Context): Type = @@ -2716,8 +2718,8 @@ object TypeComparer { def glb(tps: List[Type])(using Context): Type = tps.foldLeft(defn.AnyType: Type)(glb) - def orType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = - comparing(_.orType(tp1, tp2, isErased)) + def orType(using Context)(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = + comparing(_.orType(tp1, tp2, isSoft = isSoft, isErased = isErased)) def andType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = comparing(_.andType(tp1, tp2, isErased)) @@ -2946,9 +2948,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.hasMatchingMember(name, tp1, tp2) } - override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = - traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") { - super.lub(tp1, tp2, canConstrain) + override def lub(tp1: Type, tp2: Type, canConstrain: Boolean, isSoft: Boolean): Type = + traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, isSoft=$isSoft)") { + super.lub(tp1, tp2, canConstrain, isSoft) } override def glb(tp1: Type, tp2: Type): Type = diff --git a/tests/pos/preserve-union.scala b/tests/pos/preserve-union.scala new file mode 100644 index 000000000000..a56fcbef1b84 --- /dev/null +++ b/tests/pos/preserve-union.scala @@ -0,0 +1,7 @@ +class A { + val a: Int | String = 1 + val b: AnyVal = 2 + + val c = List(a, b) + val c1: List[AnyVal | String] = c +}