Skip to content

Commit 03a3115

Browse files
committed
Synthesise Mirror.Sum for nested hierarchies
1 parent 5bb0e92 commit 03a3115

File tree

5 files changed

+83
-8
lines changed

5 files changed

+83
-8
lines changed

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,6 +1612,9 @@ object SymDenotations {
16121612

16131613
annotations.collect { case Annotation.Child(child) => child }.reverse
16141614
end children
1615+
1616+
def subclasses(using Context): List[Symbol] =
1617+
children.flatMap(c => if c.is(Sealed) then c.children else List(c)).sortBy(_.span.start)
16151618
}
16161619

16171620
/** The contents of a class definition during a period

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ object SymUtils:
9595
* - none of its children are anonymous classes
9696
* - all of its children are addressable through a path from the parent class
9797
* and also the location of the generated mirror.
98-
* - all of its children are generic products or singletons
98+
* - all of its children are generic products, singletons, or generic sums
9999
*/
100100
def whyNotGenericSum(declScope: Symbol)(using Context): String =
101101
if (!self.is(Sealed))
@@ -113,7 +113,12 @@ object SymUtils:
113113
if (child == self) "it has anonymous or inaccessible subclasses"
114114
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
115115
else if (!child.isClass) ""
116-
else {
116+
else if (child.isGenericProduct) ""
117+
else if (child.is(Sealed)) {
118+
val s = child.whyNotGenericSum(declScope)
119+
if (s.isEmpty) s
120+
else i"its child $child is not a generic sum because $s"
121+
} else {
117122
val s = child.whyNotGenericProduct
118123
if (s.isEmpty) s
119124
else i"its child $child is not a generic product because $s"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
520520
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
521521
else {
522522
val cases =
523-
for ((child, idx) <- cls.children.zipWithIndex) yield {
523+
for ((child, idx) <- cls.subclasses.zipWithIndex) yield {
524524
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
525525
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
526526
CaseDef(pat, EmptyTree, Literal(Constant(idx)))

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,11 +275,10 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
275275
val useCompanion = cls.useCompanionAsMirror
276276

277277
if cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
278-
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
278+
val elemLabels = cls.subclasses.map(c => ConstantType(Constant(c.name.toString)))
279279

280280
def solve(sym: Symbol): Type = sym match
281-
case caseClass: ClassSymbol =>
282-
assert(caseClass.is(Case))
281+
case caseClass: ClassSymbol if caseClass.is(Case) =>
283282
if caseClass.is(Module) then
284283
caseClass.sourceModule.termRef
285284
else
@@ -313,11 +312,11 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
313312
val (monoType, elemsType) = mirroredType match
314313
case mirroredType: HKTypeLambda =>
315314
val elems = mirroredType.derivedLambdaType(
316-
resType = TypeOps.nestedPairs(cls.children.map(solve))
315+
resType = TypeOps.nestedPairs(cls.subclasses.map(solve))
317316
)
318317
(mkMirroredMonoType(mirroredType), elems)
319318
case _ =>
320-
val elems = TypeOps.nestedPairs(cls.children.map(solve))
319+
val elems = TypeOps.nestedPairs(cls.subclasses.map(solve))
321320
(mirroredType, elems)
322321

323322
val mirrorType =

tests/run/i11050.scala

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import scala.compiletime.*
2+
import scala.deriving.*
3+
4+
sealed trait TreeValue
5+
6+
sealed trait SubLevel extends TreeValue
7+
8+
case class Leaf1(value: String) extends TreeValue
9+
case class Leaf2(value: Int) extends SubLevel
10+
case class Leaf3(value: Char) extends TreeValue
11+
12+
object Test:
13+
val m = summon[Mirror.SumOf[TreeValue]]
14+
given Show[TreeValue] = Show.derived[TreeValue]
15+
16+
def main(args: Array[String]) =
17+
val leaf1 = Leaf1("1")
18+
val leaf2 = Leaf2(2)
19+
val leaf3 = Leaf3('3')
20+
21+
assertEq(List(leaf1, leaf2, leaf3).map(m.ordinal), List(0, 1, 2))
22+
assertShow[TreeValue](leaf1, "[0] Leaf1(value = \"1\")")
23+
assertShow[TreeValue](leaf2, "[1] Leaf2(value = 2)")
24+
assertShow[TreeValue](leaf3, "[2] Leaf3(value = '3')")
25+
end main
26+
27+
def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"Expected $obt == $exp")
28+
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
29+
end Test
30+
31+
trait Show[-T]:
32+
def show(x: T): String
33+
34+
object Show:
35+
given Show[Int] with { def show(x: Int) = s"$x" }
36+
given Show[Char] with { def show(x: Char) = s"'$x'" }
37+
given Show[String] with { def show(x: String) = s"$"$x$"" }
38+
39+
inline def show[T](x: T): String = summonInline[Show[T]].show(x)
40+
41+
transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
42+
def show(x: T): String = inline ev match {
43+
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
44+
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
45+
}
46+
}
47+
48+
inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
49+
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)
50+
51+
inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
52+
inline (erasedValue[Labels], erasedValue[Elems]) match {
53+
case _: (label *: labels, elem *: elems) =>
54+
val value = show(x.productElement(n).asInstanceOf[elem])
55+
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
56+
case _: (EmptyTuple, EmptyTuple) =>
57+
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
58+
}
59+
60+
transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
61+
inline erasedValue[Alts] match {
62+
case _: (alt *: alts) =>
63+
if (ord == n) summonFrom {
64+
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
65+
} else showCases[alts](n + 1)(x, ord)
66+
case _: EmptyTuple => throw new MatchError(x)
67+
}
68+
end Show

0 commit comments

Comments
 (0)