Skip to content

Commit 05c5ba1

Browse files
Merge pull request #4823 from dotty-staging/fix-#4803
Enable macros in any transparent def
2 parents e85a91e + 05a273c commit 05c5ba1

19 files changed

+214
-21
lines changed

bench/tests/power-macro/PowerMacro.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ object PowerMacro {
66

77
def powerCode(n: Long, x: Expr[Double]): Expr[Double] =
88
if (n == 0) '(1.0)
9-
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
9+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode(n / 2, '(y)) }
1010
else '{ ~x * ~powerCode(n - 1, x) }
1111

1212
}

compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,16 @@ class ReifyQuotes extends MacroTransformWithImplicits {
103103
/** The main transformer class
104104
* @param inQuote we are within a `'(...)` context that is not shadowed by a nested `~(...)`
105105
* @param outer the next outer reifier, null is this is the topmost transformer
106-
* @param level the current level, where quotes add one and splices subtract one level
106+
* @param level the current level, where quotes add one and splices subtract one level.
107+
* The initial level is 0, a level `l` where `l > 0` implies code has been quoted `l` times
108+
* and `l == -1` is code inside a top level splice (in an transparent method).
107109
* @param levels a stacked map from symbols to the levels in which they were defined
108110
* @param embedded a list of embedded quotes (if `inSplice = true`) or splices (if `inQuote = true`
109111
*/
110112
private class Reifier(inQuote: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo,
111113
val embedded: mutable.ListBuffer[Tree]) extends ImplicitsTransformer {
112114
import levels._
113-
assert(level >= 0)
115+
assert(level >= -1)
114116

115117
/** A nested reifier for a quote (if `isQuote = true`) or a splice (if not) */
116118
def nested(isQuote: Boolean): Reifier = {
@@ -205,7 +207,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
205207
}
206208

207209
/** Enter staging level of symbol defined by `tree`, if applicable. */
208-
def markDef(tree: Tree)(implicit ctx: Context) = tree match {
210+
def markDef(tree: Tree)(implicit ctx: Context): Unit = tree match {
209211
case tree: DefTree =>
210212
val sym = tree.symbol
211213
if ((sym.isClass || !sym.maybeOwner.isType) && !levelOf.contains(sym)) {
@@ -223,7 +225,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
223225
def levelOK(sym: Symbol)(implicit ctx: Context): Boolean = levelOf.get(sym) match {
224226
case Some(l) =>
225227
l == level ||
226-
l == 1 && level == 0 && isStage0Value(sym)
228+
l == 0 && level == -1 && isStageNegOneValue(sym)
227229
case None =>
228230
!sym.is(Param) || levelOK(sym.owner)
229231
}
@@ -239,7 +241,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
239241
*/
240242
def tryHeal(tp: Type, pos: Position)(implicit ctx: Context): Option[String] = tp match {
241243
case tp: TypeRef =>
242-
if (level == 0) {
244+
if (level == -1) {
243245
assert(ctx.owner.ownersIterator.exists(_.is(Transparent)))
244246
None
245247
} else {
@@ -357,7 +359,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
357359
}
358360
else body match {
359361
case body: RefTree if isCaptured(body.symbol, level + 1) =>
360-
if (isStage0Value(body.symbol)) {
362+
if (isStageNegOneValue(body.symbol)) {
361363
// Optimization: avoid the full conversion when capturing inlined `x`
362364
// in '{ x } to '{ x$1.toExpr.unary_~ } and go directly to `x$1.toExpr`
363365
liftInlineParamValue(capturers(body.symbol)(body))
@@ -368,7 +370,11 @@ class ReifyQuotes extends MacroTransformWithImplicits {
368370
}
369371
case _=>
370372
val (body1, splices) = nested(isQuote = true).split(body)
371-
pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos)
373+
if (level >= 0) pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos)
374+
else {
375+
// In top-level splice in an transparent def. Keep the tree as it is, it will be transformed at inline site.
376+
body
377+
}
372378
}
373379
}
374380

@@ -412,7 +418,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
412418
val body1 = nested(isQuote = false).transform(splice.qualifier)
413419
body1.select(splice.name)
414420
}
415-
else if (!inQuote && level == 0) {
421+
else if (!inQuote && level == 0 && !ctx.owner.is(Transparent)) {
416422
spliceOutsideQuotes(splice.pos)
417423
splice
418424
}
@@ -458,7 +464,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
458464
val tpw = tree.tpe.widen
459465
val argTpe =
460466
if (tree.isType) defn.QuotedTypeType.appliedTo(tpw)
461-
else if (isStage0Value(tree.symbol)) tpw
467+
else if (isStageNegOneValue(tree.symbol)) tpw
462468
else defn.QuotedExprType.appliedTo(tpw)
463469
val selectArg = arg.select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argTpe)
464470
val capturedArg = SyntheticValDef(UniqueName.fresh(tree.symbol.name.toTermName).toTermName, selectArg)
@@ -495,7 +501,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
495501
private def isCaptured(sym: Symbol, level: Int)(implicit ctx: Context): Boolean = {
496502
// Check phase consistency and presence of capturer
497503
( (level == 1 && levelOf.get(sym).contains(1)) ||
498-
(level == 0 && isStage0Value(sym))
504+
(level == 0 && isStageNegOneValue(sym))
499505
) && capturers.contains(sym)
500506
}
501507

@@ -537,7 +543,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
537543
val capturer = capturers(tree.symbol)
538544
def captureAndSplice(t: Tree) =
539545
splice(t.select(if (tree.isTerm) nme.UNARY_~ else tpnme.UNARY_~))
540-
if (!isStage0Value(tree.symbol)) captureAndSplice(capturer(tree))
546+
if (!isStageNegOneValue(tree.symbol)) captureAndSplice(capturer(tree))
541547
else if (level == 0) capturer(tree)
542548
else captureAndSplice(liftInlineParamValue(capturer(tree)))
543549
case Block(stats, _) =>
@@ -559,13 +565,12 @@ class ReifyQuotes extends MacroTransformWithImplicits {
559565
case _: Import =>
560566
tree
561567
case tree: DefDef if tree.symbol.is(Macro) && level == 0 =>
568+
if (enclosingInlineds.nonEmpty)
569+
return EmptyTree // Already checked at definition site and already inlined
570+
markDef(tree)
562571
tree.rhs match {
563572
case InlineSplice(_) =>
564-
if (!tree.symbol.isStatic)
565-
ctx.error("Transparent macro method must be a static method.", tree.pos)
566-
markDef(tree)
567-
val reifier = nested(isQuote = true)
568-
reifier.transform(tree) // Ignore output, only check PCP
573+
mapOverTree(enteredSyms) // Ignore output, only check PCP
569574
cpy.DefDef(tree)(rhs = defaultValue(tree.rhs.tpe))
570575
case _ =>
571576
ctx.error(
@@ -602,7 +607,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
602607
ref(lifter).select("toExpr".toTermName).appliedTo(tree)
603608
}
604609

605-
private def isStage0Value(sym: Symbol)(implicit ctx: Context): Boolean =
610+
private def isStageNegOneValue(sym: Symbol)(implicit ctx: Context): Boolean =
606611
(sym.is(Transparent) && sym.owner.is(Transparent) && !defn.isFunctionType(sym.info)) ||
607612
sym == defn.TastyTopLevelSplice_tastyContext // intrinsic value at stage 0
608613

tests/neg/quote-non-static-macro.scala renamed to tests/pos/quote-non-static-macro.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import scala.quoted._
22

33
class Foo {
4-
transparent def foo: Unit = ~Foo.impl // error
4+
transparent def foo: Unit = ~Foo.impl
55
object Bar {
6-
transparent def foo: Unit = ~Foo.impl // error
6+
transparent def foo: Unit = ~Foo.impl
77
}
88
}
99

1010
object Foo {
1111
class Baz {
12-
transparent def foo: Unit = ~impl // error
12+
transparent def foo: Unit = ~impl
1313
}
1414
object Quox {
1515
transparent def foo: Unit = ~Foo.impl

tests/run/i4803.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
1.0
2+
1.5
3+
2.25
4+
7.59375
5+
1.0
6+
1.5
7+
2.25
8+
7.59375

tests/run/i4803/App_2.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
2+
class Num2(x: Double) {
3+
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
4+
}
5+
6+
object Test {
7+
def main(args: Array[String]): Unit = {
8+
val n = new Num(1.5)
9+
println(n.power(0))
10+
println(n.power(1))
11+
println(n.power(2))
12+
println(n.power(5))
13+
14+
val n2 = new Num2(1.5)
15+
println(n.power(0))
16+
println(n.power(1))
17+
println(n.power(2))
18+
println(n.power(5))
19+
}
20+
}

tests/run/i4803/Macro_1.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
5+
if (n == 0) '(1.0)
6+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
7+
else '{ ~x * ~powerCode(x, n - 1) }
8+
}
9+
10+
class Num(x: Double) {
11+
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
12+
}

tests/run/i4803b.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
1.0
2+
1.5
3+
2.25
4+
7.59375

tests/run/i4803b/App_2.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
2+
3+
class Nums {
4+
class Num(x: Double) {
5+
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
6+
}
7+
}
8+
9+
object Test {
10+
def main(args: Array[String]): Unit = {
11+
val nums = new Nums
12+
val n = new nums.Num(1.5)
13+
println(n.power(0))
14+
println(n.power(1))
15+
println(n.power(2))
16+
println(n.power(5))
17+
}
18+
}

tests/run/i4803b/Macro_1.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
5+
if (n == 0) '(1.0)
6+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
7+
else '{ ~x * ~powerCode(x, n - 1) }
8+
}

tests/run/i4803c.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
1.0
2+
1.5
3+
2.25
4+
7.59375
5+
1.0
6+
1.5
7+
2.25
8+
7.59375

tests/run/i4803c/App_2.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
class Num(x: Double) {
5+
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
6+
}
7+
val n = new Num(1.5)
8+
println(n.power(0))
9+
println(n.power(1))
10+
println(n.power(2))
11+
println(n.power(5))
12+
13+
transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n)
14+
15+
val x: Double = 1.5
16+
17+
println(power(x, 0))
18+
println(power(x, 1))
19+
println(power(x, 2))
20+
println(power(x, 5))
21+
}
22+
}

tests/run/i4803c/Macro_1.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
5+
if (n == 0) '(1.0)
6+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
7+
else '{ ~x * ~powerCode(x, n - 1) }
8+
}

tests/run/i4803d.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
0.0
2+
2.25
3+
12.25

tests/run/i4803d/App_2.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
val x1: Double = 0
5+
val x2: Double = 1.5
6+
val x3: Double = 3.5
7+
8+
println(power2(x1))
9+
println(power2(x2))
10+
println(power2(x3))
11+
}
12+
13+
transparent def power2(x: Double) = {
14+
transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n)
15+
power(x, 2)
16+
}
17+
}

tests/run/i4803d/Macro_1.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
5+
if (n == 0) '(1.0)
6+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
7+
else '{ ~x * ~powerCode(x, n - 1) }
8+
}

tests/run/i4803e/App_2.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
val x1: Double = 0
5+
val x2: Double = 1.5
6+
val x3: Double = 3.5
7+
8+
println(power2(x1))
9+
println(power2(x2))
10+
println(power2(x3))
11+
}
12+
13+
transparent def power2(x: Double) = ~PowerMacro.power2('(x))
14+
}

tests/run/i4803e/Macro_1.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def power2(x: Expr[Double]) = '{
5+
transparent def power(x: Double, n: Long): Double =
6+
if (n == 0) 1.0
7+
else if (n % 2 == 0) { val y = x * x; power(y, n / 2) }
8+
else x * power(x, n - 1)
9+
power(~x, 2)
10+
}
11+
}

tests/run/i4803f/App_2.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
val x1: Double = 0
5+
val x2: Double = 1.5
6+
val x3: Double = 3.5
7+
8+
println(power2(x1))
9+
println(power2(x2))
10+
println(power2(x3))
11+
}
12+
13+
transparent def power2(x: Double) = ~PowerMacro.power2('(x))
14+
}

tests/run/i4803f/Macro_1.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.quoted._
2+
3+
object PowerMacro {
4+
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
5+
if (n == 0) '(1.0)
6+
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
7+
else '{ ~x * ~powerCode(x, n - 1) }
8+
9+
def power2(x: Expr[Double]) = '{
10+
transparent def power(x: Double): Double = ~powerCode('(x), 2)
11+
power(~x)
12+
}
13+
}

0 commit comments

Comments
 (0)