Skip to content

Commit f31a876

Browse files
authored
Merge pull request #5839 from dotty-staging/change-derive
Improvements to Typeclass Derivation
2 parents cf6e7a9 + 654909f commit f31a876

File tree

5 files changed

+122
-26
lines changed

5 files changed

+122
-26
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,7 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
666666
* Pre: `sym` must have a position.
667667
*/
668668
def defPath(sym: Symbol, root: Tree)(implicit ctx: Context): List[Tree] = trace.onDebug(s"defpath($sym with position ${sym.span}, ${root.show})") {
669-
require(sym.span.exists)
669+
require(sym.span.exists, sym)
670670
object accum extends TreeAccumulator[List[Tree]] {
671671
def apply(x: List[Tree], tree: Tree)(implicit ctx: Context): List[Tree] = {
672672
if (tree.span.contains(sym.span))

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 74 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ trait Deriving { this: Typer =>
3232
/** A buffer for synthesized symbols */
3333
private var synthetics = new mutable.ListBuffer[Symbol]
3434

35+
private var derivesGeneric = false
36+
3537
/** the children of `cls` ordered by textual occurrence */
3638
lazy val children: List[Symbol] = cls.children
3739

@@ -159,33 +161,67 @@ trait Deriving { this: Typer =>
159161
*
160162
* implicit def derived$D(implicit ev_1: D[T_1], ..., ev_n: D[T_n]): D[C[Ts]] = D.derived
161163
*
162-
* See test run/typeclass-derivation2 for examples that spell out what would be generated.
163-
* Note that the name of the derived method containd the name in the derives clause, not
164+
* See the body of this method for how to generalize this to typeclasses with more
165+
* or less than one type parameter.
166+
*
167+
* See test run/typeclass-derivation2 and run/derive-multi
168+
* for examples that spell out what would be generated.
169+
*
170+
* Note that the name of the derived method contains the name in the derives clause, not
164171
* the underlying class name. This allows one to disambiguate derivations of type classes
165172
* that have the same name but different prefixes through selective aliasing.
166173
*/
167174
private def processDerivedInstance(derived: untpd.Tree): Unit = {
168175
val originalType = typedAheadType(derived, AnyTypeConstructorProto).tpe
169176
val underlyingType = underlyingClassRef(originalType)
170177
val derivedType = checkClassType(underlyingType, derived.sourcePos, traitReq = false, stablePrefixReq = true)
171-
val nparams = derivedType.classSymbol.typeParams.length
178+
val typeClass = derivedType.classSymbol
179+
val nparams = typeClass.typeParams.length
172180
if (derivedType.isRef(defn.GenericClass))
173-
() // do nothing, a Generic instance will be created anyway by `addGeneric`
174-
else if (nparams == 1) {
175-
val typeClass = derivedType.classSymbol
176-
val firstKindedParams = cls.typeParams.filterNot(_.info.isLambdaSub)
181+
derivesGeneric = true
182+
else {
183+
// A matrix of all parameter combinations of current class parameters
184+
// and derived typeclass parameters.
185+
// Rows: parameters of current class
186+
// Columns: parameters of typeclass
187+
188+
// Running example: typeclass: class TC[X, Y, Z], deriving class: class A[T, U]
189+
// clsParamss =
190+
// T_X T_Y T_Z
191+
// U_X U_Y U_Z
192+
val clsParamss: List[List[TypeSymbol]] = cls.typeParams.map { tparam =>
193+
if (nparams == 0) Nil
194+
else if (nparams == 1) tparam :: Nil
195+
else typeClass.typeParams.map(tcparam =>
196+
tparam.copy(name = s"${tparam.name}_${tcparam.name}".toTypeName)
197+
.asInstanceOf[TypeSymbol])
198+
}
199+
val firstKindedParamss = clsParamss.filter {
200+
case param :: _ => !param.info.isLambdaSub
201+
case nil => false
202+
}
203+
204+
// The types of the required evidence parameters. In the running example:
205+
// TC[T_X, T_Y, T_Z], TC[U_X, U_Y, U_Z]
177206
val evidenceParamInfos =
178-
for (param <- firstKindedParams) yield derivedType.appliedTo(param.typeRef)
179-
val resultType = derivedType.appliedTo(cls.appliedRef)
207+
for (row <- firstKindedParamss)
208+
yield derivedType.appliedTo(row.map(_.typeRef))
209+
210+
// The class instances in the result type. Running example:
211+
// A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]
212+
val resultInstances =
213+
for (n <- List.range(0, nparams))
214+
yield cls.typeRef.appliedTo(clsParamss.map(row => row(n).typeRef))
215+
216+
// TC[A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]]
217+
val resultType = derivedType.appliedTo(resultInstances)
218+
219+
val clsParams: List[TypeSymbol] = clsParamss.flatten
180220
val instanceInfo =
181-
if (cls.typeParams.isEmpty) ExprType(resultType)
182-
else PolyType.fromParams(cls.typeParams, ImplicitMethodType(evidenceParamInfos, resultType))
221+
if (clsParams.isEmpty) ExprType(resultType)
222+
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
183223
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos, reportErrors = true)
184224
}
185-
else
186-
ctx.error(
187-
i"derived class $derivedType should have one type paramater but has $nparams",
188-
derived.sourcePos)
189225
}
190226

191227
/** Add value corresponding to `val genericClass = new GenericClass(...)`
@@ -210,14 +246,33 @@ trait Deriving { this: Typer =>
210246
addDerivedInstance(defn.GenericType.name, genericCompleter, codePos, reportErrors = false)
211247
}
212248

249+
/** If any of the instances has a companion with a `derived` member
250+
* that refers to `scala.reflect.Generic`, add an implied instance
251+
* of `Generic`. Note: this is just an optimization to avoid possible
252+
* code duplication. Generic instances are created on the fly if they
253+
* are missing from the companion.
254+
*/
255+
private def maybeAddGeneric(): Unit = {
256+
val genericCls = defn.GenericClass
257+
def refersToGeneric(sym: Symbol): Boolean = {
258+
val companion = sym.info.finalResultType.classSymbol.companionModule
259+
val derivd = companion.info.member(nme.derived)
260+
derivd.hasAltWith(sd => sd.info.existsPart(p => p.typeSymbol == genericCls))
261+
}
262+
if (derivesGeneric || synthetics.exists(refersToGeneric)) {
263+
derive.println(i"add generic infrastructure for $cls")
264+
addGeneric()
265+
addGenericClass()
266+
}
267+
}
268+
213269
/** Create symbols for derived instances and infrastructure,
214-
* append them to `synthetics` buffer,
215-
* and enter them into class scope.
270+
* append them to `synthetics` buffer, and enter them into class scope.
271+
* Also, add generic instances if needed.
216272
*/
217273
def enterDerived(derived: List[untpd.Tree]) = {
218274
derived.foreach(processDerivedInstance(_))
219-
addGeneric()
220-
addGenericClass()
275+
maybeAddGeneric()
221276
}
222277

223278
private def tupleElems(tp: Type): List[Type] = tp match {

docs/docs/reference/contextual/derivation.md

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,7 @@ The generated typeclass instances are placed in the companion objects `Labelled`
3838

3939
### Derivable Types
4040

41-
A trait or class can appear in a `derives` clause if
42-
43-
- it has a single type parameter, and
44-
- its companion object defines a method named `derived`.
45-
46-
These two conditions ensure that the synthesized derived instances for the trait are well-formed. The type and implementation of a `derived` method are arbitrary, but typically it has a definition like this:
41+
A trait or class can appear in a `derives` clause if its companion object defines a method named `derived`. The type and implementation of a `derived` method are arbitrary, but typically it has a definition like this:
4742
```scala
4843
def derived[T] with Generic[T] = ...
4944
```

tests/run/derive-multi.check

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
derived: A
2+
derived: B[One, Two]
3+
derived: B
4+
derived: B
5+
derived: B

tests/run/derive-multi.scala

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
class A
2+
object A {
3+
def derived: A = {
4+
println("derived: A")
5+
new A
6+
}
7+
}
8+
9+
class B[X, Y]
10+
object B {
11+
def derived[X, Y]: B[X, Y] = {
12+
println("derived: B")
13+
new B[X, Y]
14+
}
15+
}
16+
17+
case class One() derives A, B
18+
case class Two() derives A, B
19+
20+
implied for B[One, Two] {
21+
println("derived: B[One, Two]")
22+
}
23+
24+
enum Lst[T] derives A, B {
25+
case Cons(x: T, xs: Lst[T])
26+
case Nil()
27+
}
28+
29+
case class Triple[S, T, U] derives A, B
30+
31+
object Test1 {
32+
import Lst._
33+
implicitly[A]
34+
}
35+
36+
object Test extends App {
37+
Test1
38+
implicitly[B[Lst[Lst[One]], Lst[Lst[Two]]]]
39+
implicitly[B[Triple[One, One, One],
40+
Triple[Two, Two, Two]]]
41+
}

0 commit comments

Comments
 (0)