Skip to content

Commit 802128e

Browse files
authored
Merge pull request #15847 from dotty-staging/mirror-for-inner-classes
2 parents 9ff4fae + ec9541f commit 802128e

32 files changed

+1779
-94
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,4 +868,73 @@ object TypeOps:
868868
def stripTypeVars(tp: Type)(using Context): Type =
869869
new StripTypeVarsMap().apply(tp)
870870

871+
/** computes a prefix for `child`, derived from its common prefix with `pre`
872+
* - `pre` is assumed to be the prefix of `parent` at a given callsite.
873+
* - `child` is assumed to be the sealed child of `parent`, and reachable according to `whyNotGenericSum`.
874+
*/
875+
def childPrefix(pre: Type, parent: Symbol, child: Symbol)(using Context): Type =
876+
// Example, given this class hierarchy, we can see how this should work
877+
// when summoning a mirror for `wrapper.Color`:
878+
//
879+
// package example
880+
// object Outer3:
881+
// class Wrapper:
882+
// sealed trait Color
883+
// val wrapper = new Wrapper
884+
// object Inner:
885+
// case object Red extends wrapper.Color
886+
// case object Green extends wrapper.Color
887+
// case object Blue extends wrapper.Color
888+
//
889+
// summon[Mirror.SumOf[wrapper.Color]]
890+
// ^^^^^^^^^^^^^
891+
// > pre = example.Outer3.wrapper.type
892+
// > parent = sealed trait example.Outer3.Wrapper.Color
893+
// > child = module val example.Outer3.Innner.Red
894+
// > parentOwners = [example, Outer3, Wrapper] // computed from definition
895+
// > childOwners = [example, Outer3, Inner] // computed from definition
896+
// > parentRest = [Wrapper] // strip common owners from `childOwners`
897+
// > childRest = [Inner] // strip common owners from `parentOwners`
898+
// > commonPrefix = example.Outer3.type // i.e. parentRest has only 1 element, use 1st subprefix of `pre`.
899+
// > childPrefix = example.Outer3.Inner.type // select all symbols in `childRest` from `commonPrefix`
900+
901+
/** unwind the prefix into a sequence of sub-prefixes, selecting the one at `limit`
902+
* @return `NoType` if there is an unrecognised prefix type.
903+
*/
904+
def subPrefixAt(pre: Type, limit: Int): Type =
905+
def go(pre: Type, limit: Int): Type =
906+
if limit == 0 then pre // EXIT: No More prefix
907+
else pre match
908+
case pre: ThisType => go(pre.tref.prefix, limit - 1)
909+
case pre: TermRef => go(pre.prefix, limit - 1)
910+
case _:SuperType | NoPrefix => pre.ensuring(limit == 1) // EXIT: can't rewind further than this
911+
case _ => NoType // EXIT: unrecognized prefix
912+
go(pre, limit)
913+
end subPrefixAt
914+
915+
/** Successively select each symbol in the `suffix` from `pre`, such that they are reachable. */
916+
def selectAll(pre: Type, suffix: Seq[Symbol]): Type =
917+
suffix.foldLeft(pre)((pre, sym) =>
918+
pre.select(
919+
if sym.isType && sym.is(Module) then sym.sourceModule
920+
else sym
921+
)
922+
)
923+
924+
def stripCommonPrefix(xs: List[Symbol], ys: List[Symbol]): (List[Symbol], List[Symbol]) = (xs, ys) match
925+
case (x :: xs1, y :: ys1) if x eq y => stripCommonPrefix(xs1, ys1)
926+
case _ => (xs, ys)
927+
928+
val (parentRest, childRest) = stripCommonPrefix(
929+
parent.owner.ownersIterator.toList.reverse,
930+
child.owner.ownersIterator.toList.reverse
931+
)
932+
933+
val commonPrefix = subPrefixAt(pre, parentRest.size) // unwind parent owners up to common prefix
934+
935+
if commonPrefix.exists then selectAll(commonPrefix, childRest)
936+
else NoType
937+
938+
end childPrefix
939+
871940
end TypeOps

compiler/src/dotty/tools/dotc/transform/PostInlining.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@ class PostInlining extends MacroTransform, IdentityDenotTransformer:
2626
override def transform(tree: Tree)(using Context): Tree =
2727
super.transform(tree) match
2828
case tree1: Template
29-
if tree1.hasAttachment(ExtendsSingletonMirror)
30-
|| tree1.hasAttachment(ExtendsProductMirror)
31-
|| tree1.hasAttachment(ExtendsSumMirror) =>
29+
if tree1.hasAttachment(ExtendsSingletonMirror) || tree1.hasAttachment(ExtendsSumOrProductMirror) =>
3230
synthMbr.addMirrorSupport(tree1)
3331
case tree1 => tree1
3432

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,28 +163,42 @@ object SymUtils:
163163
* and also the location of the generated mirror.
164164
* - all of its children are generic products, singletons, or generic sums themselves.
165165
*/
166-
def whyNotGenericSum(using Context): String =
166+
def whyNotGenericSum(pre: Type)(using Context): String =
167167
if (!self.is(Sealed))
168168
s"it is not a sealed ${self.kindString}"
169169
else if (!self.isOneOf(AbstractOrTrait))
170170
"it is not an abstract class"
171171
else {
172172
val children = self.children
173173
val companionMirror = self.useCompanionAsSumMirror
174+
val ownerScope = if pre.isInstanceOf[SingletonType] then pre.classSymbol else NoSymbol
174175
def problem(child: Symbol) = {
175176

176-
def isAccessible(sym: Symbol): Boolean =
177-
(self.isContainedIn(sym) && (companionMirror || ctx.owner.isContainedIn(sym)))
178-
|| sym.is(Module) && isAccessible(sym.owner)
177+
def accessibleMessage(sym: Symbol): String =
178+
def inherits(sym: Symbol, scope: Symbol): Boolean =
179+
!scope.is(Package) && (scope.derivesFrom(sym) || inherits(sym, scope.owner))
180+
def isVisibleToParent(sym: Symbol): Boolean =
181+
self.isContainedIn(sym) || sym.is(Module) && isVisibleToParent(sym.owner)
182+
def isVisibleToScope(sym: Symbol): Boolean =
183+
def isReachable: Boolean = ctx.owner.isContainedIn(sym)
184+
def isMemberOfPrefix: Boolean =
185+
ownerScope.exists && inherits(sym, ownerScope)
186+
isReachable || isMemberOfPrefix || sym.is(Module) && isVisibleToScope(sym.owner)
187+
if !isVisibleToParent(sym) then i"to its parent $self"
188+
else if !companionMirror && !isVisibleToScope(sym) then i"to call site ${ctx.owner}"
189+
else ""
190+
end accessibleMessage
191+
192+
val childAccessible = accessibleMessage(child.owner)
179193

180194
if (child == self) "it has anonymous or inaccessible subclasses"
181-
else if (!isAccessible(child.owner)) i"its child $child is not accessible"
195+
else if (!childAccessible.isEmpty) i"its child $child is not accessible $childAccessible"
182196
else if (!child.isClass) "" // its a singleton enum value
183197
else {
184198
val s = child.whyNotGenericProduct
185199
if s.isEmpty then s
186200
else if child.is(Sealed) then
187-
val s = child.whyNotGenericSum
201+
val s = child.whyNotGenericSum(pre)
188202
if s.isEmpty then s
189203
else i"its child $child is not a generic sum because $s"
190204
else
@@ -195,7 +209,7 @@ object SymUtils:
195209
else children.map(problem).find(!_.isEmpty).getOrElse("")
196210
}
197211

198-
def isGenericSum(using Context): Boolean = whyNotGenericSum.isEmpty
212+
def isGenericSum(pre: Type)(using Context): Boolean = whyNotGenericSum(pre).isEmpty
199213

200214
/** If this is a constructor, its owner: otherwise this. */
201215
final def skipConstructor(using Context): Symbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@ import NullOpsDecorator._
1818

1919
object SyntheticMembers {
2020

21+
enum MirrorImpl:
22+
case OfProduct(pre: Type)
23+
case OfSum(childPres: List[Type])
24+
2125
/** Attachment marking an anonymous class as a singleton case that will extend from Mirror.Singleton */
2226
val ExtendsSingletonMirror: Property.StickyKey[Unit] = new Property.StickyKey
2327

2428
/** Attachment recording that an anonymous class should extend Mirror.Product */
25-
val ExtendsProductMirror: Property.StickyKey[Unit] = new Property.StickyKey
26-
27-
/** Attachment recording that an anonymous class should extend Mirror.Sum */
28-
val ExtendsSumMirror: Property.StickyKey[Unit] = new Property.StickyKey
29+
val ExtendsSumOrProductMirror: Property.StickyKey[MirrorImpl] = new Property.StickyKey
2930
}
3031

3132
/** Synthetic method implementations for case classes, case objects,
@@ -484,32 +485,41 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
484485
* type MirroredMonoType = C[?]
485486
* ```
486487
*/
487-
def fromProductBody(caseClass: Symbol, param: Tree)(using Context): Tree = {
488-
val (classRef, methTpe) =
489-
caseClass.primaryConstructor.info match {
488+
def fromProductBody(caseClass: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
489+
def extractParams(tpe: Type): List[Type] =
490+
tpe.asInstanceOf[MethodType].paramInfos
491+
492+
def computeFromCaseClass: (Type, List[Type]) =
493+
val (baseRef, baseInfo) =
494+
val rawRef = caseClass.typeRef
495+
val rawInfo = caseClass.primaryConstructor.info
496+
optInfo match
497+
case Some(info) =>
498+
(rawRef.asSeenFrom(info.pre, caseClass.owner), rawInfo.asSeenFrom(info.pre, caseClass.owner))
499+
case _ =>
500+
(rawRef, rawInfo)
501+
baseInfo match
490502
case tl: PolyType =>
491503
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
492504
val targs =
493505
for (tpt <- tpts) yield
494506
tpt.tpe match {
495507
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
496508
}
497-
(caseClass.typeRef.appliedTo(targs), tl.instantiate(targs))
509+
(baseRef.appliedTo(targs), extractParams(tl.instantiate(targs)))
498510
case methTpe =>
499-
(caseClass.typeRef, methTpe)
500-
}
501-
methTpe match {
502-
case methTpe: MethodType =>
503-
val elems =
504-
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
505-
val elem =
506-
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
507-
.ensureConforms(formal.translateFromRepeated(toArray = false))
508-
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
509-
}
510-
New(classRef, elems)
511-
}
512-
}
511+
(baseRef, extractParams(methTpe))
512+
end computeFromCaseClass
513+
514+
val (classRefApplied, paramInfos) = computeFromCaseClass
515+
val elems =
516+
for ((formal, idx) <- paramInfos.zipWithIndex) yield
517+
val elem =
518+
param.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
519+
.ensureConforms(formal.translateFromRepeated(toArray = false))
520+
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
521+
New(classRefApplied, elems)
522+
end fromProductBody
513523

514524
/** For an enum T:
515525
*
@@ -527,24 +537,36 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
527537
* a wildcard for each type parameter. The normalized type of an object
528538
* O is O.type.
529539
*/
530-
def ordinalBody(cls: Symbol, param: Tree)(using Context): Tree =
531-
if (cls.is(Enum)) param.select(nme.ordinal).ensureApplied
532-
else {
540+
def ordinalBody(cls: Symbol, param: Tree, optInfo: Option[MirrorImpl.OfSum])(using Context): Tree =
541+
if cls.is(Enum) then
542+
param.select(nme.ordinal).ensureApplied
543+
else
544+
def computeChildTypes: List[Type] =
545+
def rawRef(child: Symbol): Type =
546+
if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
547+
optInfo match
548+
case Some(info) => info
549+
.childPres
550+
.lazyZip(cls.children)
551+
.map((pre, child) => rawRef(child).asSeenFrom(pre, child.owner))
552+
case _ =>
553+
cls.children.map(rawRef)
554+
end computeChildTypes
555+
val childTypes = computeChildTypes
533556
val cases =
534-
for ((child, idx) <- cls.children.zipWithIndex) yield {
535-
val patType = if (child.isTerm) child.reachableTermRef else child.reachableRawTypeRef
557+
for (patType, idx) <- childTypes.zipWithIndex yield
536558
val pat = Typed(untpd.Ident(nme.WILDCARD).withType(patType), TypeTree(patType))
537559
CaseDef(pat, EmptyTree, Literal(Constant(idx)))
538-
}
560+
539561
Match(param.annotated(New(defn.UncheckedAnnot.typeRef, Nil)), cases)
540-
}
562+
end ordinalBody
541563

542564
/** - If `impl` is the companion of a generic sum, add `deriving.Mirror.Sum` parent
543565
* and `MirroredMonoType` and `ordinal` members.
544566
* - If `impl` is the companion of a generic product, add `deriving.Mirror.Product` parent
545567
* and `MirroredMonoType` and `fromProduct` members.
546-
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror, ExtendsProductMirror,
547-
* or ExtendsSumMirror, remove the attachment and generate the corresponding mirror support,
568+
* - If `impl` is marked with one of the attachments ExtendsSingletonMirror or ExtendsSumOfProductMirror,
569+
* remove the attachment and generate the corresponding mirror support,
548570
* On this case the represented class or object is referred to in a pre-existing `MirroredMonoType`
549571
* member of the template.
550572
*/
@@ -581,30 +603,33 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
581603
}
582604
def makeSingletonMirror() =
583605
addParent(defn.Mirror_SingletonClass.typeRef)
584-
def makeProductMirror(cls: Symbol) = {
606+
def makeProductMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfProduct]) = {
585607
addParent(defn.Mirror_ProductClass.typeRef)
586608
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
587-
fromProductBody(_, _).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
609+
fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
588610
}
589-
def makeSumMirror(cls: Symbol) = {
611+
def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = {
590612
addParent(defn.Mirror_SumClass.typeRef)
591613
addMethod(nme.ordinal, MethodType(monoType.typeRef :: Nil, defn.IntType), cls,
592-
ordinalBody(_, _))
614+
ordinalBody(_, _, optInfo))
593615
}
594616

595617
if (clazz.is(Module)) {
596618
if (clazz.is(Case)) makeSingletonMirror()
597-
else if (linked.isGenericProduct) makeProductMirror(linked)
598-
else if (linked.isGenericSum) makeSumMirror(linked)
619+
else if (linked.isGenericProduct) makeProductMirror(linked, None)
620+
else if (linked.isGenericSum(NoType)) makeSumMirror(linked, None)
599621
else if (linked.is(Sealed))
600-
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum}")
622+
derive.println(i"$linked is not a sum because ${linked.whyNotGenericSum(NoType)}")
601623
}
602624
else if (impl.removeAttachment(ExtendsSingletonMirror).isDefined)
603625
makeSingletonMirror()
604-
else if (impl.removeAttachment(ExtendsProductMirror).isDefined)
605-
makeProductMirror(monoType.typeRef.dealias.classSymbol)
606-
else if (impl.removeAttachment(ExtendsSumMirror).isDefined)
607-
makeSumMirror(monoType.typeRef.dealias.classSymbol)
626+
else
627+
impl.removeAttachment(ExtendsSumOrProductMirror).match
628+
case Some(prodImpl: MirrorImpl.OfProduct) =>
629+
makeProductMirror(monoType.typeRef.dealias.classSymbol, Some(prodImpl))
630+
case Some(sumImpl: MirrorImpl.OfSum) =>
631+
makeSumMirror(monoType.typeRef.dealias.classSymbol, Some(sumImpl))
632+
case _ =>
608633

609634
cpy.Template(impl)(parents = newParents, body = newBody)
610635
}

0 commit comments

Comments
 (0)