Skip to content

Commit d44fcef

Browse files
committed
remove 'given derived' from inline typeclass derivation
1 parent 13c10a9 commit d44fcef

File tree

1 file changed

+79
-61
lines changed

1 file changed

+79
-61
lines changed

docs/_docs/reference/contextual/derivation.md

Lines changed: 79 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -309,87 +309,99 @@ worked out example of such a library, see [Shapeless 3](https://github.com/miles
309309

310310
## How to write a type class `derived` method using low level mechanisms
311311

312-
The low-level method we will use to implement a type class `derived` method in this example exploits three new
313-
type-level constructs in Scala 3: inline methods, inline matches, and implicit searches via `summonInline` or `summonFrom`. Given this definition of the
314-
`Eq` type class,
312+
The low-level method we will use to implement a type class `derived` method in this example exploits three new type-level constructs in Scala 3: inline methods, inline matches, and implicit searches via `summonInline` or `summonFrom`.
313+
Given this definition of the `Eq` type class,
315314

316315
```scala
317316
trait Eq[T]:
318317
def eqv(x: T, y: T): Boolean
319318
```
320319

321320
we need to implement a method `Eq.derived` on the companion object of `Eq` that produces a given instance for `Eq[T]` given
322-
a `Mirror[T]`. Here is a possible implementation,
321+
a `Mirror[T]`.
322+
Here is a possible implementation,
323323

324324
```scala
325325
import scala.deriving.Mirror
326326

327-
inline given derived[T](using m: Mirror.Of[T]): Eq[T] =
328-
val elemInstances = summonAll[m.MirroredElemTypes] // (1)
329-
inline m match // (2)
327+
inline def derived[T](using m: Mirror.Of[T]): Eq[T] =
328+
lazy val elemInstances = summonInstances[T, m.MirroredElemTypes] // (1)
329+
inline m match // (2)
330330
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
331331
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
332332
```
333333

334-
Note that `derived` is defined as an `inline` given. This means that the method will be expanded at
335-
call sites (for instance the compiler generated instance definitions in the companion objects of ADTs which have a
336-
`derived Eq` clause), and also that it can be used recursively if necessary, to compute instances for children.
334+
Note that `derived` is defined as an `inline` given.
335+
This means that the method will be expanded at call sites (for instance the compiler generated instance definitions in the companion objects of ADTs which have a `derived Eq` clause), and also that it can be used recursively if necessary, to compute instances for children.
337336

338-
The body of this method (1) first materializes the `Eq` instances for all the child types of type the instance is
339-
being derived for. This is either all the branches of a sum type or all the fields of a product type. The
340-
implementation of `summonAll` is `inline` and uses Scala 3's `summonInline` construct to collect the instances as a
341-
`List`,
337+
The body of this method (1) first materializes the `Eq` instances for all the child types of type the instance is being derived for.
338+
This is either all the branches of a sum type or all the fields of a product type.
339+
The implementation of `summonInstances` is `inline` and uses Scala 3's `summonInline` construct to collect the instances as a `List`,
342340

343341
```scala
344-
inline def summonAll[T <: Tuple]: List[Eq[_]] =
345-
inline erasedValue[T] match
342+
inline def summonInstances[T, Elems <: Tuple]: List[Eq[?]] =
343+
inline erasedValue[Elems] match
344+
case _: (elem *: elems) => deriveOrSummon[T, elem] :: summonInstances[T, elems]
346345
case _: EmptyTuple => Nil
347-
case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts]
346+
347+
inline def deriveOrSummon[T, Elem]: Eq[Elem] =
348+
inline erasedValue[Elem] match
349+
case _: T => deriveRec[T, Elem]
350+
case _ => summonInline[Eq[Elem]]
351+
352+
inline def deriveRec[T, Elem]: Eq[Elem] =
353+
inline erasedValue[T] match
354+
case _: Elem => error("infinite recursive derivation")
355+
case _ => Eq.derived[Elem](using summonInline[Mirror.Of[Elem]]) // recursive derivation
348356
```
349357

350358
with the instances for children in hand the `derived` method uses an `inline match` to dispatch to methods which can
351-
construct instances for either sums or products (2). Note that because `derived` is `inline` the match will be
352-
resolved at compile-time and only the left-hand side of the matching case will be inlined into the generated code with
353-
types refined as revealed by the match.
359+
construct instances for either sums or products (2).
360+
Note that because `derived` is `inline` the match will be resolved at compile-time and only the right-hand side of the matching case will be inlined into the generated code with types refined as revealed by the match.
354361

355-
In the sum case, `eqSum`, we use the runtime `ordinal` values of the arguments to `eqv` to first check if the two
356-
values are of the same subtype of the ADT (3) and then, if they are, to further test for equality based on the `Eq`
357-
instance for the appropriate ADT subtype using the auxiliary method `check` (4).
362+
In the sum case, `eqSum`, we use the runtime `ordinal` values of the arguments to `eqv` to first check if the two values are of the same subtype of the ADT (3) and then, if they are, to further test for equality based on the `Eq` instance for the appropriate ADT subtype using the auxiliary method `check` (4).
358363

359364
```scala
360365
import scala.deriving.Mirror
361366

362-
def eqSum[T](s: Mirror.SumOf[T], elems: List[Eq[_]]): Eq[T] =
367+
def eqSum[T](s: Mirror.SumOf[T], elems: => List[Eq[?]]): Eq[T] =
363368
new Eq[T]:
364369
def eqv(x: T, y: T): Boolean =
365370
val ordx = s.ordinal(x) // (3)
366-
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y) // (4)
371+
(s.ordinal(y) == ordx) && check(x, y, elems(ordx)) // (4)
367372
```
368373

369-
In the product case, `eqProduct` we test the runtime values of the arguments to `eqv` for equality as products based
370-
on the `Eq` instances for the fields of the data type (5),
374+
In the product case, `eqProduct` we test the runtime values of the arguments to `eqv` for equality as products based on the `Eq` instances for the fields of the data type (5),
371375

372376
```scala
373377
import scala.deriving.Mirror
374378

375-
def eqProduct[T](p: Mirror.ProductOf[T], elems: List[Eq[_]]): Eq[T] =
379+
def eqProduct[T](p: Mirror.ProductOf[T], elems: => List[Eq[?]]): Eq[T] =
376380
new Eq[T]:
377381
def eqv(x: T, y: T): Boolean =
378-
iterator(x).zip(iterator(y)).zip(elems.iterator).forall { // (5)
379-
case ((x, y), elem) => check(elem)(x, y)
380-
}
382+
iterable(x).lazyZip(iterable(y)).lazyZip(elems).forall(check)
381383
```
382384

383385
Pulling this all together we have the following complete implementation,
384386

385387
```scala
386388
import scala.deriving.*
387-
import scala.compiletime.{erasedValue, summonInline}
389+
import scala.compiletime.{error, erasedValue, summonInline}
388390

389-
inline def summonAll[T <: Tuple]: List[Eq[_]] =
390-
inline erasedValue[T] match
391+
inline def summonInstances[T, Elems <: Tuple]: List[Eq[?]] =
392+
inline erasedValue[Elems] match
393+
case _: (elem *: elems) => deriveOrSummon[T, elem] :: summonInstances[T, elems]
391394
case _: EmptyTuple => Nil
392-
case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts]
395+
396+
inline def deriveOrSummon[T, Elem]: Eq[Elem] =
397+
inline erasedValue[Elem] match
398+
case _: T => deriveRec[T, Elem]
399+
case _ => summonInline[Eq[Elem]]
400+
401+
inline def deriveRec[T, Elem]: Eq[Elem] =
402+
inline erasedValue[T] match
403+
case _: Elem => error("infinite recursive derivation")
404+
case _ => Eq.derived[Elem](using summonInline[Mirror.Of[Elem]]) // recursive derivation
393405

394406
trait Eq[T]:
395407
def eqv(x: T, y: T): Boolean
@@ -398,26 +410,25 @@ object Eq:
398410
given Eq[Int] with
399411
def eqv(x: Int, y: Int) = x == y
400412

401-
def check(elem: Eq[_])(x: Any, y: Any): Boolean =
413+
def check(x: Any, y: Any, elem: Eq[?]): Boolean =
402414
elem.asInstanceOf[Eq[Any]].eqv(x, y)
403415

404-
def iterator[T](p: T) = p.asInstanceOf[Product].productIterator
416+
def iterable[T](p: T): Iterable[Any] = new AbstractIterable[Any]:
417+
def iterator: Iterator[Any] = p.asInstanceOf[Product].productIterator
405418

406-
def eqSum[T](s: Mirror.SumOf[T], elems: => List[Eq[_]]): Eq[T] =
419+
def eqSum[T](s: Mirror.SumOf[T], elems: => List[Eq[?]]): Eq[T] =
407420
new Eq[T]:
408421
def eqv(x: T, y: T): Boolean =
409422
val ordx = s.ordinal(x)
410-
(s.ordinal(y) == ordx) && check(elems(ordx))(x, y)
423+
(s.ordinal(y) == ordx) && check(x, y, elems(ordx))
411424

412-
def eqProduct[T](p: Mirror.ProductOf[T], elems: => List[Eq[_]]): Eq[T] =
425+
def eqProduct[T](p: Mirror.ProductOf[T], elems: => List[Eq[?]]): Eq[T] =
413426
new Eq[T]:
414427
def eqv(x: T, y: T): Boolean =
415-
iterator(x).zip(iterator(y)).zip(elems.iterator).forall {
416-
case ((x, y), elem) => check(elem)(x, y)
417-
}
428+
iterable(x).lazyZip(iterable(y)).lazyZip(elems).forall(check)
418429

419-
inline given derived[T](using m: Mirror.Of[T]): Eq[T] =
420-
lazy val elemInstances = summonAll[m.MirroredElemTypes]
430+
inline def derived[T](using m: Mirror.Of[T]): Eq[T] =
431+
lazy val elemInstances = summonInstances[T, m.MirroredElemTypes]
421432
inline m match
422433
case s: Mirror.SumOf[T] => eqSum(s, elemInstances)
423434
case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances)
@@ -427,32 +438,39 @@ end Eq
427438
we can test this relative to a simple ADT like so,
428439

429440
```scala
430-
enum Opt[+T] derives Eq:
431-
case Sm(t: T)
432-
case Nn
441+
enum Lst[+T] derives Eq:
442+
case Cns(t: T, ts: Lst[T])
443+
case Nl
444+
445+
extension [T](t: T) def ::(ts: Lst[T]): Lst[T] = Lst.Cns(t, ts)
433446

434447
@main def test(): Unit =
435-
import Opt.*
436-
val eqoi = summon[Eq[Opt[Int]]]
437-
assert(eqoi.eqv(Sm(23), Sm(23)))
438-
assert(!eqoi.eqv(Sm(23), Sm(13)))
439-
assert(!eqoi.eqv(Sm(23), Nn))
448+
import Lst.*
449+
val eqoi = summon[Eq[Lst[Int]]]
450+
assert(eqoi.eqv(23 :: 47 :: Nl, 23 :: 47 :: Nl))
451+
assert(!eqoi.eqv(23 :: Nl, 7 :: Nl))
452+
assert(!eqoi.eqv(23 :: Nl, Nl))
440453
```
441454

442-
In this case the code that is generated by the inline expansion for the derived `Eq` instance for `Opt` looks like the
455+
In this case the code that is generated by the inline expansion for the derived `Eq` instance for `Lst` looks like the
443456
following, after a little polishing,
444457

445458
```scala
446-
given derived$Eq[T](using eqT: Eq[T]): Eq[Opt[T]] =
447-
eqSum(
448-
summon[Mirror[Opt[T]]],
459+
given derived$Eq[T](using eqT: Eq[T]): Eq[Lst[T]] =
460+
eqSum(summon[Mirror.Of[Lst[T]]], {/* cached lazily */
449461
List(
450-
eqProduct(summon[Mirror[Sm[T]]], List(summon[Eq[T]])),
451-
eqProduct(summon[Mirror[Nn.type]], Nil)
462+
eqProduct(summon[Mirror.Of[Cns[T]]], {/* cached lazily */
463+
List(summon[Eq[T]], summon[Eq[Lst[T]]])
464+
}),
465+
eqProduct(summon[Mirror.Of[Nl.type]], {/* cached lazily */
466+
Nil
467+
})
452468
)
453-
)
469+
})
454470
```
455471

472+
The `lazy` modifier on `elemInstances` is necessary for preventing infinite recursion in the derived instance for recursive types such as `Lst`.
473+
456474
Alternative approaches can be taken to the way that `derived` methods can be defined. For example, more aggressively
457475
inlined variants using Scala 3 macros, whilst being more involved for type class authors to write than the example
458476
above, can produce code for type classes like `Eq` which eliminate all the abstraction artefacts (eg. the `Lists` of
@@ -466,7 +484,7 @@ given eqSum[A](using inst: => K0.CoproductInstances[Eq, A]): Eq[A] with
466484
[t] => (eqt: Eq[t], t0: t, t1: t) => eqt.eqv(t0, t1)
467485
)
468486

469-
given eqProduct[A](using inst: K0.ProductInstances[Eq, A]): Eq[A] with
487+
given eqProduct[A](using inst: => K0.ProductInstances[Eq, A]): Eq[A] with
470488
def eqv(x: A, y: A): Boolean = inst.foldLeft2(x, y)(true: Boolean)(
471489
[t] => (acc: Boolean, eqt: Eq[t], t0: t, t1: t) =>
472490
Complete(!eqt.eqv(t0, t1))(false)(true)

0 commit comments

Comments
 (0)