Skip to content

Commit 0a7038d

Browse files
committed
WIP
1 parent 3be5d70 commit 0a7038d

File tree

3 files changed

+37
-13
lines changed

3 files changed

+37
-13
lines changed

library/src/scala/quoted/Expr.scala

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,20 +175,39 @@ package quoted {
175175
}
176176

177177
// TODO generalize for any function arity (see Expr.betaReduce)
178-
def open[T, U, X](f: Expr[T => U])(content: (Expr[U], [t] => Expr[t] => Expr[T] => Expr[t]) => X)(given qctx: QuoteContext): X = {
178+
def open[T1, R, X](f: Expr[T1 => R])(content: (Expr[R], [t] => Expr[t] => Expr[T1] => Expr[t]) => X)(given qctx: QuoteContext): X = {
179179
import qctx.tasty.{given, _}
180-
f.unseal.etaExpand match
181-
case Block(List(DefDef("$anonfun", Nil, List(List(param)), _, Some(body))), Closure(Ident("$anonfun"), None)) =>
182-
val bodyExpr = body.seal.asInstanceOf[Expr[U]]
183-
def bodyFn[V](e: Expr[V])(v: Expr[T]): Expr[V] = {
184-
new TreeMap {
185-
override def transformTerm(tree: Term)(given ctx: Context): Term =
186-
super.transformTerm(tree) match
187-
case tree: Ident if tree.symbol == param.symbol => v.unseal
188-
case tree => tree
189-
}.transformTerm(e.unseal).seal.asInstanceOf[Expr[V]]
190-
}
191-
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T]) => bodyFn[t](e)(v))
180+
val (params, bodyExpr) = paramsAndBody(f)
181+
content(bodyExpr, [t] => (e: Expr[t]) => (v: Expr[T1]) => bodyFn[t](e.unseal, params, List(v.unseal)).seal.asInstanceOf[Expr[t]])
182+
}
183+
184+
def open[T1, T2, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit): X = {
185+
import qctx.tasty.{given, _}
186+
val (params, bodyExpr) = paramsAndBody(f)
187+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal)).seal.asInstanceOf[Expr[t]])
188+
}
189+
190+
def open[T1, T2, T3, R, X](f: Expr[(T1, T2) => R])(content: (Expr[R], [t] => Expr[t] => (Expr[T1], Expr[T2], Expr[T3]) => Expr[t]) => X)(given qctx: QuoteContext)(given DummyImplicit, DummyImplicit): X = {
191+
import qctx.tasty.{given, _}
192+
val (params, bodyExpr) = paramsAndBody(f)
193+
content(bodyExpr, [t] => (e: Expr[t]) => (v1: Expr[T1], v2: Expr[T2], v3: Expr[T3]) => bodyFn[t](e.unseal, params, List(v1.unseal, v2.unseal, v3.unseal)).seal.asInstanceOf[Expr[t]])
194+
}
195+
196+
private def paramsAndBody[R](given qctx: QuoteContext)(f: Expr[Any]) = {
197+
import qctx.tasty.{given, _}
198+
val Block(List(DefDef("$anonfun", Nil, List(params), _, Some(body))), Closure(Ident("$anonfun"), None)) = f.unseal.etaExpand
199+
(params, body.seal.asInstanceOf[Expr[R]])
200+
}
201+
202+
private def bodyFn[t](given qctx: QuoteContext)(e: qctx.tasty.Term, params: List[qctx.tasty.ValDef], args: List[qctx.tasty.Term]): qctx.tasty.Term = {
203+
import qctx.tasty.{given, _}
204+
val map = params.map(_.symbol).zip(args).toMap
205+
new TreeMap {
206+
override def transformTerm(tree: Term)(given ctx: Context): Term =
207+
super.transformTerm(tree) match
208+
case tree: Ident => map.getOrElse(tree.symbol, tree)
209+
case tree => tree
210+
}.transformTerm(e)
192211
}
193212

194213
}

tests/run-macros/quoted-pattern-open-expr/Macro_1.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,6 @@ private def testExpr(e: Expr[Int])(given QuoteContext): Expr[String] = {
77
case '{ val y: Int = 4; $body } => Expr("Matched closed\n" + body.show)
88
case '{ val y: Int = 4; ($body: Int => Int)(y) } => Expr("Matched open\n" + body.show)
99
case '{ val y: Int => Int = x => x + 1; ($body: (Int => Int) => Int)(y) } => Expr("Matched open\n" + body.show)
10+
case '{ def g(x: Int): Int = ($body: (Int => Int, Int) => Int)(g, x); 5 } => Expr("Matched open\n" + body.show)
1011
}
1112
}

tests/run-macros/quoted-pattern-open-expr/Test_2.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,9 @@ object Test {
44
println(test { val x: Int = 4; 6: Int })
55
println(test { val x: Int = 4; x * x })
66
println(test { val f: Int => Int = x => x + 1; f(3) })
7+
println(test { def g(x: Int): Int = x; 5 })
8+
println(test { def g(x: Int): Int = g(4); 5 })
9+
println(test { def g(x: Int): Int = g(x); 5 })
10+
711
}
812
}

0 commit comments

Comments
 (0)