Skip to content

Commit 5c924ca

Browse files
oderskymilessabin
authored andcommitted
Mirror infrastructure for generic sum types
1 parent c6c0f71 commit 5c924ca

File tree

4 files changed

+107
-34
lines changed

4 files changed

+107
-34
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -697,6 +697,9 @@ class Definitions {
697697
lazy val Mirror_Product_fromProductR: TermRef = Mirror_ProductClass.requiredMethodRef(nme.fromProduct)
698698
def Mirror_Product_fromProduct(implicit ctx: Context): Symbol = Mirror_Product_fromProductR.symbol
699699

700+
lazy val Mirror_SumType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Sum")
701+
def Mirror_SumClass(implicit ctx: Context): ClassSymbol = Mirror_SumType.symbol.asClass
702+
700703
lazy val Mirror_SingletonType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Singleton")
701704
def Mirror_SingletonClass(implicit ctx: Context): ClassSymbol = Mirror_SingletonType.symbol.asClass
702705

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,36 @@ class SymUtils(val self: Symbol) extends AnyVal {
6868
* Excluded are value classes, abstract classes and case classes with more than one
6969
* parameter section.
7070
*/
71-
def isGenericProduct(implicit ctx: Context): Boolean =
72-
self.is(CaseClass, butNot = Abstract) &&
73-
self.primaryConstructor.info.paramInfoss.length == 1 &&
74-
!isDerivedValueClass(self)
71+
def whyNotGenericProduct(implicit ctx: Context): String =
72+
if (!self.is(CaseClass)) "it is not a case class"
73+
else if (self.is(Abstract)) "it is an abstract class"
74+
else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list"
75+
else if (isDerivedValueClass(self)) "it is a value class"
76+
else ""
77+
78+
def isGenericProduct(implicit ctx: Context): Boolean = whyNotGenericProduct.isEmpty
79+
80+
/** Is this a sealed class or trait for which a sum mirror is generated?
81+
* Excluded are
82+
*/
83+
def whyNotGenericSum(implicit ctx: Context): String =
84+
if (!self.is(Sealed))
85+
s"it is not a sealed ${if (self.is(Trait)) "trait" else "class"}"
86+
else {
87+
val children = self.children
88+
def problem(child: Symbol) =
89+
if (child == self) "it has anonymous or inaccessible subclasses"
90+
else if (!child.isClass) ""
91+
else {
92+
val s = child.whyNotGenericProduct
93+
if (s.isEmpty) s
94+
else "its child $child is not a generic product because $s"
95+
}
96+
if (children.isEmpty) "it does not have subclasses"
97+
else children.filter(_.isClass).map(problem).find(!_.isEmpty).getOrElse("")
98+
}
99+
100+
def isGenericSum(implicit ctx: Context): Boolean = whyNotGenericSum.isEmpty
75101

76102
/** If this is a constructor, its owner: otherwise this. */
77103
final def skipConstructor(implicit ctx: Context): Symbol =
@@ -161,6 +187,10 @@ class SymUtils(val self: Symbol) extends AnyVal {
161187
else owner.isLocal
162188
}
163189

190+
/** The typeRef with wildcard arguments for each type parameter */
191+
def rawTypeRef(implicit ctx: Context) =
192+
self.typeRef.appliedTo(self.typeParams.map(_ => TypeBounds.empty))
193+
164194
/** Is symbol a quote operation? */
165195
def isQuote(implicit ctx: Context): Boolean =
166196
self == defn.InternalQuoted_exprQuote || self == defn.InternalQuoted_typeQuote

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,10 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
5454
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
5555
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
5656

57-
private def alreadyDefined(sym: Symbol, clazz: ClassSymbol)(implicit ctx: Context): Boolean = {
57+
private def existingDef(sym: Symbol, clazz: ClassSymbol)(implicit ctx: Context): Symbol = {
5858
val existing = sym.matchingMember(clazz.thisType)
59-
existing.exists && !(existing == sym || existing.is(Deferred))
59+
if (existing != sym && !existing.is(Deferred)) existing
60+
else NoSymbol
6061
}
6162

6263
private def synthesizeDef(sym: TermSymbol, rhsFn: List[List[Tree]] => Context => Tree)(implicit ctx: Context): Tree =
@@ -80,7 +81,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
8081
else Nil
8182

8283
def syntheticDefIfMissing(sym: Symbol): List[Tree] =
83-
if (alreadyDefined(sym, clazz)) Nil else syntheticDef(sym) :: Nil
84+
if (existingDef(sym, clazz).exists) Nil else syntheticDef(sym) :: Nil
8485

8586
def syntheticDef(sym: Symbol): Tree = {
8687
val synthetic = sym.copy(
@@ -344,42 +345,79 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
344345
}
345346
}
346347

348+
/** For an enum T:
349+
*
350+
* def ordinal(x: MonoType) = x.enumTag
351+
*
352+
* For sealed trait with children of normalized types C_1, ..., C_n:
353+
*
354+
* def ordinal(x: MonoType) = x match {
355+
* case _: C_1 => 0
356+
* ...
357+
* case _: C_n => n - 1
358+
*
359+
* Here, the normalized type of a class C is C[_, ...., _] with
360+
* a wildcard for each type parameter. The normalized type of an object
361+
* O is O.type.
362+
*/
363+
def ordinalBody(cls: Symbol, param: Tree)(implicit ctx: Context): Tree =
364+
if (cls.is(Enum)) param.select(nme.enumTag)
365+
else {
366+
val cases =
367+
for ((child, idx) <- cls.children.zipWithIndex) yield {
368+
val patType = if (child.isTerm) child.termRef else child.rawTypeRef
369+
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
370+
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
371+
}
372+
Match(param, cases)
373+
}
374+
375+
/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
376+
* and `MonoType` and `ordinal` members.
377+
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
378+
* and `MonoType` and `fromProduct` members.
379+
*/
347380
def addMirrorSupport(impl: Template)(implicit ctx: Context): Template = {
348381
val clazz = ctx.owner.asClass
349-
var newBody = serializableObjectMethod(clazz) ::: caseAndValueMethods(clazz) ::: impl.body
382+
val linked = clazz.linkedClass
383+
384+
var newBody = impl.body
350385
var newParents = impl.parents
351-
def addParent(parent: Type) = {
386+
def addParent(parent: Type): Unit = {
352387
newParents = newParents :+ TypeTree(parent)
353388
val oldClassInfo = clazz.classInfo
354389
val newClassInfo = oldClassInfo.derivedClassInfo(
355390
classParents = oldClassInfo.classParents :+ parent)
356391
clazz.copySymDenotation(info = newClassInfo).installAfter(thisPhase)
357392
}
393+
def addMethod(name: TermName, info: Type, body: (Symbol, Tree, Context) => Tree): Unit = {
394+
val meth = ctx.newSymbol(clazz, name, Synthetic | Method, info, coord = clazz.coord)
395+
if (!existingDef(meth, clazz).exists) {
396+
meth.entered
397+
newBody = newBody :+
398+
synthesizeDef(meth, vrefss => ctx => body(linked, vrefss.head.head, ctx))
399+
}
400+
}
401+
lazy val monoType = {
402+
val monoType =
403+
ctx.newSymbol(clazz, tpnme.MonoType, Synthetic, TypeAlias(linked.rawTypeRef), coord = clazz.coord)
404+
existingDef(monoType, clazz).orElse {
405+
newBody = newBody :+ TypeDef(monoType).withSpan(ctx.owner.span.focus)
406+
monoType.entered
407+
}
408+
}
358409
if (clazz.is(Module)) {
359-
if (clazz.is(Case)) addParent(defn.Mirror_SingletonType)
360-
else {
361-
val linked = clazz.linkedClass
362-
if (linked.isGenericProduct) {
363-
addParent(defn.Mirror_ProductType)
364-
val rawClassType =
365-
linked.typeRef.appliedTo(linked.typeParams.map(_ => TypeBounds.empty))
366-
val monoType =
367-
ctx.newSymbol(clazz, tpnme.MonoType, Synthetic, TypeAlias(rawClassType), coord = clazz.coord)
368-
if (!alreadyDefined(monoType, clazz)) {
369-
monoType.entered
370-
newBody = newBody :+ TypeDef(monoType).withSpan(ctx.owner.span.focus)
371-
}
372-
val fromProduct =
373-
ctx.newSymbol(clazz, nme.fromProduct, Synthetic | Method,
374-
info = MethodType(defn.ProductType :: Nil, monoType.typeRef), coord = clazz.coord)
375-
if (!alreadyDefined(fromProduct, clazz)) {
376-
fromProduct.entered
377-
newBody = newBody :+
378-
synthesizeDef(fromProduct, vrefss => ctx =>
379-
fromProductBody(linked, vrefss.head.head)(ctx)
380-
.ensureConforms(rawClassType)) // t4758.scala or i3381.scala are examples where a cast is needed
381-
}
382-
}
410+
if (clazz.is(Case))
411+
addParent(defn.Mirror_SingletonType)
412+
else if (linked.isGenericProduct) {
413+
addParent(defn.Mirror_ProductType)
414+
addMethod(nme.fromProduct, MethodType(defn.ProductType :: Nil, monoType.typeRef),
415+
fromProductBody(_, _)(_).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
416+
}
417+
else if (linked.isGenericSum) {
418+
addParent(defn.Mirror_SumType)
419+
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType),
420+
ordinalBody(_, _)(_))
383421
}
384422
}
385423

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -868,8 +868,10 @@ class Namer { typer: Typer =>
868868
val child = if (denot.is(Module)) denot.sourceModule else denot.symbol
869869
register(child, parent)
870870
}
871-
else if (denot.is(CaseVal, butNot = Method | Module))
871+
else if (denot.is(CaseVal, butNot = Method | Module)) {
872+
assert(denot.is(Enum), denot)
872873
register(denot.symbol, denot.info)
874+
}
873875
}
874876

875877
/** Intentionally left without `implicit ctx` parameter. We need

0 commit comments

Comments
 (0)