Skip to content

Commit a9dd109

Browse files
committed
Fix-scala#4867 Keep Unions when explicit
Improves the inference of union types by preserving the union when there is a type ascription with an union. If the term has not been explicity ascribed with an union then the existing semantics of joining the orType is maintained. To determine whether an explicit union ascription exists, a lexical check on the untyped tree is performed.
1 parent 4bbddd7 commit a9dd109

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import util.SimpleIdentitySet
2323
import reporting.diagnostic.Message
2424
import ast.tpd._
2525
import ast.TreeTypeMap
26+
import ast.untpd
2627
import printing.Texts._
2728
import printing.Printer
2829
import Hashable._
@@ -38,6 +39,7 @@ import java.lang.ref.WeakReference
3839

3940
import scala.annotation.internal.sharable
4041
import scala.annotation.threadUnsafe
42+
import dotty.tools.dotc.ast.Trees.Untyped
4143

4244
import dotty.tools.dotc.transform.SymUtils._
4345

@@ -1099,7 +1101,7 @@ object Types {
10991101
* re-lubbing it while allowing type parameters to be constrained further.
11001102
* Any remaining union types are replaced by their joins.
11011103
*
1102-
* For instance, if `A` is an unconstrained type variable, then
1104+
* For instance, if `A` is an unconstrained type variable, then
11031105
*
11041106
* ArrayBuffer[Int] | ArrayBuffer[A]
11051107
*
@@ -1108,6 +1110,8 @@ object Types {
11081110
*
11091111
* Exception (if `-YexplicitNulls` is set): if this type is a nullable union (i.e. of the form `T | Null`),
11101112
* then the top-level union isn't widened. This is needed so that type inference can infer nullable types.
1113+
*
1114+
* Another exception is when there is an explicit type ascription, then the union isn't widened.
11111115
*/
11121116
def widenUnion(implicit ctx: Context): Type = widen match {
11131117
case tp @ OrNull(tp1): OrType =>
@@ -1122,7 +1126,14 @@ object Types {
11221126
def widenUnionWithoutNull(implicit ctx: Context): Type = widen match {
11231127
case tp @ OrType(lhs, rhs) =>
11241128
ctx.typeComparer.lub(lhs.widenUnionWithoutNull, rhs.widenUnionWithoutNull, canConstrain = true) match {
1125-
case union: OrType => union.join
1129+
case union: OrType =>
1130+
val keepUnion = ctx.tree match {
1131+
case DefDef(_, _, _, untpd.TypedSplice(_), _) => true
1132+
case ValDef(name, untpd.InfixOp(_, op, _), _) => op.symbol == ctx.definitions.orType
1133+
case _ => false
1134+
}
1135+
if (keepUnion) union else union.join
1136+
11261137
case res => res
11271138
}
11281139
case tp @ AndType(tp1, tp2) =>

0 commit comments

Comments
 (0)