Skip to content

Commit cde7d7d

Browse files
committed
Treat soft and hard unions differently when widening
1 parent c5b491f commit cde7d7d

File tree

6 files changed

+150
-63
lines changed

6 files changed

+150
-63
lines changed

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
428428
case Some(b) => return b
429429
case None =>
430430

431+
def widenOK =
432+
(tp2.widenSingletons eq tp2)
433+
&& (tp1.widenSingletons ne tp1)
434+
&& recur(tp1.widenSingletons, tp2)
435+
431436
def joinOK = tp2.dealiasKeepRefiningAnnots match {
432437
case tp2: AppliedType if !tp2.tycon.typeSymbol.isClass =>
433438
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
@@ -439,25 +444,30 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
439444
false
440445
}
441446

447+
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
448+
// before splitting the LHS into its constituents. That way, the RHS variables are
449+
// constraint by the hard union and can be instantiated to it. If we just split and add
450+
// the two parts of the LHS separately to the constraint, the lower bound would become
451+
// a soft union.
452+
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
453+
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
454+
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
455+
case _ => true
456+
457+
// An & on the left side loses information. We compensate by also trying the join.
458+
// This is less ad-hoc than it looks since we produce joins in type inference,
459+
// and then need to check that they are indeed supertypes of the original types
460+
// under -Ycheck. Test case is i7965.scala.
442461
def containsAnd(tp: Type): Boolean = tp.dealiasKeepRefiningAnnots match
443462
case tp: AndType => true
444463
case OrType(tp1, tp2) => containsAnd(tp1) || containsAnd(tp2)
445464
case _ => false
446465

447-
def widenOK =
448-
(tp2.widenSingletons eq tp2) &&
449-
(tp1.widenSingletons ne tp1) &&
450-
recur(tp1.widenSingletons, tp2)
451-
452466
widenOK
453467
|| joinOK
454-
|| recur(tp11, tp2) && recur(tp12, tp2)
468+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
455469
|| containsAnd(tp1) && recur(tp1.join, tp2)
456-
// An & on the left side loses information. Compensate by also trying the join.
457-
// This is less ad-hoc than it looks since we produce joins in type inference,
458-
// and then need to check that they are indeed supertypes of the original types
459-
// under -Ycheck. Test case is i7965.scala.
460-
case tp1: MatchType =>
470+
case tp1: MatchType =>
461471
val reduced = tp1.reduced
462472
if (reduced.exists) recur(reduced, tp2) else thirdTry
463473
case _: FlexType =>
@@ -511,35 +521,36 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
511521
fourthTry
512522
}
513523

524+
def compareTypeParamRef(tp2: TypeParamRef): Boolean =
525+
assumedTrue(tp2) || {
526+
val alwaysTrue =
527+
// The following condition is carefully formulated to catch all cases
528+
// where the subtype relation is true without needing to add a constraint
529+
// It's tricky because we might need to either approximate tp2 by its
530+
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
531+
// So if the constraint is not yet frozen, we do the same comparison again
532+
// with a frozen constraint, which means that we get a chance to do the
533+
// widening in `fourthTry` before adding to the constraint.
534+
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
535+
else isSubTypeWhenFrozen(tp1, tp2)
536+
alwaysTrue ||
537+
frozenConstraint && (tp1 match {
538+
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
539+
case _ => false
540+
}) || {
541+
if (canConstrain(tp2) && !approx.low)
542+
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
543+
else fourthTry
544+
}
545+
}
546+
514547
def thirdTry: Boolean = tp2 match {
515548
case tp2 @ AppliedType(tycon2, args2) =>
516549
compareAppliedType2(tp2, tycon2, args2)
517550
case tp2: NamedType =>
518551
thirdTryNamed(tp2)
519552
case tp2: TypeParamRef =>
520-
def compareTypeParamRef =
521-
assumedTrue(tp2) || {
522-
val alwaysTrue =
523-
// The following condition is carefully formulated to catch all cases
524-
// where the subtype relation is true without needing to add a constraint
525-
// It's tricky because we might need to either approximate tp2 by its
526-
// lower bound or else widen tp1 and check that the result is a subtype of tp2.
527-
// So if the constraint is not yet frozen, we do the same comparison again
528-
// with a frozen constraint, which means that we get a chance to do the
529-
// widening in `fourthTry` before adding to the constraint.
530-
if (frozenConstraint) recur(tp1, bounds(tp2).lo)
531-
else isSubTypeWhenFrozen(tp1, tp2)
532-
alwaysTrue ||
533-
frozenConstraint && (tp1 match {
534-
case tp1: TypeParamRef => constraint.isLess(tp1, tp2)
535-
case _ => false
536-
}) || {
537-
if (canConstrain(tp2) && !approx.low)
538-
addConstraint(tp2, tp1.widenExpr, fromBelow = true)
539-
else fourthTry
540-
}
541-
}
542-
compareTypeParamRef
553+
compareTypeParamRef(tp2)
543554
case tp2: RefinedType =>
544555
def compareRefinedSlow: Boolean = {
545556
val name2 = tp2.refinedName

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,14 @@ object TypeOps:
150150
tp.derivedAlias(simplify(tp.alias, theMap))
151151
case AndType(l, r) if !ctx.mode.is(Mode.Type) =>
152152
simplify(l, theMap) & simplify(r, theMap)
153-
case OrType(l, r) if !ctx.mode.is(Mode.Type) =>
153+
case tp as OrType(l, r)
154+
if !ctx.mode.is(Mode.Type)
155+
&& (tp.isSoft || defn.isBottomType(l) || defn.isBottomType(r)) =>
156+
// Normalize A | Null and Null | A to A even if the union is hard (i.e.
157+
// explicitly declared), but not if -Yexplicit-nulls is set. The reason is
158+
// that in this case the normal asSeenFrom machinery is not prepared to deal
159+
// with Nulls (which have no base classes). Under -Yexplicit-nulls, we take
160+
// corrective steps, so no widening is wanted.
154161
simplify(l, theMap) | simplify(r, theMap)
155162
case AnnotatedType(parent, annot)
156163
if !ctx.mode.is(Mode.Type) && annot.symbol == defn.UncheckedVarianceAnnot =>

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 68 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,10 +1151,11 @@ object Types {
11511151
case _ => this
11521152
}
11531153

1154-
/** Widen this type and if the result contains embedded union types, replace
1154+
/** Widen this type and if the result contains embedded soft union types, replace
11551155
* them by their joins.
1156-
* "Embedded" means: inside type lambdas, intersections or recursive types, or in prefixes of refined types.
1157-
* If an embedded union is found, we first try to simplify or eliminate it by
1156+
* "Embedded" means: inside type lambdas, intersections or recursive types,
1157+
* in prefixes of refined types, or in hard union types.
1158+
* If an embedded soft union is found, we first try to simplify or eliminate it by
11581159
* re-lubbing it while allowing type parameters to be constrained further.
11591160
* Any remaining union types are replaced by their joins.
11601161
*
@@ -1165,36 +1166,78 @@ object Types {
11651166
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
11661167
* instead of `ArrayBuffer[? >: Int | A <: Int & A]`
11671168
*
1169+
* Hard unions inside soft ones are treated specially. For illustration assume we
1170+
* want to widen the type `(A | C) \/ (B | C)` where `\/` means soft union and `|`
1171+
* means hard union. In that case, the hard unions `A | C` and `B | C` are treated
1172+
* in an asymmetric way. Only the first parts `A` and `B` are joined and the rest
1173+
* is added again with a hard union to the result. So
1174+
*
1175+
* widenUnion[ (A | C) \/ (B | C) ]
1176+
* = widenUnion[ A \/ B ] | C | C
1177+
* = D | C | C
1178+
* = D | C
1179+
*
1180+
* In general, If a hard union A | B_1 | ... | B_n is part of of a soft union,
1181+
* only A forms part of the join, and B_1, ..., B_n are pushed out, just `C` is
1182+
* pushed out above. All types that are pushed out are recombined with the result
1183+
* of the join with a lub, but that lub yields again a hard union, not a soft one.
1184+
*
11681185
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11691186
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
11701187
*/
1171-
def widenUnion(using Context): Type = widen match {
1188+
def widenUnion(using Context): Type = widen.match {
11721189
case tp @ OrNull(tp1): OrType =>
11731190
// Don't widen `T|Null`, since otherwise we wouldn't be able to infer nullable unions.
11741191
val tp1Widen = tp1.widenUnionWithoutNull
11751192
if (tp1Widen.isRef(defn.AnyClass)) tp1Widen
11761193
else tp.derivedOrType(tp1Widen, defn.NullType)
11771194
case tp =>
11781195
tp.widenUnionWithoutNull
1179-
}
1196+
}.reporting(i"widenUnion($this) = $result")
11801197

1181-
def widenUnionWithoutNull(using Context): Type = widen match {
1182-
case tp @ OrType(lhs, rhs) =>
1183-
TypeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
1184-
case union: OrType => union.join
1185-
case res => res
1186-
}
1187-
case tp @ AndType(tp1, tp2) =>
1188-
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1189-
case tp: RefinedType =>
1190-
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1191-
case tp: RecType =>
1192-
tp.rebind(tp.parent.widenUnion)
1193-
case tp: HKTypeLambda =>
1194-
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1195-
case tp =>
1196-
tp
1197-
}
1198+
def widenUnionWithoutNull(using Context): Type =
1199+
1200+
// Split hard union `A | B1 | ... | Bn` into leftmost part `A` and list of
1201+
// pushed out parts `B1, ..., Bn`.
1202+
def splitAlts(tp: Type, follow: List[Type]): (Type, List[Type]) = tp match
1203+
case tp as OrType(lhs, rhs) if !tp.isSoft =>
1204+
splitAlts(lhs, rhs :: follow)
1205+
case _ =>
1206+
(tp, follow)
1207+
1208+
// Convert any soft unions in result of lub to hard ones */
1209+
def harden(tp: Type): Type = tp match
1210+
case tp as OrType(tp1, tp2) if tp.isSoft =>
1211+
OrType(harden(tp1), harden(tp2), soft = false)
1212+
case _ =>
1213+
tp
1214+
1215+
def recombine(tp1: Type, tp2: Type) = harden(TypeComparer.lub(tp1, tp2))
1216+
1217+
widen match
1218+
case tp @ OrType(lhs, rhs) =>
1219+
if tp.isSoft then
1220+
val (lhsCore, lhsExtras) = splitAlts(lhs.widenUnionWithoutNull, Nil)
1221+
val (rhsCore, rhsExtras) = splitAlts(rhs.widenUnionWithoutNull, Nil)
1222+
val core = TypeComparer.lub(lhsCore, rhsCore, canConstrain = true) match
1223+
case union: OrType => union.join
1224+
case res => res
1225+
rhsExtras.foldLeft(lhsExtras.foldLeft(core)(recombine))(recombine)
1226+
else
1227+
val lhs1 = lhs.widenUnionWithoutNull
1228+
val rhs1 = rhs.widenUnionWithoutNull
1229+
if (lhs1 eq lhs) && (rhs1 eq rhs) then tp else recombine(lhs1, rhs1)
1230+
case tp @ AndType(tp1, tp2) =>
1231+
tp derived_& (tp1.widenUnionWithoutNull, tp2.widenUnionWithoutNull)
1232+
case tp: RefinedType =>
1233+
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
1234+
case tp: RecType =>
1235+
tp.rebind(tp.parent.widenUnion)
1236+
case tp: HKTypeLambda =>
1237+
tp.derivedLambdaType(resType = tp.resType.widenUnion)
1238+
case tp =>
1239+
tp
1240+
end widenUnionWithoutNull
11981241

11991242
/** Widen all top-level singletons reachable by dealiasing
12001243
* and going to the operands of & and |.
@@ -3054,9 +3097,9 @@ object Types {
30543097
myWidened
30553098
}
30563099

3057-
def derivedOrType(tp1: Type, tp2: Type)(using Context): Type =
3058-
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
3059-
else OrType.make(tp1, tp2, isSoft)
3100+
def derivedOrType(tp1: Type, tp2: Type, soft: Boolean = isSoft)(using Context): Type =
3101+
if ((tp1 eq this.tp1) && (tp2 eq this.tp2) && soft == isSoft) this
3102+
else OrType.make(tp1, tp2, soft)
30603103

30613104
override def computeHash(bs: Binders): Int =
30623105
doHash(bs, if isSoft then 0 else 1, tp1, tp2)

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -214,8 +214,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
214214
// of AndType and OrType to account for associativity
215215
case AndType(tp1, tp2) =>
216216
toTextInfixType(tpnme.raw.AMP, tp1, tp2) { toText(tpnme.raw.AMP) }
217-
case OrType(tp1, tp2) =>
218-
toTextInfixType(tpnme.raw.BAR, tp1, tp2) { toText(tpnme.raw.BAR) }
217+
case tp as OrType(tp1, tp2) =>
218+
toTextInfixType(tpnme.raw.BAR, tp1, tp2) {
219+
if tp.isSoft && printDebug then toText(tpnme.ZOR) else toText(tpnme.raw.BAR)
220+
}
219221
case tp @ EtaExpansion(tycon)
220222
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
221223
// don't eta contract if the application would be printed specially

tests/pos/widen-union.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
object Test1:
3+
val x: Int | String = 1
4+
val y = x
5+
val z: Int | String = y
6+
7+
object Test2:
8+
type Sig = Int | String
9+
def consistent(x: Sig, y: Sig): Boolean = ???// x == y
10+
11+
def consistentLists(xs: List[Sig], ys: List[Sig]): Boolean =
12+
xs.corresponds(ys)(consistent) // OK
13+
|| xs.corresponds(ys)(consistent(_, _)) // error, found: Any, required: Int | String
14+
15+
object Test3:
16+
17+
def g[X](x: X | String): Int = ???
18+
def y: Boolean | String = ???
19+
g[Boolean](y)
20+
g(y)
21+
g[Boolean](identity(y))
22+
g(identity(y))
23+
24+

tests/run/i8726.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
case class A(a: Int)
2-
object C { def unapply(a: A): true | true = true }
2+
object C { def unapply(a: A): true = true }
33

44
@main
55
def Test = (A(1): A | A) match { case C() => "OK" }

0 commit comments

Comments
 (0)