Skip to content

Commit 24ebd19

Browse files
committed
Port validateBaseTypes from Scala 2
In order to fix a soundness issue with parametrized classes / traits (#11018), we port the validation of base types from Scala 2 to Scala 3. The original implementation of the check can be found here: https://github.com/scala/scala/blob/9bb659e62a9239c01aec14c171f8598bb1a576fe/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala#L841-L882
1 parent 8effbc4 commit 24ebd19

File tree

1 file changed

+46
-0
lines changed

1 file changed

+46
-0
lines changed

compiler/src/dotty/tools/dotc/typer/RefChecks.scala

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,53 @@ object RefChecks {
782782
report.error(problem(), clazz.srcPos)
783783
}
784784

785+
// Checks base types of the current clazz against each other
786+
//
787+
// In particular, it checks that there are no two base classes with
788+
// different type instantiations.
789+
//
790+
// ported from Scala 2:
791+
// https://github.com/scala/scala/blob/9bb659e62a9239c01aec14c171f8598bb1a576fe/src/compiler/scala/tools/nsc/typechecker/RefChecks.scala#L834-L883
792+
def validateBaseTypes(): Unit = {
793+
val tpe = clazz.thisType // in Scala 2 this was clazz.tpe
794+
val seenParents = mutable.HashSet[Type]()
795+
val baseClasses: List[ClassSymbol] = clazz.info.baseClasses
796+
797+
// tracks types that we have seen for a particular base class in baseClasses
798+
val seenTypes = mutable.Map.empty[Symbol, List[Type]]
799+
800+
// validate all base types of a class in reverse linear order.
801+
def register(tp: Type): Unit = {
802+
val baseClass = tp.typeSymbol
803+
if (baseClasses contains baseClass) {
804+
val alreadySeen = seenTypes.getOrElse(baseClass, Nil)
805+
if (alreadySeen.forall { tp1 => !(tp1 <:< tp) })
806+
seenTypes.update(baseClass, tp :: alreadySeen.filter { tp1 => !(tp <:< tp1) })
807+
}
808+
val remaining = tp.parents filterNot seenParents
809+
seenParents ++= remaining
810+
remaining foreach register
811+
}
812+
register(tpe)
813+
814+
seenTypes.foreach {
815+
case (cls, Nil) =>
816+
assert(false) // this case should not be reachable
817+
case (cls, _ :: Nil) =>
818+
() // Ok
819+
case (cls, tp1 :: tp2 :: _) =>
820+
val msg =
821+
em"""illegal inheritance;
822+
|
823+
| $clazz inherits different type instances of $cls:
824+
| $tp1 and $tp2"""
825+
826+
report.error(msg, clazz.srcPos)
827+
}
828+
}
829+
785830
checkParameterizedTraitsOK()
831+
validateBaseTypes()
786832
}
787833

788834
/** Check that `site` does not inherit conflicting generic instances of `baseCls`,

0 commit comments

Comments
 (0)