|
| 1 | +--- |
| 2 | +layout: doc-page |
| 3 | +title: How to write a type class `derived` method using macros |
| 4 | +--- |
| 5 | + |
| 6 | +In the main [derivation](./derivation.md) documentation page, we explained the |
| 7 | +details behind `Mirror`s and type class derivation. Here we demonstrate how to |
| 8 | +implement a type class `derived` method using macros only. We follow the same |
| 9 | +example of deriving `Eq` instances and for simplicity we support a `Product` |
| 10 | +type e.g., a case class `Person`. The low-level method we will use to implement |
| 11 | +the `derived` method exploits quotes, splices of both expressions and types and |
| 12 | +the `scala.quoted.matching.summonExpr` method which is the equivalent of |
| 13 | +`summonFrom`. The former is suitable for use in a quote context, used within |
| 14 | +macros. |
| 15 | + |
| 16 | +As in the original code, the type class definition is the same: |
| 17 | + |
| 18 | +```scala |
| 19 | +trait Eq[T] { |
| 20 | + def eqv(x: T, y: T): Boolean |
| 21 | +} |
| 22 | +``` |
| 23 | + |
| 24 | +we need to implement a method `Eq.derived` on the companion object of `Eq` that |
| 25 | +produces a quoted instance for `Eq[T]`. Here is a possible signature, |
| 26 | + |
| 27 | +```scala |
| 28 | +given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] |
| 29 | +``` |
| 30 | + |
| 31 | +and for comparison reasons we give the same signature we had with `inline`: |
| 32 | + |
| 33 | +```scala |
| 34 | +inline given derived[T]: (m: Mirror.Of[T]) => Eq[T] = ??? |
| 35 | +``` |
| 36 | + |
| 37 | +Note, that since a type is used in a subsequent stage it will need to be lifted |
| 38 | +to a `Type` by using the corresponding context bound. Also, not that we can |
| 39 | +summon the quoted `Mirror` inside the body of the `derived` this we can omit it |
| 40 | +from the signature. The body of the `derived` method is shown below: |
| 41 | + |
| 42 | + |
| 43 | +```scala |
| 44 | +given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = { |
| 45 | + import qctx.tasty.{_, given} |
| 46 | + |
| 47 | + val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get |
| 48 | + |
| 49 | + ev match { |
| 50 | + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} => |
| 51 | + val elemInstances = summonAll(elementTypes) |
| 52 | + val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { |
| 53 | + elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) { |
| 54 | + case (acc, (elem, index)) => |
| 55 | + val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})} |
| 56 | + val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})} |
| 57 | + |
| 58 | + '{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) } |
| 59 | + } |
| 60 | + } |
| 61 | + '{ |
| 62 | + eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) |
| 63 | + } |
| 64 | + |
| 65 | + // case for Mirror.ProductOf[T] |
| 66 | + // ... |
| 67 | + } |
| 68 | +} |
| 69 | +``` |
| 70 | + |
| 71 | +Note, that in the `inline` case we can merely write |
| 72 | +`summonAll[m.MirroredElemTypes]` inside the inline method but here, since |
| 73 | +`summonExpr` is required, we can extract the element types in a macro fashion. |
| 74 | +Being inside a macro, our first reaction would be to write the code below. Since |
| 75 | +the path inside the type argument is not stable this cannot be used: |
| 76 | + |
| 77 | +```scala |
| 78 | +'{ |
| 79 | + summonAll[$m.MirroredElemTypes] |
| 80 | +} |
| 81 | +``` |
| 82 | + |
| 83 | +Instead we extract the tuple-type for element types using pattern matching over |
| 84 | +quotes and more specifically of the refined type: |
| 85 | + |
| 86 | +```scala |
| 87 | + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes } } => ... |
| 88 | +``` |
| 89 | + |
| 90 | +The implementation of `summonAll` as a macro can be show below assuming that we |
| 91 | +have the given instances for our primitive types: |
| 92 | + |
| 93 | +```scala |
| 94 | + def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match { |
| 95 | + case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes) |
| 96 | + case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes) |
| 97 | + case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes) |
| 98 | + case '[Unit] => Nil |
| 99 | + } |
| 100 | +``` |
| 101 | + |
| 102 | +One additional difference with the body of `derived` here as opposed to the one |
| 103 | +with `inline` is that with macros we need to synthesize the body of the code during the |
| 104 | +macro-expansion time. That is the rationale behind the `eqProductBody` function. |
| 105 | +Assuming that we calculate the equality of two `Person`s defined with a case |
| 106 | +class that holds a name of type `String` and an age of type `Int`, the equality |
| 107 | +check we want to generate is the following: |
| 108 | + |
| 109 | +```scala |
| 110 | +true |
| 111 | + && Eq[String].eqv(x.productElement(0),y.productElement(0)) |
| 112 | + && Eq[Int].eqv(x.productElement(1), y.productElement(1)) |
| 113 | +``` |
| 114 | + |
| 115 | +### Calling the derived method inside the macro |
| 116 | + |
| 117 | +Following the rules in [Macros](../metaprogramming.md) we create two methods. |
| 118 | +One that hosts the top-level splice `eqv` and one that is the implementation. |
| 119 | +Alternatively and what is shown below is that we can call the `eqv` method |
| 120 | +directly. The `eqGen` can trigger the derivation. |
| 121 | + |
| 122 | +```scala |
| 123 | +inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y) |
| 124 | + |
| 125 | +implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] } |
| 126 | +``` |
| 127 | + |
| 128 | +Note, that we use inline method syntax and we can compare instance such as |
| 129 | +`Sm(Person("Test", 23)) === Sm(Person("Test", 24))` for e.g., the following two |
| 130 | +types: |
| 131 | + |
| 132 | +```scala |
| 133 | +case class Person(name: String, age: Int) |
| 134 | + |
| 135 | +enum Opt[+T] { |
| 136 | + case Sm(t: T) |
| 137 | + case Nn |
| 138 | +} |
| 139 | +``` |
| 140 | + |
| 141 | +The full code is shown below: |
| 142 | + |
| 143 | +```scala |
| 144 | +import scala.deriving._ |
| 145 | +import scala.quoted._ |
| 146 | +import scala.quoted.matching._ |
| 147 | + |
| 148 | +trait Eq[T] { |
| 149 | + def eqv(x: T, y: T): Boolean |
| 150 | +} |
| 151 | + |
| 152 | +object Eq { |
| 153 | + given Eq[String] { |
| 154 | + def eqv(x: String, y: String) = x == y |
| 155 | + } |
| 156 | + |
| 157 | + given Eq[Int] { |
| 158 | + def eqv(x: Int, y: Int) = x == y |
| 159 | + } |
| 160 | + |
| 161 | + def eqProduct[T](body: (T, T) => Boolean): Eq[T] = |
| 162 | + new Eq[T] { |
| 163 | + def eqv(x: T, y: T): Boolean = body(x, y) |
| 164 | + } |
| 165 | + |
| 166 | + def eqSum[T](body: (T, T) => Boolean): Eq[T] = |
| 167 | + new Eq[T] { |
| 168 | + def eqv(x: T, y: T): Boolean = body(x, y) |
| 169 | + } |
| 170 | + |
| 171 | + def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match { |
| 172 | + case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes) |
| 173 | + case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes) |
| 174 | + case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes) |
| 175 | + case '[Unit] => Nil |
| 176 | + } |
| 177 | + |
| 178 | + given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = { |
| 179 | + import qctx.tasty.{_, given} |
| 180 | + |
| 181 | + val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get |
| 182 | + |
| 183 | + ev match { |
| 184 | + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} => |
| 185 | + val elemInstances = summonAll(elementTypes) |
| 186 | + val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { |
| 187 | + elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) { |
| 188 | + case (acc, (elem, index)) => |
| 189 | + val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})} |
| 190 | + val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})} |
| 191 | + |
| 192 | + '{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) } |
| 193 | + } |
| 194 | + } |
| 195 | + '{ |
| 196 | + eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) |
| 197 | + } |
| 198 | + |
| 199 | + case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = $elementTypes }} => |
| 200 | + val elemInstances = summonAll(elementTypes) |
| 201 | + val eqSumBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { |
| 202 | + val ordx = '{ $m.ordinal($x) } |
| 203 | + val ordy = '{ $m.ordinal($y) } |
| 204 | + |
| 205 | + val elements = Expr.ofList(elemInstances) |
| 206 | + '{ |
| 207 | + $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) |
| 208 | + } |
| 209 | + } |
| 210 | + |
| 211 | + '{ |
| 212 | + eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) |
| 213 | + } |
| 214 | + } |
| 215 | + } |
| 216 | +} |
| 217 | + |
| 218 | +object Macro3 { |
| 219 | + inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y) |
| 220 | + |
| 221 | + implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] } |
| 222 | +} |
| 223 | +``` |
0 commit comments