Skip to content

Commit 0006307

Browse files
committed
Fix #5941: implement GenLens macro
1 parent 29d4ccb commit 0006307

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
abstract class Lens[S, T] {
2+
def get(s: S): T
3+
def set(t: T, s: S) :S
4+
}
5+
6+
import scala.quoted._
7+
import scala.tasty._
8+
9+
object Lens {
10+
def apply[S, T](_get: S => T)(_set: T => S => S): Lens[S, T] = new Lens {
11+
def get(s: S): T = _get(s)
12+
def set(t: T, s: S): S = _set(t)(s)
13+
}
14+
15+
/** case class Address(streetNumber: Int, streetName: String)
16+
*
17+
* Lens.gen[Address, Int](_.streetNumber) ~~>
18+
*
19+
* Lens[Address, Int](_.streetNumber)(n => a => a.copy(streetNumber = n))
20+
*/
21+
inline def gen[S, T](get: S => T): Lens[S, T] = ~impl('(get))
22+
23+
def impl[S: Type, T: Type](getter: Expr[S => T])(implicit refl: Reflection): Expr[Lens[S, T]] = {
24+
import refl._
25+
import util._
26+
import quoted.Toolbox.Default._
27+
28+
// obj.copy(field = value)
29+
def setterBody(obj: Expr[S], value: Expr[T], field: String): Expr[S] =
30+
Term.Select.overloaded(obj.unseal, "copy", Nil, Term.NamedArg(field, value.unseal) :: Nil).seal[S]
31+
32+
getter.unseal.underlyingArgument match {
33+
case Term.Block(
34+
DefDef(_, Nil, (param :: Nil) :: Nil, _, Some(Term.Select(o, field))) :: Nil,
35+
Term.Lambda(meth, _)
36+
) =>
37+
'{
38+
val setter = (t: T) => (s: S) => ~setterBody('(s), '(t), field)
39+
apply(~getter)(setter)
40+
}
41+
}
42+
}
43+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
case class Address(streetNumber: Int, streetName: String)
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
// val len = Lens.gen[Address, Int](_.streetNumber)
6+
val len = Lens.gen[Address, Int]( (a: Address) => a.streetNumber)
7+
val address = Address(10, "High Street")
8+
assert(len.get(address) == 10)
9+
val addr2 = len.set(5, address)
10+
assert(len.get(addr2) == 5)
11+
}
12+
}

0 commit comments

Comments
 (0)