Skip to content

Commit ee5a82f

Browse files
committed
Synthesise Mirror.Sum for nested hierarchies
1 parent 38b983c commit ee5a82f

File tree

5 files changed

+160
-14
lines changed

5 files changed

+160
-14
lines changed

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ object SymUtils:
9595
* - none of its children are anonymous classes
9696
* - all of its children are addressable through a path from the parent class
9797
* and also the location of the generated mirror.
98-
* - all of its children are generic products or singletons
98+
* - all of its children are generic products, singletons, or generic sums themselves.
9999
*/
100100
def whyNotGenericSum(declScope: Symbol)(using Context): String =
101101
if (!self.is(Sealed))
@@ -118,7 +118,11 @@ object SymUtils:
118118
else {
119119
val s = child.whyNotGenericProduct
120120
if (s.isEmpty) s
121-
else i"its child $child is not a generic product because $s"
121+
else if (child.is(Sealed)) {
122+
val s = child.whyNotGenericSum(if child.useCompanionAsMirror then child.linkedClass else ctx.owner)
123+
if (s.isEmpty) s
124+
else i"its child $child is not a generic sum because $s"
125+
} else i"its child $child is not a generic product because $s"
122126
}
123127
}
124128
if (children.isEmpty) "it does not have subclasses"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
525525
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
526526
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
527527
}
528-
Match(param, cases)
528+
Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases)
529529
}
530530

531531
/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -285,12 +285,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
285285
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
286286

287287
def solve(sym: Symbol): Type = sym match
288-
case caseClass: ClassSymbol =>
289-
assert(caseClass.is(Case))
290-
if caseClass.is(Module) then
291-
caseClass.sourceModule.termRef
288+
case childClass: ClassSymbol =>
289+
assert(childClass.isOneOf(Case | Sealed))
290+
if childClass.is(Module) then
291+
childClass.sourceModule.termRef
292292
else
293-
caseClass.primaryConstructor.info match
293+
childClass.primaryConstructor.info match
294294
case info: PolyType =>
295295
// Compute the the full child type by solving the subtype constraint
296296
// `C[X1, ..., Xn] <: P`, where
@@ -307,13 +307,13 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
307307
case tp => tp
308308
resType <:< target
309309
val tparams = poly.paramRefs
310-
val variances = caseClass.typeParams.map(_.paramVarianceSign)
310+
val variances = childClass.typeParams.map(_.paramVarianceSign)
311311
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
312312
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
313313
resType.substParams(poly, instanceTypes)
314-
instantiate(using ctx.fresh.setExploreTyperState().setOwner(caseClass))
314+
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
315315
case _ =>
316-
caseClass.typeRef
316+
childClass.typeRef
317317
case child => child.termRef
318318
end solve
319319

@@ -328,9 +328,9 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
328328
(mirroredType, elems)
329329

330330
val mirrorType =
331-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
332-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
333-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
331+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
332+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
333+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
334334
val mirrorRef =
335335
if useCompanion then companionPath(mirroredType, span)
336336
else anonymousMirror(monoType, ExtendsSumMirror, span)

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ class CompilationTests {
197197
compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
198198
compileFile("tests/run-custom-args/defaults-serizaliable-no-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
199199
compileFilesInDir("tests/run-custom-args/erased", defaultOptions.and("-language:experimental.erasedDefinitions")),
200+
compileFilesInDir("tests/run-custom-args/fatal-warnings", defaultOptions.and("-Xfatal-warnings")),
200201
compileFilesInDir("tests/run-deep-subtype", allowDeepSubtypes),
201202
compileFilesInDir("tests/run", defaultOptions.and("-Ysafe-init"))
202203
).checkRuns()
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import scala.compiletime.*
2+
import scala.deriving.*
3+
4+
object OriginalReport:
5+
sealed trait TreeValue
6+
sealed trait SubLevel extends TreeValue
7+
case class Leaf1(value: String) extends TreeValue
8+
case class Leaf2(value: Int) extends SubLevel
9+
10+
// Variants from the initial failure in akka.event.LogEvent
11+
object FromAkkaCB:
12+
sealed trait A
13+
sealed trait B extends A
14+
sealed trait C extends A
15+
case class D() extends B, C
16+
case class E() extends C, B
17+
18+
object FromAkkaCB2:
19+
sealed trait A
20+
sealed trait N extends A
21+
case class B() extends A
22+
case class C() extends A, N
23+
24+
object FromAkkaCB3:
25+
sealed trait A
26+
case class B() extends A
27+
case class C() extends A
28+
class D extends C // ignored pattern: class extending a case class
29+
30+
object NoUnreachableWarnings:
31+
sealed trait Top
32+
object Top
33+
34+
final case class MiddleA() extends Top with Bottom
35+
final case class MiddleB() extends Top with Bottom
36+
final case class MiddleC() extends Top with Bottom
37+
38+
sealed trait Bottom extends Top
39+
40+
object FromAkkaCB4:
41+
sealed trait LogEvent
42+
object LogEvent
43+
case class Error() extends LogEvent
44+
class Error2() extends Error() with LogEventWithMarker // ignored pattern
45+
case class Warning() extends LogEvent
46+
sealed trait LogEventWithMarker extends LogEvent // must be defined late
47+
48+
object FromAkkaCB4simpler:
49+
sealed trait LogEvent
50+
object LogEvent
51+
case class Error() extends LogEvent
52+
class Error2() extends LogEventWithMarker // not a case class
53+
case class Warning() extends LogEvent
54+
sealed trait LogEventWithMarker extends LogEvent
55+
56+
object Test:
57+
def main(args: Array[String]): Unit =
58+
testOriginalReport()
59+
testFromAkkaCB()
60+
testFromAkkaCB2()
61+
end main
62+
63+
def testOriginalReport() =
64+
import OriginalReport._
65+
val m = summon[Mirror.SumOf[TreeValue]]
66+
given Show[TreeValue] = Show.derived[TreeValue]
67+
val leaf1 = Leaf1("1")
68+
val leaf2 = Leaf2(2)
69+
70+
assertEq(List(leaf1, leaf2).map(m.ordinal), List(1, 0))
71+
assertShow[TreeValue](leaf1, "[1] Leaf1(value = \"1\")")
72+
assertShow[TreeValue](leaf2, "[0] [0] Leaf2(value = 2)")
73+
end testOriginalReport
74+
75+
def testFromAkkaCB() =
76+
import FromAkkaCB._
77+
val m = summon[Mirror.SumOf[A]]
78+
given Show[A] = Show.derived[A]
79+
val d = D()
80+
val e = E()
81+
82+
assertEq(List(d, e).map(m.ordinal), List(0, 0))
83+
assertShow[A](d, "[0] [0] D")
84+
assertShow[A](e, "[0] [1] E")
85+
end testFromAkkaCB
86+
87+
def testFromAkkaCB2() =
88+
import FromAkkaCB2._
89+
val m = summon[Mirror.SumOf[A]]
90+
val n = summon[Mirror.SumOf[N]]
91+
given Show[A] = Show.derived[A]
92+
val b = B()
93+
val c = C()
94+
95+
assertEq(List(b, c).map(m.ordinal), List(1, 0))
96+
assertShow[A](b, "[1] B")
97+
assertShow[A](c, "[0] [0] C")
98+
end testFromAkkaCB2
99+
100+
def assertEq[A](obt: A, exp: A) = assert(obt == exp, s"$obt != $exp (obtained != expected)")
101+
def assertShow[A: Show](x: A, s: String) = assertEq(Show.show(x), s)
102+
end Test
103+
104+
trait Show[-T]:
105+
def show(x: T): String
106+
107+
object Show:
108+
given Show[Int] with { def show(x: Int) = s"$x" }
109+
given Show[Char] with { def show(x: Char) = s"'$x'" }
110+
given Show[String] with { def show(x: String) = s"$"$x$"" }
111+
112+
inline def show[T](x: T): String = summonInline[Show[T]].show(x)
113+
114+
transparent inline def derived[T](implicit ev: Mirror.Of[T]): Show[T] = new {
115+
def show(x: T): String = inline ev match {
116+
case m: Mirror.ProductOf[T] => showProduct(x.asInstanceOf[Product], m)
117+
case m: Mirror.SumOf[T] => showCases[m.MirroredElemTypes](0)(x, m.ordinal(x))
118+
}
119+
}
120+
121+
transparent inline def showProduct[T](x: Product, m: Mirror.ProductOf[T]): String =
122+
constValue[m.MirroredLabel] + showElems[m.MirroredElemTypes, m.MirroredElemLabels](0, Nil)(x)
123+
124+
transparent inline def showElems[Elems <: Tuple, Labels <: Tuple](n: Int, elems: List[String])(x: Product): String =
125+
inline (erasedValue[Labels], erasedValue[Elems]) match {
126+
case _: (label *: labels, elem *: elems) =>
127+
val value = show(x.productElement(n).asInstanceOf[elem])
128+
showElems[elems, labels](n + 1, s"${constValue[label]} = $value" :: elems)(x)
129+
case _: (EmptyTuple, EmptyTuple) =>
130+
if elems.isEmpty then "" else elems.mkString(s"(", ", ", ")")
131+
}
132+
133+
transparent inline def showCases[Alts <: Tuple](n: Int)(x: Any, ord: Int): String =
134+
inline erasedValue[Alts] match {
135+
case _: (alt *: alts) =>
136+
if (ord == n) summonFrom {
137+
case m: Mirror.Of[`alt`] => s"[$ord] " + derived[alt](using m).show(x.asInstanceOf[alt])
138+
} else showCases[alts](n + 1)(x, ord)
139+
case _: EmptyTuple => throw new MatchError(x)
140+
}
141+
end Show

0 commit comments

Comments
 (0)