Skip to content

Commit 9567fc4

Browse files
authored
Merge pull request #2330 from dotty-staging/change-minor-types
Simplify handling of union types
2 parents b19d1fb + cdc4eec commit 9567fc4

14 files changed

+82
-60
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -258,21 +258,19 @@ trait ConstraintHandling {
258258
}
259259

260260
// First, solve the constraint.
261-
var inst = approximation(param, fromBelow)
261+
var inst = approximation(param, fromBelow).simplified
262262

263263
// Then, approximate by (1.) - (3.) and simplify as follows.
264264
// 1. If instance is from below and is a singleton type, yet
265265
// upper bound is not a singleton type, widen the instance.
266266
if (fromBelow && isSingleton(inst) && !isSingleton(upperBound))
267267
inst = inst.widen
268268

269-
inst = inst.simplified
270-
271269
// 2. If instance is from below and is a fully-defined union type, yet upper bound
272270
// is not a union type, approximate the union type from above by an intersection
273271
// of all common base types.
274-
if (fromBelow && isOrType(inst) && isFullyDefined(inst) && !isOrType(upperBound))
275-
inst = ctx.harmonizeUnion(inst)
272+
if (fromBelow && isOrType(inst) && !isOrType(upperBound))
273+
inst = inst.widenUnion
276274

277275
inst
278276
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,6 @@ class Definitions {
357357
enterCompleteClassSymbol(
358358
ScalaPackageClass, tpnme.Singleton, PureInterfaceCreationFlags | Final,
359359
List(AnyClass.typeRef), EmptyScope)
360-
def SingletonType = SingletonClass.typeRef
361360

362361
lazy val SeqType: TypeRef = ctx.requiredClassRef("scala.collection.Seq")
363362
def SeqClass(implicit ctx: Context) = SeqType.symbol.asClass

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
331331
else thirdTry(tp1, tp2)
332332
case tp1 @ OrType(tp11, tp12) =>
333333
def joinOK = tp2.dealias match {
334-
case tp12: HKApply =>
334+
case _: HKApply =>
335335
// If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a
336336
// type parameter, we will instantiate `C` to `A` and then fail when comparing
337337
// with `B[Y]`. To do the right thing, we need to instantiate `C` to the
@@ -1511,10 +1511,17 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
15111511

15121512
override def compareHkApply2(tp1: Type, tp2: HKApply, tycon2: Type, args2: List[Type]): Boolean = {
15131513
def addendum = ""
1514-
traceIndented(i"compareHkApply $tp1, $tp2$addendum") {
1514+
traceIndented(i"compareHkApply2 $tp1, $tp2$addendum") {
15151515
super.compareHkApply2(tp1, tp2, tycon2, args2)
15161516
}
15171517
}
15181518

1519+
override def compareHkApply1(tp1: HKApply, tycon1: Type, args1: List[Type], tp2: Type): Boolean = {
1520+
def addendum = ""
1521+
traceIndented(i"compareHkApply1 $tp1, $tp2$addendum") {
1522+
super.compareHkApply1(tp1, tycon1, args1, tp2)
1523+
}
1524+
}
1525+
15191526
override def toString = "Subtype trace:" + { try b.toString finally b.clear() }
15201527
}

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -273,37 +273,6 @@ trait TypeOps { this: Context => // TODO: Make standalone object.
273273
}
274274
}
275275

276-
/** Given a disjunction T1 | ... | Tn of types with potentially embedded
277-
* type variables, constrain type variables further if this eliminates
278-
* some of the branches of the disjunction. Do this also for disjunctions
279-
* embedded in intersections, as parents in refinements, and in recursive types.
280-
*
281-
* For instance, if `A` is an unconstrained type variable, then
282-
*
283-
* ArrayBuffer[Int] | ArrayBuffer[A]
284-
*
285-
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
286-
* instead of `ArrayBuffer[_ >: Int | A <: Int & A]`
287-
*/
288-
def harmonizeUnion(tp: Type): Type = tp match {
289-
case tp: OrType =>
290-
joinIfScala2(ctx.typeComparer.lub(harmonizeUnion(tp.tp1), harmonizeUnion(tp.tp2), canConstrain = true))
291-
case tp @ AndType(tp1, tp2) =>
292-
tp derived_& (harmonizeUnion(tp1), harmonizeUnion(tp2))
293-
case tp: RefinedType =>
294-
tp.derivedRefinedType(harmonizeUnion(tp.parent), tp.refinedName, tp.refinedInfo)
295-
case tp: RecType =>
296-
tp.rebind(harmonizeUnion(tp.parent))
297-
case _ =>
298-
tp
299-
}
300-
301-
/** Under -language:Scala2: Replace or-types with their joins */
302-
private def joinIfScala2(tp: Type) = tp match {
303-
case tp: OrType if scala2Mode => tp.join
304-
case _ => tp
305-
}
306-
307276
/** Not currently needed:
308277
*
309278
def liftToRec(f: (Type, Type) => Type)(tp1: Type, tp2: Type)(implicit ctx: Context) = {

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -830,23 +830,23 @@ object Types {
830830
* def o: Outer
831831
* <o.x.type>.widen = o.C
832832
*/
833-
@tailrec final def widen(implicit ctx: Context): Type = widenSingleton match {
833+
final def widen(implicit ctx: Context): Type = widenSingleton match {
834834
case tp: ExprType => tp.resultType.widen
835835
case tp => tp
836836
}
837837

838838
/** Widen from singleton type to its underlying non-singleton
839839
* base type by applying one or more `underlying` dereferences.
840840
*/
841-
@tailrec final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match {
841+
final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match {
842842
case tp: SingletonType if !tp.isOverloaded => tp.underlying.widenSingleton
843843
case _ => this
844844
}
845845

846846
/** Widen from TermRef to its underlying non-termref
847847
* base type, while also skipping Expr types.
848848
*/
849-
@tailrec final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match {
849+
final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match {
850850
case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr
851851
case _ => this
852852
}
@@ -860,7 +860,7 @@ object Types {
860860
}
861861

862862
/** Widen type if it is unstable (i.e. an ExprType, or TermRef to unstable symbol */
863-
@tailrec final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match {
863+
final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match {
864864
case tp: ExprType => tp.resultType.widenIfUnstable
865865
case tp: TermRef if !tp.symbol.isStable => tp.underlying.widenIfUnstable
866866
case _ => this
@@ -872,6 +872,35 @@ object Types {
872872
case _ => this
873873
}
874874

875+
/** If this type contains embedded union types, replace them by their joins.
876+
* "Embedded" means: inside intersectons or recursive types, or in prefixes of refined types.
877+
* If an embedded union is found, we first try to simplify or eliminate it by
878+
* re-lubbing it while allowing type parameters to be constrained further.
879+
* Any remaining union types are replaced by their joins.
880+
*
881+
* For instance, if `A` is an unconstrained type variable, then
882+
*
883+
* ArrayBuffer[Int] | ArrayBuffer[A]
884+
*
885+
* is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]`
886+
* instead of `ArrayBuffer[_ >: Int | A <: Int & A]`
887+
*/
888+
def widenUnion(implicit ctx: Context): Type = this match {
889+
case OrType(tp1, tp2) =>
890+
ctx.typeComparer.lub(tp1.widenUnion, tp2.widenUnion, canConstrain = true) match {
891+
case union: OrType => union.join
892+
case res => res
893+
}
894+
case tp @ AndType(tp1, tp2) =>
895+
tp derived_& (tp1.widenUnion, tp2.widenUnion)
896+
case tp: RefinedType =>
897+
tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo)
898+
case tp: RecType =>
899+
tp.rebind(tp.parent.widenUnion)
900+
case _ =>
901+
this
902+
}
903+
875904
/** Eliminate anonymous classes */
876905
final def deAnonymize(implicit ctx: Context): Type = this match {
877906
case tp:TypeRef if tp.symbol.isAnonymousClass =>

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,13 +1034,13 @@ class Namer { typer: Typer =>
10341034
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
10351035
def isInline = sym.is(FinalOrInline, butNot = Method | Mutable)
10361036

1037-
// Widen rhs type and approximate `|' but keep ConstantTypes if
1037+
// Widen rhs type and eliminate `|' but keep ConstantTypes if
10381038
// definition is inline (i.e. final in Scala2) and keep module singleton types
10391039
// instead of widening to the underlying module class types.
10401040
def widenRhs(tp: Type): Type = tp.widenTermRefExpr match {
10411041
case ctp: ConstantType if isInline => ctp
10421042
case ref: TypeRef if ref.symbol.is(ModuleClass) => tp
1043-
case _ => ctx.harmonizeUnion(tp.widen)
1043+
case _ => tp.widen.widenUnion
10441044
}
10451045

10461046
// Replace aliases to Unit by Unit itself. If we leave the alias in

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ object ProtoTypes {
404404
/** Create a new TypeVar that represents a dependent method parameter singleton */
405405
def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar = {
406406
val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)(
407-
pt => TypeBounds.upper(AndType(tp, defn.SingletonType)) :: Nil,
407+
pt => TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)) :: Nil,
408408
pt => defn.AnyType)
409409
constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true)
410410
._2.head.tpe.asInstanceOf[TypeVar]

compiler/test/dotc/tests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ class tests extends CompilerTest {
167167
@Test def rewrites = compileFile(posScala2Dir, "rewrites", "-rewrite" :: scala2mode)
168168

169169
@Test def pos_t8146a = compileFile(posSpecialDir, "t8146a")(allowDeepSubtypes)
170+
@Test def pos_jon = compileFile(posSpecialDir, "jon")(allowDeepSubtypes)
170171

171172
@Test def pos_t5545 = {
172173
// compile by hand in two batches, since junit lacks the infrastructure to

tests/neg/union.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
object Test {
2+
3+
class A
4+
class B extends A
5+
class C extends A
6+
class D extends A
7+
8+
val b = true
9+
val x = if (b) new B else new C
10+
val y: B | C = x // error
11+
}
12+
13+
object O {
14+
class A
15+
class B
16+
def f[T](x: T, y: T): T = x
17+
18+
val x: A = f(new A { }, new A)
19+
20+
val y1: A | B = f(new A { }, new B) // error
21+
val y2: A | B = f[A | B](new A { }, new B) // ok
22+
23+
val z = if (???) new A{} else new B
24+
25+
val z1: A | B = z // error
26+
27+
val z2: A | B = if (???) new A else new B // ok
28+
}
File renamed without changes.

tests/pos/anonClassSubtyping.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ object O {
55

66
val x: A = f(new A { }, new A)
77

8-
val y: A | B = f(new A { }, new B)
8+
val z: A | B = if (???) new A{} else new A
99
}

tests/pos/constraining-lub.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ object Test {
1717

1818
val x: Inv[Int] = inv(true)
1919

20-
def inv2(cond: Boolean) =
20+
def inv2(cond: Boolean): Inv[Int] | Inv2[Int] =
2121
if (cond) {
2222
if (cond)
2323
new Inv(1)

tests/pos/intersection.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ object intersection {
99
val z = if (???) x else y
1010

1111
val a: A & B => Unit = z
12-
val b: (A => Unit) | (B => Unit) = z
12+
//val b: (A => Unit) | (B => Unit) = z // error under new or-type rules
13+
14+
val c: (A => Unit) | (B => Unit) = if (???) x else y // ok
1315

1416
type needsA = A => Nothing
1517
type needsB = B => Nothing

tests/pos/union.scala

Lines changed: 0 additions & 11 deletions
This file was deleted.

0 commit comments

Comments
 (0)