Skip to content

Commit e6628bb

Browse files
committed
Add quote pattern matching docs and simpler underlyingArgument
1 parent 0a33ad3 commit e6628bb

File tree

7 files changed

+166
-2
lines changed

7 files changed

+166
-2
lines changed

docs/docs/reference/metaprogramming/macros.md

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,4 +614,81 @@ compilation of the suspended files using the output of the previous (partial) co
614614
In case all files are suspended due to cyclic dependencies the compilation will fail with an error.
615615

616616

617+
### Pattern matching on quoted expressions
618+
619+
It is possible to deconstruct or extract values out of `Expr` using pattern matching.
620+
621+
#### scala.quoted .matching
622+
623+
In `scala.quoted.matching` contains object that can help extract values from `Expr`.
624+
625+
* `scala.quoted.matching.Const`: matches an expression a literal value and returns the value.
626+
* `scala.quoted.matching.ExprSeq`: matches an explicit sequence of expresions and returns them. These sequences are useful to get individual `Expr[T]` out of a varargs expression of type `Expr[Seq[T]]`.
627+
* `scala.quoted.matching.ConstSeq`: matches an explicit sequence of literal values and returns them.
628+
629+
These could be used in the following way to optimize any call to `sum` that has statically known values.
630+
```scala
631+
inline def sum(args: =>Int*): Int = ${ sumExpr('args) }
632+
private def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr.underlyingArgument match {
633+
case ConstSeq(args) => // args is of type Seq[Int]
634+
Expr(args.sum) // precompute result of sum
635+
case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]]
636+
val staticSum: Int = argExprs.map {
637+
case Const(arg) => arg
638+
case _ => 0
639+
}.sum
640+
val dynamicSum: Seq[Expr[Int]] = argExprs.filter {
641+
case Const(_) => false
642+
case arg => true
643+
}
644+
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
645+
case _ =>
646+
'{ $argsExpr.sum }
647+
}
648+
```
649+
650+
#### Quoted patterns
651+
652+
Quoted pattens allow to deconstruct complex code that contains a precise structure, types or methods.
653+
Patterns `'{ ... }` can be placed in any location where Scala expects a pattern.
654+
655+
For example
656+
```scala
657+
optimize {
658+
sum(sum(1, a, 2), 3, b)
659+
} // should be optimized to 6 + a + b
660+
```
661+
662+
```scala
663+
def sum(args: =>Int*): Int = args.sum
664+
def optimize(arg: Int): Int = ${ optimizeExpr('arg) }
665+
def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
666+
// Match a call to sum without any arguments
667+
case '{ sum() } => Expr(0)
668+
// Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument.
669+
case '{ sum($n) } => n
670+
// Match a call to sum and extracts all its args in an `Expr[Seq[Int]]`
671+
case '{ sum($args: _*) } => sumExpr(args)
672+
case body => body
673+
}
674+
def sumExpr(args: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = args.underlyingArgument match {
675+
def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match {
676+
case '{ sum($subArgs: _*) } => subArgs.flatMap(listflatSumArgsSum)
677+
case arg => Seq(arg)
678+
}
679+
val args2 = args1.flatMap(flatSumArgs)
680+
val staticSum: Int = args2.map {
681+
case Const(arg) => arg
682+
case _ => 0
683+
}.sum
684+
val dynamicSum: Seq[Expr[Int]] = args2.filter {
685+
case Const(_) => false
686+
case arg => true
687+
}
688+
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
689+
}
690+
```
691+
692+
693+
### More details
617694
[More details](./macros-spec.md)

library/src/scala/quoted/Expr.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,26 @@ package quoted {
3030
final def matches(that: Expr[Any])(given qctx: QuoteContext): Boolean =
3131
!scala.internal.quoted.Expr.unapply[Unit, Unit](this)(given that, false, qctx).isEmpty
3232

33+
/** Returns the undelying argument that was in the call before inlining.
34+
*
35+
* ```
36+
* inline foo(x: Int): Int = baz(x, x)
37+
* foo(bar())
38+
* ```
39+
* is inlined as
40+
* ```
41+
* val x = bar()
42+
* baz(x, x)
43+
* ```
44+
* in this case the undelying argument of `x` will be `bar()`.
45+
*
46+
* Warning: Using the undelying argument directly in the expansion of a macro may change the parameter
47+
* semantics from by-value to by-name.
48+
*/
49+
def underlyingArgument(given qctx: QuoteContext): Expr[T] = {
50+
import qctx.tasty.{given, _}
51+
this.unseal.underlyingArgument.seal.asInstanceOf[Expr[T]]
52+
}
3353
}
3454

3555
object Expr {

library/src/scala/quoted/matching/ConstSeq.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@ package matching
44
/** Literal sequence of literal constant value expressions */
55
object ConstSeq {
66

7-
/** Matches literal sequence of literal constant value expressions */
7+
/** Matches literal sequence of literal constant value expressions and return a sequence of values.
8+
*
9+
* Usage:
10+
* ```scala
11+
* inline def sum(args: Int*): Int = ${ sumExpr('args) }
12+
* def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match
13+
* case ConstSeq(args) =>
14+
* // args: Seq[Int]
15+
* ...
16+
* }
17+
* ```
18+
*/
819
def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[T]] = expr match {
920
case ExprSeq(elems) =>
1021
elems.foldRight(Option(List.empty[T])) { (elem, acc) =>

library/src/scala/quoted/matching/ExprSeq.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,18 @@ package matching
44
/** Literal sequence of expressions */
55
object ExprSeq {
66

7-
/** Matches a literal sequence of expressions */
7+
/** Matches a literal sequence of expressions and return a sequence of expressions.
8+
*
9+
* Usage:
10+
* ```scala
11+
* inline def sum(args: Int*): Int = ${ sumExpr('args) }
12+
* def sumExpr(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = argsExpr match
13+
* case ExprSeq(argExprs) =>
14+
* // argExprs: Seq[Expr[Int]]
15+
* ...
16+
* }
17+
* ```
18+
*/
819
def unapply[T](expr: Expr[Seq[T]])(given qctx: QuoteContext): Option[Seq[Expr[T]]] = {
920
import qctx.tasty.{_, given}
1021
def rec(tree: Term): Option[Seq[Expr[T]]] = tree match {
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
6
2+
6
3+
10.+(Test.a)
4+
15
5+
args.sum[scala.Int](scala.math.Numeric.IntIsIntegral)
6+
9
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
inline def sum(args: Int*): Int = ${ sumExpr('args) }
5+
6+
inline def sumShow(args: Int*): String = ${ sumExprShow('args) }
7+
8+
private def sumExprShow(argsExpr: Expr[Seq[Int]])(given QuoteContext): Expr[String] =
9+
Expr(sumExpr(argsExpr).show)
10+
11+
private def sumExpr(argsExpr: Expr[Seq[Int]])(given qctx: QuoteContext): Expr[Int] = {
12+
import qctx.tasty.{given, _}
13+
argsExpr.underlyingArgument match {
14+
case ConstSeq(args) => // args is of type Seq[Int]
15+
Expr(args.sum) // precompute result of sum
16+
case ExprSeq(argExprs) => // argExprs is of type Seq[Expr[Int]]
17+
val staticSum: Int = argExprs.map {
18+
case Const(arg) => arg
19+
case _ => 0
20+
}.sum
21+
val dynamicSum: Seq[Expr[Int]] = argExprs.filter {
22+
case Const(_) => false
23+
case arg => true
24+
}
25+
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
26+
case _ =>
27+
'{ $argsExpr.sum }
28+
}
29+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test extends App {
2+
println(sumShow(1, 2, 3))
3+
println(sum(1, 2, 3))
4+
val a: Int = 5
5+
println(sumShow(1, a, 4, 5))
6+
println(sum(1, a, 4, 5))
7+
val seq: Seq[Int] = Seq(1, 3, 5)
8+
println(sumShow(seq: _*))
9+
println(sum(seq: _*))
10+
}

0 commit comments

Comments
 (0)