|
| 1 | +import scala.deriving._ |
| 2 | +import scala.compiletime.{erasedValue, summonInline} |
| 3 | + |
| 4 | +inline def summonAll[T <: Tuple]: List[Eq[_]] = inline erasedValue[T] match { |
| 5 | + case _: EmptyTuple => Nil |
| 6 | + case _: (t *: ts) => summonInline[Eq[t]] :: summonAll[ts] |
| 7 | +} |
| 8 | + |
| 9 | +trait Eq[T] { |
| 10 | + def eqv(x: T, y: T): Boolean |
| 11 | +} |
| 12 | + |
| 13 | +object Eq { |
| 14 | + given Eq[Int] { |
| 15 | + def eqv(x: Int, y: Int) = x == y |
| 16 | + } |
| 17 | + |
| 18 | + def check(elem: Eq[_])(x: Any, y: Any): Boolean = |
| 19 | + elem.asInstanceOf[Eq[Any]].eqv(x, y) |
| 20 | + |
| 21 | + def iterator[T](p: T) = p.asInstanceOf[Product].productIterator |
| 22 | + |
| 23 | + def eqSum[T](s: Mirror.SumOf[T], elems: => List[Eq[_]]): Eq[T] = |
| 24 | + new Eq[T] { |
| 25 | + def eqv(x: T, y: T): Boolean = { |
| 26 | + val ordx = s.ordinal(x) |
| 27 | + (s.ordinal(y) == ordx) && check(elems(ordx))(x, y) |
| 28 | + } |
| 29 | + } |
| 30 | + |
| 31 | + def eqProduct[T](p: Mirror.ProductOf[T], elems: => List[Eq[_]]): Eq[T] = |
| 32 | + new Eq[T] { |
| 33 | + def eqv(x: T, y: T): Boolean = |
| 34 | + iterator(x).zip(iterator(y)).zip(elems.iterator).forall { |
| 35 | + case ((x, y), elem) => check(elem)(x, y) |
| 36 | + } |
| 37 | + } |
| 38 | + |
| 39 | + inline given derived[T](using m: Mirror.Of[T]) as Eq[T] = { |
| 40 | + lazy val elemInstances = summonAll[m.MirroredElemTypes] |
| 41 | + inline m match { |
| 42 | + case s: Mirror.SumOf[T] => eqSum(s, elemInstances) |
| 43 | + case p: Mirror.ProductOf[T] => eqProduct(p, elemInstances) |
| 44 | + } |
| 45 | + } |
| 46 | +} |
| 47 | + |
| 48 | +enum Tree[T] derives Eq { |
| 49 | + case Branch(left: Tree[T], right: Tree[T]) |
| 50 | + case Leaf(elem: T) |
| 51 | +} |
| 52 | + |
| 53 | +@main |
| 54 | +def Test = { |
| 55 | + import Tree._ |
| 56 | + |
| 57 | + val t1 = Branch(Leaf(1), Leaf(1)) |
| 58 | + assert(summon[Eq[Tree[Int]]].eqv(t1, t1)) |
| 59 | +} |
0 commit comments