diff --git a/bench/tests/power-macro/PowerMacro.scala b/bench/tests/power-macro/PowerMacro.scala index cc8ac09b7d2a..3a18f97e980e 100644 --- a/bench/tests/power-macro/PowerMacro.scala +++ b/bench/tests/power-macro/PowerMacro.scala @@ -6,7 +6,7 @@ object PowerMacro { def powerCode(n: Long, x: Expr[Double]): Expr[Double] = if (n == 0) '(1.0) - else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } } + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode(n / 2, '(y)) } else '{ ~x * ~powerCode(n - 1, x) } } diff --git a/compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala b/compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala index 72e6c97c10dc..ded1a950be50 100644 --- a/compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala +++ b/compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala @@ -103,14 +103,16 @@ class ReifyQuotes extends MacroTransformWithImplicits { /** The main transformer class * @param inQuote we are within a `'(...)` context that is not shadowed by a nested `~(...)` * @param outer the next outer reifier, null is this is the topmost transformer - * @param level the current level, where quotes add one and splices subtract one level + * @param level the current level, where quotes add one and splices subtract one level. + * The initial level is 0, a level `l` where `l > 0` implies code has been quoted `l` times + * and `l == -1` is code inside a top level splice (in an transparent method). * @param levels a stacked map from symbols to the levels in which they were defined * @param embedded a list of embedded quotes (if `inSplice = true`) or splices (if `inQuote = true` */ private class Reifier(inQuote: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo, val embedded: mutable.ListBuffer[Tree]) extends ImplicitsTransformer { import levels._ - assert(level >= 0) + assert(level >= -1) /** A nested reifier for a quote (if `isQuote = true`) or a splice (if not) */ def nested(isQuote: Boolean): Reifier = { @@ -205,7 +207,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { } /** Enter staging level of symbol defined by `tree`, if applicable. */ - def markDef(tree: Tree)(implicit ctx: Context) = tree match { + def markDef(tree: Tree)(implicit ctx: Context): Unit = tree match { case tree: DefTree => val sym = tree.symbol if ((sym.isClass || !sym.maybeOwner.isType) && !levelOf.contains(sym)) { @@ -223,7 +225,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { def levelOK(sym: Symbol)(implicit ctx: Context): Boolean = levelOf.get(sym) match { case Some(l) => l == level || - l == 1 && level == 0 && isStage0Value(sym) + l == 0 && level == -1 && isStageNegOneValue(sym) case None => !sym.is(Param) || levelOK(sym.owner) } @@ -239,7 +241,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { */ def tryHeal(tp: Type, pos: Position)(implicit ctx: Context): Option[String] = tp match { case tp: TypeRef => - if (level == 0) { + if (level == -1) { assert(ctx.owner.ownersIterator.exists(_.is(Transparent))) None } else { @@ -357,7 +359,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { } else body match { case body: RefTree if isCaptured(body.symbol, level + 1) => - if (isStage0Value(body.symbol)) { + if (isStageNegOneValue(body.symbol)) { // Optimization: avoid the full conversion when capturing inlined `x` // in '{ x } to '{ x$1.toExpr.unary_~ } and go directly to `x$1.toExpr` liftInlineParamValue(capturers(body.symbol)(body)) @@ -368,7 +370,11 @@ class ReifyQuotes extends MacroTransformWithImplicits { } case _=> val (body1, splices) = nested(isQuote = true).split(body) - pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos) + if (level >= 0) pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos) + else { + // In top-level splice in an transparent def. Keep the tree as it is, it will be transformed at inline site. + body + } } } @@ -412,7 +418,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { val body1 = nested(isQuote = false).transform(splice.qualifier) body1.select(splice.name) } - else if (!inQuote && level == 0) { + else if (!inQuote && level == 0 && !ctx.owner.is(Transparent)) { spliceOutsideQuotes(splice.pos) splice } @@ -458,7 +464,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { val tpw = tree.tpe.widen val argTpe = if (tree.isType) defn.QuotedTypeType.appliedTo(tpw) - else if (isStage0Value(tree.symbol)) tpw + else if (isStageNegOneValue(tree.symbol)) tpw else defn.QuotedExprType.appliedTo(tpw) val selectArg = arg.select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argTpe) val capturedArg = SyntheticValDef(UniqueName.fresh(tree.symbol.name.toTermName).toTermName, selectArg) @@ -495,7 +501,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { private def isCaptured(sym: Symbol, level: Int)(implicit ctx: Context): Boolean = { // Check phase consistency and presence of capturer ( (level == 1 && levelOf.get(sym).contains(1)) || - (level == 0 && isStage0Value(sym)) + (level == 0 && isStageNegOneValue(sym)) ) && capturers.contains(sym) } @@ -537,7 +543,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { val capturer = capturers(tree.symbol) def captureAndSplice(t: Tree) = splice(t.select(if (tree.isTerm) nme.UNARY_~ else tpnme.UNARY_~)) - if (!isStage0Value(tree.symbol)) captureAndSplice(capturer(tree)) + if (!isStageNegOneValue(tree.symbol)) captureAndSplice(capturer(tree)) else if (level == 0) capturer(tree) else captureAndSplice(liftInlineParamValue(capturer(tree))) case Block(stats, _) => @@ -559,13 +565,12 @@ class ReifyQuotes extends MacroTransformWithImplicits { case _: Import => tree case tree: DefDef if tree.symbol.is(Macro) && level == 0 => + if (enclosingInlineds.nonEmpty) + return EmptyTree // Already checked at definition site and already inlined + markDef(tree) tree.rhs match { case InlineSplice(_) => - if (!tree.symbol.isStatic) - ctx.error("Transparent macro method must be a static method.", tree.pos) - markDef(tree) - val reifier = nested(isQuote = true) - reifier.transform(tree) // Ignore output, only check PCP + mapOverTree(enteredSyms) // Ignore output, only check PCP cpy.DefDef(tree)(rhs = defaultValue(tree.rhs.tpe)) case _ => ctx.error( @@ -602,7 +607,7 @@ class ReifyQuotes extends MacroTransformWithImplicits { ref(lifter).select("toExpr".toTermName).appliedTo(tree) } - private def isStage0Value(sym: Symbol)(implicit ctx: Context): Boolean = + private def isStageNegOneValue(sym: Symbol)(implicit ctx: Context): Boolean = (sym.is(Transparent) && sym.owner.is(Transparent) && !defn.isFunctionType(sym.info)) || sym == defn.TastyTopLevelSplice_tastyContext // intrinsic value at stage 0 diff --git a/tests/neg/quote-non-static-macro.scala b/tests/pos/quote-non-static-macro.scala similarity index 55% rename from tests/neg/quote-non-static-macro.scala rename to tests/pos/quote-non-static-macro.scala index 32600eae7bb7..00df6d9b1c3a 100644 --- a/tests/neg/quote-non-static-macro.scala +++ b/tests/pos/quote-non-static-macro.scala @@ -1,15 +1,15 @@ import scala.quoted._ class Foo { - transparent def foo: Unit = ~Foo.impl // error + transparent def foo: Unit = ~Foo.impl object Bar { - transparent def foo: Unit = ~Foo.impl // error + transparent def foo: Unit = ~Foo.impl } } object Foo { class Baz { - transparent def foo: Unit = ~impl // error + transparent def foo: Unit = ~impl } object Quox { transparent def foo: Unit = ~Foo.impl diff --git a/tests/run/i4803.check b/tests/run/i4803.check new file mode 100644 index 000000000000..927510ef8789 --- /dev/null +++ b/tests/run/i4803.check @@ -0,0 +1,8 @@ +1.0 +1.5 +2.25 +7.59375 +1.0 +1.5 +2.25 +7.59375 diff --git a/tests/run/i4803/App_2.scala b/tests/run/i4803/App_2.scala new file mode 100644 index 000000000000..fa550642c722 --- /dev/null +++ b/tests/run/i4803/App_2.scala @@ -0,0 +1,20 @@ + +class Num2(x: Double) { + transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n) +} + +object Test { + def main(args: Array[String]): Unit = { + val n = new Num(1.5) + println(n.power(0)) + println(n.power(1)) + println(n.power(2)) + println(n.power(5)) + + val n2 = new Num2(1.5) + println(n.power(0)) + println(n.power(1)) + println(n.power(2)) + println(n.power(5)) + } +} diff --git a/tests/run/i4803/Macro_1.scala b/tests/run/i4803/Macro_1.scala new file mode 100644 index 000000000000..deae9bead04e --- /dev/null +++ b/tests/run/i4803/Macro_1.scala @@ -0,0 +1,12 @@ +import scala.quoted._ + +object PowerMacro { + def powerCode(x: Expr[Double], n: Long): Expr[Double] = + if (n == 0) '(1.0) + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) } + else '{ ~x * ~powerCode(x, n - 1) } +} + +class Num(x: Double) { + transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n) +} diff --git a/tests/run/i4803b.check b/tests/run/i4803b.check new file mode 100644 index 000000000000..5defea0d3920 --- /dev/null +++ b/tests/run/i4803b.check @@ -0,0 +1,4 @@ +1.0 +1.5 +2.25 +7.59375 diff --git a/tests/run/i4803b/App_2.scala b/tests/run/i4803b/App_2.scala new file mode 100644 index 000000000000..2e93422e36cf --- /dev/null +++ b/tests/run/i4803b/App_2.scala @@ -0,0 +1,18 @@ + + +class Nums { + class Num(x: Double) { + transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n) + } +} + +object Test { + def main(args: Array[String]): Unit = { + val nums = new Nums + val n = new nums.Num(1.5) + println(n.power(0)) + println(n.power(1)) + println(n.power(2)) + println(n.power(5)) + } +} diff --git a/tests/run/i4803b/Macro_1.scala b/tests/run/i4803b/Macro_1.scala new file mode 100644 index 000000000000..681f3b2fac63 --- /dev/null +++ b/tests/run/i4803b/Macro_1.scala @@ -0,0 +1,8 @@ +import scala.quoted._ + +object PowerMacro { + def powerCode(x: Expr[Double], n: Long): Expr[Double] = + if (n == 0) '(1.0) + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) } + else '{ ~x * ~powerCode(x, n - 1) } +} diff --git a/tests/run/i4803c.check b/tests/run/i4803c.check new file mode 100644 index 000000000000..927510ef8789 --- /dev/null +++ b/tests/run/i4803c.check @@ -0,0 +1,8 @@ +1.0 +1.5 +2.25 +7.59375 +1.0 +1.5 +2.25 +7.59375 diff --git a/tests/run/i4803c/App_2.scala b/tests/run/i4803c/App_2.scala new file mode 100644 index 000000000000..f3b2655b8c2c --- /dev/null +++ b/tests/run/i4803c/App_2.scala @@ -0,0 +1,22 @@ + +object Test { + def main(args: Array[String]): Unit = { + class Num(x: Double) { + transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n) + } + val n = new Num(1.5) + println(n.power(0)) + println(n.power(1)) + println(n.power(2)) + println(n.power(5)) + + transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n) + + val x: Double = 1.5 + + println(power(x, 0)) + println(power(x, 1)) + println(power(x, 2)) + println(power(x, 5)) + } +} diff --git a/tests/run/i4803c/Macro_1.scala b/tests/run/i4803c/Macro_1.scala new file mode 100644 index 000000000000..681f3b2fac63 --- /dev/null +++ b/tests/run/i4803c/Macro_1.scala @@ -0,0 +1,8 @@ +import scala.quoted._ + +object PowerMacro { + def powerCode(x: Expr[Double], n: Long): Expr[Double] = + if (n == 0) '(1.0) + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) } + else '{ ~x * ~powerCode(x, n - 1) } +} diff --git a/tests/run/i4803d.check b/tests/run/i4803d.check new file mode 100644 index 000000000000..e0fd53acdb8e --- /dev/null +++ b/tests/run/i4803d.check @@ -0,0 +1,3 @@ +0.0 +2.25 +12.25 diff --git a/tests/run/i4803d/App_2.scala b/tests/run/i4803d/App_2.scala new file mode 100644 index 000000000000..293882746f0b --- /dev/null +++ b/tests/run/i4803d/App_2.scala @@ -0,0 +1,17 @@ + +object Test { + def main(args: Array[String]): Unit = { + val x1: Double = 0 + val x2: Double = 1.5 + val x3: Double = 3.5 + + println(power2(x1)) + println(power2(x2)) + println(power2(x3)) + } + + transparent def power2(x: Double) = { + transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n) + power(x, 2) + } +} diff --git a/tests/run/i4803d/Macro_1.scala b/tests/run/i4803d/Macro_1.scala new file mode 100644 index 000000000000..681f3b2fac63 --- /dev/null +++ b/tests/run/i4803d/Macro_1.scala @@ -0,0 +1,8 @@ +import scala.quoted._ + +object PowerMacro { + def powerCode(x: Expr[Double], n: Long): Expr[Double] = + if (n == 0) '(1.0) + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) } + else '{ ~x * ~powerCode(x, n - 1) } +} diff --git a/tests/run/i4803e/App_2.scala b/tests/run/i4803e/App_2.scala new file mode 100644 index 000000000000..33fe9a015e03 --- /dev/null +++ b/tests/run/i4803e/App_2.scala @@ -0,0 +1,14 @@ + +object Test { + def main(args: Array[String]): Unit = { + val x1: Double = 0 + val x2: Double = 1.5 + val x3: Double = 3.5 + + println(power2(x1)) + println(power2(x2)) + println(power2(x3)) + } + + transparent def power2(x: Double) = ~PowerMacro.power2('(x)) +} diff --git a/tests/run/i4803e/Macro_1.scala b/tests/run/i4803e/Macro_1.scala new file mode 100644 index 000000000000..13667220618d --- /dev/null +++ b/tests/run/i4803e/Macro_1.scala @@ -0,0 +1,11 @@ +import scala.quoted._ + +object PowerMacro { + def power2(x: Expr[Double]) = '{ + transparent def power(x: Double, n: Long): Double = + if (n == 0) 1.0 + else if (n % 2 == 0) { val y = x * x; power(y, n / 2) } + else x * power(x, n - 1) + power(~x, 2) + } +} diff --git a/tests/run/i4803f/App_2.scala b/tests/run/i4803f/App_2.scala new file mode 100644 index 000000000000..33fe9a015e03 --- /dev/null +++ b/tests/run/i4803f/App_2.scala @@ -0,0 +1,14 @@ + +object Test { + def main(args: Array[String]): Unit = { + val x1: Double = 0 + val x2: Double = 1.5 + val x3: Double = 3.5 + + println(power2(x1)) + println(power2(x2)) + println(power2(x3)) + } + + transparent def power2(x: Double) = ~PowerMacro.power2('(x)) +} diff --git a/tests/run/i4803f/Macro_1.scala b/tests/run/i4803f/Macro_1.scala new file mode 100644 index 000000000000..fe022306989c --- /dev/null +++ b/tests/run/i4803f/Macro_1.scala @@ -0,0 +1,13 @@ +import scala.quoted._ + +object PowerMacro { + def powerCode(x: Expr[Double], n: Long): Expr[Double] = + if (n == 0) '(1.0) + else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) } + else '{ ~x * ~powerCode(x, n - 1) } + + def power2(x: Expr[Double]) = '{ + transparent def power(x: Double): Double = ~powerCode('(x), 2) + power(~x) + } +}