Skip to content

Enable macros in any transparent def #4823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bench/tests/power-macro/PowerMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object PowerMacro {

def powerCode(n: Long, x: Expr[Double]): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode(n / 2, '(y)) }
else '{ ~x * ~powerCode(n - 1, x) }

}
39 changes: 22 additions & 17 deletions compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,16 @@ class ReifyQuotes extends MacroTransformWithImplicits {
/** The main transformer class
* @param inQuote we are within a `'(...)` context that is not shadowed by a nested `~(...)`
* @param outer the next outer reifier, null is this is the topmost transformer
* @param level the current level, where quotes add one and splices subtract one level
* @param level the current level, where quotes add one and splices subtract one level.
* The initial level is 0, a level `l` where `l > 0` implies code has been quoted `l` times
* and `l == -1` is code inside a top level splice (in an transparent method).
* @param levels a stacked map from symbols to the levels in which they were defined
* @param embedded a list of embedded quotes (if `inSplice = true`) or splices (if `inQuote = true`
*/
private class Reifier(inQuote: Boolean, val outer: Reifier, val level: Int, levels: LevelInfo,
val embedded: mutable.ListBuffer[Tree]) extends ImplicitsTransformer {
import levels._
assert(level >= 0)
assert(level >= -1)

/** A nested reifier for a quote (if `isQuote = true`) or a splice (if not) */
def nested(isQuote: Boolean): Reifier = {
Expand Down Expand Up @@ -205,7 +207,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
}

/** Enter staging level of symbol defined by `tree`, if applicable. */
def markDef(tree: Tree)(implicit ctx: Context) = tree match {
def markDef(tree: Tree)(implicit ctx: Context): Unit = tree match {
case tree: DefTree =>
val sym = tree.symbol
if ((sym.isClass || !sym.maybeOwner.isType) && !levelOf.contains(sym)) {
Expand All @@ -223,7 +225,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
def levelOK(sym: Symbol)(implicit ctx: Context): Boolean = levelOf.get(sym) match {
case Some(l) =>
l == level ||
l == 1 && level == 0 && isStage0Value(sym)
l == 0 && level == -1 && isStageNegOneValue(sym)
case None =>
!sym.is(Param) || levelOK(sym.owner)
}
Expand All @@ -239,7 +241,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
*/
def tryHeal(tp: Type, pos: Position)(implicit ctx: Context): Option[String] = tp match {
case tp: TypeRef =>
if (level == 0) {
if (level == -1) {
assert(ctx.owner.ownersIterator.exists(_.is(Transparent)))
None
} else {
Expand Down Expand Up @@ -357,7 +359,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
}
else body match {
case body: RefTree if isCaptured(body.symbol, level + 1) =>
if (isStage0Value(body.symbol)) {
if (isStageNegOneValue(body.symbol)) {
// Optimization: avoid the full conversion when capturing inlined `x`
// in '{ x } to '{ x$1.toExpr.unary_~ } and go directly to `x$1.toExpr`
liftInlineParamValue(capturers(body.symbol)(body))
Expand All @@ -368,7 +370,11 @@ class ReifyQuotes extends MacroTransformWithImplicits {
}
case _=>
val (body1, splices) = nested(isQuote = true).split(body)
pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos)
if (level >= 0) pickledQuote(body1, splices, body.tpe, isType).withPos(quote.pos)
else {
// In top-level splice in an transparent def. Keep the tree as it is, it will be transformed at inline site.
body
}
}
}

Expand Down Expand Up @@ -412,7 +418,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
val body1 = nested(isQuote = false).transform(splice.qualifier)
body1.select(splice.name)
}
else if (!inQuote && level == 0) {
else if (!inQuote && level == 0 && !ctx.owner.is(Transparent)) {
spliceOutsideQuotes(splice.pos)
splice
}
Expand Down Expand Up @@ -458,7 +464,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
val tpw = tree.tpe.widen
val argTpe =
if (tree.isType) defn.QuotedTypeType.appliedTo(tpw)
else if (isStage0Value(tree.symbol)) tpw
else if (isStageNegOneValue(tree.symbol)) tpw
else defn.QuotedExprType.appliedTo(tpw)
val selectArg = arg.select(nme.apply).appliedTo(Literal(Constant(i))).asInstance(argTpe)
val capturedArg = SyntheticValDef(UniqueName.fresh(tree.symbol.name.toTermName).toTermName, selectArg)
Expand Down Expand Up @@ -495,7 +501,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
private def isCaptured(sym: Symbol, level: Int)(implicit ctx: Context): Boolean = {
// Check phase consistency and presence of capturer
( (level == 1 && levelOf.get(sym).contains(1)) ||
(level == 0 && isStage0Value(sym))
(level == 0 && isStageNegOneValue(sym))
) && capturers.contains(sym)
}

Expand Down Expand Up @@ -537,7 +543,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
val capturer = capturers(tree.symbol)
def captureAndSplice(t: Tree) =
splice(t.select(if (tree.isTerm) nme.UNARY_~ else tpnme.UNARY_~))
if (!isStage0Value(tree.symbol)) captureAndSplice(capturer(tree))
if (!isStageNegOneValue(tree.symbol)) captureAndSplice(capturer(tree))
else if (level == 0) capturer(tree)
else captureAndSplice(liftInlineParamValue(capturer(tree)))
case Block(stats, _) =>
Expand All @@ -559,13 +565,12 @@ class ReifyQuotes extends MacroTransformWithImplicits {
case _: Import =>
tree
case tree: DefDef if tree.symbol.is(Macro) && level == 0 =>
if (enclosingInlineds.nonEmpty)
return EmptyTree // Already checked at definition site and already inlined
markDef(tree)
tree.rhs match {
case InlineSplice(_) =>
if (!tree.symbol.isStatic)
ctx.error("Transparent macro method must be a static method.", tree.pos)
markDef(tree)
val reifier = nested(isQuote = true)
reifier.transform(tree) // Ignore output, only check PCP
mapOverTree(enteredSyms) // Ignore output, only check PCP
cpy.DefDef(tree)(rhs = defaultValue(tree.rhs.tpe))
case _ =>
ctx.error(
Expand Down Expand Up @@ -602,7 +607,7 @@ class ReifyQuotes extends MacroTransformWithImplicits {
ref(lifter).select("toExpr".toTermName).appliedTo(tree)
}

private def isStage0Value(sym: Symbol)(implicit ctx: Context): Boolean =
private def isStageNegOneValue(sym: Symbol)(implicit ctx: Context): Boolean =
(sym.is(Transparent) && sym.owner.is(Transparent) && !defn.isFunctionType(sym.info)) ||
sym == defn.TastyTopLevelSplice_tastyContext // intrinsic value at stage 0

Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import scala.quoted._

class Foo {
transparent def foo: Unit = ~Foo.impl // error
transparent def foo: Unit = ~Foo.impl
object Bar {
transparent def foo: Unit = ~Foo.impl // error
transparent def foo: Unit = ~Foo.impl
}
}

object Foo {
class Baz {
transparent def foo: Unit = ~impl // error
transparent def foo: Unit = ~impl
}
object Quox {
transparent def foo: Unit = ~Foo.impl
Expand Down
8 changes: 8 additions & 0 deletions tests/run/i4803.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1.0
1.5
2.25
7.59375
1.0
1.5
2.25
7.59375
20 changes: 20 additions & 0 deletions tests/run/i4803/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@

class Num2(x: Double) {
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
}

object Test {
def main(args: Array[String]): Unit = {
val n = new Num(1.5)
println(n.power(0))
println(n.power(1))
println(n.power(2))
println(n.power(5))

val n2 = new Num2(1.5)
println(n.power(0))
println(n.power(1))
println(n.power(2))
println(n.power(5))
}
}
12 changes: 12 additions & 0 deletions tests/run/i4803/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import scala.quoted._

object PowerMacro {
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
else '{ ~x * ~powerCode(x, n - 1) }
}

class Num(x: Double) {
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
}
4 changes: 4 additions & 0 deletions tests/run/i4803b.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
1.0
1.5
2.25
7.59375
18 changes: 18 additions & 0 deletions tests/run/i4803b/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@


class Nums {
class Num(x: Double) {
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
}
}

object Test {
def main(args: Array[String]): Unit = {
val nums = new Nums
val n = new nums.Num(1.5)
println(n.power(0))
println(n.power(1))
println(n.power(2))
println(n.power(5))
}
}
8 changes: 8 additions & 0 deletions tests/run/i4803b/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.quoted._

object PowerMacro {
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
else '{ ~x * ~powerCode(x, n - 1) }
}
8 changes: 8 additions & 0 deletions tests/run/i4803c.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
1.0
1.5
2.25
7.59375
1.0
1.5
2.25
7.59375
22 changes: 22 additions & 0 deletions tests/run/i4803c/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

object Test {
def main(args: Array[String]): Unit = {
class Num(x: Double) {
transparent def power(transparent n: Long) = ~PowerMacro.powerCode('(x), n)
}
val n = new Num(1.5)
println(n.power(0))
println(n.power(1))
println(n.power(2))
println(n.power(5))

transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n)

val x: Double = 1.5

println(power(x, 0))
println(power(x, 1))
println(power(x, 2))
println(power(x, 5))
}
}
8 changes: 8 additions & 0 deletions tests/run/i4803c/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.quoted._

object PowerMacro {
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
else '{ ~x * ~powerCode(x, n - 1) }
}
3 changes: 3 additions & 0 deletions tests/run/i4803d.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
0.0
2.25
12.25
17 changes: 17 additions & 0 deletions tests/run/i4803d/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@

object Test {
def main(args: Array[String]): Unit = {
val x1: Double = 0
val x2: Double = 1.5
val x3: Double = 3.5

println(power2(x1))
println(power2(x2))
println(power2(x3))
}

transparent def power2(x: Double) = {
transparent def power(x: Double, transparent n: Long) = ~PowerMacro.powerCode('(x), n)
power(x, 2)
}
}
8 changes: 8 additions & 0 deletions tests/run/i4803d/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import scala.quoted._

object PowerMacro {
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
else '{ ~x * ~powerCode(x, n - 1) }
}
14 changes: 14 additions & 0 deletions tests/run/i4803e/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

object Test {
def main(args: Array[String]): Unit = {
val x1: Double = 0
val x2: Double = 1.5
val x3: Double = 3.5

println(power2(x1))
println(power2(x2))
println(power2(x3))
}

transparent def power2(x: Double) = ~PowerMacro.power2('(x))
}
11 changes: 11 additions & 0 deletions tests/run/i4803e/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import scala.quoted._

object PowerMacro {
def power2(x: Expr[Double]) = '{
transparent def power(x: Double, n: Long): Double =
if (n == 0) 1.0
else if (n % 2 == 0) { val y = x * x; power(y, n / 2) }
else x * power(x, n - 1)
power(~x, 2)
}
}
14 changes: 14 additions & 0 deletions tests/run/i4803f/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@

object Test {
def main(args: Array[String]): Unit = {
val x1: Double = 0
val x2: Double = 1.5
val x3: Double = 3.5

println(power2(x1))
println(power2(x2))
println(power2(x3))
}

transparent def power2(x: Double) = ~PowerMacro.power2('(x))
}
13 changes: 13 additions & 0 deletions tests/run/i4803f/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.quoted._

object PowerMacro {
def powerCode(x: Expr[Double], n: Long): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode('(y), n / 2) }
else '{ ~x * ~powerCode(x, n - 1) }

def power2(x: Expr[Double]) = '{
transparent def power(x: Double): Double = ~powerCode('(x), 2)
power(~x)
}
}