Skip to content

Commit 9b962b3

Browse files
committed
Separate macro & user code
1 parent c3fd6cd commit 9b962b3

File tree

2 files changed

+170
-168
lines changed

2 files changed

+170
-168
lines changed
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
package com.softwaremill.quicklens
2+
3+
import scala.quoted.*
4+
5+
object QuicklensMacros {
6+
def toPathModify[S: Type, A: Type](obj: Expr[S], f: Expr[(A => A) => S])(using Quotes): Expr[PathModify[S, A]] = '{ PathModify( ${obj}, ${f} ) }
7+
8+
def fromPathModify[S: Type, A: Type](pathModify: Expr[PathModify[S, A]])(using Quotes): Expr[(A => A) => S] = '{ ${pathModify}.f }
9+
10+
def to[T: Type, R: Type](f: Expr[T] => Expr[R])(using Quotes): Expr[T => R] = '{ (x: T) => ${ f('x) } }
11+
12+
def from[T: Type, R: Type](f: Expr[T => R])(using Quotes): Expr[T] => Expr[R] = (x: Expr[T]) => '{ $f($x) }
13+
14+
def modifyLensApplyImpl[T, U](path: Expr[T => U])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] =
15+
'{PathLazyModify((t, mod) => ${modifyImpl('t, path)}.using(mod))}
16+
17+
def modifyAllLensApplyImpl[T, U](path1: Expr[T => U], paths: Expr[Seq[T => U]])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] =
18+
'{PathLazyModify((t, mod) => ${modifyAllImpl('t, path1, paths)}.using(mod))}
19+
20+
def modifyAllImpl[S, A](obj: Expr[S], focus: Expr[S => A], focusesExpr: Expr[Seq[S => A]])(using qctx: Quotes, tpeS: Type[S], tpeA: Type[A]): Expr[PathModify[S, A]] = {
21+
import qctx.reflect.*
22+
23+
val focuses = focusesExpr match {
24+
case Varargs(args) => args
25+
}
26+
27+
val modF1 = fromPathModify(modifyImpl(obj, focus))
28+
val modF = to[(A => A), S] { (mod: Expr[A => A]) =>
29+
focuses.foldLeft(from[(A => A), S](modF1).apply(mod)) {
30+
case (objAcc, focus) =>
31+
val modCur = fromPathModify(modifyImpl(objAcc, focus))
32+
from[(A => A), S](modCur).apply(mod)
33+
}
34+
}
35+
36+
toPathModify(obj, modF)
37+
}
38+
39+
def modifyImpl[S, A](obj: Expr[S], focus: Expr[S => A])(using qctx: Quotes, tpeS: Type[S], tpeA: Type[A]): Expr[PathModify[S, A]] = {
40+
import qctx.reflect.*
41+
42+
def unsupportedShapeInfo(tree: Tree) = s"Unsupported path element. Path must have shape: _.field1.field2.each.field3.(...), got: $tree"
43+
44+
def methodSupported(method: String) =
45+
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)
46+
47+
enum PathSymbol:
48+
case Field(name: String)
49+
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
50+
51+
def toPath(tree: Tree): Seq[PathSymbol] = {
52+
tree match {
53+
/** Field access */
54+
case Select(deep, ident) =>
55+
toPath(deep) :+ PathSymbol.Field(ident)
56+
/** Method call with arguments and using clause */
57+
case Apply(Apply(Apply(TypeApply(Ident(s), typeTrees), idents), args), List(givn)) if methodSupported(s) =>
58+
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
59+
/** Method call with no arguments and using clause */
60+
case Apply(Apply(TypeApply(Ident(s), typeTrees), idents), List(givn)) if methodSupported(s) =>
61+
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
62+
/** Method call with one type parameter and using clause */
63+
case a@Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
64+
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
65+
/** Field access */
66+
case Apply(deep, idents) =>
67+
toPath(deep) ++ idents.flatMap(toPath)
68+
/** Wild card from path */
69+
case i: Ident if i.name.startsWith("_") =>
70+
Seq.empty
71+
case _ =>
72+
report.throwError(unsupportedShapeInfo(tree))
73+
}
74+
}
75+
76+
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
77+
term.tpe.typeSymbol.memberMethod(name).head
78+
}
79+
80+
def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = {
81+
val caseFields = term.tpe.typeSymbol.caseFields
82+
val idx = caseFields.map(_.name).indexOf(name)
83+
(caseFields.find(_.name == name).get, idx+1)
84+
}
85+
86+
def caseClassCopy(owner: Symbol, mod: Expr[A => A], obj: Term, field: PathSymbol.Field, tail: Seq[PathSymbol]): Term =
87+
val objSymbol = obj.tpe.typeSymbol
88+
if objSymbol.flags.is(Flags.Case) then
89+
val copy = termMethodByNameUnsafe(obj, "copy")
90+
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
91+
val namedArg = NamedArg(field.name, mapToCopy(owner, mod, Select(obj, fieldMethod), tail))
92+
val fieldsIdxs = 1.to(obj.tpe.typeSymbol.caseFields.length)
93+
val args = fieldsIdxs.map { i =>
94+
if i == idx then namedArg
95+
else Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
96+
}.toList
97+
98+
obj.tpe.widen match {
99+
// if the object's type is parametrised, we need to call .copy with the same type parameters
100+
case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
101+
case _ => Apply(Select(obj, copy), args)
102+
}
103+
else if (objSymbol.flags.is(Flags.Sealed) && objSymbol.flags.is(Flags.Trait)) || objSymbol.flags.is(Flags.Enum) then
104+
// if the source is a sealed trait / enum, generating a pattern match with a .copy for each child (implementing case class)
105+
val cases = obj.tpe.typeSymbol.children.map { child =>
106+
val subtype = TypeIdent(child)
107+
val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe)
108+
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), field, tail))
109+
}
110+
Match(obj, cases)
111+
112+
else
113+
report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
114+
115+
def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, path: Seq[PathSymbol]): Term = path match
116+
case Nil =>
117+
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
118+
Apply(Select(mod.asTerm, apply), List(objTerm))
119+
case (field: PathSymbol.Field) :: tail =>
120+
caseClassCopy(owner, mod, objTerm, field, tail)
121+
122+
/**
123+
* For FunctionDelegate(method, givn, T, args)
124+
*
125+
* Generates:
126+
* `givn.method[T](obj, x => mapToCopy(...), ...args)`
127+
*
128+
*/
129+
case (f: PathSymbol.FunctionDelegate) :: tail =>
130+
val defdefSymbol = Symbol.newMethod(
131+
owner,
132+
"$anonfun",
133+
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
134+
)
135+
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
136+
val fun = TypeApply(
137+
Select(f.givn, fMethod),
138+
List(f.typeTree)
139+
)
140+
val defdefStatements = DefDef(
141+
defdefSymbol, {
142+
case List(List(x)) => Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tail))
143+
}
144+
)
145+
val closure = Closure(Ref(defdefSymbol), None)
146+
val block = Block(List(defdefStatements), closure)
147+
Apply(fun, List(objTerm, block) ++ f.args)
148+
149+
val focusTree: Tree = focus.asTerm
150+
val path = focusTree match {
151+
/** Single inlined path */
152+
case Inlined(_, _, Block(List(DefDef(_, _, _, Some(p))), _)) =>
153+
toPath(p)
154+
/** One of paths from modifyAll */
155+
case Block(List(DefDef(_, _, _, Some(p))), _) =>
156+
toPath(p)
157+
case _ =>
158+
report.throwError(unsupportedShapeInfo(focusTree))
159+
}
160+
161+
val objTree: Tree = obj.asTerm
162+
val objTerm: Term = objTree match {
163+
case Inlined(_, _, term) => term
164+
}
165+
166+
val res: (Expr[A => A] => Expr[S]) = (mod: Expr[A => A]) => mapToCopy(Symbol.spliceOwner, mod, objTerm, path).asExpr.asInstanceOf[Expr[S]]
167+
toPathModify(obj, to(res))
168+
}
169+
}

quicklens/src/main/scala-3/com/softwaremill/quicklens/package.scala

Lines changed: 1 addition & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@ package com.softwaremill
22

33
import scala.collection.{Factory, SortedMap}
44
import scala.annotation.compileTimeOnly
5-
import scala.quoted.*
5+
import com.softwaremill.quicklens.QuicklensMacros._
66

77
package object quicklens {
8-
98
trait ModifyPimp {
109
extension [S, A](inline obj: S)
1110
/**
@@ -118,16 +117,10 @@ package object quicklens {
118117
inline def apply[U](inline path: T => U): PathLazyModify[T, U] = ${ modifyLensApplyImpl('path) }
119118
}
120119

121-
private def modifyLensApplyImpl[T, U](path: Expr[T => U])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] =
122-
'{PathLazyModify((t, mod) => ${modifyImpl('t, path)}.using(mod))}
123-
124120
case class MultiLensHelper[T] private[quicklens] () {
125121
inline def apply[U](inline path1: T => U, inline paths: (T => U)*): PathLazyModify[T, U] = ${ modifyAllLensApplyImpl('path1, 'paths) }
126122
}
127123

128-
private def modifyAllLensApplyImpl[T, U](path1: Expr[T => U], paths: Expr[Seq[T => U]])(using Quotes, Type[T], Type[U]): Expr[PathLazyModify[T, U]] =
129-
'{PathLazyModify((t, mod) => ${modifyAllImpl('t, path1, paths)}.using(mod))}
130-
131124
case class PathLazyModify[T, U](doModify: (T, U => U) => T) { self =>
132125
/** see [[PathModify.using]] */
133126
def using(mod: U => U): T => T = obj => doModify(obj, mod)
@@ -283,164 +276,4 @@ package object quicklens {
283276
PathModify[T, V](t, vv => f1(t).f(u => f2(u).f(vv)))
284277

285278
private def canOnlyBeUsedInsideModify(method: String) = s"$method can only be used as a path component inside modify"
286-
287-
//
288-
289-
def toPathModify[S: Type, A: Type](obj: Expr[S], f: Expr[(A => A) => S])(using Quotes): Expr[PathModify[S, A]] = '{ PathModify( ${obj}, ${f} ) }
290-
291-
def fromPathModify[S: Type, A: Type](pathModify: Expr[PathModify[S, A]])(using Quotes): Expr[(A => A) => S] = '{ ${pathModify}.f }
292-
293-
def to[T: Type, R: Type](f: Expr[T] => Expr[R])(using Quotes): Expr[T => R] = '{ (x: T) => ${ f('x) } }
294-
295-
def from[T: Type, R: Type](f: Expr[T => R])(using Quotes): Expr[T] => Expr[R] = (x: Expr[T]) => '{ $f($x) }
296-
297-
def modifyAllImpl[S, A](obj: Expr[S], focus: Expr[S => A], focusesExpr: Expr[Seq[S => A]])(using qctx: Quotes, tpeS: Type[S], tpeA: Type[A]): Expr[PathModify[S, A]] = {
298-
import qctx.reflect.*
299-
300-
val focuses = focusesExpr match {
301-
case Varargs(args) => args
302-
}
303-
304-
val modF1 = fromPathModify(modifyImpl(obj, focus))
305-
val modF = to[(A => A), S] { (mod: Expr[A => A]) =>
306-
focuses.foldLeft(from[(A => A), S](modF1).apply(mod)) {
307-
case (objAcc, focus) =>
308-
val modCur = fromPathModify(modifyImpl(objAcc, focus))
309-
from[(A => A), S](modCur).apply(mod)
310-
}
311-
}
312-
313-
toPathModify(obj, modF)
314-
}
315-
316-
def modifyImpl[S, A](obj: Expr[S], focus: Expr[S => A])(using qctx: Quotes, tpeS: Type[S], tpeA: Type[A]): Expr[PathModify[S, A]] = {
317-
import qctx.reflect.*
318-
319-
def unsupportedShapeInfo(tree: Tree) = s"Unsupported path element. Path must have shape: _.field1.field2.each.field3.(...), got: $tree"
320-
321-
def methodSupported(method: String) =
322-
Seq("at", "each", "eachWhere", "eachRight", "eachLeft", "atOrElse", "index", "when").contains(method)
323-
324-
enum PathSymbol:
325-
case Field(name: String)
326-
case FunctionDelegate(name: String, givn: Term, typeTree: TypeTree, args: List[Term])
327-
328-
def toPath(tree: Tree): Seq[PathSymbol] = {
329-
tree match {
330-
/** Field access */
331-
case Select(deep, ident) =>
332-
toPath(deep) :+ PathSymbol.Field(ident)
333-
/** Method call with arguments and using clause */
334-
case Apply(Apply(Apply(TypeApply(Ident(s), typeTrees), idents), args), List(givn)) if methodSupported(s) =>
335-
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, args)
336-
/** Method call with no arguments and using clause */
337-
case Apply(Apply(TypeApply(Ident(s), typeTrees), idents), List(givn)) if methodSupported(s) =>
338-
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
339-
/** Method call with one type parameter and using clause */
340-
case a@Apply(TypeApply(Apply(TypeApply(Ident(s), _), idents), typeTrees), List(givn)) if methodSupported(s) =>
341-
idents.flatMap(toPath) :+ PathSymbol.FunctionDelegate(s, givn, typeTrees.last, List.empty)
342-
/** Field access */
343-
case Apply(deep, idents) =>
344-
toPath(deep) ++ idents.flatMap(toPath)
345-
/** Wild card from path */
346-
case i: Ident if i.name.startsWith("_") =>
347-
Seq.empty
348-
case _ =>
349-
report.throwError(unsupportedShapeInfo(tree))
350-
}
351-
}
352-
353-
def termMethodByNameUnsafe(term: Term, name: String): Symbol = {
354-
term.tpe.typeSymbol.memberMethod(name).head
355-
}
356-
357-
def termAccessorMethodByNameUnsafe(term: Term, name: String): (Symbol, Int) = {
358-
val caseFields = term.tpe.typeSymbol.caseFields
359-
val idx = caseFields.map(_.name).indexOf(name)
360-
(caseFields.find(_.name == name).get, idx+1)
361-
}
362-
363-
def caseClassCopy(owner: Symbol, mod: Expr[A => A], obj: Term, field: PathSymbol.Field, tail: Seq[PathSymbol]): Term =
364-
val objSymbol = obj.tpe.typeSymbol
365-
if objSymbol.flags.is(Flags.Case) then
366-
val copy = termMethodByNameUnsafe(obj, "copy")
367-
val (fieldMethod, idx) = termAccessorMethodByNameUnsafe(obj, field.name)
368-
val namedArg = NamedArg(field.name, mapToCopy(owner, mod, Select(obj, fieldMethod), tail))
369-
val fieldsIdxs = 1.to(obj.tpe.typeSymbol.caseFields.length)
370-
val args = fieldsIdxs.map { i =>
371-
if i == idx then namedArg
372-
else Select(obj, termMethodByNameUnsafe(obj, "copy$default$" + i.toString))
373-
}.toList
374-
375-
obj.tpe.widen match {
376-
// if the object's type is parametrised, we need to call .copy with the same type parameters
377-
case AppliedType(_, typeParams) => Apply(TypeApply(Select(obj, copy), typeParams.map(Inferred(_))), args)
378-
case _ => Apply(Select(obj, copy), args)
379-
}
380-
else if (objSymbol.flags.is(Flags.Sealed) && objSymbol.flags.is(Flags.Trait)) || objSymbol.flags.is(Flags.Enum) then
381-
// if the source is a sealed trait / enum, generating a pattern match with a .copy for each child (implementing case class)
382-
val cases = obj.tpe.typeSymbol.children.map { child =>
383-
val subtype = TypeIdent(child)
384-
val bind = Symbol.newBind(owner, "c", Flags.EmptyFlags, subtype.tpe)
385-
CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, caseClassCopy(owner, mod, Ref(bind), field, tail))
386-
}
387-
Match(obj, cases)
388-
389-
else
390-
report.throwError(s"Unsupported source object: must be a case class or sealed trait, but got: $objSymbol")
391-
392-
def mapToCopy(owner: Symbol, mod: Expr[A => A], objTerm: Term, path: Seq[PathSymbol]): Term = path match
393-
case Nil =>
394-
val apply = termMethodByNameUnsafe(mod.asTerm, "apply")
395-
Apply(Select(mod.asTerm, apply), List(objTerm))
396-
case (field: PathSymbol.Field) :: tail =>
397-
caseClassCopy(owner, mod, objTerm, field, tail)
398-
399-
/**
400-
* For FunctionDelegate(method, givn, T, args)
401-
*
402-
* Generates:
403-
* `givn.method[T](obj, x => mapToCopy(...), ...args)`
404-
*
405-
*/
406-
case (f: PathSymbol.FunctionDelegate) :: tail =>
407-
val defdefSymbol = Symbol.newMethod(
408-
owner,
409-
"$anonfun",
410-
MethodType(List("x"))(_ => List(f.typeTree.tpe), _ => f.typeTree.tpe)
411-
)
412-
val fMethod = termMethodByNameUnsafe(f.givn, f.name)
413-
val fun = TypeApply(
414-
Select(f.givn, fMethod),
415-
List(f.typeTree)
416-
)
417-
val defdefStatements = DefDef(
418-
defdefSymbol, {
419-
case List(List(x)) => Some(mapToCopy(defdefSymbol, mod, x.asExpr.asTerm, tail))
420-
}
421-
)
422-
val closure = Closure(Ref(defdefSymbol), None)
423-
val block = Block(List(defdefStatements), closure)
424-
Apply(fun, List(objTerm, block) ++ f.args)
425-
426-
val focusTree: Tree = focus.asTerm
427-
val path = focusTree match {
428-
/** Single inlined path */
429-
case Inlined(_, _, Block(List(DefDef(_, _, _, Some(p))), _)) =>
430-
toPath(p)
431-
/** One of paths from modifyAll */
432-
case Block(List(DefDef(_, _, _, Some(p))), _) =>
433-
toPath(p)
434-
case _ =>
435-
report.throwError(unsupportedShapeInfo(focusTree))
436-
}
437-
438-
val objTree: Tree = obj.asTerm
439-
val objTerm: Term = objTree match {
440-
case Inlined(_, _, term) => term
441-
}
442-
443-
val res: (Expr[A => A] => Expr[S]) = (mod: Expr[A => A]) => mapToCopy(Symbol.spliceOwner, mod, objTerm, path).asExpr.asInstanceOf[Expr[S]]
444-
toPathModify(obj, to(res))
445-
}
446279
}

0 commit comments

Comments
 (0)