Skip to content

Commit 031e6f2

Browse files
authored
Merge pull request #8011 from dotty-staging/fix-#8007
Fix #8007: Add regression and show type class derivation with macros
2 parents 9982f0d + 4033d71 commit 031e6f2

File tree

8 files changed

+453
-1
lines changed

8 files changed

+453
-1
lines changed
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
---
2+
layout: doc-page
3+
title: How to write a type class `derived` method using macros
4+
---
5+
6+
In the main [derivation](./derivation.md) documentation page, we explained the
7+
details behind `Mirror`s and type class derivation. Here we demonstrate how to
8+
implement a type class `derived` method using macros only. We follow the same
9+
example of deriving `Eq` instances and for simplicity we support a `Product`
10+
type e.g., a case class `Person`. The low-level method we will use to implement
11+
the `derived` method exploits quotes, splices of both expressions and types and
12+
the `scala.quoted.matching.summonExpr` method which is the equivalent of
13+
`summonFrom`. The former is suitable for use in a quote context, used within
14+
macros.
15+
16+
As in the original code, the type class definition is the same:
17+
18+
```scala
19+
trait Eq[T] {
20+
def eqv(x: T, y: T): Boolean
21+
}
22+
```
23+
24+
we need to implement a method `Eq.derived` on the companion object of `Eq` that
25+
produces a quoted instance for `Eq[T]`. Here is a possible signature,
26+
27+
```scala
28+
given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]]
29+
```
30+
31+
and for comparison reasons we give the same signature we had with `inline`:
32+
33+
```scala
34+
inline given derived[T]: (m: Mirror.Of[T]) => Eq[T] = ???
35+
```
36+
37+
Note, that since a type is used in a subsequent stage it will need to be lifted
38+
to a `Type` by using the corresponding context bound. Also, not that we can
39+
summon the quoted `Mirror` inside the body of the `derived` this we can omit it
40+
from the signature. The body of the `derived` method is shown below:
41+
42+
43+
```scala
44+
given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = {
45+
import qctx.tasty.{_, given}
46+
47+
val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get
48+
49+
ev match {
50+
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} =>
51+
val elemInstances = summonAll(elementTypes)
52+
val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => {
53+
elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) {
54+
case (acc, (elem, index)) =>
55+
val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})}
56+
val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})}
57+
58+
'{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) }
59+
}
60+
}
61+
'{
62+
eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)})
63+
}
64+
65+
// case for Mirror.ProductOf[T]
66+
// ...
67+
}
68+
}
69+
```
70+
71+
Note, that in the `inline` case we can merely write
72+
`summonAll[m.MirroredElemTypes]` inside the inline method but here, since
73+
`summonExpr` is required, we can extract the element types in a macro fashion.
74+
Being inside a macro, our first reaction would be to write the code below. Since
75+
the path inside the type argument is not stable this cannot be used:
76+
77+
```scala
78+
'{
79+
summonAll[$m.MirroredElemTypes]
80+
}
81+
```
82+
83+
Instead we extract the tuple-type for element types using pattern matching over
84+
quotes and more specifically of the refined type:
85+
86+
```scala
87+
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes } } => ...
88+
```
89+
90+
The implementation of `summonAll` as a macro can be show below assuming that we
91+
have the given instances for our primitive types:
92+
93+
```scala
94+
def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match {
95+
case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes)
96+
case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes)
97+
case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes)
98+
case '[Unit] => Nil
99+
}
100+
```
101+
102+
One additional difference with the body of `derived` here as opposed to the one
103+
with `inline` is that with macros we need to synthesize the body of the code during the
104+
macro-expansion time. That is the rationale behind the `eqProductBody` function.
105+
Assuming that we calculate the equality of two `Person`s defined with a case
106+
class that holds a name of type `String` and an age of type `Int`, the equality
107+
check we want to generate is the following:
108+
109+
```scala
110+
true
111+
&& Eq[String].eqv(x.productElement(0),y.productElement(0))
112+
&& Eq[Int].eqv(x.productElement(1), y.productElement(1))
113+
```
114+
115+
### Calling the derived method inside the macro
116+
117+
Following the rules in [Macros](../metaprogramming.md) we create two methods.
118+
One that hosts the top-level splice `eqv` and one that is the implementation.
119+
Alternatively and what is shown below is that we can call the `eqv` method
120+
directly. The `eqGen` can trigger the derivation.
121+
122+
```scala
123+
inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y)
124+
125+
implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
126+
```
127+
128+
Note, that we use inline method syntax and we can compare instance such as
129+
`Sm(Person("Test", 23)) === Sm(Person("Test", 24))` for e.g., the following two
130+
types:
131+
132+
```scala
133+
case class Person(name: String, age: Int)
134+
135+
enum Opt[+T] {
136+
case Sm(t: T)
137+
case Nn
138+
}
139+
```
140+
141+
The full code is shown below:
142+
143+
```scala
144+
import scala.deriving._
145+
import scala.quoted._
146+
import scala.quoted.matching._
147+
148+
trait Eq[T] {
149+
def eqv(x: T, y: T): Boolean
150+
}
151+
152+
object Eq {
153+
given Eq[String] {
154+
def eqv(x: String, y: String) = x == y
155+
}
156+
157+
given Eq[Int] {
158+
def eqv(x: Int, y: Int) = x == y
159+
}
160+
161+
def eqProduct[T](body: (T, T) => Boolean): Eq[T] =
162+
new Eq[T] {
163+
def eqv(x: T, y: T): Boolean = body(x, y)
164+
}
165+
166+
def eqSum[T](body: (T, T) => Boolean): Eq[T] =
167+
new Eq[T] {
168+
def eqv(x: T, y: T): Boolean = body(x, y)
169+
}
170+
171+
def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match {
172+
case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes)
173+
case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes)
174+
case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes)
175+
case '[Unit] => Nil
176+
}
177+
178+
given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = {
179+
import qctx.tasty.{_, given}
180+
181+
val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get
182+
183+
ev match {
184+
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} =>
185+
val elemInstances = summonAll(elementTypes)
186+
val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => {
187+
elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) {
188+
case (acc, (elem, index)) =>
189+
val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})}
190+
val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})}
191+
192+
'{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) }
193+
}
194+
}
195+
'{
196+
eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)})
197+
}
198+
199+
case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = $elementTypes }} =>
200+
val elemInstances = summonAll(elementTypes)
201+
val eqSumBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => {
202+
val ordx = '{ $m.ordinal($x) }
203+
val ordy = '{ $m.ordinal($y) }
204+
205+
val elements = Expr.ofList(elemInstances)
206+
'{
207+
$ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y)
208+
}
209+
}
210+
211+
'{
212+
eqSum((x: T, y: T) => ${eqSumBody('x, 'y)})
213+
}
214+
}
215+
}
216+
}
217+
218+
object Macro3 {
219+
inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y)
220+
221+
implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
222+
}
223+
```

docs/docs/reference/contextual/derivation.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,10 @@ inline def derived[A](given gen: K0.Generic[A]): Eq[A] = gen.derive(eqSum, eqPro
350350

351351
The framework described here enables all three of these approaches without mandating any of them.
352352

353+
For a brief discussion on how to use macros to write a type class `derived`
354+
method please read more at [How to write a type class `derived` method using
355+
macros](./derivation-macro.md).
356+
353357
### Deriving instances elsewhere
354358

355359
Sometimes one would like to derive a type class instance for an ADT after the ADT is defined, without being able to

docs/docs/reference/metaprogramming/macros.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -565,7 +565,7 @@ sum
565565
### Find implicits within a macro
566566

567567
Similarly to the `summonFrom` construct, it is possible to make implicit search available
568-
in a quote context. For this we simply provide `scala.quoted.matching.summonExpr:
568+
in a quote context. For this we simply provide `scala.quoted.matching.summonExpr`:
569569

570570
```scala
571571
inline def setFor[T]: Set[T] = ${ setForExpr[T] }

tests/run-macros/i8007.check

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
List("name", "age")
2+
3+
Test 23
4+
()
5+
6+
true
7+
8+
false
9+
10+
true
11+
12+
true
13+
14+
false
15+

tests/run-macros/i8007/Macro_1.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import scala.deriving._
2+
import scala.quoted._
3+
import scala.quoted.matching._
4+
5+
object Macro1 {
6+
7+
def mirrorFields[T](t: Type[T])(given qctx: QuoteContext): List[String] =
8+
t match {
9+
case '[$field *: $fields] => field.show :: mirrorFields(fields)
10+
case '[Unit] => Nil
11+
}
12+
13+
// Demonstrates the use of quoted pattern matching
14+
// over a refined type extracting the tuple type
15+
// for e.g., MirroredElemLabels
16+
inline def test1[T](value: =>T): List[String] =
17+
${ test1Impl('value) }
18+
19+
def test1Impl[T: Type](value: Expr[T])(given qctx: QuoteContext): Expr[List[String]] = {
20+
import qctx.tasty.{_, given}
21+
22+
val mirrorTpe = '[Mirror.Of[T]]
23+
24+
summonExpr(given mirrorTpe).get match {
25+
case '{ $m: Mirror.ProductOf[T]{ type MirroredElemLabels = $t } } => {
26+
Expr(mirrorFields(t))
27+
}
28+
}
29+
}
30+
}

tests/run-macros/i8007/Macro_2.scala

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import scala.deriving._
2+
import scala.quoted._
3+
import scala.quoted.matching._
4+
5+
object Macro2 {
6+
7+
def mirrorFields[T](t: Type[T])(given qctx: QuoteContext): List[String] =
8+
t match {
9+
case '[$field *: $fields] => field.show.substring(1, field.show.length-1) :: mirrorFields(fields)
10+
case '[Unit] => Nil
11+
}
12+
13+
trait JsonEncoder[T] {
14+
def encode(elem: T): String
15+
}
16+
17+
object JsonEncoder {
18+
def emitJsonEncoder[T](body: T => String): JsonEncoder[T]=
19+
new JsonEncoder[T] {
20+
def encode(elem: T): String = body(elem)
21+
}
22+
23+
def derived[T: Type](ev: Expr[Mirror.Of[T]])(given qctx: QuoteContext): Expr[JsonEncoder[T]] = {
24+
import qctx.tasty.{_, given}
25+
26+
val fields = ev match {
27+
case '{ $m: Mirror.ProductOf[T] { type MirroredElemLabels = $t } } =>
28+
mirrorFields(t)
29+
}
30+
31+
val body: Expr[T] => Expr[String] = elem =>
32+
fields.reverse.foldLeft(Expr("")){ (acc, field) =>
33+
val res = Select.unique(elem.unseal, field).seal
34+
'{ $res.toString + " " + $acc }
35+
}
36+
37+
'{
38+
emitJsonEncoder((x: T) => ${body('x)})
39+
}
40+
}
41+
}
42+
43+
inline def test2[T](value: =>T): Unit = ${ test2Impl('value) }
44+
45+
def test2Impl[T: Type](value: Expr[T])(given qctx: QuoteContext): Expr[Unit] = {
46+
import qctx.tasty.{_, given}
47+
48+
val mirrorTpe = '[Mirror.Of[T]]
49+
val mirrorExpr = summonExpr(given mirrorTpe).get
50+
val derivedInstance = JsonEncoder.derived(mirrorExpr)
51+
52+
'{
53+
val res = $derivedInstance.encode($value)
54+
println(res)
55+
}
56+
}
57+
}

0 commit comments

Comments
 (0)