Skip to content

Commit a220938

Browse files
committed
Synthesise Mirror.Sum for nested hierarchies
1 parent 8f3fdf5 commit a220938

File tree

3 files changed

+146
-13
lines changed

3 files changed

+146
-13
lines changed

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ object SymUtils:
9595
* - none of its children are anonymous classes
9696
* - all of its children are addressable through a path from the parent class
9797
* and also the location of the generated mirror.
98-
* - all of its children are generic products or singletons
98+
* - all of its children are generic products, singletons, or generic sums themselves.
9999
*/
100100
def whyNotGenericSum(declScope: Symbol)(using Context): String =
101101
if (!self.is(Sealed))
@@ -116,7 +116,11 @@ object SymUtils:
116116
else {
117117
val s = child.whyNotGenericProduct
118118
if (s.isEmpty) s
119-
else i"its child $child is not a generic product because $s"
119+
else if (child.is(Sealed)) {
120+
val s = child.whyNotGenericSum(child.linkedClass)
121+
if (s.isEmpty) s
122+
else i"its child $child is not a generic sum because $s"
123+
} else i"its child $child is not a generic product because $s"
120124
}
121125
}
122126
if (children.isEmpty) "it does not have subclasses"

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
288288
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
289289

290290
def solve(sym: Symbol): Type = sym match
291-
case caseClass: ClassSymbol =>
292-
assert(caseClass.is(Case))
293-
if caseClass.is(Module) then
294-
caseClass.sourceModule.termRef
291+
case childClass: ClassSymbol =>
292+
assert(childClass.isOneOf(Case | Sealed))
293+
if childClass.is(Module) then
294+
childClass.sourceModule.termRef
295295
else
296-
caseClass.primaryConstructor.info match
296+
childClass.primaryConstructor.info match
297297
case info: PolyType =>
298298
// Compute the the full child type by solving the subtype constraint
299299
// `C[X1, ..., Xn] <: P`, where
@@ -310,13 +310,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
310310
case tp => tp
311311
resType <:< target
312312
val tparams = poly.paramRefs
313-
val variances = caseClass.typeParams.map(_.paramVarianceSign)
313+
val variances = childClass.typeParams.map(_.paramVarianceSign)
314314
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
315315
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
316316
resType.substParams(poly, instanceTypes)
317-
instantiate(using ctx.fresh.setExploreTyperState().setOwner(caseClass))
317+
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
318318
case _ =>
319-
caseClass.typeRef
319+
childClass.typeRef
320320
case child => child.termRef
321321
end solve
322322

@@ -331,9 +331,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
331331
(mirroredType, elems)
332332

333333
val mirrorType =
334-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
335-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
336-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
334+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
335+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
336+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
337337
val mirrorRef =
338338
if useCompanion then companionPath(mirroredType, span)
339339
else anonymousMirror(monoType, ExtendsSumMirror, span)

tests/run/i11050.scala

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import scala.compiletime.*
2+
import scala.deriving.*
3+
4+
object OriginalReport:
5+
sealed trait TreeValue
6+
sealed trait SubLevel extends TreeValue
7+
case class Leaf1(value: String) extends TreeValue
8+
case class Leaf2(value: Int) extends SubLevel
9+
10+
// Variants from the initial failure in akka.event.LogEvent
11+
object FromAkkaCB:
12+
sealed trait A
13+
sealed trait B extends A
14+
sealed trait C extends A
15+
case class D() extends B, C
16+
case class E() extends C, B
17+
18+
object FromAkkaCB2:
19+
sealed trait A
20+
sealed trait N extends A
21+
case class B() extends A
22+
case class C() extends A, N
23+
24+
object FromAkkaCB3:
25+
sealed trait A
26+
case class B() extends A
27+
case class C() extends A
28+
class D extends C // ignored pattern: class extending a case class
29+
30+
object FromAkkaCB4:
31+
sealed trait A
32+
sealed trait N extends A
33+
case class B() extends A
34+
case class C() extends A
35+
class D extends C, N // ignored
36+
37+
object FromAkkaCB5:
38+
sealed trait A
39+
sealed trait N extends A
40+
case class B() extends A
41+
case class C() extends A
42+
class D extends N // ignored
43+
44+
object Test:
45+
def main(args: Array[String]): Unit =
46+
testOriginalReport()
47+
testFromAkkaCB()
48+
testFromAkkaCB2()
49+
end main
50+
51+
def testOriginalReport() =
52+
import OriginalReport._
53+
val m = summon[Mirror.SumOf[TreeValue]]
54+
given Show[TreeValue] = Show.derived[TreeValue]
55+
val leaf1 = Leaf1("1")
56+
val leaf2 = Leaf2(2)
57+
58+
assertEq(List(leaf1, leaf2).map(m.ordinal), List(1, 0))
59+
assertShow[TreeValue](leaf1, "[1] Leaf1(value = \"1\")")
60+
assertShow[TreeValue](leaf2, "[0] [0] Leaf2(value = 2)")
61+
end testOriginalReport
62+
63+
def testFromAkkaCB() =
64+
import FromAkkaCB._
65+
val m = summon[Mirror.SumOf[A]]
66+
given Show[A] = Show.derived[A]
67+
val d = D()
68+
val e = E()
69+
70+
assertEq(List(d, e).map(m.ordinal), List(0, 0))
71+
assertShow[A](d, "[0] [0] D")
72+
assertShow[A](e, "[0] [1] E")
73+
end testFromAkkaCB
74+
75+
def testFromAkkaCB2() =
76+
import FromAkkaCB2._
77+
val m = summon[Mirror.SumOf[A]]
78+
val n = summon[Mirror.SumOf[N]]
79+
given Show[A] = Show.derived[A]
80+
val b = B()
81+
val c = C()
82+
83+
assertEq(List(b, c).map(m.ordinal), List(1, 0))
84+
assertShow[A](b, "[1] B")
85+
assertShow[A](c, "[0] [0] C")
86+
end testFromAkkaCB2
87+
88+
def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"$obt != $exp (obtained != expected)")
89+
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
90+
end Test
91+
92+
trait Show[-T]:
93+
def show(x: T): String
94+
95+
object Show:
96+
given Show[Int] with { def show(x: Int) = s"$x" }
97+
given Show[Char] with { def show(x: Char) = s"'$x'" }
98+
given Show[String] with { def show(x: String) = s"$"$x$"" }
99+
100+
inline def show[T](x: T): String = summonInline[Show[T]].show(x)
101+
102+
transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
103+
def show(x: T): String = inline ev match {
104+
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
105+
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
106+
}
107+
}
108+
109+
transparent inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
110+
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)
111+
112+
transparent inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
113+
inline (erasedValue[Labels], erasedValue[Elems]) match {
114+
case _: (label *: labels, elem *: elems) =>
115+
val value = show(x.productElement(n).asInstanceOf[elem])
116+
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
117+
case _: (EmptyTuple, EmptyTuple) =>
118+
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
119+
}
120+
121+
transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
122+
inline erasedValue[Alts] match {
123+
case _: (alt *: alts) =>
124+
if (ord == n) summonFrom {
125+
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
126+
} else showCases[alts](n + 1)(x, ord)
127+
case _: EmptyTuple => throw new MatchError(x)
128+
}
129+
end Show

0 commit comments

Comments
 (0)