Skip to content

Commit 1a777d6

Browse files
committed
Fix doc and add regression test
1 parent e6628bb commit 1a777d6

File tree

4 files changed

+59
-6
lines changed

4 files changed

+59
-6
lines changed

docs/docs/reference/metaprogramming/macros.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ In case all files are suspended due to cyclic dependencies the compilation will
618618

619619
It is possible to deconstruct or extract values out of `Expr` using pattern matching.
620620

621-
#### scala.quoted .matching
621+
#### scala.quoted.matching
622622

623623
In `scala.quoted.matching` contains object that can help extract values from `Expr`.
624624

@@ -661,19 +661,19 @@ optimize {
661661

662662
```scala
663663
def sum(args: =>Int*): Int = args.sum
664-
def optimize(arg: Int): Int = ${ optimizeExpr('arg) }
665-
def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
664+
inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) }
665+
private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
666666
// Match a call to sum without any arguments
667667
case '{ sum() } => Expr(0)
668668
// Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument.
669669
case '{ sum($n) } => n
670670
// Match a call to sum and extracts all its args in an `Expr[Seq[Int]]`
671-
case '{ sum($args: _*) } => sumExpr(args)
671+
case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args)
672672
case body => body
673673
}
674-
def sumExpr(args: Expr[Seq[Int]])(given QuoteContext): Expr[Int] = args.underlyingArgument match {
674+
private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = {
675675
def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match {
676-
case '{ sum($subArgs: _*) } => subArgs.flatMap(listflatSumArgsSum)
676+
case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs)
677677
case arg => Seq(arg)
678678
}
679679
val args2 = args1.flatMap(flatSumArgs)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
6
2+
6
3+
12.+(Test.a)
4+
17
5+
4.+(Macro_1$package.sum((Test.seq: scala.<repeated>[scala.Int])))
6+
13
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import scala.quoted._
2+
import scala.quoted.matching._
3+
4+
def sum(args: =>Int*): Int = args.sum
5+
6+
inline def showOptimize(arg: Int): String = ${ showOptimizeExpr('arg) }
7+
inline def optimize(arg: Int): Int = ${ optimizeExpr('arg) }
8+
9+
private def showOptimizeExpr(body: Expr[Int])(given QuoteContext): Expr[String] =
10+
Expr(optimizeExpr(body).show)
11+
12+
private def optimizeExpr(body: Expr[Int])(given QuoteContext): Expr[Int] = body match {
13+
// Match a call to sum without any arguments
14+
case '{ sum() } => Expr(0)
15+
// Match a call to sum with an argument $n of type Int. n will be the Expr[Int] representing the argument.
16+
case '{ sum($n) } => n
17+
// Match a call to sum and extracts all its args in an `Expr[Seq[Int]]`
18+
case '{ sum(${ExprSeq(args)}: _*) } => sumExpr(args)
19+
case body => body
20+
}
21+
22+
private def sumExpr(args1: Seq[Expr[Int]])(given QuoteContext): Expr[Int] = {
23+
def flatSumArgs(arg: Expr[Int]): Seq[Expr[Int]] = arg match {
24+
case '{ sum(${ExprSeq(subArgs)}: _*) } => subArgs.flatMap(flatSumArgs)
25+
case arg => Seq(arg)
26+
}
27+
val args2 = args1.flatMap(flatSumArgs)
28+
val staticSum: Int = args2.map {
29+
case Const(arg) => arg
30+
case _ => 0
31+
}.sum
32+
val dynamicSum: Seq[Expr[Int]] = args2.filter {
33+
case Const(_) => false
34+
case arg => true
35+
}
36+
dynamicSum.foldLeft(Expr(staticSum))((acc, arg) => '{ $acc + $arg })
37+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test extends App {
2+
println(showOptimize(sum(1, 2, 3)))
3+
println(optimize(sum(1, 2, 3)))
4+
val a: Int = 5
5+
println(showOptimize(sum(1, a, sum(1, 2, 3), 5)))
6+
println(optimize(sum(1, a, sum(1, 2, 3), 5)))
7+
val seq: Seq[Int] = Seq(1, 3, 5)
8+
println(showOptimize(sum(1, sum(seq: _*), 3)))
9+
println(optimize(sum(1, sum(seq: _*), 3)))
10+
}

0 commit comments

Comments
 (0)