diff --git a/src/dotty/tools/dotc/ast/Desugar.scala b/src/dotty/tools/dotc/ast/Desugar.scala index 027c3238d91c..23a85688068b 100644 --- a/src/dotty/tools/dotc/ast/Desugar.scala +++ b/src/dotty/tools/dotc/ast/Desugar.scala @@ -180,8 +180,9 @@ object desugar { constr1.mods, constr1.name, tparams, vparamss, constr1.tpt, constr1.rhs) // a reference to the class type, with all parameters given. - val classTypeRef: Tree = { - // Dotty deviation: Without type annotation infers Ident | AppliedTypeTree, which + val classTypeRef/*: Tree*/ = { + // -Xkeep-unions difference: classTypeRef needs type annotation, otherwise + // infers Ident | AppliedTypeTree, which // renders the :\ in companions below untypable. val tycon = Ident(cdef.name) withPos cdef.pos.startPos val tparams = impl.constr.tparams diff --git a/src/dotty/tools/dotc/config/ScalaSettings.scala b/src/dotty/tools/dotc/config/ScalaSettings.scala index aaf3b83858ff..c2cfda746a07 100644 --- a/src/dotty/tools/dotc/config/ScalaSettings.scala +++ b/src/dotty/tools/dotc/config/ScalaSettings.scala @@ -90,6 +90,7 @@ class ScalaSettings extends Settings.SettingGroup { val XoldPatmat = BooleanSetting("-Xoldpatmat", "Use the pre-2.10 pattern matcher. Otherwise, the 'virtualizing' pattern matcher is used in 2.10.") val XnoPatmatAnalysis = BooleanSetting("-Xno-patmat-analysis", "Don't perform exhaustivity/unreachability analysis. Also, ignore @switch annotation.") val XfullLubs = BooleanSetting("-Xfull-lubs", "Retains pre 2.10 behavior of less aggressive truncation of least upper bounds.") + val XkeepUnions = BooleanSetting("-Xkeep-unions", "Do not approximate union types by their common classes") /** -Y "Private" settings */ diff --git a/src/dotty/tools/dotc/core/Definitions.scala b/src/dotty/tools/dotc/core/Definitions.scala index ef16a970dfa4..cb1239f928d0 100644 --- a/src/dotty/tools/dotc/core/Definitions.scala +++ b/src/dotty/tools/dotc/core/Definitions.scala @@ -283,8 +283,9 @@ class Definitions { object FunctionType { def apply(args: List[Type], resultType: Type) = FunctionClass(args.length).typeRef.appliedTo(args ::: resultType :: Nil) - def unapply(ft: Type): Option[(List[Type], Type)] = { // Dotty deviation: Type annotation needed because inferred type - // is Some[(List[Type], Type)] | None, which is not a legal unapply type. + def unapply(ft: Type)/*: Option[(List[Type], Type)]*/ = { + // -Xkeep-unions difference: unapply needs result type because inferred type + // is Some[(List[Type], Type)] | None, which is not a legal unapply type. val tsym = ft.typeSymbol lazy val targs = ft.argInfos if ((FunctionClasses contains tsym) && diff --git a/src/dotty/tools/dotc/core/TypeApplications.scala b/src/dotty/tools/dotc/core/TypeApplications.scala index b4c30d902976..0abd28a716ab 100644 --- a/src/dotty/tools/dotc/core/TypeApplications.scala +++ b/src/dotty/tools/dotc/core/TypeApplications.scala @@ -170,17 +170,20 @@ class TypeApplications(val self: Type) extends AnyVal { /** The type arguments of this type's base type instance wrt.`base`. * Existential types in arguments are disallowed. */ - final def baseArgTypes(base: Symbol)(implicit ctx: Context): List[Type] = baseArgInfos(base) mapConserve noBounds + final def baseArgTypes(base: Symbol)(implicit ctx: Context): List[Type] = + baseArgInfos(base) mapConserve noBounds /** The type arguments of this type's base type instance wrt.`base`. * Existential types in arguments are approximanted by their lower bound. */ - final def baseArgTypesLo(base: Symbol)(implicit ctx: Context): List[Type] = baseArgInfos(base) mapConserve boundsToLo + final def baseArgTypesLo(base: Symbol)(implicit ctx: Context): List[Type] = + baseArgInfos(base) mapConserve boundsToLo /** The type arguments of this type's base type instance wrt.`base`. * Existential types in arguments are approximanted by their upper bound. */ - final def baseArgTypesHi(base: Symbol)(implicit ctx: Context): List[Type] = baseArgInfos(base) mapConserve boundsToHi + final def baseArgTypesHi(base: Symbol)(implicit ctx: Context): List[Type] = + baseArgInfos(base) mapConserve boundsToHi /** The first type argument of the base type instance wrt `base` of this type */ final def firstBaseArgInfo(base: Symbol)(implicit ctx: Context): Type = base.typeParams match { @@ -193,8 +196,11 @@ class TypeApplications(val self: Type) extends AnyVal { /** The base type including all type arguments of this type. * Existential types in arguments are returned as TypeBounds instances. */ - final def baseTypeWithArgs(base: Symbol)(implicit ctx: Context): Type = - self.baseTypeRef(base).appliedTo(baseArgInfos(base)) + final def baseTypeWithArgs(base: Symbol)(implicit ctx: Context): Type = self.dealias match { + case AndType(tp1, tp2) => tp1.baseTypeWithArgs(base) & tp2.baseTypeWithArgs(base) + case OrType(tp1, tp2) => tp1.baseTypeWithArgs(base) | tp2.baseTypeWithArgs(base) + case _ => self.baseTypeRef(base).appliedTo(baseArgInfos(base)) + } /** Translate a type of the form From[T] to To[T], keep other types as they are. * `from` and `to` must be static classes, both with one type parameter, and the same variance. @@ -205,7 +211,7 @@ class TypeApplications(val self: Type) extends AnyVal { else self /** If this is an encoding of a (partially) applied type, return its arguments, - * otherwise return Nil. + * otherwise return Nil. * Existential types in arguments are returned as TypeBounds instances. */ final def argInfos(implicit ctx: Context): List[Type] = { diff --git a/src/dotty/tools/dotc/core/TypeComparer.scala b/src/dotty/tools/dotc/core/TypeComparer.scala index da8263ac1a3f..348d22d4b356 100644 --- a/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/src/dotty/tools/dotc/core/TypeComparer.scala @@ -890,6 +890,16 @@ class TypeComparer(initctx: Context) extends DotClass { /** Try to distribute `&` inside type, detect and handle conflicts */ private def distributeAnd(tp1: Type, tp2: Type): Type = tp1 match { + case tp1: RefinedType => + tp2 match { + case tp2: RefinedType if tp1.refinedName == tp2.refinedName => + tp1.derivedRefinedType( + tp1.parent & tp2.parent, + tp1.refinedName, + tp1.refinedInfo & tp2.refinedInfo) + case _ => + NoType + } case tp1: TypeBounds => tp2 match { case tp2: TypeBounds => tp1 & tp2 @@ -935,18 +945,6 @@ class TypeComparer(initctx: Context) extends DotClass { case _ => rt1 & tp2 } - case tp1: RefinedType => - // opportunistically merge same-named refinements - // this does not change anything semantically (i.e. merging or not merging - // gives =:= types), but it keeps the type smaller. - tp2 match { - case tp2: RefinedType if tp1.refinedName == tp2.refinedName => - tp1.derivedRefinedType( - tp1.parent & tp2.parent, tp1.refinedName, - tp1.refinedInfo & tp2.refinedInfo) - case _ => - NoType - } case tp1: TypeVar if tp1.isInstantiated => tp1.underlying & tp2 case tp1: AnnotatedType => @@ -957,6 +955,16 @@ class TypeComparer(initctx: Context) extends DotClass { /** Try to distribute `|` inside type, detect and handle conflicts */ private def distributeOr(tp1: Type, tp2: Type): Type = tp1 match { + case tp1: RefinedType => + tp2 match { + case tp2: RefinedType if tp1.refinedName == tp2.refinedName => + tp1.derivedRefinedType( + tp1.parent | tp2.parent, + tp1.refinedName, + tp1.refinedInfo | tp2.refinedInfo) + case _ => + NoType + } case tp1: TypeBounds => tp2 match { case tp2: TypeBounds => tp1 | tp2 diff --git a/src/dotty/tools/dotc/core/TypeOps.scala b/src/dotty/tools/dotc/core/TypeOps.scala index ca69ab6156ff..93b732e04756 100644 --- a/src/dotty/tools/dotc/core/TypeOps.scala +++ b/src/dotty/tools/dotc/core/TypeOps.scala @@ -2,7 +2,7 @@ package dotty.tools.dotc package core import Contexts._, Types._, Symbols._, Names._, Flags._, Scopes._ -import SymDenotations._ +import SymDenotations._, Decorators._ import util.SimpleMap trait TypeOps { this: Context => @@ -74,6 +74,37 @@ trait TypeOps { this: Context => def apply(tp: Type) = simplify(tp, this) } + /** Approximate union type by intersection of its dominators. + * See Type#approximateUnion for an explanation. + */ + def approximateUnion(tp: Type): Type = { + /** a faster version of cs1 intersect cs2 */ + def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] = { + val cs2AsSet = new util.HashSet[ClassSymbol](100) + cs2.foreach(cs2AsSet.addEntry) + cs1.filter(cs2AsSet.contains) + } + /** The minimal set of classes in `cs` which derive all other classes in `cs` */ + def dominators(cs: List[ClassSymbol], accu: List[ClassSymbol]): List[ClassSymbol] = (cs: @unchecked) match { + case c :: rest => + val accu1 = if (accu exists (_ derivesFrom c)) accu else c :: accu + if (cs == c.baseClasses) accu1 else dominators(rest, accu1) + } + if (ctx.settings.XkeepUnions.value) tp + else tp match { + case tp: OrType => + val commonBaseClasses = tp.mapReduceOr(_.baseClasses)(intersect) + val doms = dominators(commonBaseClasses, Nil) + doms.map(tp.baseTypeWithArgs).reduceLeft(AndType.apply) + case tp @ AndType(tp1, tp2) => + tp derived_& (approximateUnion(tp1), approximateUnion(tp2)) + case tp: RefinedType => + tp.derivedRefinedType(approximateUnion(tp.parent), tp.refinedName, tp.refinedInfo) + case _ => + tp + } + } + final def isVolatile(tp: Type): Boolean = { /** Pre-filter to avoid expensive DNF computation */ def needsChecking(tp: Type, isPart: Boolean): Boolean = tp match { diff --git a/src/dotty/tools/dotc/core/Types.scala b/src/dotty/tools/dotc/core/Types.scala index bb30d9a9c7a6..0792bd5bf0de 100644 --- a/src/dotty/tools/dotc/core/Types.scala +++ b/src/dotty/tools/dotc/core/Types.scala @@ -102,8 +102,19 @@ object Types { } /** Is this type an instance of a non-bottom subclass of the given class `cls`? */ - final def derivesFrom(cls: Symbol)(implicit ctx: Context): Boolean = - classSymbol.derivesFrom(cls) + final def derivesFrom(cls: Symbol)(implicit ctx: Context): Boolean = this match { + case tp: TypeRef => + val sym = tp.symbol + if (sym.isClass) sym.derivesFrom(cls) else tp.underlying.derivesFrom(cls) + case tp: TypeProxy => + tp.underlying.derivesFrom(cls) + case tp: AndType => + tp.tp1.derivesFrom(cls) || tp.tp2.derivesFrom(cls) + case tp: OrType => + tp.tp1.derivesFrom(cls) && tp.tp2.derivesFrom(cls) + case _ => + false + } /** A type T is a legal prefix in a type selection T#A if * T is stable or T contains no uninstantiated type variables. @@ -272,6 +283,7 @@ object Types { /** The base classes of this type as determined by ClassDenotation * in linearization order, with the class itself as first element. + * For AndTypes/OrTypes, the union/intersection of the operands' baseclasses. * Inherited by all type proxies. `Nil` for all other types. */ final def baseClasses(implicit ctx: Context): List[ClassSymbol] = track("baseClasses") { @@ -280,6 +292,10 @@ object Types { tp.underlying.baseClasses case tp: ClassInfo => tp.cls.baseClasses + case AndType(tp1, tp2) => + tp1.baseClasses union tp2.baseClasses + case OrType(tp1, tp2) => + tp1.baseClasses intersect tp2.baseClasses case _ => Nil } } @@ -865,6 +881,19 @@ object Types { */ def simplified(implicit ctx: Context) = ctx.simplify(this, null) + /** Approximations of union types: We replace a union type Tn | ... | Tn + * by the smallest intersection type of baseclass instances of T1,...,Tn. + * Example: Given + * + * trait C[+T] + * trait D + * class A extends C[A] with D + * class B extends C[B] with D with E + * + * we approximate `A | B` by `C[A | B] with D` + */ + def approximateUnion(implicit ctx: Context) = ctx.approximateUnion(this) + /** customized hash code of this type. * NotCached for uncached types. Cached types * compute hash and use it as the type's hashCode. @@ -1371,6 +1400,10 @@ object Types { if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this else AndType.make(tp1, tp2) + def derived_& (tp1: Type, tp2: Type)(implicit ctx: Context): Type = + if ((tp1 eq this.tp1) && (tp2 eq this.tp2)) this + else tp1 & tp2 + def derivedAndOrType(tp1: Type, tp2: Type)(implicit ctx: Context): Type = derivedAndType(tp1, tp2) @@ -1751,10 +1784,38 @@ object Types { case OrType(tp1, tp2) => isSingleton(tp1) & isSingleton(tp2) case _ => false } + def isFullyDefined(tp: Type): Boolean = tp match { + case tp: TypeVar => tp.isInstantiated && isFullyDefined(tp.instanceOpt) + case tp: TypeProxy => isFullyDefined(tp.underlying) + case tp: AndOrType => isFullyDefined(tp.tp1) && isFullyDefined(tp.tp2) + case _ => true + } + def isOrType(tp: Type): Boolean = tp.stripTypeVar.dealias match { + case tp: OrType => true + case AndType(tp1, tp2) => isOrType(tp1) | isOrType(tp2) + case RefinedType(parent, _) => isOrType(parent) + case WildcardType(bounds: TypeBounds) => isOrType(bounds.hi) + case _ => false + } + + // First, solve the constraint. var inst = ctx.typeComparer.approximation(origin, fromBelow) + + // Then, approximate by (1.) and (2.) and simplify as follows. + // 1. If instance is from below and is a singleton type, yet + // upper bound is not a singleton type, widen the instance. if (fromBelow && isSingleton(inst) && !isSingleton(upperBound)) inst = inst.widen - instantiateWith(inst.simplified) + + inst = inst.simplified + + // 2. If instance is from below and is a fully-defined union type, yet upper bound + // is not a union type, approximate the union type from above by an intersection + // of all common base types. + if (fromBelow && isOrType(inst) && isFullyDefined(inst) && !isOrType(upperBound)) + inst = inst.approximateUnion + + instantiateWith(inst) } /** Unwrap to instance (if instantiated) or origin (if not), until result @@ -1895,7 +1956,7 @@ object Types { def | (that: TypeBounds)(implicit ctx: Context): TypeBounds = { val v = this commonVariance that - if (v == 0 && (this.lo eq this.hi) && (that.lo eq that.hi)) + if (v != 0 && (this.lo eq this.hi) && (that.lo eq that.hi)) if (v > 0) derivedTypeAlias(this.hi | that.hi, v) else derivedTypeAlias(this.lo & that.lo, v) else derivedTypeBounds(this.lo & that.lo, this.hi | that.hi, v) diff --git a/src/dotty/tools/dotc/typer/Namer.scala b/src/dotty/tools/dotc/typer/Namer.scala index c24021936d57..3827450666cb 100644 --- a/src/dotty/tools/dotc/typer/Namer.scala +++ b/src/dotty/tools/dotc/typer/Namer.scala @@ -567,7 +567,7 @@ class Namer { typer: Typer => // println(s"final inherited for $sym: ${inherited.toString}") !!! // println(s"owner = ${sym.owner}, decls = ${sym.owner.info.decls.show}") val rhsCtx = ctx.fresh addMode Mode.InferringReturnType - def rhsType = typedAheadExpr(mdef.rhs, rhsProto)(rhsCtx).tpe.widen + def rhsType = typedAheadExpr(mdef.rhs, rhsProto)(rhsCtx).tpe.widen.approximateUnion def lhsType = fullyDefinedType(rhsType, "right-hand side", mdef.pos) inherited orElse lhsType } diff --git a/test/dotc/tests.scala b/test/dotc/tests.scala index 9914485ac3cb..722bb8ac7390 100644 --- a/test/dotc/tests.scala +++ b/test/dotc/tests.scala @@ -42,6 +42,7 @@ class tests extends CompilerTest { @Test def pos_structural() = compileFile(posDir, "structural") @Test def pos_i39 = compileFile(posDir, "i39") @Test def pos_overloadedAccess = compileFile(posDir, "overloadedAccess") + @Test def pos_approximateUnion = compileFile(posDir, "approximateUnion") @Test def neg_blockescapes() = compileFile(negDir, "blockescapesNeg", xerrors = 1) @Test def neg_typedapply() = compileFile(negDir, "typedapply", xerrors = 4) diff --git a/tests/pos/approximateUnion.scala b/tests/pos/approximateUnion.scala new file mode 100644 index 000000000000..c3fe0e1625f5 --- /dev/null +++ b/tests/pos/approximateUnion.scala @@ -0,0 +1,96 @@ +object approximateUnion { + + trait C[+T] + trait D + trait E + trait X[-T] + + { + trait A extends C[A] with D + trait B extends C[B] with D + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[A | B] & D = x + val ytest: Some[C[A | B] & D] = y + } + + { + trait A extends C[X[A]] with D + trait B extends C[X[B]] with D with E + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[X[A & B]] & D = x + val ytest: Some[C[X[A & B]] & D] = y + } +} + +object approximateUnion2 { + + trait C[T] + trait D + trait E + trait X[-T] + + { + trait A extends C[A] with D + trait B extends C[B] with D + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[_ >: A & B <: A | B] & D = x + val ytest: Some[C[_ >: A & B <: A | B] & D] = y + } + + { + trait A extends C[X[A]] with D + trait B extends C[X[B]] with D with E + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[_ >: X[A | B] <: X[A & B]] & D = x + val ytest: Some[C[_ >: X[A | B] <: X[A & B]]] = y + } +} + +object approximateUnion3 { + + trait C[-T] + trait D + trait E + trait X[-T] + + { + trait A extends C[A] with D + trait B extends C[B] with D + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[A & B] & D = x + val ytest: Some[C[A & B] & D] = y + } + + { + trait A extends C[X[A]] with D + trait B extends C[X[B]] with D with E + + val coin = true + val x = if (coin) new A else new B + val y = Some(if (coin) new A else new B) + + val xtest: C[X[A | B]] & D = x + val ytest2: Some[C[X[A | B]] & D] = y + } +} +