From ace235cc7d63a15abad725cff7d82a84f8b7273a Mon Sep 17 00:00:00 2001 From: Guillaume Martres Date: Sat, 29 May 2021 17:12:45 +0200 Subject: [PATCH] Preserve hard unions in more situations There's multiple places where we take apart a union, apply some transformation to its parts, then either return the original union if nothing changed or recombine the transformed parts into a new union. The problem is that to check if nothing changed we use referential equality when `=:=` would be more correct, so we sometimes end up returning a union equivalent to the original one except the new one is a soft union. I've seen this lead to subtle type inference differences both when experimenting with changes to constraint solving and when trying to replace our Uniques hashmaps by weak hashmaps. We could fix this by replacing these `eq` calls by `=:=` but instead this commit takes the approach of preserving the softness of a union even when its components are modified, see the test case. --- .../tools/dotc/core/OrderingConstraint.scala | 7 ++- .../dotty/tools/dotc/core/TypeComparer.scala | 50 ++++++++++--------- tests/pos/preserve-union.scala | 7 +++ 3 files changed, 38 insertions(+), 26 deletions(-) create mode 100644 tests/pos/preserve-union.scala diff --git a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala index d6971027682b..406ec58b9846 100644 --- a/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala @@ -309,8 +309,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds, val r1 = recur(tp.tp1, fromBelow) val r2 = recur(tp.tp2, fromBelow) if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp - else if tp.isAnd then r1 & r2 - else r1 | r2 + else tp.match + case tp: OrType => + TypeComparer.lub(r1, r2, isSoft = tp.isSoft) + case _ => + r1 & r2 case tp: TypeParamRef => if tp eq param then if fromBelow then defn.NothingType else defn.AnyType diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 47f36255de81..66305a306daa 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -2018,12 +2018,12 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling val tp2a = dropIfSuper(tp2, tp1) if tp2a ne tp2 then glb(tp1, tp2a) else tp2 match // normalize to disjunctive normal form if possible. - case OrType(tp21, tp22) => - tp1 & tp21 | tp1 & tp22 + case tp2 @ OrType(tp21, tp22) => + lub(tp1 & tp21, tp1 & tp22, isSoft = tp2.isSoft) case _ => tp1 match - case OrType(tp11, tp12) => - tp11 & tp2 | tp12 & tp2 + case tp1 @ OrType(tp11, tp12) => + lub(tp11 & tp2, tp12 & tp2, isSoft = tp1.isSoft) case tp1: ConstantType => tp2 match case tp2: ConstantType => @@ -2045,9 +2045,10 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling /** The least upper bound of two types * @param canConstrain If true, new constraints might be added to simplify the lub. + * @param isSoft If the lub is a union, this determines whether it's a soft union. * @note We do not admit singleton types in or-types as lubs. */ - def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain)", subtyping, show = true) /*<|<*/ { + def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true): Type = /*>|>*/ trace(s"lub(${tp1.show}, ${tp2.show}, canConstrain=$canConstrain, isSoft=$isSoft)", subtyping, show = true) /*<|<*/ { if (tp1 eq tp2) tp1 else if (!tp1.exists) tp1 else if (!tp2.exists) tp2 @@ -2073,8 +2074,8 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling def widen(tp: Type) = if (widenInUnions) tp.widen else tp.widenIfUnstable val tp1w = widen(tp1) val tp2w = widen(tp2) - if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain) - else orType(tp1w, tp2w) // no need to check subtypes again + if ((tp1 ne tp1w) || (tp2 ne tp2w)) lub(tp1w, tp2w, canConstrain = canConstrain, isSoft = isSoft) + else orType(tp1w, tp2w, isSoft = isSoft) // no need to check subtypes again } mergedLub(tp1.stripLazyRef, tp2.stripLazyRef) } @@ -2183,11 +2184,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling case tp2 @ OrType(tp21, tp22) => val higher1 = mergeIfSuper(tp1, tp21, canConstrain) if (higher1 eq tp21) tp2 - else if (higher1.exists) higher1 | tp22 + else if (higher1.exists) lub(higher1, tp22, isSoft = tp2.isSoft) else { val higher2 = mergeIfSuper(tp1, tp22, canConstrain) if (higher2 eq tp22) tp2 - else if (higher2.exists) tp21 | higher2 + else if (higher2.exists) lub(tp21, higher2, isSoft = tp2.isSoft) else NoType } case _ => @@ -2235,17 +2236,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * ExprType, LambdaType). Also, when forming an `|`, * instantiated TypeVars are dereferenced and annotations are stripped. * + * @param isSoft If the result is a union, this determines whether it's a soft union. * @param isErased Apply erasure semantics. If erased is true, instead of creating * an OrType, the lub will be computed using TypeCreator#erasedLub. */ - final def orType(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = { - val t1 = distributeOr(tp1, tp2) + final def orType(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = { + val t1 = distributeOr(tp1, tp2, isSoft) if (t1.exists) t1 else { - val t2 = distributeOr(tp2, tp1) + val t2 = distributeOr(tp2, tp1, isSoft) if (t2.exists) t2 else if (isErased) erasedLub(tp1, tp2) - else liftIfHK(tp1, tp2, OrType(_, _, soft = true), _ | _, _ & _) + else liftIfHK(tp1, tp2, OrType(_, _, soft = isSoft), _ | _, _ & _) } } @@ -2333,18 +2335,18 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling * * The rhs is a proper supertype of the lhs. */ - private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match { + private def distributeOr(tp1: Type, tp2: Type, isSoft: Boolean = true): Type = tp1 match { case ExprType(rt1) => tp2 match { case ExprType(rt2) => - ExprType(rt1 | rt2) + ExprType(lub(rt1, rt2, isSoft = isSoft)) case _ => NoType } case tp1: TypeVar if tp1.isInstantiated => - tp1.underlying | tp2 + lub(tp1.underlying, tp2, isSoft = isSoft) case tp1: AnnotatedType if !tp1.isRefining => - tp1.underlying | tp2 + lub(tp1.underlying, tp2, isSoft = isSoft) case _ => NoType } @@ -2699,8 +2701,8 @@ object TypeComparer { def matchingMethodParams(tp1: MethodType, tp2: MethodType)(using Context): Boolean = comparing(_.matchingMethodParams(tp1, tp2)) - def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false)(using Context): Type = - comparing(_.lub(tp1, tp2, canConstrain)) + def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false, isSoft: Boolean = true)(using Context): Type = + comparing(_.lub(tp1, tp2, canConstrain = canConstrain, isSoft = isSoft)) /** The least upper bound of a list of types */ final def lub(tps: List[Type])(using Context): Type = @@ -2716,8 +2718,8 @@ object TypeComparer { def glb(tps: List[Type])(using Context): Type = tps.foldLeft(defn.AnyType: Type)(glb) - def orType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = - comparing(_.orType(tp1, tp2, isErased)) + def orType(using Context)(tp1: Type, tp2: Type, isSoft: Boolean = true, isErased: Boolean = ctx.erasedTypes): Type = + comparing(_.orType(tp1, tp2, isSoft = isSoft, isErased = isErased)) def andType(using Context)(tp1: Type, tp2: Type, isErased: Boolean = ctx.erasedTypes): Type = comparing(_.andType(tp1, tp2, isErased)) @@ -2946,9 +2948,9 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { super.hasMatchingMember(name, tp1, tp2) } - override def lub(tp1: Type, tp2: Type, canConstrain: Boolean = false): Type = - traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain)") { - super.lub(tp1, tp2, canConstrain) + override def lub(tp1: Type, tp2: Type, canConstrain: Boolean, isSoft: Boolean): Type = + traceIndented(s"lub(${show(tp1)}, ${show(tp2)}, canConstrain=$canConstrain, isSoft=$isSoft)") { + super.lub(tp1, tp2, canConstrain, isSoft) } override def glb(tp1: Type, tp2: Type): Type = diff --git a/tests/pos/preserve-union.scala b/tests/pos/preserve-union.scala new file mode 100644 index 000000000000..a56fcbef1b84 --- /dev/null +++ b/tests/pos/preserve-union.scala @@ -0,0 +1,7 @@ +class A { + val a: Int | String = 1 + val b: AnyVal = 2 + + val c = List(a, b) + val c1: List[AnyVal | String] = c +}