diff --git a/docs/docs/reference/contextual/derivation-macro.md b/docs/docs/reference/contextual/derivation-macro.md new file mode 100644 index 000000000000..89e794037d97 --- /dev/null +++ b/docs/docs/reference/contextual/derivation-macro.md @@ -0,0 +1,223 @@ +--- +layout: doc-page +title: How to write a type class `derived` method using macros +--- + +In the main [derivation](./derivation.md) documentation page, we explained the +details behind `Mirror`s and type class derivation. Here we demonstrate how to +implement a type class `derived` method using macros only. We follow the same +example of deriving `Eq` instances and for simplicity we support a `Product` +type e.g., a case class `Person`. The low-level method we will use to implement +the `derived` method exploits quotes, splices of both expressions and types and +the `scala.quoted.matching.summonExpr` method which is the equivalent of +`summonFrom`. The former is suitable for use in a quote context, used within +macros. + +As in the original code, the type class definition is the same: + +```scala +trait Eq[T] { + def eqv(x: T, y: T): Boolean +} +``` + +we need to implement a method `Eq.derived` on the companion object of `Eq` that +produces a quoted instance for `Eq[T]`. Here is a possible signature, + +```scala +given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] +``` + +and for comparison reasons we give the same signature we had with `inline`: + +```scala +inline given derived[T]: (m: Mirror.Of[T]) => Eq[T] = ??? +``` + +Note, that since a type is used in a subsequent stage it will need to be lifted +to a `Type` by using the corresponding context bound. Also, not that we can +summon the quoted `Mirror` inside the body of the `derived` this we can omit it +from the signature. The body of the `derived` method is shown below: + + +```scala +given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = { + import qctx.tasty.{_, given} + + val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get + + ev match { + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} => + val elemInstances = summonAll(elementTypes) + val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { + elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) { + case (acc, (elem, index)) => + val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})} + val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})} + + '{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) } + } + } + '{ + eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) + } + + // case for Mirror.ProductOf[T] + // ... + } +} +``` + +Note, that in the `inline` case we can merely write +`summonAll[m.MirroredElemTypes]` inside the inline method but here, since +`summonExpr` is required, we can extract the element types in a macro fashion. +Being inside a macro, our first reaction would be to write the code below. Since +the path inside the type argument is not stable this cannot be used: + +```scala +'{ + summonAll[$m.MirroredElemTypes] +} +``` + +Instead we extract the tuple-type for element types using pattern matching over +quotes and more specifically of the refined type: + +```scala + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes } } => ... +``` + +The implementation of `summonAll` as a macro can be show below assuming that we +have the given instances for our primitive types: + +```scala + def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match { + case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes) + case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes) + case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes) + case '[Unit] => Nil + } +``` + +One additional difference with the body of `derived` here as opposed to the one +with `inline` is that with macros we need to synthesize the body of the code during the +macro-expansion time. That is the rationale behind the `eqProductBody` function. +Assuming that we calculate the equality of two `Person`s defined with a case +class that holds a name of type `String` and an age of type `Int`, the equality +check we want to generate is the following: + +```scala +true + && Eq[String].eqv(x.productElement(0),y.productElement(0)) + && Eq[Int].eqv(x.productElement(1), y.productElement(1)) +``` + +### Calling the derived method inside the macro + +Following the rules in [Macros](../metaprogramming.md) we create two methods. +One that hosts the top-level splice `eqv` and one that is the implementation. +Alternatively and what is shown below is that we can call the `eqv` method +directly. The `eqGen` can trigger the derivation. + +```scala +inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y) + +implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] } +``` + +Note, that we use inline method syntax and we can compare instance such as +`Sm(Person("Test", 23)) === Sm(Person("Test", 24))` for e.g., the following two +types: + +```scala +case class Person(name: String, age: Int) + +enum Opt[+T] { + case Sm(t: T) + case Nn +} +``` + +The full code is shown below: + +```scala +import scala.deriving._ +import scala.quoted._ +import scala.quoted.matching._ + +trait Eq[T] { + def eqv(x: T, y: T): Boolean +} + +object Eq { + given Eq[String] { + def eqv(x: String, y: String) = x == y + } + + given Eq[Int] { + def eqv(x: Int, y: Int) = x == y + } + + def eqProduct[T](body: (T, T) => Boolean): Eq[T] = + new Eq[T] { + def eqv(x: T, y: T): Boolean = body(x, y) + } + + def eqSum[T](body: (T, T) => Boolean): Eq[T] = + new Eq[T] { + def eqv(x: T, y: T): Boolean = body(x, y) + } + + def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match { + case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes) + case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes) + case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes) + case '[Unit] => Nil + } + + given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = { + import qctx.tasty.{_, given} + + val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get + + ev match { + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} => + val elemInstances = summonAll(elementTypes) + val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { + elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) { + case (acc, (elem, index)) => + val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})} + val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})} + + '{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) } + } + } + '{ + eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) + } + + case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = $elementTypes }} => + val elemInstances = summonAll(elementTypes) + val eqSumBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { + val ordx = '{ $m.ordinal($x) } + val ordy = '{ $m.ordinal($y) } + + val elements = Expr.ofList(elemInstances) + '{ + $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) + } + } + + '{ + eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) + } + } + } +} + +object Macro3 { + inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y) + + implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] } +} +``` \ No newline at end of file diff --git a/docs/docs/reference/contextual/derivation.md b/docs/docs/reference/contextual/derivation.md index 96883ce983e6..795b4297ebd1 100644 --- a/docs/docs/reference/contextual/derivation.md +++ b/docs/docs/reference/contextual/derivation.md @@ -347,6 +347,10 @@ inline def derived[A](given gen: K0.Generic[A]): Eq[A] = gen.derive(eqSum, eqPro The framework described here enables all three of these approaches without mandating any of them. +For a brief discussion on how to use macros to write a type class `derived` +method please read more at [How to write a type class `derived` method using +macros](./derivation-macro.md). + ### Deriving instances elsewhere Sometimes one would like to derive a type class instance for an ADT after the ADT is defined, without being able to diff --git a/docs/docs/reference/metaprogramming/macros.md b/docs/docs/reference/metaprogramming/macros.md index 50286136db65..9e0214304c64 100644 --- a/docs/docs/reference/metaprogramming/macros.md +++ b/docs/docs/reference/metaprogramming/macros.md @@ -569,7 +569,7 @@ sum ### Find implicits within a macro Similarly to the `summonFrom` construct, it is possible to make implicit search available -in a quote context. For this we simply provide `scala.quoted.matching.summonExpr: +in a quote context. For this we simply provide `scala.quoted.matching.summonExpr`: ```scala inline def setFor[T]: Set[T] = ${ setForExpr[T] } diff --git a/tests/run-macros/i8007.check b/tests/run-macros/i8007.check new file mode 100644 index 000000000000..0ccbe496ef31 --- /dev/null +++ b/tests/run-macros/i8007.check @@ -0,0 +1,15 @@ +List("name", "age") + +Test 23 +() + +true + +false + +true + +true + +false + diff --git a/tests/run-macros/i8007/Macro_1.scala b/tests/run-macros/i8007/Macro_1.scala new file mode 100644 index 000000000000..20e6f5997dbf --- /dev/null +++ b/tests/run-macros/i8007/Macro_1.scala @@ -0,0 +1,30 @@ +import scala.deriving._ +import scala.quoted._ +import scala.quoted.matching._ + +object Macro1 { + + def mirrorFields[T](t: Type[T])(given qctx: QuoteContext): List[String] = + t match { + case '[$field *: $fields] => field.show :: mirrorFields(fields) + case '[Unit] => Nil + } + + // Demonstrates the use of quoted pattern matching + // over a refined type extracting the tuple type + // for e.g., MirroredElemLabels + inline def test1[T](value: =>T): List[String] = + ${ test1Impl('value) } + + def test1Impl[T: Type](value: Expr[T])(given qctx: QuoteContext): Expr[List[String]] = { + import qctx.tasty.{_, given} + + val mirrorTpe = '[Mirror.Of[T]] + + summonExpr(given mirrorTpe).get match { + case '{ $m: Mirror.ProductOf[T]{ type MirroredElemLabels = $t } } => { + Expr(mirrorFields(t)) + } + } + } +} \ No newline at end of file diff --git a/tests/run-macros/i8007/Macro_2.scala b/tests/run-macros/i8007/Macro_2.scala new file mode 100644 index 000000000000..ed65f04bf4a8 --- /dev/null +++ b/tests/run-macros/i8007/Macro_2.scala @@ -0,0 +1,57 @@ +import scala.deriving._ +import scala.quoted._ +import scala.quoted.matching._ + +object Macro2 { + + def mirrorFields[T](t: Type[T])(given qctx: QuoteContext): List[String] = + t match { + case '[$field *: $fields] => field.show.substring(1, field.show.length-1) :: mirrorFields(fields) + case '[Unit] => Nil + } + + trait JsonEncoder[T] { + def encode(elem: T): String + } + + object JsonEncoder { + def emitJsonEncoder[T](body: T => String): JsonEncoder[T]= + new JsonEncoder[T] { + def encode(elem: T): String = body(elem) + } + + def derived[T: Type](ev: Expr[Mirror.Of[T]])(given qctx: QuoteContext): Expr[JsonEncoder[T]] = { + import qctx.tasty.{_, given} + + val fields = ev match { + case '{ $m: Mirror.ProductOf[T] { type MirroredElemLabels = $t } } => + mirrorFields(t) + } + + val body: Expr[T] => Expr[String] = elem => + fields.reverse.foldLeft(Expr("")){ (acc, field) => + val res = Select.unique(elem.unseal, field).seal + '{ $res.toString + " " + $acc } + } + + '{ + emitJsonEncoder((x: T) => ${body('x)}) + } + } + } + + inline def test2[T](value: =>T): Unit = ${ test2Impl('value) } + + def test2Impl[T: Type](value: Expr[T])(given qctx: QuoteContext): Expr[Unit] = { + import qctx.tasty.{_, given} + + val mirrorTpe = '[Mirror.Of[T]] + val mirrorExpr = summonExpr(given mirrorTpe).get + val derivedInstance = JsonEncoder.derived(mirrorExpr) + + '{ + val res = $derivedInstance.encode($value) + println(res) + } + } +} \ No newline at end of file diff --git a/tests/run-macros/i8007/Macro_3.scala b/tests/run-macros/i8007/Macro_3.scala new file mode 100644 index 000000000000..d530f86427b6 --- /dev/null +++ b/tests/run-macros/i8007/Macro_3.scala @@ -0,0 +1,79 @@ +import scala.deriving._ +import scala.quoted._ +import scala.quoted.matching._ + +trait Eq[T] { + def eqv(x: T, y: T): Boolean +} + +object Eq { + given Eq[String] { + def eqv(x: String, y: String) = x == y + } + + given Eq[Int] { + def eqv(x: Int, y: Int) = x == y + } + + def eqProduct[T](body: (T, T) => Boolean): Eq[T] = + new Eq[T] { + def eqv(x: T, y: T): Boolean = body(x, y) + } + + def eqSum[T](body: (T, T) => Boolean): Eq[T] = + new Eq[T] { + def eqv(x: T, y: T): Boolean = body(x, y) + } + + def summonAll[T](t: Type[T])(given qctx: QuoteContext): List[Expr[Eq[_]]] = t match { + case '[String *: $tpes] => '{ summon[Eq[String]] } :: summonAll(tpes) + case '[Int *: $tpes] => '{ summon[Eq[Int]] } :: summonAll(tpes) + case '[$tpe *: $tpes] => derived(given tpe, qctx) :: summonAll(tpes) + case '[Unit] => Nil + } + + given derived[T: Type](given qctx: QuoteContext): Expr[Eq[T]] = { + import qctx.tasty.{_, given} + + val ev: Expr[Mirror.Of[T]] = summonExpr(given '[Mirror.Of[T]]).get + + ev match { + case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = $elementTypes }} => + val elemInstances = summonAll(elementTypes) + val eqProductBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { + elemInstances.zipWithIndex.foldLeft(Expr(true: Boolean)) { + case (acc, (elem, index)) => + val e1 = '{$x.asInstanceOf[Product].productElement(${Expr(index)})} + val e2 = '{$y.asInstanceOf[Product].productElement(${Expr(index)})} + + '{ $acc && $elem.asInstanceOf[Eq[Any]].eqv($e1, $e2) } + } + } + '{ + eqProduct((x: T, y: T) => ${eqProductBody('x, 'y)}) + } + + case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = $elementTypes }} => + val elemInstances = summonAll(elementTypes) + val eqSumBody: (Expr[T], Expr[T]) => Expr[Boolean] = (x, y) => { + val ordx = '{ $m.ordinal($x) } + val ordy = '{ $m.ordinal($y) } + + val elements = Expr.ofList(elemInstances) + '{ + $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) + } + } + + '{ + eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) + } + } + } +} + +object Macro3 { + inline def [T](x: =>T) === (y: =>T)(given eq: Eq[T]): Boolean = eq.eqv(x, y) + + implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] } +} \ No newline at end of file diff --git a/tests/run-macros/i8007/Test_4.scala b/tests/run-macros/i8007/Test_4.scala new file mode 100644 index 000000000000..1809d2e023eb --- /dev/null +++ b/tests/run-macros/i8007/Test_4.scala @@ -0,0 +1,44 @@ +import Macro1._ +import Macro2._ +import Macro3._ +import Macro3.eqGen + +case class Person(name: String, age: Int) + +enum Opt[+T] { + case Sm(t: T) + case Nn +} + +@main def Test() = { + import Opt._ + import Eq.{given, _} + + val t1 = test1(Person("Test", 23)) + println(t1) + println + + val t2 = test2(Person("Test", 23)) + println(t2) + println + + val t3 = Person("Test", 23) === Person("Test", 23) + println(t3) // true + println + + val t4 = Person("Test", 23) === Person("Test", 24) + println(t4) // false + println + + val t5 = Sm(23) === Sm(23) + println(t5) // true + println + + val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23)) + println(t6) // true + println + + val t7 = Sm(Person("Test", 23)) === Sm(Person("Test", 24)) + println(t7) // false + println +} \ No newline at end of file