diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 154f6ad90c41..20c1c5c8e64f 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -1407,32 +1407,13 @@ object SymDenotations { baseData._2 def computeBaseData(implicit onBehalf: BaseData, ctx: Context): (List[ClassSymbol], BaseClassSet) = { - val seen = new BaseClassSetBuilder - def addBaseClasses(bcs: List[ClassSymbol], to: List[ClassSymbol]) - : List[ClassSymbol] = bcs match { - case bc :: bcs1 => - val bcs1added = addBaseClasses(bcs1, to) - if (seen contains bc) bcs1added - else { - seen.add(bc) - bc :: bcs1added - } - case nil => - to - } - def addParentBaseClasses(ps: List[TypeRef], to: List[ClassSymbol]): List[ClassSymbol] = ps match { - case p :: ps1 => - addParentBaseClasses(ps1, - addBaseClasses(p.symbol.asClass.baseClasses, to)) - case nil => - to - } def emptyParentsExpected = is(Package) || (symbol == defn.AnyClass) || ctx.erasedTypes && (symbol == defn.ObjectClass) if (classParents.isEmpty && !emptyParentsExpected) onBehalf.signalProvisional() - (classSymbol :: addParentBaseClasses(classParents, Nil), - seen.result) + val builder = new BaseDataBuilder + for (p <- classParents) builder.addAll(p.symbol.asClass.baseClasses) + (classSymbol :: builder.baseClasses, builder.baseClassSet) } final override def derivesFrom(base: Symbol)(implicit ctx: Context): Boolean = @@ -2096,7 +2077,14 @@ object SymDenotations { def contains(sym: Symbol): Boolean = contains(sym, classIds.length) } - private class BaseClassSetBuilder { + object BaseClassSet { + def apply(bcs: List[ClassSymbol]): BaseClassSet = + new BaseClassSet(bcs.toArray.map(_.id)) + } + + /** A class to combine base data from parent types */ + class BaseDataBuilder { + private var classes: List[ClassSymbol] = Nil private var classIds = new Array[Int](32) private var length = 0 @@ -2106,19 +2094,32 @@ object SymDenotations { classIds = classIds1 } - def contains(sym: Symbol): Boolean = - new BaseClassSet(classIds).contains(sym, length) - - def add(sym: Symbol): Unit = { + private def add(sym: Symbol): Unit = { if (length == classIds.length) resize(length * 2) classIds(length) = sym.id length += 1 } - def result = { + def addAll(bcs: List[ClassSymbol]): this.type = { + val len = length + bcs match { + case bc :: bcs1 => + addAll(bcs1) + if (!new BaseClassSet(classIds).contains(bc, len)) { + add(bc) + classes = bc :: classes + } + case nil => + } + this + } + + def baseClassSet = { if (length != classIds.length) resize(length) new BaseClassSet(classIds) } + + def baseClasses: List[ClassSymbol] = classes } @sharable private var indent = 0 // for completions printing diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index fc2502bb70b3..7bdec511655c 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1143,11 +1143,11 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling { case OrType(tp11, tp12) => tp11 & tp2 | tp12 & tp2 case _ => - val t1 = mergeIfSub(tp1, tp2) - if (t1.exists) t1 + val tp1a = dropIfSuper(tp1, tp2) + if (tp1a ne tp1) glb(tp1a, tp2) else { - val t2 = mergeIfSub(tp2, tp1) - if (t2.exists) t2 + val tp2a = dropIfSuper(tp2, tp1) + if (tp2a ne tp2) glb(tp1, tp2a) else tp1 match { case tp1: ConstantType => tp2 match { @@ -1204,6 +1204,22 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling { final def lub(tps: List[Type]): Type = ((defn.NothingType: Type) /: tps)(lub(_,_, canConstrain = false)) + private def recombineAndOr(tp: AndOrType, tp1: Type, tp2: Type) = + if (!tp1.exists) tp2 + else if (!tp2.exists) tp1 + else tp.derivedAndOrType(tp1, tp2) + + /** If some (&-operand of) this type is a supertype of `sub` replace it with `NoType`. + */ + private def dropIfSuper(tp: Type, sub: Type): Type = + if (isSubTypeWhenFrozen(sub, tp)) NoType + else tp match { + case tp @ AndType(tp1, tp2) => + recombineAndOr(tp, dropIfSuper(tp1, sub), dropIfSuper(tp2, sub)) + case _ => + tp + } + /** Merge `t1` into `tp2` if t1 is a subtype of some &-summand of tp2. */ private def mergeIfSub(tp1: Type, tp2: Type): Type = diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 6d66cacdd067..43a3da922c0a 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -384,19 +384,15 @@ object Types { /** The base classes of this type as determined by ClassDenotation * in linearization order, with the class itself as first element. - * For AndTypes/OrTypes, the union/intersection of the operands' baseclasses. - * Inherited by all type proxies. `Nil` for all other types. + * Inherited by all type proxies. Overridden for And and Or types. + * `Nil` for all other types. */ - final def baseClasses(implicit ctx: Context): List[ClassSymbol] = track("baseClasses") { + def baseClasses(implicit ctx: Context): List[ClassSymbol] = track("baseClasses") { this match { case tp: TypeProxy => tp.underlying.baseClasses case tp: ClassInfo => tp.cls.baseClasses - case AndType(tp1, tp2) => - tp1.baseClasses union tp2.baseClasses - case OrType(tp1, tp2) => - tp1.baseClasses intersect tp2.baseClasses case _ => Nil } } @@ -472,22 +468,24 @@ object Types { */ final def findMember(name: Name, pre: Type, excluded: FlagSet)(implicit ctx: Context): Denotation = { @tailrec def go(tp: Type): Denotation = tp match { - case tp: RefinedType => - if (name eq tp.refinedName) goRefined(tp) else go(tp.parent) - case tp: ThisType => - goThis(tp) - case tp: TypeRef => - tp.denot.findMember(name, pre, excluded) case tp: TermRef => go (tp.underlying match { case mt: MethodType if mt.paramInfos.isEmpty && (tp.symbol is Stable) => mt.resultType case tp1 => tp1 }) - case tp: TypeParamRef => - goParam(tp) + case tp: TypeRef => + tp.denot.findMember(name, pre, excluded) + case tp: ThisType => + goThis(tp) + case tp: RefinedType => + if (name eq tp.refinedName) goRefined(tp) else go(tp.parent) case tp: RecType => goRec(tp) + case tp: TypeParamRef => + goParam(tp) + case tp: SuperType => + goSuper(tp) case tp: HKApply => goApply(tp) case tp: TypeProxy => @@ -611,6 +609,12 @@ object Types { go(next) } } + def goSuper(tp: SuperType) = go(tp.underlying) match { + case d: JointRefDenotation => + typr.println(i"redirecting super.$name from $tp to ${d.symbol.showLocated}") + new UniqueRefDenotation(d.symbol, tp.memberInfo(d.symbol), d.validFor) + case d => d + } def goAnd(l: Type, r: Type) = { go(l) & (go(r), pre, safeIntersection = ctx.pendingMemberSearches.contains(name)) } @@ -2296,6 +2300,37 @@ object Types { def tp2: Type def isAnd: Boolean def derivedAndOrType(tp1: Type, tp2: Type)(implicit ctx: Context): Type // needed? + + private[this] var myBaseClassesPeriod: Period = Nowhere + private[this] var myBaseClasses: List[ClassSymbol] = _ + + /** Base classes of And are the merge of the operand base classes + * For OrTypes, it's the intersection. + */ + override final def baseClasses(implicit ctx: Context) = { + if (myBaseClassesPeriod != ctx.period) { + val bcs1 = tp1.baseClasses + val bcs1set = BaseClassSet(bcs1) + def recur(bcs2: List[ClassSymbol]): List[ClassSymbol] = bcs2 match { + case bc2 :: bcs2rest => + if (isAnd) + if (bcs1set contains bc2) + if (bc2.is(Trait)) recur(bcs2rest) + else bcs1 // common class, therefore rest is the same in both sequences + else bc2 :: recur(bcs2rest) + else + if (bcs1set contains bc2) + if (bc2.is(Trait)) bc2 :: recur(bcs2rest) + else bcs2 + else recur(bcs2rest) + case nil => + if (isAnd) bcs1 else bcs2 + } + myBaseClasses = recur(tp2.baseClasses) + myBaseClassesPeriod = ctx.period + } + myBaseClasses + } } abstract case class AndType(tp1: Type, tp2: Type) extends CachedGroundType with AndOrType { diff --git a/tests/neg/i2677.scala b/tests/neg/i2677.scala new file mode 100644 index 000000000000..72940bf02628 --- /dev/null +++ b/tests/neg/i2677.scala @@ -0,0 +1,6 @@ +trait A { def x = "foo" } +trait B { def x = 42 } +object Test { + val AB = new A with B { override def x = super.x } // error: wrong override + AB.x +} \ No newline at end of file diff --git a/tests/neg/overrides.scala b/tests/neg/overrides.scala index 149220bd560d..89e20e94302f 100644 --- a/tests/neg/overrides.scala +++ b/tests/neg/overrides.scala @@ -79,27 +79,6 @@ class X3 { override type T = A1 // error: overrides nothing } -package p3 { - -// Dotty change of rules: Toverrider#f does not -// override TCommon#f, hence the accidental override rule -// applies. -trait TCommon { - def f: String -} - -class C1 extends TCommon { - def f = "in C1" -} - -trait TOverrider { this: TCommon => - override def f = "in TOverrider" // The overridden self-type member... -} - -class C2 extends C1 with TOverrider // ... fails to override, here. // error: accidental override - -} - package p4 { abstract class C[T] { def head: T } diff --git a/tests/pos/Orderings.scala b/tests/pos/Orderings.scala new file mode 100644 index 000000000000..24c89cd7b147 --- /dev/null +++ b/tests/pos/Orderings.scala @@ -0,0 +1,20 @@ +object Orderings { + + // A type class: + trait Ord[T] { def less(x: T, y: T): Boolean } + + implicit val intOrd: Ord[Int] = new { + def less(x: Int, y: Int) = x < y + } + + implicit def listOrd[T](implicit ev: Ord[T]): Ord[List[T]] = new { + def less(xs: List[T], ys: List[T]): Boolean = + if ys.isEmpty then false + else if xs.isEmpty then true + else if xs.head == ys.head then less(xs.tail, ys.tail) + else ev.less(xs.head, ys.head) + } + + def isLess[T]: T => T => implicit Ord[T] => Boolean = + x => y => implicitly[Ord[T]].less(x, y) +} diff --git a/tests/pos/override-via-self.scala b/tests/pos/override-via-self.scala new file mode 100644 index 000000000000..0a8147b81f42 --- /dev/null +++ b/tests/pos/override-via-self.scala @@ -0,0 +1,18 @@ +// Question: Does TOverrider#f override TCommon#f? +// If not, the accidental override rule applies. +// Dotty used to say no, but with the change to baseClasses in AndTypes says +// yes. Not sure what the right answer is. But in any case we should +// keep the test to notice if there's a difference in behavior. +trait TCommon { + def f: String +} +class C1 extends TCommon { + def f = "in C1" +} + +trait TOverrider { this: TCommon => + override def f = "in TOverrider" // The overridden self-type member... +} + +class C2 extends C1 with TOverrider // ... failed to override, here. But now it is OK. + diff --git a/tests/run/supercalls-traits.check b/tests/run/supercalls-traits.check new file mode 100644 index 000000000000..f559999a8f5d --- /dev/null +++ b/tests/run/supercalls-traits.check @@ -0,0 +1,5 @@ +C1A1A2B1B2C2 +C1B3B4C3 +IT +AT +ER diff --git a/tests/run/supercalls-traits.scala b/tests/run/supercalls-traits.scala index 241419314e68..2a29f1e1a7ba 100644 --- a/tests/run/supercalls-traits.scala +++ b/tests/run/supercalls-traits.scala @@ -14,9 +14,42 @@ class Base[A](exp: => Option[A]) object Empty extends Base[Nothing](None) + +trait B1 extends C1 { override def f() = { super.f(); print("B1") }} +trait B2 extends B1 { override def f() = { super.f(); print("B2") }} +trait A1 extends C1 { override def f() = { super.f(); print("A1") }} +trait A2 extends A1 { override def f() = { super.f(); print("A2") }} +class C1 { def f() = print("C1") } +class C2 extends A2 with B2 { override def f() = { super.f(); print("C2") }} + + +trait B3 extends C1 { override def f() = { super.f(); print("B3") }} +trait B4 extends C1 { this: B3 => override def f() = { super.f(); print("B4") }} +class C3 extends C1 with B3 with B4 { override def f() = { super.f(); print("C3") }} + +trait DT { + def f(): Unit +} +trait IT extends DT { + def f() = { println("IT") } +} +abstract class MPT { +} +trait AT extends MPT with DT { + abstract override def f() = { super.f(); println("AT") } +} +class ER extends MPT with IT with AT { + override def f() = { super.f(); println("ER") } +} + object Test { def main(args: Array[String]): Unit = { assert(new C().foo == 3) + new C2().f() + println() + new C3().f() + println() + new ER().f() Empty } }