Skip to content

[docs] Typeclass Derivation no longer recommends given derived[T]: TC[T] #17414

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 140 additions & 141 deletions docs/_docs/reference/contextual/derivation-macro.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,11 @@ title: "How to write a type class `derived` method using macros"
nightlyOf: https://docs.scala-lang.org/scala3/reference/contextual/derivation-macro.html
---

In the main [derivation](./derivation.md) documentation page, we explained the
details behind `Mirror`s and type class derivation. Here we demonstrate how to
implement a type class `derived` method using macros only. We follow the same
example of deriving `Eq` instances and for simplicity we support a `Product`
type e.g., a case class `Person`. The low-level method we will use to implement
the `derived` method exploits quotes, splices of both expressions and types and
the `scala.quoted.Expr.summon` method which is the equivalent of
`summonFrom`. The former is suitable for use in a quote context, used within
macros.
In the main [derivation](./derivation.md) documentation page, we explained the details behind `Mirror`s and type class derivation.
Here we demonstrate how to implement a type class `derived` method using macros only.
We follow the same example of deriving `Eq` instances and for simplicity we support a `Product` type e.g., a case class `Person`.
The low-level technique that we will use to implement the `derived` method exploits quotes, splices of both expressions and types and the `scala.quoted.Expr.summon` method which is the equivalent of `scala.compiletime.summonFrom`.
The former is suitable for use in a quote context, used within macros.

As in the original code, the type class definition is the same:

Expand All @@ -21,185 +17,188 @@ trait Eq[T]:
def eqv(x: T, y: T): Boolean
```

we need to implement a method `Eq.derived` on the companion object of `Eq` that
produces a quoted instance for `Eq[T]`. Here is a possible signature,
We need to implement an inline method `Eq.derived` on the companion object of `Eq` that calls into a macro to produce a quoted instance for `Eq[T]`.
Here is a possible signature:


```scala
given derived[T: Type](using Quotes): Expr[Eq[T]]
inline def derived[T]: Eq[T] = ${ derivedMacro[T] }

def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] = ???
```

and for comparison reasons we give the same signature we had with `inline`:
Note, that since a type is used in a subsequent macro compilation stage it will need to be lifted to a `quoted.Type` by using the corresponding context bound (seen in `derivedMacro`).


For comparison, here is the signature of the inline `derived` method from the [main derivation page](./derivation.md):
```scala
inline given derived[T](using Mirror.Of[T]): Eq[T] = ???
inline def derived[T](using m: Mirror.Of[T]): Eq[T] = ???
```

Note, that since a type is used in a subsequent stage it will need to be lifted
to a `Type` by using the corresponding context bound. Also, note that we can
summon the quoted `Mirror` inside the body of the `derived` thus we can omit it
from the signature. The body of the `derived` method is shown below:
Note that the macro-based `derived` signature does not have a `Mirror` parameter.
This is because we can summon the `Mirror` inside the body of `derivedMacro` thus we can omit it from the signature.

One additional possibility with the body of `derivedMacro` here as opposed to the one with `inline` is that with macros it is simpler to create a fully optimised method body for `eqv`.

Let's say we wanted to derive an `Eq` instance for the following case class `Person`,
```scala
case class Person(name: String, age: Int) derives Eq
```

the equality check we are going to generate is the following:

```scala
given derived[T: Type](using Quotes): Expr[Eq[T]] =
import quotes.reflect.*
(x: Person, y: Person) =>
summon[Eq[String]].eqv(x.productElement(0), y.productElement(0))
&& summon[Eq[Int]].eqv(x.productElement(1), y.productElement(1))
```

> Note that it is possible, by using the [reflection API](../metaprogramming/reflection.md), to further optimise and directly reference the fields of `Person`, but for clear understanding we will only use quoted expressions.

The code to generates this body can be seen in the `eqProductBody` method, shown here as part of the definition for the `derivedMacro` method:


```scala
def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] =

val ev: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].get

ev match
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} =>
val elemInstances = summonAll[elementTypes]
val elemInstances = summonInstances[T, elementTypes]
def eqProductBody(x: Expr[Product], y: Expr[Product])(using Quotes): Expr[Boolean] = {
elemInstances.zipWithIndex.foldLeft(Expr(true)) {
case (acc, ('{ $elem: Eq[t] }, index)) =>
val indexExpr = Expr(index)
val e1 = '{ $x.productElement($indexExpr).asInstanceOf[t] }
val e2 = '{ $y.productElement($indexExpr).asInstanceOf[t] }
'{ $acc && $elem.eqv($e1, $e2) }
}
if elemInstances.isEmpty then
Expr(true)
else
elemInstances.zipWithIndex.map {
case ('{ $elem: Eq[t] }, index) =>
val indexExpr = Expr(index)
val e1 = '{ $x.productElement($indexExpr).asInstanceOf[t] }
val e2 = '{ $y.productElement($indexExpr).asInstanceOf[t] }
'{ $elem.eqv($e1, $e2) }
}.reduce((acc, elem) => '{ $acc && $elem })
end if
}
'{ eqProduct((x: T, y: T) => ${eqProductBody('x.asExprOf[Product], 'y.asExprOf[Product])}) }

// case for Mirror.ProductOf[T]
// ...
// case for Mirror.SumOf[T] ...
```

Note, that in the `inline` case we can merely write
`summonAll[m.MirroredElemTypes]` inside the inline method but here, since
`Expr.summon` is required, we can extract the element types in a macro fashion.
Being inside a macro, our first reaction would be to write the code below. Since
the path inside the type argument is not stable this cannot be used:
Note, that in the version without macros, we can merely write `summonInstances[T, m.MirroredElemTypes]` inside the inline method but here, since `Expr.summon` is required, we can extract the element types in a macro fashion.
Being inside a macro, our first reaction would be to write the code below:

```scala
'{
summonAll[$m.MirroredElemTypes]
summonInstances[T, $m.MirroredElemTypes]
}
```

Instead we extract the tuple-type for element types using pattern matching over
quotes and more specifically of the refined type:
However, since the path inside the type argument is not stable this cannot be used.
Instead we extract the tuple-type for element types using pattern matching over quotes and more specifically of the refined type:

```scala
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} => ...
```

Shown below is the implementation of `summonAll` as a macro. We assume that
given instances for our primitive types exist.

```scala
def summonAll[T: Type](using Quotes): List[Expr[Eq[_]]] =
Type.of[T] match
case '[String *: tpes] => '{ summon[Eq[String]] } :: summonAll[tpes]
case '[Int *: tpes] => '{ summon[Eq[Int]] } :: summonAll[tpes]
case '[tpe *: tpes] => derived[tpe] :: summonAll[tpes]
case '[EmptyTuple] => Nil
```

One additional difference with the body of `derived` here as opposed to the one
with `inline` is that with macros we need to synthesize the body of the code during the
macro-expansion time. That is the rationale behind the `eqProductBody` function.
Assuming that we calculate the equality of two `Person`s defined with a case
class that holds a name of type [`String`](https://scala-lang.org/api/3.x/scala/Predef$.html#String-0)
and an age of type `Int`, the equality check we want to generate is the following:

```scala
true
&& Eq[String].eqv(x.productElement(0),y.productElement(0))
&& Eq[Int].eqv(x.productElement(1), y.productElement(1))
```

## Calling the derived method inside the macro
Shown below is the implementation of `summonInstances` as a macro, which for each type `elem` in the tuple type, calls
`deriveOrSummon[T, elem]`.

Following the rules in [Macros](../metaprogramming/metaprogramming.md) we create two methods.
One that hosts the top-level splice `eqv` and one that is the implementation.
Alternatively and what is shown below is that we can call the `eqv` method
directly. The `eqGen` can trigger the derivation.
To understand `deriveOrSummon`, consider that if `elem` derives from the parent `T` type, then it is a recursive derivation.
Recursive derivation usually happens for types such as `scala.collection.immutable.::`. If `elem` does not derive from `T`, then there must exist a contextual `Eq[elem]` instance.

```scala
extension [T](inline x: T)
inline def === (inline y: T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

inline given eqGen[T]: Eq[T] = ${ Eq.derived[T] }
```

Note, that we use inline method syntax and we can compare instance such as
`Sm(Person("Test", 23)) === Sm(Person("Test", 24))` for e.g., the following two
types:

```scala
case class Person(name: String, age: Int)

enum Opt[+T]:
case Sm(t: T)
case Nn
def summonInstances[T: Type, Elems: Type](using Quotes): List[Expr[Eq[?]]] =
Type.of[Elems] match
case '[elem *: elems] => deriveOrSummon[T, elem] :: summonInstances[T, elems]
case '[EmptyTuple] => Nil

def deriveOrSummon[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
Type.of[Elem] match
case '[T] => deriveRec[T, Elem]
case _ => '{ summonInline[Eq[Elem]] }

def deriveRec[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
Type.of[T] match
case '[Elem] => '{ error("infinite recursive derivation") }
case _ => derivedMacro[Elem] // recursive derivation
```

The full code is shown below:

```scala
import compiletime.*
import scala.deriving.*
import scala.quoted.*


trait Eq[T]:
def eqv(x: T, y: T): Boolean
def eqv(x: T, y: T): Boolean

object Eq:
given Eq[String] with
def eqv(x: String, y: String) = x == y

given Eq[Int] with
def eqv(x: Int, y: Int) = x == y

def eqProduct[T](body: (T, T) => Boolean): Eq[T] =
new Eq[T]:
def eqv(x: T, y: T): Boolean = body(x, y)

def eqSum[T](body: (T, T) => Boolean): Eq[T] =
new Eq[T]:
def eqv(x: T, y: T): Boolean = body(x, y)

def summonAll[T: Type](using Quotes): List[Expr[Eq[_]]] =
Type.of[T] match
case '[String *: tpes] => '{ summon[Eq[String]] } :: summonAll[tpes]
case '[Int *: tpes] => '{ summon[Eq[Int]] } :: summonAll[tpes]
case '[tpe *: tpes] => derived[tpe] :: summonAll[tpes]
case '[EmptyTuple] => Nil

given derived[T: Type](using q: Quotes): Expr[Eq[T]] =
import quotes.reflect.*

val ev: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].get

ev match
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} =>
val elemInstances = summonAll[elementTypes]
val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) =>
elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) {
case (acc, (elem, index)) =>
val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})}
val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})}

'{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) }
}
'{ eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) }

case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = elementTypes }} =>
val elemInstances = summonAll[elementTypes]
val eqSumBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) =>
val ordx = '{ $m.ordinal($x) }
val ordy = '{ $m.ordinal($y) }

val elements = Expr.ofList(elemInstances)
'{ $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) }

'{ eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) }
end derived
given Eq[String] with
def eqv(x: String, y: String) = x == y

given Eq[Int] with
def eqv(x: Int, y: Int) = x == y

def eqProduct[T](body: (T, T) => Boolean): Eq[T] =
new Eq[T]:
def eqv(x: T, y: T): Boolean = body(x, y)

def eqSum[T](body: (T, T) => Boolean): Eq[T] =
new Eq[T]:
def eqv(x: T, y: T): Boolean = body(x, y)

def summonInstances[T: Type, Elems: Type](using Quotes): List[Expr[Eq[?]]] =
Type.of[Elems] match
case '[elem *: elems] => deriveOrSummon[T, elem] :: summonInstances[T, elems]
case '[EmptyTuple] => Nil

def deriveOrSummon[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
Type.of[Elem] match
case '[T] => deriveRec[T, Elem]
case _ => '{ summonInline[Eq[Elem]] }

def deriveRec[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
import quotes.reflect.*
Type.of[T] match
case '[Elem] => report.errorAndAbort("infinite recursive derivation")
case _ => derivedMacro[Elem] // recursive derivation

inline def derived[T]: Eq[T] = ${ derivedMacro[T] }

def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] =

val ev: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].get

ev match
case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} =>
val elemInstances = summonInstances[T, elementTypes]
def eqProductBody(x: Expr[Product], y: Expr[Product])(using Quotes): Expr[Boolean] = {
if elemInstances.isEmpty then
Expr(true)
else
elemInstances.zipWithIndex.map {
case ('{ $elem: Eq[t] }, index) =>
val indexExpr = Expr(index)
val e1 = '{ $x.productElement($indexExpr).asInstanceOf[t] }
val e2 = '{ $y.productElement($indexExpr).asInstanceOf[t] }
'{ $elem.eqv($e1, $e2) }
}.reduce((acc, elem) => '{ $acc && $elem })
end if
}
'{ eqProduct((x: T, y: T) => ${eqProductBody('x.asExprOf[Product], 'y.asExprOf[Product])}) }

case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = elementTypes }} =>
val elemInstances = summonInstances[T, elementTypes]
val elements = Expr.ofList(elemInstances)

def eqSumBody(x: Expr[T], y: Expr[T])(using Quotes): Expr[Boolean] =
val ordx = '{ $m.ordinal($x) }
val ordy = '{ $m.ordinal($y) }
'{ $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) }

'{ eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) }
end derivedMacro
end Eq

object Macro3:
extension [T](inline x: T)
inline def === (inline y: T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)

inline given eqGen[T]: Eq[T] = ${ Eq.derived[T] }
```
Loading