Skip to content

Commit 4611bdf

Browse files
oderskyDarkDimius
authored andcommitted
Appromiximate union types by intersections.
Appromiximate union types by intersections of their common base classes. Controlled by option -Xkeep-unions. If option is set, no approximation is done. Motivations for approximating: There are two. First, union types are departure from Scala 2. From time to time they lead to failure of inference. One example experiences in Dotty was in a foldLeft, where the accumulator type was inferred to be Tree before and was now a union of two tree specific kinds. Tree was the correct type, whereas the union type was too specific. These failures are not common (in the Dotty codebase there were 3, I believe), but they cause considerable difficulty to diagnose. So it seems safer to have a compatibility mode with Scala 2. The second motivation is that union types can become large and unwieldy. A function like TreeCopier has a result type consisting of ~ 40 alternatives, where the alternative type would be just Tree. Once we gain more experience with union types, we might consider flipping the option, and making union types the default. But for now it is safer this way, I believe.
1 parent 021c251 commit 4611bdf

File tree

5 files changed

+175
-2
lines changed

5 files changed

+175
-2
lines changed

src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,37 @@ trait TypeOps { this: Context =>
7777
def apply(tp: Type) = simplify(tp, this)
7878
}
7979

80+
/** Approximate union type by intersection of its dominators.
81+
* See Type#approximateUnion for an explanation.
82+
*/
83+
def approximateUnion(tp: Type): Type = {
84+
/** a faster version of cs1 intersect cs2 */
85+
def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] = {
86+
val cs2AsSet = new util.HashSet[ClassSymbol](100)
87+
cs2.foreach(cs2AsSet.addEntry)
88+
cs1.filter(cs2AsSet.contains)
89+
}
90+
/** The minimal set of classes in `cs` which derive all other classes in `cs` */
91+
def dominators(cs: List[ClassSymbol], accu: List[ClassSymbol]): List[ClassSymbol] = (cs: @unchecked) match {
92+
case c :: rest =>
93+
val accu1 = if (accu exists (_ derivesFrom c)) accu else c :: accu
94+
if (cs == c.baseClasses) accu1 else dominators(rest, accu1)
95+
}
96+
if (ctx.featureEnabled(defn.LanguageModuleClass, nme.keepUnions)) tp
97+
else tp match {
98+
case tp: OrType =>
99+
val commonBaseClasses = tp.mapReduceOr(_.baseClasses)(intersect)
100+
val doms = dominators(commonBaseClasses, Nil)
101+
doms.map(tp.baseTypeWithArgs).reduceLeft(AndType.apply)
102+
case tp @ AndType(tp1, tp2) =>
103+
tp derived_& (approximateUnion(tp1), approximateUnion(tp2))
104+
case tp: RefinedType =>
105+
tp.derivedRefinedType(approximateUnion(tp.parent), tp.refinedName, tp.refinedInfo)
106+
case _ =>
107+
tp
108+
}
109+
}
110+
80111
/** A type is volatile if its DNF contains an alternative of the form
81112
* {P1, ..., Pn}, {N1, ..., Nk}, where the Pi are parent typerefs and the
82113
* Nj are refinement names, and one the 4 following conditions is met:

src/dotty/tools/dotc/core/Types.scala

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -845,6 +845,19 @@ object Types {
845845
*/
846846
def simplified(implicit ctx: Context) = ctx.simplify(this, null)
847847

848+
/** Approximations of union types: We replace a union type Tn | ... | Tn
849+
* by the smallest intersection type of baseclass instances of T1,...,Tn.
850+
* Example: Given
851+
*
852+
* trait C[+T]
853+
* trait D
854+
* class A extends C[A] with D
855+
* class B extends C[B] with D with E
856+
*
857+
* we approximate `A | B` by `C[A | B] with D`
858+
*/
859+
def approximateUnion(implicit ctx: Context) = ctx.approximateUnion(this)
860+
848861
/** customized hash code of this type.
849862
* NotCached for uncached types. Cached types
850863
* compute hash and use it as the type's hashCode.
@@ -1355,6 +1368,10 @@ object Types {
13551368
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
13561369
else AndType.make(tp1, tp2)
13571370

1371+
def derived_& (tp1: Type, tp2: Type)(implicit ctx: Context): Type =
1372+
if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this
1373+
else tp1 & tp2
1374+
13581375
def derivedAndOrType(tp1: Type, tp2: Type)(implicit ctx: Context): Type =
13591376
derivedAndType(tp1, tp2)
13601377

@@ -1735,10 +1752,38 @@ object Types {
17351752
case OrType(tp1, tp2) => isSingleton(tp1) & isSingleton(tp2)
17361753
case _ => false
17371754
}
1755+
def isFullyDefined(tp: Type): Boolean = tp match {
1756+
case tp: TypeVar => tp.isInstantiated && isFullyDefined(tp.instanceOpt)
1757+
case tp: TypeProxy => isFullyDefined(tp.underlying)
1758+
case tp: AndOrType => isFullyDefined(tp.tp1) && isFullyDefined(tp.tp2)
1759+
case _ => true
1760+
}
1761+
def isOrType(tp: Type): Boolean = tp.stripTypeVar.dealias match {
1762+
case tp: OrType => true
1763+
case AndType(tp1, tp2) => isOrType(tp1) | isOrType(tp2)
1764+
case RefinedType(parent, _) => isOrType(parent)
1765+
case WildcardType(bounds: TypeBounds) => isOrType(bounds.hi)
1766+
case _ => false
1767+
}
1768+
1769+
// First, solve the constraint.
17381770
var inst = ctx.typeComparer.approximation(origin, fromBelow)
1771+
1772+
// Then, approximate by (1.) and (2.) and simplify as follows.
1773+
// 1. If instance is from below and is a singleton type, yet
1774+
// upper bound is not a singleton type, widen the instance.
17391775
if (fromBelow && isSingleton(inst) && !isSingleton(upperBound))
17401776
inst = inst.widen
1741-
instantiateWith(inst.simplified)
1777+
1778+
inst = inst.simplified
1779+
1780+
// 2. If instance is from below and is a fully-defined union type, yet upper bound
1781+
// is not a union type, approximate the union type from above by an intersection
1782+
// of all common base types.
1783+
if (fromBelow && isOrType(inst) && isFullyDefined(inst) && !isOrType(upperBound))
1784+
inst = inst.approximateUnion
1785+
1786+
instantiateWith(inst)
17421787
}
17431788

17441789
/** Unwrap to instance (if instantiated) or origin (if not), until result

src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ class Namer { typer: Typer =>
576576
// println(s"final inherited for $sym: ${inherited.toString}") !!!
577577
// println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}")
578578
val rhsCtx = ctx.fresh addMode Mode.InferringReturnType
579-
def rhsType = typedAheadExpr(mdef.rhs, rhsProto)(rhsCtx).tpe.widen
579+
def rhsType = typedAheadExpr(mdef.rhs, rhsProto)(rhsCtx).tpe.widen.approximateUnion
580580
def lhsType = fullyDefinedType(rhsType, "right-hand side", mdef.pos)
581581
inherited orElse lhsType
582582
}

test/dotc/tests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class tests extends CompilerTest {
4343
@Test def pos_structural() = compileFile(posDir, "structural", twice)
4444
@Test def pos_i39 = compileFile(posDir, "i39", twice)
4545
@Test def pos_overloadedAccess = compileFile(posDir, "overloadedAccess", twice)
46+
@Test def pos_approximateUnion = compileFile(posDir, "approximateUnion", twice)
4647
*/
4748
@Test def pos_all = compileFiles(posDir, twice)
4849

tests/pos/approximateUnion.scala

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
object approximateUnion {
2+
3+
trait C[+T]
4+
trait D
5+
trait E
6+
trait X[-T]
7+
8+
{
9+
trait A extends C[A] with D
10+
trait B extends C[B] with D
11+
12+
val coin = true
13+
val x = if (coin) new A else new B
14+
val y = Some(if (coin) new A else new B)
15+
16+
val xtest: C[A | B] & D = x
17+
val ytest: Some[C[A | B] & D] = y
18+
}
19+
20+
{
21+
trait A extends C[X[A]] with D
22+
trait B extends C[X[B]] with D with E
23+
24+
val coin = true
25+
val x = if (coin) new A else new B
26+
val y = Some(if (coin) new A else new B)
27+
28+
val xtest: C[X[A & B]] & D = x
29+
val ytest: Some[C[X[A & B]] & D] = y
30+
}
31+
}
32+
33+
object approximateUnion2 {
34+
35+
trait C[T]
36+
trait D
37+
trait E
38+
trait X[-T]
39+
40+
{
41+
trait A extends C[A] with D
42+
trait B extends C[B] with D
43+
44+
val coin = true
45+
val x = if (coin) new A else new B
46+
val y = Some(if (coin) new A else new B)
47+
48+
val xtest: C[_ >: A & B <: A | B] & D = x
49+
val ytest: Some[C[_ >: A & B <: A | B] & D] = y
50+
}
51+
52+
{
53+
trait A extends C[X[A]] with D
54+
trait B extends C[X[B]] with D with E
55+
56+
val coin = true
57+
val x = if (coin) new A else new B
58+
val y = Some(if (coin) new A else new B)
59+
60+
val xtest: C[_ >: X[A | B] <: X[A & B]] & D = x
61+
val ytest: Some[C[_ >: X[A | B] <: X[A & B]]] = y
62+
}
63+
}
64+
65+
object approximateUnion3 {
66+
67+
trait C[-T]
68+
trait D
69+
trait E
70+
trait X[-T]
71+
72+
{
73+
trait A extends C[A] with D
74+
trait B extends C[B] with D
75+
76+
val coin = true
77+
val x = if (coin) new A else new B
78+
val y = Some(if (coin) new A else new B)
79+
80+
val xtest: C[A & B] & D = x
81+
val ytest: Some[C[A & B] & D] = y
82+
}
83+
84+
{
85+
trait A extends C[X[A]] with D
86+
trait B extends C[X[B]] with D with E
87+
88+
val coin = true
89+
val x = if (coin) new A else new B
90+
val y = Some(if (coin) new A else new B)
91+
92+
val xtest: C[X[A | B]] & D = x
93+
val ytest2: Some[C[X[A | B]] & D] = y
94+
}
95+
}
96+

0 commit comments

Comments
 (0)