diff --git a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala index de96f644a91c..fa41de5c1d8b 100644 --- a/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala +++ b/compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala @@ -258,7 +258,7 @@ trait ConstraintHandling { } // First, solve the constraint. - var inst = approximation(param, fromBelow) + var inst = approximation(param, fromBelow).simplified // Then, approximate by (1.) - (3.) and simplify as follows. // 1. If instance is from below and is a singleton type, yet @@ -266,13 +266,11 @@ trait ConstraintHandling { if (fromBelow && isSingleton(inst) && !isSingleton(upperBound)) inst = inst.widen - inst = inst.simplified - // 2. If instance is from below and is a fully-defined union type, yet upper bound // is not a union type, approximate the union type from above by an intersection // of all common base types. - if (fromBelow && isOrType(inst) && isFullyDefined(inst) && !isOrType(upperBound)) - inst = ctx.harmonizeUnion(inst) + if (fromBelow && isOrType(inst) && !isOrType(upperBound)) + inst = inst.widenUnion inst } diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 4abaf3bc788c..f46ea887a61f 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -357,7 +357,6 @@ class Definitions { enterCompleteClassSymbol( ScalaPackageClass, tpnme.Singleton, PureInterfaceCreationFlags | Final, List(AnyClass.typeRef), EmptyScope) - def SingletonType = SingletonClass.typeRef lazy val SeqType: TypeRef = ctx.requiredClassRef("scala.collection.Seq") def SeqClass(implicit ctx: Context) = SeqType.symbol.asClass diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index 54b96a2530b3..281d3cbca353 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -331,7 +331,7 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling { else thirdTry(tp1, tp2) case tp1 @ OrType(tp11, tp12) => def joinOK = tp2.dealias match { - case tp12: HKApply => + case _: HKApply => // If we apply the default algorithm for `A[X] | B[Y] <: C[Z]` where `C` is a // type parameter, we will instantiate `C` to `A` and then fail when comparing // with `B[Y]`. To do the right thing, we need to instantiate `C` to the @@ -1511,10 +1511,17 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) { override def compareHkApply2(tp1: Type, tp2: HKApply, tycon2: Type, args2: List[Type]): Boolean = { def addendum = "" - traceIndented(i"compareHkApply $tp1, $tp2$addendum") { + traceIndented(i"compareHkApply2 $tp1, $tp2$addendum") { super.compareHkApply2(tp1, tp2, tycon2, args2) } } + override def compareHkApply1(tp1: HKApply, tycon1: Type, args1: List[Type], tp2: Type): Boolean = { + def addendum = "" + traceIndented(i"compareHkApply1 $tp1, $tp2$addendum") { + super.compareHkApply1(tp1, tycon1, args1, tp2) + } + } + override def toString = "Subtype trace:" + { try b.toString finally b.clear() } } diff --git a/compiler/src/dotty/tools/dotc/core/TypeOps.scala b/compiler/src/dotty/tools/dotc/core/TypeOps.scala index 4a1c3d04469d..d4941026e9d4 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeOps.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeOps.scala @@ -273,37 +273,6 @@ trait TypeOps { this: Context => // TODO: Make standalone object. } } - /** Given a disjunction T1 | ... | Tn of types with potentially embedded - * type variables, constrain type variables further if this eliminates - * some of the branches of the disjunction. Do this also for disjunctions - * embedded in intersections, as parents in refinements, and in recursive types. - * - * For instance, if `A` is an unconstrained type variable, then - * - * ArrayBuffer[Int] | ArrayBuffer[A] - * - * is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]` - * instead of `ArrayBuffer[_ >: Int | A <: Int & A]` - */ - def harmonizeUnion(tp: Type): Type = tp match { - case tp: OrType => - joinIfScala2(ctx.typeComparer.lub(harmonizeUnion(tp.tp1), harmonizeUnion(tp.tp2), canConstrain = true)) - case tp @ AndType(tp1, tp2) => - tp derived_& (harmonizeUnion(tp1), harmonizeUnion(tp2)) - case tp: RefinedType => - tp.derivedRefinedType(harmonizeUnion(tp.parent), tp.refinedName, tp.refinedInfo) - case tp: RecType => - tp.rebind(harmonizeUnion(tp.parent)) - case _ => - tp - } - - /** Under -language:Scala2: Replace or-types with their joins */ - private def joinIfScala2(tp: Type) = tp match { - case tp: OrType if scala2Mode => tp.join - case _ => tp - } - /** Not currently needed: * def liftToRec(f: (Type, Type) => Type)(tp1: Type, tp2: Type)(implicit ctx: Context) = { diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 955a5a11c6f2..04a74b48291d 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -830,7 +830,7 @@ object Types { * def o: Outer * .widen = o.C */ - @tailrec final def widen(implicit ctx: Context): Type = widenSingleton match { + final def widen(implicit ctx: Context): Type = widenSingleton match { case tp: ExprType => tp.resultType.widen case tp => tp } @@ -838,7 +838,7 @@ object Types { /** Widen from singleton type to its underlying non-singleton * base type by applying one or more `underlying` dereferences. */ - @tailrec final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match { + final def widenSingleton(implicit ctx: Context): Type = stripTypeVar match { case tp: SingletonType if !tp.isOverloaded => tp.underlying.widenSingleton case _ => this } @@ -846,7 +846,7 @@ object Types { /** Widen from TermRef to its underlying non-termref * base type, while also skipping Expr types. */ - @tailrec final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match { + final def widenTermRefExpr(implicit ctx: Context): Type = stripTypeVar match { case tp: TermRef if !tp.isOverloaded => tp.underlying.widenExpr.widenTermRefExpr case _ => this } @@ -860,7 +860,7 @@ object Types { } /** Widen type if it is unstable (i.e. an ExprType, or TermRef to unstable symbol */ - @tailrec final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match { + final def widenIfUnstable(implicit ctx: Context): Type = stripTypeVar match { case tp: ExprType => tp.resultType.widenIfUnstable case tp: TermRef if !tp.symbol.isStable => tp.underlying.widenIfUnstable case _ => this @@ -872,6 +872,35 @@ object Types { case _ => this } + /** If this type contains embedded union types, replace them by their joins. + * "Embedded" means: inside intersectons or recursive types, or in prefixes of refined types. + * If an embedded union is found, we first try to simplify or eliminate it by + * re-lubbing it while allowing type parameters to be constrained further. + * Any remaining union types are replaced by their joins. + * + * For instance, if `A` is an unconstrained type variable, then + * + * ArrayBuffer[Int] | ArrayBuffer[A] + * + * is approximated by constraining `A` to be =:= to `Int` and returning `ArrayBuffer[Int]` + * instead of `ArrayBuffer[_ >: Int | A <: Int & A]` + */ + def widenUnion(implicit ctx: Context): Type = this match { + case OrType(tp1, tp2) => + ctx.typeComparer.lub(tp1.widenUnion, tp2.widenUnion, canConstrain = true) match { + case union: OrType => union.join + case res => res + } + case tp @ AndType(tp1, tp2) => + tp derived_& (tp1.widenUnion, tp2.widenUnion) + case tp: RefinedType => + tp.derivedRefinedType(tp.parent.widenUnion, tp.refinedName, tp.refinedInfo) + case tp: RecType => + tp.rebind(tp.parent.widenUnion) + case _ => + this + } + /** Eliminate anonymous classes */ final def deAnonymize(implicit ctx: Context): Type = this match { case tp:TypeRef if tp.symbol.isAnonymousClass => diff --git a/compiler/src/dotty/tools/dotc/typer/Namer.scala b/compiler/src/dotty/tools/dotc/typer/Namer.scala index 23de2a089b93..192ba04630f3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Namer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Namer.scala @@ -1034,13 +1034,13 @@ class Namer { typer: Typer => // println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}") def isInline = sym.is(FinalOrInline, butNot = Method | Mutable) - // Widen rhs type and approximate `|' but keep ConstantTypes if + // Widen rhs type and eliminate `|' but keep ConstantTypes if // definition is inline (i.e. final in Scala2) and keep module singleton types // instead of widening to the underlying module class types. def widenRhs(tp: Type): Type = tp.widenTermRefExpr match { case ctp: ConstantType if isInline => ctp case ref: TypeRef if ref.symbol.is(ModuleClass) => tp - case _ => ctx.harmonizeUnion(tp.widen) + case _ => tp.widen.widenUnion } // Replace aliases to Unit by Unit itself. If we leave the alias in diff --git a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala index 28150eec5316..e0399ffec964 100644 --- a/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala +++ b/compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala @@ -404,7 +404,7 @@ object ProtoTypes { /** Create a new TypeVar that represents a dependent method parameter singleton */ def newDepTypeVar(tp: Type)(implicit ctx: Context): TypeVar = { val poly = PolyType(DepParamName.fresh().toTypeName :: Nil)( - pt => TypeBounds.upper(AndType(tp, defn.SingletonType)) :: Nil, + pt => TypeBounds.upper(AndType(tp, defn.SingletonClass.typeRef)) :: Nil, pt => defn.AnyType) constrained(poly, untpd.EmptyTree, alwaysAddTypeVars = true) ._2.head.tpe.asInstanceOf[TypeVar] diff --git a/compiler/test/dotc/tests.scala b/compiler/test/dotc/tests.scala index 0c902ee5d034..3c07df88a952 100644 --- a/compiler/test/dotc/tests.scala +++ b/compiler/test/dotc/tests.scala @@ -167,6 +167,7 @@ class tests extends CompilerTest { @Test def rewrites = compileFile(posScala2Dir, "rewrites", "-rewrite" :: scala2mode) @Test def pos_t8146a = compileFile(posSpecialDir, "t8146a")(allowDeepSubtypes) + @Test def pos_jon = compileFile(posSpecialDir, "jon")(allowDeepSubtypes) @Test def pos_t5545 = { // compile by hand in two batches, since junit lacks the infrastructure to diff --git a/tests/neg/union.scala b/tests/neg/union.scala new file mode 100644 index 000000000000..c594e83d74bc --- /dev/null +++ b/tests/neg/union.scala @@ -0,0 +1,28 @@ +object Test { + + class A + class B extends A + class C extends A + class D extends A + + val b = true + val x = if (b) new B else new C + val y: B | C = x // error +} + +object O { + class A + class B + def f[T](x: T, y: T): T = x + + val x: A = f(new A { }, new A) + + val y1: A | B = f(new A { }, new B) // error + val y2: A | B = f[A | B](new A { }, new B) // ok + + val z = if (???) new A{} else new B + + val z1: A | B = z // error + + val z2: A | B = if (???) new A else new B // ok +} diff --git a/tests/pos/jon.scala b/tests/pos-special/jon.scala similarity index 100% rename from tests/pos/jon.scala rename to tests/pos-special/jon.scala diff --git a/tests/pos/anonClassSubtyping.scala b/tests/pos/anonClassSubtyping.scala index b5591d826412..f89d69619ef2 100644 --- a/tests/pos/anonClassSubtyping.scala +++ b/tests/pos/anonClassSubtyping.scala @@ -5,5 +5,5 @@ object O { val x: A = f(new A { }, new A) - val y: A | B = f(new A { }, new B) + val z: A | B = if (???) new A{} else new A } diff --git a/tests/pos/constraining-lub.scala b/tests/pos/constraining-lub.scala index 80da2ec868b4..31a6e63ca3b6 100644 --- a/tests/pos/constraining-lub.scala +++ b/tests/pos/constraining-lub.scala @@ -17,7 +17,7 @@ object Test { val x: Inv[Int] = inv(true) - def inv2(cond: Boolean) = + def inv2(cond: Boolean): Inv[Int] | Inv2[Int] = if (cond) { if (cond) new Inv(1) diff --git a/tests/pos/intersection.scala b/tests/pos/intersection.scala index 48551920c318..d2e445dbafb0 100644 --- a/tests/pos/intersection.scala +++ b/tests/pos/intersection.scala @@ -9,7 +9,9 @@ object intersection { val z = if (???) x else y val a: A & B => Unit = z - val b: (A => Unit) | (B => Unit) = z + //val b: (A => Unit) | (B => Unit) = z // error under new or-type rules + + val c: (A => Unit) | (B => Unit) = if (???) x else y // ok type needsA = A => Nothing type needsB = B => Nothing diff --git a/tests/pos/union.scala b/tests/pos/union.scala deleted file mode 100644 index 8b20a8458a48..000000000000 --- a/tests/pos/union.scala +++ /dev/null @@ -1,11 +0,0 @@ -object Test { - - class A - class B extends A - class C extends A - class D extends A - - val b = true - val x = if (b) new B else new C - val y: B | C = x -}