diff --git a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala index 0342e8291495..a969bce99268 100644 --- a/compiler/src/dotty/tools/dotc/core/SymDenotations.scala +++ b/compiler/src/dotty/tools/dotc/core/SymDenotations.scala @@ -1614,6 +1614,66 @@ object SymDenotations { annotations.collect { case Annotation.Child(child) => child }.reverse end children + + /** Recursively assemble all children of this symbol, Preserves order of insertion. + */ + final def sealedStrictDescendants(using Context): List[Symbol] = + + @tailrec + def findLvlN( + explore: mutable.ArrayDeque[Symbol], + seen: util.HashSet[Symbol], + acc: mutable.ListBuffer[Symbol] + ): List[Symbol] = + if explore.isEmpty then + acc.toList + else + val sym = explore.head + val explore1 = explore.dropInPlace(1) + val lvlN = sym.children + val notSeen = lvlN.filterConserve(!seen.contains(_)) + if notSeen.isEmpty then + findLvlN(explore1, seen, acc) + else + findLvlN(explore1 ++= notSeen, {seen ++= notSeen; seen}, acc ++= notSeen) + end findLvlN + + /** Scans through `explore` to see if there are recursive children. + * If a symbol in `explore` has children that are not contained in + * `lvl1`, fallback to `findLvlN`, or else return `lvl1`. + */ + @tailrec + def findLvl2( + lvl1: List[Symbol], explore: List[Symbol], seenOrNull: util.HashSet[Symbol] | Null + ): List[Symbol] = explore match + case sym :: explore1 => + val lvl2 = sym.children + if lvl2.isEmpty then // no children, scan rest of explore1 + findLvl2(lvl1, explore1, seenOrNull) + else // check if we have seen the children before + val seen = // initialise the seen set if not already + if seenOrNull != null then seenOrNull + else util.HashSet.from(lvl1) + val notSeen = lvl2.filterConserve(!seen.contains(_)) + if notSeen.isEmpty then // we found children, but we had already seen them, scan the rest of explore1 + findLvl2(lvl1, explore1, seen) + else // found unseen recursive children, we should fallback to the loop + findLvlN( + explore = mutable.ArrayDeque.from(explore1).appendAll(notSeen), + seen = {seen ++= notSeen; seen}, + acc = mutable.ListBuffer.from(lvl1).appendAll(notSeen) + ) + case nil => + lvl1 + end findLvl2 + + val lvl1 = children + findLvl2(lvl1, lvl1, seenOrNull = null) + end sealedStrictDescendants + + /** Same as `sealedStrictDescendants` but prepends this symbol as well. + */ + final def sealedDescendants(using Context): List[Symbol] = this.symbol :: sealedStrictDescendants } /** The contents of a class definition during a period diff --git a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala index cf8b3e01822a..b2559b5ccac4 100644 --- a/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala +++ b/compiler/src/dotty/tools/dotc/sbt/ExtractAPI.scala @@ -221,7 +221,7 @@ private class ExtractAPICollector(using Context) extends ThunkHolder { val modifiers = apiModifiers(sym) val anns = apiAnnotations(sym).toArray val topLevel = sym.isTopLevelClass - val childrenOfSealedClass = sym.children.sorted(classFirstSort).map(c => + val childrenOfSealedClass = sym.sealedDescendants.sorted(classFirstSort).map(c => if (c.isClass) apiType(c.typeRef) else diff --git a/compiler/src/dotty/tools/dotc/util/HashSet.scala b/compiler/src/dotty/tools/dotc/util/HashSet.scala index e7406f9ab094..e99754c7267b 100644 --- a/compiler/src/dotty/tools/dotc/util/HashSet.scala +++ b/compiler/src/dotty/tools/dotc/util/HashSet.scala @@ -7,6 +7,11 @@ object HashSet: */ inline val DenseLimit = 8 + def from[T](xs: IterableOnce[T]): HashSet[T] = + val set = new HashSet[T]() + set ++= xs + set + /** A hash set that allows some privileged protected access to its internals * @param initialCapacity Indicates the initial number of slots in the hash table. * The actual number of slots is always a power of 2, so the diff --git a/compiler/test/dotty/tools/dotc/core/SealedDescendantsTest.scala b/compiler/test/dotty/tools/dotc/core/SealedDescendantsTest.scala new file mode 100644 index 000000000000..7d90d0ed8870 --- /dev/null +++ b/compiler/test/dotty/tools/dotc/core/SealedDescendantsTest.scala @@ -0,0 +1,108 @@ +package dotty.tools.dotc.core + +import dotty.tools.dotc.core.Contexts.{Context, ctx} +import dotty.tools.dotc.core.Symbols.* + +import org.junit.Assert._ +import org.junit.Test + +import dotty.tools.DottyTest + +class SealedDescendantsTest extends DottyTest { + + @Test + def zincIssue979: Unit = + val source = + """ + sealed trait Z + sealed trait A extends Z + class B extends A + class C extends A + class D extends A + """ + + expectedDescendents(source, "Z", + "Z" :: + "A" :: + "B" :: + "C" :: + "D" :: Nil + ) + end zincIssue979 + + @Test + def enumOpt: Unit = + val source = + """ + enum Opt[+T] { + case Some(t: T) + case None + } + """ + + expectedDescendents(source, "Opt", + "Opt" :: + "Some" :: + "None.type" :: Nil + ) + end enumOpt + + @Test + def hierarchicalSharedChildren: Unit = + // Q is a child of both Z and A and should appear once + // X is a child of both A and Q and should appear once + val source = + """ + sealed trait Z + sealed trait A extends Z + sealed trait Q extends A with Z + trait X extends A with Q + case object Y extends Q + """ + + expectedDescendents(source, "Z", + "Z" :: + "A" :: + "Q" :: + "X" :: + "Y.type" :: Nil + ) + end hierarchicalSharedChildren + + @Test + def hierarchicalSharedChildrenB: Unit = + val source = + """ + sealed trait Z + case object A extends Z with D with E + sealed trait B extends Z + trait C extends B + sealed trait D extends B + sealed trait E extends D + """ + + expectedDescendents(source, "Z", + "Z" :: + "A.type" :: + "B" :: + "C" :: + "D" :: + "E" :: Nil + ) + end hierarchicalSharedChildrenB + + def expectedDescendents(source: String, root: String, expected: List[String]) = + exploreRoot(source, root) { rootCls => + val descendents = rootCls.sealedDescendants.map(sym => s"${sym.name}${if (sym.isTerm) ".type" else ""}") + assertEquals(expected.toString, descendents.toString) + } + + def exploreRoot(source: String, root: String)(op: Context ?=> ClassSymbol => Unit) = + val source0 = source.linesIterator.map(_.trim).mkString("\n|") + val source1 = s"""package testsealeddescendants + |$source0""".stripMargin + checkCompile("typer", source1) { (_, context) => + given Context = context + op(requiredClass(s"testsealeddescendants.$root")) + } +} diff --git a/sbt-test/source-dependencies/sealed-extends-sealed/A.scala b/sbt-test/source-dependencies/sealed-extends-sealed/A.scala new file mode 100644 index 000000000000..680c28adc94a --- /dev/null +++ b/sbt-test/source-dependencies/sealed-extends-sealed/A.scala @@ -0,0 +1,4 @@ +sealed trait Z +sealed trait A extends Z +class B extends A +class C extends A diff --git a/sbt-test/source-dependencies/sealed-extends-sealed/App.scala b/sbt-test/source-dependencies/sealed-extends-sealed/App.scala new file mode 100644 index 000000000000..e41c9149d06e --- /dev/null +++ b/sbt-test/source-dependencies/sealed-extends-sealed/App.scala @@ -0,0 +1,6 @@ +object App { + def foo(z: Z) = z match { + case _: B => + case _: C => + } +} diff --git a/sbt-test/source-dependencies/sealed-extends-sealed/changes/A.scala b/sbt-test/source-dependencies/sealed-extends-sealed/changes/A.scala new file mode 100644 index 000000000000..c8eb7651f412 --- /dev/null +++ b/sbt-test/source-dependencies/sealed-extends-sealed/changes/A.scala @@ -0,0 +1,5 @@ +sealed trait Z +sealed trait A extends Z +class B extends A +class C extends A +class D extends A diff --git a/sbt-test/source-dependencies/sealed-extends-sealed/project/DottyInjectedPlugin.scala b/sbt-test/source-dependencies/sealed-extends-sealed/project/DottyInjectedPlugin.scala new file mode 100644 index 000000000000..fc8fd26fc29e --- /dev/null +++ b/sbt-test/source-dependencies/sealed-extends-sealed/project/DottyInjectedPlugin.scala @@ -0,0 +1,12 @@ +import sbt._ +import Keys._ + +object DottyInjectedPlugin extends AutoPlugin { + override def requires = plugins.JvmPlugin + override def trigger = allRequirements + + override val projectSettings = Seq( + scalaVersion := sys.props("plugin.scalaVersion"), + scalacOptions ++= Seq("-source:3.0-migration", "-Xfatal-warnings") + ) +} diff --git a/sbt-test/source-dependencies/sealed-extends-sealed/test b/sbt-test/source-dependencies/sealed-extends-sealed/test new file mode 100644 index 000000000000..6cae243fdcd5 --- /dev/null +++ b/sbt-test/source-dependencies/sealed-extends-sealed/test @@ -0,0 +1,8 @@ +> compile + +# Introduce a new class C that also extends A +$ copy-file changes/A.scala A.scala + +# App.scala needs recompiling because the pattern match in it +# is no longer exhaustive, which emits a warning +-> compile