Skip to content

Commit 63cfe8c

Browse files
committed
Fix #4801: Allow multiple splices in macros
1 parent 45641b7 commit 63cfe8c

File tree

7 files changed

+82
-70
lines changed

7 files changed

+82
-70
lines changed

compiler/src/dotty/tools/dotc/transform/MacrosSplitter.scala

Lines changed: 26 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -57,53 +57,38 @@ class MacrosSplitter extends ReifyQuotes {
5757
case tree: Template => super.transform(tree)
5858
case macroDefTree: DefDef if macroDefTree.symbol.is(Inline) =>
5959

60-
val reifier = new Reifier(true, null, 1, new LevelInfo, new mutable.ListBuffer[Tree])
60+
val reifier = new Reifier(true, false, null, 1, new LevelInfo, new mutable.ListBuffer[Tree])
6161

6262
val transformedTree = reifier.transform(macroDefTree) // Ignore output, we only need the its embedding
6363

6464
if (reifier.embedded.isEmpty) macroDefTree // Not a macro
65-
else {
66-
if (!macroDefTree.symbol.isStatic) // TODO remove restriction (issue #4803)
67-
ctx.error("Inline macro method must be a static method.", macroDefTree.pos)
68-
if (InlineSplice.unapply(macroDefTree.rhs).isEmpty) { // TODO allow multiple splices (issue #4801)
69-
ctx.error(
70-
"""Malformed inline macro.
71-
|
72-
|Expected the ~ to be at the top of the RHS:
73-
| inline def foo(...): Int = ~impl(...)
74-
|or
75-
| inline def foo(...): Int = ~{
76-
| val x = 1
77-
| impl(... x ...)
78-
| }
79-
""".stripMargin, macroDefTree.rhs.pos)
80-
EmptyTree
81-
}
82-
else {
83-
val splicers = List.newBuilder[DefDef]
84-
val transformer = new TreeMap() {
85-
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
86-
case Hole(idx, args) =>
87-
val targs = args.filter(_.isType).map(_.tpe)
88-
89-
val Block((lambdaDef: tpd.DefDef) :: Nil, _) = reifier.embedded(idx)
90-
91-
val splicer = genSplicer(macroDefTree, lambdaDef, targs)
92-
splicers += splicer
93-
94-
val liftedArgs = args.map { arg =>
95-
if (arg.symbol.is(Inline) || arg.symbol == defn.TastyTopLevelSplice_tastyContext) arg
96-
else if (arg.isType) ref(defn.QuotedType_apply).appliedToType(arg.tpe)
97-
else ref(defn.QuotedExpr_apply).appliedToType(arg.tpe.widen).appliedTo(arg)
98-
}
99-
ref(splicer.symbol).appliedToTypes(targs).appliedToArgs(liftedArgs).select(nme.UNARY_~)
100-
case _ => super.transform(tree)
101-
}
65+
else if (!macroDefTree.symbol.isStatic) { // TODO remove restriction (issue #4803)
66+
ctx.error("Inline macro method must be a static method.", macroDefTree.pos)
67+
EmptyTree
68+
} else {
69+
val splicers = List.newBuilder[DefDef]
70+
val transformer = new TreeMap() {
71+
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = tree match {
72+
case Hole(idx, args) =>
73+
val targs = args.filter(_.isType).map(_.tpe)
74+
75+
val Block((lambdaDef: tpd.DefDef) :: Nil, _) = reifier.embedded(idx)
76+
77+
val splicer = genSplicer(macroDefTree, lambdaDef, targs)
78+
splicers += splicer
79+
80+
val liftedArgs = args.map { arg =>
81+
if (arg.symbol.is(Inline) || arg.symbol == defn.TastyTopLevelSplice_tastyContext) arg
82+
else if (arg.isType) ref(defn.QuotedType_apply).appliedToType(arg.tpe)
83+
else ref(defn.QuotedExpr_apply).appliedToType(arg.tpe.widen).appliedTo(arg)
84+
}
85+
ref(splicer.symbol).appliedToTypes(targs).appliedToArgs(liftedArgs).select(nme.UNARY_~)
86+
case _ => super.transform(tree)
10287
}
103-
val newRhs = transformer.transform(transformedTree.asInstanceOf[DefDef].rhs)
104-
val newDef = cpy.DefDef(macroDefTree)(rhs = newRhs)
105-
Thicket(newDef :: splicers.result())
10688
}
89+
val newRhs = transformer.transform(transformedTree.asInstanceOf[DefDef].rhs)
90+
val newDef = cpy.DefDef(macroDefTree)(rhs = newRhs)
91+
Thicket(newDef :: splicers.result())
10792
}
10893
case _ =>
10994
tree
@@ -171,18 +156,4 @@ class MacrosSplitter extends ReifyQuotes {
171156
polyDefDef(splicerSym, tparams => vparamss => treeTypeMap(tparams, vparamss.head).transform(body))
172157
}
173158

174-
/** InlineSplice is used to detect cases where the expansion
175-
* consists of a (possibly multiple & nested) block or a sole expression.
176-
*/
177-
object InlineSplice {
178-
def unapply(tree: Tree)(implicit ctx: Context): Option[Select] = {
179-
tree match {
180-
case expansion: Select if expansion.symbol.isSplice => Some(expansion)
181-
case Block(List(stat), Literal(Constant(()))) => unapply(stat)
182-
case Block(Nil, expr) => unapply(expr)
183-
case _ => None
184-
}
185-
}
186-
}
187-
188159
}

compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
8484
if (ctx.compilationUnit.containsQuotesOrSplices) super.run
8585

8686
protected def newTransformer(implicit ctx: Context): Transformer =
87-
new Reifier(inQuote = false, null, 0, new LevelInfo, new mutable.ListBuffer[Tree])
87+
new Reifier(inQuote = false, inlined = false, null, 0, new LevelInfo, new mutable.ListBuffer[Tree])
8888

8989
protected class LevelInfo {
9090
/** A map from locally defined symbols to the staging levels of their definitions */
@@ -113,17 +113,20 @@ class ReifyQuotes extends MacroTransformWithImplicits {
113113
* @param levels a stacked map from symbols to the levels in which they were defined
114114
* @param embedded a list of embedded quotes (if `inSplice = true`) or splices (if `inQuote = true`
115115
*/
116-
protected class Reifier(inQuote: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo,
116+
protected class Reifier(inQuote: Boolean, inlined: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo,
117117
val embedded: mutable.ListBuffer[Tree]) extends ImplicitsTransformer {
118118
import levels._
119119
assert(level >= 0)
120120

121121
/** A nested reifier for a quote (if `isQuote = true`) or a splice (if not) */
122122
def nested(isQuote: Boolean): Reifier = {
123123
val nestedEmbedded = if (level > 1 || (level == 1 && isQuote)) embedded else new mutable.ListBuffer[Tree]
124-
new Reifier(isQuote, this, if (isQuote) level + 1 else level - 1, levels, nestedEmbedded)
124+
new Reifier(isQuote, inlined, this, if (isQuote) level + 1 else level - 1, levels, nestedEmbedded)
125125
}
126126

127+
def inlinedReifier: Reifier =
128+
new Reifier(inQuote, inlined = true, this, level, levels, embedded)
129+
127130
/** We are in a `~(...)` context that is not shadowed by a nested `'(...)` */
128131
def inSplice: Boolean = outer != null && !inQuote
129132

@@ -418,6 +421,11 @@ class ReifyQuotes extends MacroTransformWithImplicits {
418421
val body1 = nested(isQuote = false).transform(splice.qualifier)
419422
body1.select(splice.name)
420423
}
424+
else if (inlined && level == 0) {
425+
val spliced = Splicer.splice(splice.qualifier, splice.pos, macroClassLoader).withPos(splice.pos)
426+
if (ctx.reporter.hasErrors) EmptyTree
427+
else transform(spliced)
428+
}
421429
else if (!inQuote && level == 0) {
422430
spliceOutsideQuotes(splice.pos)
423431
splice
@@ -561,16 +569,12 @@ class ReifyQuotes extends MacroTransformWithImplicits {
561569
val last = enteredSyms
562570
stats.foreach(markDef)
563571
mapOverTree(last)
564-
case Inlined(call, bindings, InlineSplice(expansion @ Select(body, name))) =>
565-
assert(call.symbol.is(Inline))
572+
case Inlined(call, bindings, expansion) if !inlined =>
566573
val tree2 =
567574
if (level == 0) {
568-
// Simplification of the call done in PostTyper for non-macros can also be performed now
569-
// see PostTyper `case Inlined(...) =>` for description of the simplification
570-
val call2 = Ident(call.symbol.topLevelClass.typeRef).withPos(call.pos)
571-
val spliced = Splicer.splice(body, call, bindings, tree.pos, macroClassLoader).withPos(tree.pos)
575+
val spliced = inlinedReifier.transform(expansion)
572576
if (ctx.reporter.hasErrors) EmptyTree
573-
else transform(cpy.Inlined(tree)(call2, bindings, spliced))
577+
else cpy.Inlined(tree)(call, bindings, spliced)
574578
}
575579
else super.transform(tree)
576580

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ object Splicer {
3434
*
3535
* See: `ReifyQuotes`
3636
*/
37-
def splice(tree: Tree, call: Tree, bindings: List[Tree], pos: Position, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
37+
def splice(tree: Tree, pos: Position, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
3838
case Quoted(quotedTree) => quotedTree
3939
case _ =>
4040
val interpreter = new Interpreter(pos, classLoader)

tests/neg/quote-macro-splice.scala renamed to tests/pos/quote-macro-splice.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@ import scala.quoted._
22

33
object Test {
44

5-
inline def foo1: Int = { // error
5+
inline def foo1: Int = {
66
println()
77
~impl(1.toExpr)
88
}
99

10-
inline def foo2: Int = { // error
10+
inline def foo2: Int = {
1111
~impl(1.toExpr)
1212
~impl(2.toExpr)
1313
}
1414

15-
inline def foo3: Int = { // error
15+
inline def foo3: Int = {
1616
val a = 1
1717
~impl('(a))
1818
}
1919

20-
inline def foo4: Int = { // error
20+
inline def foo4: Int = {
2121
~impl('(1))
2222
1
2323
}

tests/run/quote-multi-splice.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
5.0
2+
0.2
3+
5.0
4+
0.2
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
2+
import scala.quoted.Expr
3+
4+
object PowerMacro {
5+
6+
inline def powerV1(inline n: Long, x: Double) = {
7+
val y = ~powerCode(Math.max(n, -n), '(x))
8+
if (n < 0) 1 / y else y
9+
}
10+
11+
inline def powerV2(inline n: Long, x: Double) = {
12+
val a = ~powerCode(n, '(x))
13+
val b = 1 / ~powerCode(-n, '(x))
14+
assert(a == b)
15+
a
16+
}
17+
18+
def powerCode(n: Long, x: Expr[Double]): Expr[Double] =
19+
if (n < 0) '(1.0 / ~powerCode(-n, x))
20+
else if (n == 0) '(1.0)
21+
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
22+
else '{ ~x * ~powerCode(n - 1, x) }
23+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
object Test {
2+
import PowerMacro._
3+
4+
def main(args: Array[String]): Unit = {
5+
println(powerV1(1, 5.0))
6+
println(powerV1(-1, 5.0))
7+
println(powerV2(1, 5.0))
8+
println(powerV2(-1, 5.0))
9+
}
10+
}

0 commit comments

Comments
 (0)