Skip to content

Commit be1a063

Browse files
oderskymilessabin
authored andcommitted
Synthesize implicits for product and sum mirrors
1 parent 934c4ad commit be1a063

File tree

8 files changed

+136
-10
lines changed

8 files changed

+136
-10
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,9 @@ class Definitions {
690690
lazy val ModuleSerializationProxyConstructor: TermSymbol =
691691
ModuleSerializationProxyClass.requiredMethod(nme.CONSTRUCTOR, List(ClassType(TypeBounds.empty)))
692692

693-
//lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
693+
lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror")
694+
def MirrorClass(implicit ctx: Context): ClassSymbol = MirrorType.symbol.asClass
695+
694696
lazy val Mirror_ProductType: TypeRef = ctx.requiredClassRef("scala.deriving.Mirror.Product")
695697
def Mirror_ProductClass(implicit ctx: Context): ClassSymbol = Mirror_ProductType.symbol.asClass
696698

@@ -711,7 +713,7 @@ class Definitions {
711713
def ShapeCaseClass(implicit ctx: Context): ClassSymbol = ShapeCaseType.symbol.asClass
712714
lazy val ShapeCasesType: TypeRef = ctx.requiredClassRef("scala.compiletime.Shape.Cases")
713715
def ShapeCasesClass(implicit ctx: Context): ClassSymbol = ShapeCasesType.symbol.asClass
714-
lazy val MirrorType: TypeRef = ctx.requiredClassRef("scala.reflect.Mirror")
716+
lazy val ReflectMirrorType: TypeRef = ctx.requiredClassRef("scala.reflect.Mirror")
715717
lazy val GenericClassType: TypeRef = ctx.requiredClassRef("scala.reflect.GenericClass")
716718

717719
lazy val LanguageModuleRef: TermSymbol = ctx.requiredModule("scala.language")

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,10 +326,13 @@ object StdNames {
326326
val AnnotatedType: N = "AnnotatedType"
327327
val AppliedTypeTree: N = "AppliedTypeTree"
328328
val ArrayAnnotArg: N = "ArrayAnnotArg"
329+
val CaseLabel: N = "CaseLabel"
329330
val CAP: N = "CAP"
330331
val Constant: N = "Constant"
331332
val ConstantType: N = "ConstantType"
332333
val doubleHash: N = "doubleHash"
334+
val ElemLabels: N = "ElemLabels"
335+
val ElemTypes: N = "ElemTypes"
333336
val ExistentialTypeTree: N = "ExistentialTypeTree"
334337
val Flag : N = "Flag"
335338
val floatHash: N = "floatHash"

compiler/src/dotty/tools/dotc/transform/TypeUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import TypeErasure.ErasedValueType
77
import Types._
88
import Contexts._
99
import Symbols._
10+
import Names.Name
1011

1112
object TypeUtils {
1213
/** A decorator that provides methods on types
@@ -63,5 +64,7 @@ object TypeUtils {
6364
}
6465
extractAlias(lo)
6566
}
67+
68+
def refinedWith(name: Name, info: Type)(implicit ctx: Context) = RefinedType(self, name, info)
6669
}
6770
}

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,7 @@ trait Deriving { this: Typer =>
372372
TypeDef(shapeAlias)
373373
}
374374
val reflectMethod: DefDef = {
375-
val meth = newMethod(nme.reflect, MethodType(clsArg :: Nil, defn.MirrorType)).entered
375+
val meth = newMethod(nme.reflect, MethodType(clsArg :: Nil, defn.ReflectMirrorType)).entered
376376
def rhs(paramRef: Tree)(implicit ctx: Context): Tree = {
377377
def reflectCase(scrut: Tree, idx: Int, elems: List[Type]): Tree = {
378378
val ordinal = Literal(Constant(idx))
@@ -401,7 +401,7 @@ trait Deriving { this: Typer =>
401401
}
402402

403403
val reifyMethod: DefDef = {
404-
val meth = newMethod(nme.reify, MethodType(defn.MirrorType :: Nil, clsArg)).entered
404+
val meth = newMethod(nme.reify, MethodType(defn.ReflectMirrorType :: Nil, clsArg)).entered
405405
def rhs(paramRef: Tree)(implicit ctx: Context): Tree = {
406406
def reifyCase(caseType: Type, elems: List[Type]): Tree = caseType match {
407407
case caseType: TermRef =>

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 100 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ import ErrorReporting._
2727
import reporting.diagnostic.Message
2828
import Inferencing.fullyDefinedType
2929
import Trees._
30+
import transform.SymUtils._
31+
import transform.TypeUtils._
3032
import Hashable._
3133
import util.{Property, SourceFile, NoSource}
3234
import config.Config
@@ -856,18 +858,104 @@ trait Implicits { self: Typer =>
856858
EmptyTree
857859
}
858860

861+
lazy val synthesizedProductMirror: SpecialHandler =
862+
(formal: Type, span: Span) => implicit (ctx: Context) => {
863+
formal.member(tpnme.MonoType).info match {
864+
case monoAlias @ TypeAlias(monoType) =>
865+
if (monoType.termSymbol.is(CaseVal)) {
866+
val modul = monoType.termSymbol
867+
val caseLabel = ConstantType(Constant(modul.name.toString))
868+
val mirrorType = defn.Mirror_SingletonType
869+
.refinedWith(tpnme.MonoType, monoAlias)
870+
.refinedWith(tpnme.CaseLabel, TypeAlias(caseLabel))
871+
ref(modul).withSpan(span).cast(mirrorType)
872+
}
873+
else if (monoType.classSymbol.isGenericProduct) {
874+
val cls = monoType.classSymbol
875+
val accessors = cls.caseAccessors.filterNot(_.is(PrivateLocal))
876+
val elemTypes = accessors.map(monoType.memberInfo(_))
877+
val caseLabel = ConstantType(Constant(cls.name.toString))
878+
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
879+
val mirrorType =
880+
defn.Mirror_ProductType
881+
.refinedWith(tpnme.MonoType, monoAlias)
882+
.refinedWith(tpnme.ElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
883+
.refinedWith(tpnme.CaseLabel, TypeAlias(caseLabel))
884+
.refinedWith(tpnme.ElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
885+
val modul = cls.linkedClass.sourceModule
886+
assert(modul.is(Module))
887+
ref(modul).withSpan(span).cast(mirrorType)
888+
}
889+
else EmptyTree
890+
case _ => EmptyTree
891+
}
892+
}
893+
894+
lazy val synthesizedSumMirror: SpecialHandler =
895+
(formal: Type, span: Span) => implicit (ctx: Context) =>
896+
formal.member(tpnme.MonoType).info match {
897+
case monoAlias @ TypeAlias(monoType) if monoType.classSymbol.isGenericSum =>
898+
val cls = monoType.classSymbol
899+
val elemTypes = cls.children.map {
900+
case caseClass: ClassSymbol =>
901+
assert(caseClass.is(Case))
902+
if (caseClass.is(Module))
903+
caseClass.sourceModule.termRef
904+
else caseClass.primaryConstructor.info match {
905+
case info: PolyType =>
906+
def instantiate(implicit ctx: Context) = {
907+
val poly = constrained(info, untpd.EmptyTree)._1
908+
val mono @ MethodType(_) = poly.resultType
909+
val resType = mono.finalResultType
910+
resType <:< cls.appliedRef
911+
val tparams = poly.paramRefs
912+
val variances = caseClass.typeParams.map(_.paramVariance)
913+
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
914+
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
915+
resType.substParams(poly, instanceTypes)
916+
}
917+
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
918+
case _ =>
919+
caseClass.typeRef
920+
}
921+
case child => child.termRef
922+
}
923+
val mirrorType =
924+
defn.Mirror_SumType
925+
.refinedWith(tpnme.MonoType, monoAlias)
926+
.refinedWith(tpnme.ElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
927+
var modul = cls.linkedClass.sourceModule
928+
if (!modul.exists) ???
929+
ref(modul).withSpan(span).cast(mirrorType)
930+
case _ =>
931+
EmptyTree
932+
}
933+
934+
lazy val synthesizedMirror: SpecialHandler =
935+
(formal: Type, span: Span) => implicit (ctx: Context) =>
936+
formal.member(tpnme.MonoType).info match {
937+
case monoAlias @ TypeAlias(monoType) =>
938+
if (monoType.termSymbol.is(CaseVal) || monoType.classSymbol.isGenericProduct)
939+
synthesizedProductMirror(formal, span)(ctx)
940+
else
941+
synthesizedSumMirror(formal, span)(ctx)
942+
}
943+
859944
private var mySpecialHandlers: SpecialHandlers = null
860945

861946
private def specialHandlers(implicit ctx: Context) = {
862947
if (mySpecialHandlers == null)
863948
mySpecialHandlers = List(
864-
defn.ClassTagClass -> synthesizedClassTag,
865-
defn.QuotedTypeClass -> synthesizedTypeTag,
866-
defn.GenericClass -> synthesizedGeneric,
949+
defn.ClassTagClass -> synthesizedClassTag,
950+
defn.QuotedTypeClass -> synthesizedTypeTag,
951+
defn.GenericClass -> synthesizedGeneric,
867952
defn.TastyReflectionClass -> synthesizedTastyContext,
868-
defn.EqlClass -> synthesizedEq,
953+
defn.EqlClass -> synthesizedEq,
869954
defn.TupledFunctionClass -> synthesizedTupleFunction,
870-
defn.ValueOfClass -> synthesizedValueOf
955+
defn.ValueOfClass -> synthesizedValueOf,
956+
defn.Mirror_ProductClass -> synthesizedProductMirror,
957+
defn.Mirror_SumClass -> synthesizedSumMirror,
958+
defn.MirrorClass -> synthesizedMirror
871959
)
872960
mySpecialHandlers
873961
}
@@ -881,7 +969,13 @@ trait Implicits { self: Typer =>
881969
case fail @ SearchFailure(failed) =>
882970
def trySpecialCases(handlers: SpecialHandlers): Tree = handlers match {
883971
case (cls, handler) :: rest =>
884-
val base = formal.baseType(cls)
972+
def baseWithRefinements(tp: Type): Type = tp.dealias match {
973+
case tp @ RefinedType(parent, rname, rinfo) =>
974+
tp.derivedRefinedType(baseWithRefinements(parent), rname, rinfo)
975+
case _ =>
976+
tp.baseType(cls)
977+
}
978+
val base = baseWithRefinements(formal)
885979
val result =
886980
if (base <:< formal) {
887981
// With the subtype test we enforce that the searched type `formal` is of the right form

library/src/scala/deriving.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ object deriving {
3939

4040
trait Singleton extends Product {
4141
type MonoType = this.type
42+
type ElemTypes = Unit
43+
type ElemLabels = Unit
44+
4245
def fromProduct(p: scala.Product) = this
4346

4447
def productElement(n: Int): Any = throw new IndexOutOfBoundsException(n.toString)

tests/run/deriving.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
A(1,2)
2+
A(1,2)
3+
B
4+
1

tests/run/deriving.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,23 @@ case class A(x: Int, y: Int) extends T
55
case object B extends T
66

77
object Test extends App {
8+
import deriving.{Mirror, EmptyProduct}
89

910
case class AA[X >: Null <: AnyRef](x: X, y: X, z: String)
11+
12+
println(the[Mirror.ProductOf[A]].fromProduct(A(1, 2)))
13+
assert(the[Mirror.SumOf[T]].ordinal(A(1, 2)) == 0)
14+
assert(the[Mirror.Sum { type MonoType = T }].ordinal(B) == 1)
15+
the[Mirror.Of[A]] match {
16+
case m: Mirror.Product =>
17+
println(m.fromProduct(A(1, 2)))
18+
}
19+
the[Mirror.Of[B.type]] match {
20+
case m: Mirror.Product =>
21+
println(m.fromProduct(EmptyProduct))
22+
}
23+
the[Mirror.Of[T]] match {
24+
case m: Mirror.SumOf[T] =>
25+
println(m.ordinal(B))
26+
}
1027
}

0 commit comments

Comments
 (0)