Skip to content

Allow top-level splices not directly in the RHS #4826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bench/tests/power-macro/PowerMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ object PowerMacro {

def powerCode(n: Long, x: Expr[Double]): Expr[Double] =
if (n == 0) '(1.0)
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode(n / 2, '(y)) }
else '{ ~x * ~powerCode(n - 1, x) }

}
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/tastyreflect/TastyImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,9 @@ class TastyImpl(val rootContext: Contexts.Context) extends scala.tasty.Tasty { s
}

object Inlined extends InlinedExtractor {
def unapply(x: Term)(implicit ctx: Context): Option[(Term, List[Statement], Term)] = x match {
def unapply(x: Term)(implicit ctx: Context): Option[(Option[Term], List[Statement], Term)] = x match {
case x: tpd.Inlined @unchecked =>
Some((x.call, x.bindings, x.expansion))
Some((optional(x.call), x.bindings, x.expansion))
case _ => None
}
}
Expand Down
7 changes: 1 addition & 6 deletions compiler/src/dotty/tools/dotc/transform/PostTyper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -245,12 +245,7 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase
// be duplicated
// 2. To enable correct pickling (calls can share symbols with the inlined code, which
// would trigger an assertion when pickling).
// In the case of macros we keep the call to be able to reconstruct the parameters that
// are passed to the macro. This same simplification is applied in ReifiedQuotes when the
// macro splices are evaluated.
val callTrace =
if (call.symbol.is(Macro)) call
else Ident(call.symbol.topLevelClass.typeRef).withPos(call.pos)
val callTrace = Ident(call.symbol.topLevelClass.typeRef).withPos(call.pos)
cpy.Inlined(tree)(callTrace, transformSub(bindings), transform(expansion)(inlineContext(call)))
case tree: Template =>
withNoCheckNews(tree.parents.flatMap(newPart)) {
Expand Down
190 changes: 46 additions & 144 deletions compiler/src/dotty/tools/dotc/transform/ReifyQuotes.scala

Large diffs are not rendered by default.

224 changes: 140 additions & 84 deletions compiler/src/dotty/tools/dotc/transform/Splicer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,25 @@ package dotty.tools.dotc
package transform

import java.io.{PrintWriter, StringWriter}
import java.lang.reflect.Method
import java.lang.reflect.{InvocationTargetException, Method}

import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.ast.Trees._
import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags.Package
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds.FlatName
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD
import dotty.tools.dotc.core.quoted._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.TypeErasure
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.tastyreflect.TastyImpl

import scala.util.control.NonFatal
import dotty.tools.dotc.util.Positions.Position
import dotty.tools.dotc.util.SourcePosition

import scala.reflect.ClassTag

Expand All @@ -32,89 +34,52 @@ object Splicer {
*
* See: `ReifyQuotes`
*/
def splice(tree: Tree, call: Tree, bindings: List[Tree], pos: Position, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
def splice(tree: Tree, pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context): Tree = tree match {
case Quoted(quotedTree) => quotedTree
case _ =>
val liftedArgs = getLiftedArgs(call, bindings)
val interpreter = new Interpreter(pos, classLoader)
val interpreted = interpreter.interpretCallToSymbol[Seq[Any] => Object](call.symbol)
val tctx = new TastyImpl(ctx)
evaluateMacro(pos) {
try {
// Some parts of the macro are evaluated during the unpickling performed in quotedExprToTree
val evaluated = interpreted.map(lambda => lambda(tctx :: liftedArgs).asInstanceOf[scala.quoted.Expr[Nothing]])
evaluated.fold(tree)(PickledQuotes.quotedExprToTree)
val interpreted = interpreter.interpret[scala.quoted.Expr[Any]](tree)
interpreted.fold(tree)(x => PickledQuotes.quotedExprToTree(x))
}
catch {
case ex: scala.quoted.QuoteError =>
ctx.error(ex.getMessage, pos)
EmptyTree
case NonFatal(ex) =>
val msg =
s"""Failed to evaluate macro.
| Caused by ${ex.getClass}: ${if (ex.getMessage == null) "" else ex.getMessage}
| ${ex.getStackTrace.takeWhile(_.getClassName != "dotty.tools.dotc.transform.Splicer$").init.mkString("\n ")}
""".stripMargin
ctx.error(msg, pos)
EmptyTree
}
}

/** Given the inline code and bindings, compute the lifted arguments that will be used to execute the macro
* - Type parameters are lifted to quoted.Types.TreeType
* - Inline parameters are listed as their value
* - Other parameters are lifted to quoted.Types.TreeExpr (may reference a binding)
*/
private def getLiftedArgs(call: Tree, bindings: List[Tree])(implicit ctx: Context): List[Any] = {
val bindMap = bindings.collect {
case vdef: ValDef => (vdef.rhs, ref(vdef.symbol).withPos(vdef.rhs.pos))
}.toMap
def allArgs(call: Tree, acc: List[List[Tree]]): List[List[Tree]] = call match {
case call: Apply => allArgs(call.fun, call.args :: acc)
case call: TypeApply => allArgs(call.fun, call.args :: acc)
case _ => acc
}
def liftArgs(tpe: Type, args: List[List[Tree]]): List[Any] = tpe match {
case tp: MethodType =>
val args1 = args.head.zip(tp.paramInfos).map {
case (arg: Literal, tp) if tp.hasAnnotation(defn.TransparentParamAnnot) => arg.const.value
case (arg, tp) =>
assert(!tp.hasAnnotation(defn.TransparentParamAnnot))
// Replace argument by its binding
val arg1 = bindMap.getOrElse(arg, arg)
new scala.quoted.Exprs.TastyTreeExpr(arg1)
}
args1 ::: liftArgs(tp.resType, args.tail)
case tp: PolyType =>
val args1 = args.head.map(tp => new scala.quoted.Types.TreeType(tp))
args1 ::: liftArgs(tp.resType, args.tail)
case _ => Nil
}

liftArgs(call.symbol.info, allArgs(call, Nil))
/** Check that the Tree can be spliced. `~'(xyz)` becomes `xyz`
* and for `~xyz` the tree of `xyz` is interpreted for which the
* resulting expression is returned as a `Tree`
*
* See: `ReifyQuotes`
*/
def canBeSpliced(tree: Tree)(implicit ctx: Context): Boolean = tree match {
case Quoted(_) => true
case _ => (new CanBeInterpreted).apply(tree)
}

/* Evaluate the code in the macro and handle exceptions durring evaluation */
private def evaluateMacro(pos: Position)(code: => Tree)(implicit ctx: Context): Tree = {
try code
catch {
case ex: scala.quoted.QuoteError =>
ctx.error(ex.getMessage, pos)
EmptyTree
case NonFatal(ex) =>
val msg =
s"""Failed to evaluate inlined quote.
| Caused by ${ex.getClass}: ${if (ex.getMessage == null) "" else ex.getMessage}
| ${ex.getStackTrace.takeWhile(_.getClassName != "dotty.tools.dotc.transform.Splicer$").init.mkString("\n ")}
""".stripMargin
ctx.error(msg, pos)
EmptyTree
}
}
/** Tree interpreter that evaluates the tree */
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {

/** Tree interpreter that can interpret calls to static methods with it's default arguments
*
* The interpreter assumes that all calls in the trees are to code that was
* previously compiled and is present in the classpath of the current context.
*/
private class Interpreter(pos: Position, classLoader: ClassLoader)(implicit ctx: Context) {
type Res = Object

/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
* Return Some of the result or None if some error happen during the interpretation.
*/
def interpretCallToSymbol[T](sym: Symbol)(implicit ct: ClassTag[T]): Option[T] = {
def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] = {
try {
val (clazz, instance) = loadModule(sym.owner)
val paramClasses = paramsSig(sym)
val interpretedArgs = paramClasses.map(defaultValue)
val method = getMethod(clazz, sym.name, paramClasses)
stopIfRuntimeException(method.invoke(instance, interpretedArgs: _*)) match {
interpretTree(tree)(Map.empty) match {
case obj: T => Some(obj)
case obj =>
// TODO upgrade to a full type tag check or something similar
Expand All @@ -128,6 +93,27 @@ object Splicer {
}
}

protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
new scala.quoted.Exprs.TastyTreeExpr(tree)

protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
new scala.quoted.Types.TreeType(tree)

protected def interpretLiteral(value: Any)(implicit env: Env): Object =
value.asInstanceOf[Object]

protected def interpretTastyContext()(implicit env: Env): Object =
new TastyImpl(ctx)

protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
val (clazz, instance) = loadModule(fn.symbol.owner)
val method = getMethod(clazz, fn.symbol.name, paramsSig(fn.symbol))
stopIfRuntimeException(method.invoke(instance, args: _*))
}

protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.pos)

private def loadModule(sym: Symbol): (Class[_], Object) = {
if (sym.owner.is(Package)) {
// is top level object
Expand Down Expand Up @@ -172,6 +158,15 @@ object Splicer {
ex.printStackTrace(new PrintWriter(sw))
sw.write("\n")
throw new StopInterpretation(sw.toString, pos)
case ex: InvocationTargetException =>
val sw = new StringWriter()
sw.write("An exception occurred while executing macro expansion\n")
sw.write(ex.getTargetException.getMessage)
sw.write("\n")
ex.getTargetException.printStackTrace(new PrintWriter(sw))
sw.write("\n")
throw new StopInterpretation(sw.toString, pos)

}
}

Expand Down Expand Up @@ -223,22 +218,83 @@ object Splicer {
}
}

/** Get the default value for the given class */
private def defaultValue(clazz: Class[_]): Object = {
if (clazz == classOf[Boolean]) false.asInstanceOf[Object]
else if (clazz == classOf[Byte]) 0.toByte.asInstanceOf[Object]
else if (clazz == classOf[Char]) 0.toChar.asInstanceOf[Object]
else if (clazz == classOf[Short]) 0.asInstanceOf[Object]
else if (clazz == classOf[Int]) 0.asInstanceOf[Object]
else if (clazz == classOf[Long]) 0L.asInstanceOf[Object]
else if (clazz == classOf[Float]) 0f.asInstanceOf[Object]
else if (clazz == classOf[Double]) 0d.asInstanceOf[Object]
else null
/** Exception that stops interpretation if some issue is found */
private class StopInterpretation(val msg: String, val pos: SourcePosition) extends Exception

}

/** Tree interpreter that tests if tree can be interpreted */
private class CanBeInterpreted(implicit ctx: Context) extends AbstractInterpreter {

type Res = Boolean

def apply(tree: Tree): Boolean = interpretTree(tree)(Map.empty)

def interpretQuote(tree: tpd.Tree)(implicit env: Env): Boolean = true
def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Boolean = true
def interpretLiteral(value: Any)(implicit env: Env): Boolean = true
def interpretTastyContext()(implicit env: Env): Boolean = true
def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)

def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
// Assuming that top-level splices can only be in transparent methods
// and splices are expanded at inline site, references to transparent values
// will be know literal constant trees.
tree.symbol.is(Transparent)
}
}

/** Exception that stops interpretation if some issue is found */
private class StopInterpretation(val msg: String, val pos: Position) extends Exception
/** Abstract Tree interpreter that can interpret calls to static methods with quoted or transparent arguments */
private abstract class AbstractInterpreter(implicit ctx: Context) {
type Env = Map[Name, Res]
type Res

protected def interpretQuote(tree: Tree)(implicit env: Env): Res
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Res
protected def interpretLiteral(value: Any)(implicit env: Env): Res
protected def interpretTastyContext()(implicit env: Env): Res
protected def interpretStaticMethodCall(fn: Tree, args: => List[Res])(implicit env: Env): Res
protected def unexpectedTree(tree: Tree)(implicit env: Env): Res

protected final def interpretTree(tree: Tree)(implicit env: Env): Res = tree match {
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.QuotedExpr_apply =>
interpretQuote(quoted)

case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.QuotedType_apply =>
interpretTypeQuote(quoted)

case Literal(Constant(value)) =>
interpretLiteral(value)

case _ if tree.symbol == defn.TastyTopLevelSplice_tastyContext =>
interpretTastyContext()

case StaticMethodCall(fn, args) =>
interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))

// Interpret `foo(j = x, i = y)` which it is expanded to
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
case Block(stats, expr) =>
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
case stat: ValDef if stat.symbol.is(Synthetic) =>
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
case stat => return unexpectedTree(stat)
})
interpretTree(expr)(newEnv)
case NamedArg(_, arg) => interpretTree(arg)
case Ident(name) if env.contains(name) => env(name)

case _ => unexpectedTree(tree)
}

object StaticMethodCall {
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
case fn: RefTree if fn.symbol.isStatic => Some((fn, Nil))
case Apply(StaticMethodCall(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
case TypeApply(StaticMethodCall(fn, args), _) => Some((fn, args))
case _ => None
}
}
}

}
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/EtaExpansion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ abstract class Lifter {
// don't instantiate here, as the type params could be further constrained, see tests/pos/pickleinf.scala
var liftedType = expr.tpe.widen
if (liftedFlags.is(Method)) liftedType = ExprType(liftedType)
val lifted = ctx.newSymbol(ctx.owner, name, liftedFlags, liftedType, coord = positionCoord(expr.pos))
val lifted = ctx.newSymbol(ctx.owner, name, liftedFlags | Synthetic, liftedType, coord = positionCoord(expr.pos))
defs += liftedDef(lifted, expr).withPos(expr.pos)
ref(lifted.termRef).withPos(expr.pos.focus)
}
Expand Down
2 changes: 0 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -952,8 +952,6 @@ class RefChecks extends MiniPhase { thisPhase =>

override def transformDefDef(tree: DefDef)(implicit ctx: Context) = {
checkDeprecatedOvers(tree)
if (tree.symbol.is(Macro))
tree.symbol.resetFlag(Macro)
tree
}

Expand Down
5 changes: 3 additions & 2 deletions docs/docs/reference/principled-meta-programming.md
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ statically known exponent:
private def powerCode(n: Int, x: Expr[Double]): Expr[Double] =
if (n == 0) '(1.0)
else if (n == 1) x
else if (n % 2 == 0) '{ { val y = ~x * ~x; ~powerCode(n / 2, '(y)) } }
else if (n % 2 == 0) '{ val y = ~x * ~x; ~powerCode(n / 2, '(y)) }
else '{ ~x * ~powerCode(n - 1, x) }

The reference to `n` as an argument in `~powerCode(n, '(x))` is not
Expand Down Expand Up @@ -436,7 +436,8 @@ we currently impose the following restrictions on the use of splices.
1. A top-level splice must appear in a transparent function (turning that function
into a macro)

2. The splice must call a previously compiled method.
2. The splice must call a previously compiled (previous to the call of the transparent definition)
static method passing quoted arguments, constant arguments or transparent arguments.

3. Splices inside splices (but no intervening quotes) are not allowed.

Expand Down
2 changes: 1 addition & 1 deletion library/src/scala/tasty/Tasty.scala
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ abstract class Tasty { tasty =>

val Inlined: InlinedExtractor
abstract class InlinedExtractor {
def unapply(x: Term)(implicit ctx: Context): Option[(Term, List[Definition], Term)]
def unapply(x: Term)(implicit ctx: Context): Option[(Option[Term], List[Definition], Term)]
}

val SelectOuter: SelectOuterExtractor
Expand Down
5 changes: 3 additions & 2 deletions tests/neg-with-compiler/quote-run-in-macro-1/quoted_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import scala.quoted._
import dotty.tools.dotc.quoted.Toolbox._

object Macros {
transparent def foo(i: => Int): Int = ~{
val y: Int = ('(i)).run
transparent def foo(i: => Int): Int = ~fooImpl('(i))
def fooImpl(i: Expr[Int]): Expr[Int] = {
val y: Int = i.run
y.toExpr
}
}
5 changes: 3 additions & 2 deletions tests/neg-with-compiler/quote-run-in-macro-2/quoted_1.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ import scala.quoted._
import dotty.tools.dotc.quoted.Toolbox._

object Macros {
transparent def foo(i: => Int): Int = ~{
val y: Int = ('(i)).run
transparent def foo(i: => Int): Int = ~fooImpl('(i))
def fooImpl(i: Expr[Int]): Expr[Int] = {
val y: Int = i.run
y.toExpr
}
}
Loading