Skip to content

Preserve hard unions in more situations #12654

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
val r1 = recur(tp.tp1, fromBelow)
val r2 = recur(tp.tp2, fromBelow)
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
else if tp.isAnd then r1 & r2
else r1 | r2
else tp.match
case tp: OrType =>
TypeComparer.lub(r1, r2, isSoft = tp.isSoft)
case _ =>
r1 & r2
case tp: TypeParamRef =>
if tp eq param then
if fromBelow then defn.NothingType else defn.AnyType
Expand Down
50 changes: 26 additions & 24 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2018,12 +2018,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
val tp2a = dropIfSuper(tp2, tp1)
if tp2a ne tp2 then glb(tp1, tp2a)
else tp2 match // normalize to disjunctive normal form if possible.
case OrType(tp21, tp22) =>
tp1 & tp21 | tp1 & tp22
case tp2 @ OrType(tp21, tp22) =>
lub(tp1 & tp21, tp1 & tp22, isSoft = tp2.isSoft)
case _ =>
tp1 match
case OrType(tp11, tp12) =>
tp11 & tp2 | tp12 & tp2
case tp1 @ OrType(tp11, tp12) =>
lub(tp11 & tp2, tp12 & tp2, isSoft = tp1.isSoft)
case tp1: ConstantType =>
tp2 match
case tp2: ConstantType =>
Expand All @@ -2045,9 +2045,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling

/** The least upper bound of two types
* @param canConstrain If true, new constraints might be added to simplify the lub.
* @param isSoft If the lub is a union, this determines whether it's a soft union.
* @note We do not admit singleton types in or-types as lubs.
*/
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ {
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain, isSoft=$isSoft)", subtyping, show = true) /*<|<*/ {
if (tp1 eq tp2) tp1
else if (!tp1.exists) tp1
else if (!tp2.exists) tp2
Expand All @@ -2073,8 +2074,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
def widen(tp: Type) = if (widenInUnions) tp.widen else tp.widenIfUnstable
val tp1w = widen(tp1)
val tp2w = widen(tp2)
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain)
else orType(tp1w, tp2w) // no need to check subtypes again
if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain = canConstrain, isSoft = isSoft)
else orType(tp1w, tp2w, isSoft = isSoft) // no need to check subtypes again
}
mergedLub(tp1.stripLazyRef, tp2.stripLazyRef)
}
Expand Down Expand Up @@ -2183,11 +2184,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
case tp2 @ OrType(tp21, tp22) =>
val higher1 = mergeIfSuper(tp1, tp21, canConstrain)
if (higher1 eq tp21) tp2
else if (higher1.exists) higher1 | tp22
else if (higher1.exists) lub(higher1, tp22, isSoft = tp2.isSoft)
else {
val higher2 = mergeIfSuper(tp1, tp22, canConstrain)
if (higher2 eq tp22) tp2
else if (higher2.exists) tp21 | higher2
else if (higher2.exists) lub(tp21, higher2, isSoft = tp2.isSoft)
else NoType
}
case _ =>
Expand Down Expand Up @@ -2235,17 +2236,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
* ExprType, LambdaType). Also, when forming an `|`,
* instantiated TypeVars are dereferenced and annotations are stripped.
*
* @param isSoft If the result is a union, this determines whether it's a soft union.
* @param isErased Apply erasure semantics. If erased is true, instead of creating
* an OrType, the lub will be computed using TypeCreator#erasedLub.
*/
final def orType(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = {
val t1 = distributeOr(tp1, tp2)
final def orType(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = {
val t1 = distributeOr(tp1, tp2, isSoft)
if (t1.exists) t1
else {
val t2 = distributeOr(tp2, tp1)
val t2 = distributeOr(tp2, tp1, isSoft)
if (t2.exists) t2
else if (isErased) erasedLub(tp1, tp2)
else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _)
else liftIfHK(tp1, tp2, OrType(_, _, soft = isSoft), _ | _, _ & _)
}
}

Expand Down Expand Up @@ -2333,18 +2335,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
*
* The rhs is a proper supertype of the lhs.
*/
private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match {
private def distributeOr(tp1: Type, tp2: Type, isSoft: Boolean = true): Type = tp1 match {
case ExprType(rt1) =>
tp2 match {
case ExprType(rt2) =>
ExprType(rt1 | rt2)
ExprType(lub(rt1, rt2, isSoft = isSoft))
case _ =>
NoType
}
case tp1: TypeVar if tp1.isInstantiated =>
tp1.underlying | tp2
lub(tp1.underlying, tp2, isSoft = isSoft)
case tp1: AnnotatedType if !tp1.isRefining =>
tp1.underlying | tp2
lub(tp1.underlying, tp2, isSoft = isSoft)
case _ =>
NoType
}
Expand Down Expand Up @@ -2699,8 +2701,8 @@ object TypeComparer {
def matchingMethodParams(tp1: MethodType, tp2: MethodType)(using Context): Boolean =
comparing(_.matchingMethodParams(tp1, tp2))

def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false)(using Context): Type =
comparing(_.lub(tp1, tp2, canConstrain))
def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true)(using Context): Type =
comparing(_.lub(tp1, tp2, canConstrain = canConstrain, isSoft = isSoft))

/** The least upper bound of a list of types */
final def lub(tps: List[Type])(using Context): Type =
Expand All @@ -2716,8 +2718,8 @@ object TypeComparer {
def glb(tps: List[Type])(using Context): Type =
tps.foldLeft(defn.AnyType: Type)(glb)

def orType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
comparing(_.orType(tp1, tp2, isErased))
def orType(using Context)(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type =
comparing(_.orType(tp1, tp2, isSoft = isSoft, isErased = isErased))

def andType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type =
comparing(_.andType(tp1, tp2, isErased))
Expand Down Expand Up @@ -2946,9 +2948,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
super.hasMatchingMember(name, tp1, tp2)
}

override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type =
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") {
super.lub(tp1, tp2, canConstrain)
override def lub(tp1: Type, tp2: Type, canConstrain: Boolean, isSoft: Boolean): Type =
traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, isSoft=$isSoft)") {
super.lub(tp1, tp2, canConstrain, isSoft)
}

override def glb(tp1: Type, tp2: Type): Type =
Expand Down
7 changes: 7 additions & 0 deletions tests/pos/preserve-union.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class A {
val a: Int | String = 1
val b: AnyVal = 2

val c = List(a, b)
val c1: List[AnyVal | String] = c
}