Skip to content

Commit 638b00e

Browse files
committed
Generate MonoType and fromProduct for generic products
Generate MonoType and fromProduct members for generic products.
1 parent 2f295c4 commit 638b00e

File tree

10 files changed

+228
-52
lines changed

10 files changed

+228
-52
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import reporting.diagnostic.messages._
1515
import reporting.trace
1616
import annotation.constructorOnly
1717
import printing.Formatting.hl
18+
import config.Printers
1819

1920
import scala.annotation.internal.sharable
2021

@@ -51,7 +52,7 @@ object desugar {
5152
private type VarInfo = (NameTree, Tree)
5253

5354
/** Is `name` the name of a method that can be invalidated as a compiler-generated
54-
* case class method that clashes with a user-defined method?
55+
* case class method if it clashes with a user-defined method?
5556
*/
5657
def isRetractableCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean = name match {
5758
case nme.apply | nme.unapply | nme.unapplySeq | nme.copy => true
@@ -394,6 +395,10 @@ object desugar {
394395
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
395396
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
396397

398+
/** The untyped analogue of SymUtils.isGenericProduct */
399+
val isGenericProduct =
400+
mods.is(Case, butNot = Abstract) && constr1.vparamss.length == 1 && !isValueClass
401+
397402
val originalTparams = constr1.tparams
398403
val originalVparamss = constr1.vparamss
399404
lazy val derivedEnumParams = enumClass.typeParams.map(derivedTypeParam)
@@ -585,11 +590,16 @@ object desugar {
585590
else Nil
586591
}
587592

593+
def mirrorMemberType(str: String) =
594+
Select(Select(scalaDot("deriving".toTermName), "Mirror".toTermName), str.toTypeName)
595+
588596
var parents1 = parents
589597
if (isEnumCase && parents.isEmpty)
590598
parents1 = enumClassTypeRef :: Nil
591-
if (isCaseClass | isCaseObject)
599+
if (isCaseClass)
592600
parents1 = parents1 :+ scalaDot(str.Product.toTypeName) :+ scalaDot(nme.Serializable.toTypeName)
601+
else if (isCaseObject)
602+
parents1 = parents1 :+ mirrorMemberType("Singleton") :+ scalaDot(nme.Serializable.toTypeName)
593603
else if (isObject)
594604
parents1 = parents1 :+ scalaDot(nme.Serializable.toTypeName)
595605
if (isEnum)
@@ -600,11 +610,11 @@ object desugar {
600610
if (mods.is(Module)) (impl.derived, Nil) else (Nil, impl.derived)
601611

602612
// The thicket which is the desugared version of the companion object
603-
// synthetic object C extends parentTpt derives class-derived { defs }
604-
def companionDefs(parentTpt: Tree, defs: List[Tree]) = {
613+
// synthetic object C extends parentTpts derives class-derived { defs }
614+
def companionDefs(parentTpts: List[Tree], defs: List[Tree]) = {
605615
val mdefs = moduleDef(
606616
ModuleDef(
607-
className.toTermName, Template(emptyConstructor, parentTpt :: Nil, companionDerived, EmptyValDef, defs))
617+
className.toTermName, Template(emptyConstructor, parentTpts, companionDerived, EmptyValDef, defs))
608618
.withMods(companionMods | Synthetic))
609619
.withSpan(cdef.span).toList
610620
if (companionDerived.nonEmpty)
@@ -627,6 +637,7 @@ object desugar {
627637
// For all other classes, the parent is AnyRef.
628638
val companions =
629639
if (isCaseClass) {
640+
630641
// The return type of the `apply` method, and an (empty or singleton) list
631642
// of widening coercions
632643
val (applyResultTpt, widenDefs) =
@@ -654,38 +665,53 @@ object desugar {
654665
// todo: also use anyRef if constructor has a dependent method type (or rule that out)!
655666
(constrVparamss :\ classTypeRef) (
656667
(vparams, restpe) => Function(vparams map (_.tpt), restpe))
668+
val companionParents =
669+
if (isGenericProduct) companionParent :: mirrorMemberType("Product") :: Nil
670+
else companionParent :: Nil
657671
def widenedCreatorExpr =
658672
(creatorExpr /: widenDefs)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
659673
val applyMeths =
660674
if (mods is Abstract) Nil
661675
else {
662-
val copiedFlagsMask = DefaultParameterized | (copiedAccessFlags & Private)
663-
val appMods = {
664-
val mods = Modifiers(Synthetic | constr1.mods.flags & copiedFlagsMask)
665-
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
666-
else mods
676+
def applyDef = {
677+
val copiedFlagsMask = DefaultParameterized | (copiedAccessFlags & Private)
678+
val appMods = {
679+
val mods = Modifiers(Synthetic | constr1.mods.flags & copiedFlagsMask)
680+
if (restrictedAccess) mods.withPrivateWithin(constr1.mods.privateWithin)
681+
else mods
682+
}
683+
DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
684+
.withMods(appMods)
667685
}
668-
val app = DefDef(nme.apply, derivedTparams, derivedVparamss, applyResultTpt, widenedCreatorExpr)
669-
.withMods(appMods)
670-
app :: widenDefs
686+
applyDef :: widenDefs
671687
}
688+
689+
val monoTypeDefs =
690+
if (isGenericProduct) {
691+
val monoType = appliedTypeTree(classTycon, constrTparams.map(_ => TypeBoundsTree(EmptyTree, EmptyTree)))
692+
TypeDef(tpnme.MonoType, monoType).withMods(synthetic) :: Nil
693+
}
694+
else Nil
695+
672696
val unapplyMeth = {
673697
val hasRepeatedParam = constrVparamss.head.exists {
674698
case ValDef(_, tpt, _) => isRepeated(tpt)
675699
}
676700
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
677-
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
678-
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
679-
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
701+
val param = makeSyntheticParameter(tpt = classTypeRef)
702+
val rhs = if (arity == 0) Literal(Constant(true)) else Ident(param.name)
703+
DefDef(methName, derivedTparams, (param :: Nil) :: Nil, TypeTree(), rhs)
680704
.withMods(synthetic)
681705
}
682-
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
706+
companionDefs(
707+
companionParents,
708+
applyMeths ::: unapplyMeth :: monoTypeDefs ::: companionMembers)
683709
}
684710
else if (companionMembers.nonEmpty || companionDerived.nonEmpty || isEnum)
685-
companionDefs(anyRef, companionMembers)
711+
companionDefs(anyRef :: Nil, companionMembers)
686712
else if (isValueClass) {
687713
impl.constr.vparamss match {
688-
case (_ :: Nil) :: _ => companionDefs(anyRef, Nil)
714+
case (_ :: Nil) :: _ => companionDefs(anyRef :: Nil, Nil)
689715
case _ => Nil // error will be emitted in typer
690716
}
691717
}
@@ -765,7 +791,7 @@ object desugar {
765791
}
766792

767793
flatTree(cdef1 :: companions ::: implicitWrappers)
768-
}
794+
}.reporting(res => i"desugared: $res", Printers.desugar)
769795

770796
/** Expand
771797
*

compiler/src/dotty/tools/dotc/config/Printers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ object Printers {
1919
val cyclicErrors: Printer = noPrinter
2020
val debug = noPrinter // no type annotation here to force inlining
2121
val derive: Printer = noPrinter
22+
val desugar: Printer = noPrinter
2223
val dottydoc: Printer = noPrinter
2324
val exhaustivity: Printer = noPrinter
2425
val gadts: Printer = noPrinter

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -688,6 +688,16 @@ class Definitions {
688688
lazy val ModuleSerializationProxyConstructor: TermSymbol =
689689
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
690690

691+
//lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
692+
lazy val Mirror_ProductType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Product")
693+
def Mirror_ProductClass(implicit ctx: Context): ClassSymbol = Mirror_ProductType.symbol.asClass
694+
695+
lazy val Mirror_Product_fromProductR: TermRef = Mirror_ProductClass.requiredMethodRef(nme.fromProduct)
696+
def Mirror_Product_fromProduct(implicit ctx: Context): Symbol = Mirror_Product_fromProductR.symbol
697+
698+
lazy val Mirror_SingletonType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Singleton")
699+
def Mirror_SingletonClass(implicit ctx: Context): ClassSymbol = Mirror_SingletonType.symbol.asClass
700+
691701
lazy val GenericType: TypeRef = ctx.requiredClassRef("scala.reflect.Generic")
692702
def GenericClass(implicit ctx: Context): ClassSymbol = GenericType.symbol.asClass
693703
lazy val ShapeType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape")

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ object StdNames {
340340
val longHash: N = "longHash"
341341
val MatchCase: N = "MatchCase"
342342
val Modifiers: N = "Modifiers"
343+
val MonoType: N = "MonoType"
343344
val NestedAnnotArg: N = "NestedAnnotArg"
344345
val NoFlags: N = "NoFlags"
345346
val NoPrefix: N = "NoPrefix"
@@ -432,6 +433,7 @@ object StdNames {
432433
val flagsFromBits : N = "flagsFromBits"
433434
val flatMap: N = "flatMap"
434435
val foreach: N = "foreach"
436+
val fromProduct: N = "fromProduct"
435437
val genericArrayOps: N = "genericArrayOps"
436438
val genericClass: N = "genericClass"
437439
val get: N = "get"

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import StdNames._
1212
import NameKinds._
1313
import Flags._
1414
import Annotations._
15+
import ValueClasses.isDerivedValueClass
1516

1617
import language.implicitConversions
1718
import scala.annotation.tailrec
@@ -59,10 +60,18 @@ class SymUtils(val self: Symbol) extends AnyVal {
5960

6061
def isSuperAccessor(implicit ctx: Context): Boolean = self.name.is(SuperAccessorName)
6162

62-
/** A type or term parameter or a term parameter accessor */
63+
/** Is this a type or term parameter or a term parameter accessor? */
6364
def isParamOrAccessor(implicit ctx: Context): Boolean =
6465
self.is(Param) || self.is(ParamAccessor)
6566

67+
/** Is this a case class for which a product mirror is generated?
68+
* Excluded are value classes, abstract classes and case classes with more than one
69+
* parameter section. See also: desugar.isGenericProduct */
70+
def isGenericProduct(implicit ctx: Context): Boolean =
71+
self.is(CaseClass, butNot = Abstract) &&
72+
self.primaryConstructor.info.paramInfoss.length == 1 &&
73+
!isDerivedValueClass(self)
74+
6675
/** If this is a constructor, its owner: otherwise this. */
6776
final def skipConstructor(implicit ctx: Context): Symbol =
6877
if (self.isConstructor) self.owner else self

compiler/src/dotty/tools/dotc/transform/SyntheticMethods.scala

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import DenotTransformers._
88
import Decorators._
99
import NameOps._
1010
import Annotations.Annotation
11+
import typer.ProtoTypes.constrained
12+
import ast.untpd
1113
import ValueClasses.isDerivedValueClass
14+
import SymUtils._
1215

1316
/** Synthetic method implementations for case classes, case objects,
1417
* and value classes.
@@ -38,18 +41,21 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
3841
private[this] var myValueSymbols: List[Symbol] = Nil
3942
private[this] var myCaseSymbols: List[Symbol] = Nil
4043
private[this] var myCaseModuleSymbols: List[Symbol] = Nil
44+
private[this] var myProductMirrorSymbols: List[Symbol] = Nil
4145

4246
private def initSymbols(implicit ctx: Context) =
4347
if (myValueSymbols.isEmpty) {
4448
myValueSymbols = List(defn.Any_hashCode, defn.Any_equals)
4549
myCaseSymbols = myValueSymbols ++ List(defn.Any_toString, defn.Product_canEqual,
4650
defn.Product_productArity, defn.Product_productPrefix, defn.Product_productElement)
4751
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
52+
myProductMirrorSymbols = List(defn.Mirror_Product_fromProduct)
4853
}
4954

5055
def valueSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myValueSymbols }
5156
def caseSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseSymbols }
5257
def caseModuleSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myCaseModuleSymbols }
58+
def productMirrorSymbols(implicit ctx: Context): List[Symbol] = { initSymbols; myProductMirrorSymbols }
5359

5460
/** If this is a case or value class, return the appropriate additional methods,
5561
* otherwise return nothing.
@@ -66,6 +72,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
6672
else caseSymbols
6773
}
6874
else if (isDerivedValueClass(clazz)) valueSymbols
75+
else if (clazz.is(Module) && clazz.linkedClass.isGenericProduct) productMirrorSymbols
6976
else Nil
7077

7178
def syntheticDefIfMissing(sym: Symbol): List[Tree] = {
@@ -78,6 +85,7 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
7885
val synthetic = sym.copy(
7986
owner = clazz,
8087
flags = sym.flags &~ Deferred | Synthetic | Override,
88+
info = clazz.thisType.memberInfo(sym),
8189
coord = clazz.coord).enteredAfter(thisPhase).asTerm
8290

8391
def forwardToRuntime(vrefss: List[List[Tree]]): Tree =
@@ -95,6 +103,10 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
95103
case nme.productArity => vrefss => Literal(Constant(accessors.length))
96104
case nme.productPrefix => ownName
97105
case nme.productElement => vrefss => productElementBody(accessors.length, vrefss.head.head)
106+
case nme.fromProduct =>
107+
vrefss =>
108+
fromProductBody(accessors, vrefss.head.head)
109+
.ensureConforms(synthetic.info.finalResultType)
98110
}
99111
ctx.log(s"adding $synthetic to $clazz at ${ctx.phase}")
100112
DefDef(synthetic, syntheticRHS(ctx.withOwner(synthetic))).withSpan(ctx.owner.span.focus)
@@ -138,6 +150,53 @@ class SyntheticMethods(thisPhase: DenotTransformer) {
138150
Match(index, (cases :+ defaultCase).toList)
139151
}
140152

153+
/** The class
154+
*
155+
* ```
156+
* case class C[T <: U](x: T, y: String*)
157+
* ```
158+
*
159+
* gets the `fromProduct` method:
160+
*
161+
* ```
162+
* def fromProduct(x$0: Product): MonoType =
163+
* new C[U](
164+
* x$0.productElement(0).asInstanceOf[U],
165+
* x$0.productElement(1).asInstanceOf[Seq[String]]: _*)
166+
* ```
167+
* where
168+
* ```
169+
* type MonoType = C[_]
170+
* ```
171+
*/
172+
def fromProductBody(accessors: List[Symbol], prod: Tree)(implicit ctx: Context): Tree = {
173+
val caseClass = clazz.linkedClass
174+
val (classRef, methTpe) =
175+
caseClass.primaryConstructor.info match {
176+
case tl: PolyType =>
177+
val (tl1, tpts) = constrained(tl, untpd.EmptyTree, alwaysAddTypeVars = true)
178+
val targs =
179+
for (tpt <- tpts) yield
180+
tpt.tpe match {
181+
case tvar: TypeVar => tvar.instantiate(fromBelow = false)
182+
}
183+
(caseClass.typeRef.appliedTo(targs), tl.instantiate(targs))
184+
case methTpe =>
185+
(caseClass.typeRef, methTpe)
186+
}
187+
methTpe match {
188+
case methTpe: MethodType =>
189+
val elems =
190+
for ((formal, idx) <- methTpe.paramInfos.zipWithIndex) yield {
191+
val elem =
192+
prod.select(defn.Product_productElement).appliedTo(Literal(Constant(idx)))
193+
.ensureConforms(formal.underlyingIfRepeated(isJava = false))
194+
if (formal.isRepeatedParam) ctx.typer.seqToRepeated(elem) else elem
195+
}
196+
New(classRef, elems)
197+
}
198+
}
199+
141200
/** The class
142201
*
143202
* ```

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class Namer { typer: Typer =>
447447
case _ => tree
448448
}
449449

450-
/** For all class definitions `stat` in `xstats`: If the companion class if
450+
/** For all class definitions `stat` in `xstats`: If the companion class is
451451
* not also defined in `xstats`, invalidate it by setting its info to
452452
* NoType.
453453
*/
@@ -702,7 +702,7 @@ class Namer { typer: Typer =>
702702
// If a top-level object or class has no companion in the current run, we
703703
// enter a dummy companion (`denot.isAbsent` returns true) in scope. This
704704
// ensures that we never use a companion from a previous run or from the
705-
// classpath. See tests/pos/false-companion for an example where this
705+
// class path. See tests/pos/false-companion for an example where this
706706
// matters.
707707
if (ctx.owner.is(PackageClass)) {
708708
for (cdef @ TypeDef(moduleName, _) <- moduleDef.values) {

compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,9 @@ trait TypeAssigner {
159159
private def toRepeated(tree: Tree, from: ClassSymbol)(implicit ctx: Context): Tree =
160160
Typed(tree, TypeTree(tree.tpe.widen.translateParameterized(from, defn.RepeatedParamClass)))
161161

162-
def seqToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.SeqClass)
162+
def seqToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.SeqClass)
163163

164-
def arrayToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.ArrayClass)
164+
def arrayToRepeated(tree: Tree)(implicit ctx: Context): Tree = toRepeated(tree, defn.ArrayClass)
165165

166166
/** A denotation exists really if it exists and does not point to a stale symbol. */
167167
final def reallyExists(denot: Denotation)(implicit ctx: Context): Boolean = try

0 commit comments

Comments
 (0)