Skip to content

Split macro body check from interpreter #6831

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
340 changes: 174 additions & 166 deletions compiler/src/dotty/tools/dotc/transform/Splicer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,85 @@ object Splicer {
*/
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match {
case Quoted(_) => // ok
case _ => (new CheckValidMacroBody).apply(tree)
case _ =>
def checkValidStat(tree: Tree): Unit = tree match {
case tree: ValDef if tree.symbol.is(Synthetic) =>
// Check val from `foo(j = x, i = y)` which it is expanded to
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
checkIfValidArgument(tree.rhs)
case _ =>
ctx.error("Macro should not have statements", tree.sourcePos)
}
def checkIfValidArgument(tree: Tree): Unit = tree match {
case Block(Nil, expr) => checkIfValidArgument(expr)
case Typed(expr, _) => checkIfValidArgument(expr)

case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
// OK

case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
// OK

case Literal(Constant(value)) =>
// OK

case _ if tree.symbol == defn.QuoteContext_macroContext =>
// OK

case Call(fn, args)
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
fn.symbol.is(Module) || fn.symbol.isStatic ||
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
args.foreach(_.foreach(checkIfValidArgument))

case NamedArg(_, arg) =>
checkIfValidArgument(arg)

case SeqLiteral(elems, _) =>
elems.foreach(checkIfValidArgument)

case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) =>
// OK

case _ =>
ctx.error(
"""Malformed macro parameter
|
|Parameters may be:
| * Quoted parameters or fields
| * References to inline parameters
| * Literal values of primitive types
|""".stripMargin, tree.sourcePos)
}
def checkIfValidStaticCall(tree: Tree): Unit = tree match {
case Block(stats, expr) =>
stats.foreach(checkValidStat)
checkIfValidStaticCall(expr)

case Typed(expr, _) =>
checkIfValidStaticCall(expr)

case Call(fn, args)
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
fn.symbol.is(Module) || fn.symbol.isStatic ||
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
args.flatten.foreach(checkIfValidArgument)

case _ =>
ctx.error(
"""Malformed macro.
|
|Expected the splice ${...} to contain a single call to a static method.
|""".stripMargin, tree.sourcePos)
}

checkIfValidStaticCall(tree)
}

/** Tree interpreter that evaluates the tree */
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {

def checking: Boolean = false

type Result = Object
type Env = Map[Name, Object]

/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
* Return Some of the result or None if some error happen during the interpretation.
Expand All @@ -93,22 +163,92 @@ object Splicer {
}
}

protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
val quoted1 = quoted match {
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
// inline proxy for by-name parameter
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
case Inlined(EmptyTree, _, quoted) => quoted
case _ => quoted
}
interpretQuote(quoted1)

case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
interpretTypeQuote(quoted)

case Literal(Constant(value)) =>
interpretLiteral(value)

case _ if tree.symbol == defn.QuoteContext_macroContext =>
interpretQuoteContext()

// TODO disallow interpreted method calls as arguments
case Call(fn, args) =>
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
interpretNew(fn.symbol, args.flatten.map(interpretTree))
} else if (fn.symbol.is(Module)) {
interpretModuleAccess(fn.symbol)
} else if (fn.symbol.isStatic) {
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
staticMethodCall(args.flatten.map(interpretTree))
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
staticMethodCall(args.flatten.map(interpretTree))
} else if (env.contains(fn.name)) {
env(fn.name)
} else if (tree.symbol.is(InlineProxy)) {
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
} else {
unexpectedTree(tree)
}

// Interpret `foo(j = x, i = y)` which it is expanded to
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
case Block(stats, expr) => interpretBlock(stats, expr)
case NamedArg(_, arg) => interpretTree(arg)

case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)

case Typed(expr, _) =>
interpretTree(expr)

case SeqLiteral(elems, _) =>
interpretVarargs(elems.map(e => interpretTree(e)))

case _ =>
unexpectedTree(tree)
}

private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
var unexpected: Option[Object] = None
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
case stat: ValDef =>
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
case stat =>
if (unexpected.isEmpty)
unexpected = Some(unexpectedTree(stat))
accEnv
})
unexpected.getOrElse(interpretTree(expr)(newEnv))
}

private def interpretQuote(tree: Tree)(implicit env: Env): Object =
new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span))

protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
new scala.internal.quoted.TreeType(tree)

protected def interpretLiteral(value: Any)(implicit env: Env): Object =
private def interpretLiteral(value: Any)(implicit env: Env): Object =
value.asInstanceOf[Object]

protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
args.toSeq

protected def interpretQuoteContext()(implicit env: Env): Object =
private def interpretQuoteContext()(implicit env: Env): Object =
new scala.quoted.QuoteContext(ReflectionImpl(ctx, pos))

protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
val (inst, clazz) =
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) {
(null, loadReplLineClass(moduleClass))
Expand All @@ -125,19 +265,20 @@ object Splicer {

val name = getDirectName(fn.info.finalResultType, fn.name.asTermName)
val method = getMethod(clazz, name, paramsSig(fn))
stopIfRuntimeException(method.invoke(inst, args: _*))

(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*))
}

protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
loadModule(fn.moduleClass)

protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
val clazz = loadClass(fn.owner.fullName.toString)
val constr = clazz.getConstructor(paramsSig(fn): _*)
constr.newInstance(args: _*).asInstanceOf[Object]
}

protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.sourcePos)

private def loadModule(sym: Symbol): Object = {
Expand Down Expand Up @@ -265,158 +406,25 @@ object Splicer {

}

/** Tree interpreter that tests if tree can be interpreted */
private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter {
def checking: Boolean = true

type Result = Unit

def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty)

protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
protected def interpretLiteral(value: Any)(implicit env: Env): Unit = ()
protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = ()
protected def interpretQuoteContext()(implicit env: Env): Unit = ()
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = ()
protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)

def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = {
// Assuming that top-level splices can only be in inline methods
// and splices are expanded at inline site, references to inline values
// will be known literal constant trees.
if (!tree.symbol.is(Inline))
ctx.error(
"""Malformed macro.
|
|Expected the splice ${...} to contain a single call to a static method.
|
|Where parameters may be:
| * Quoted paramers or fields
| * References to inline parameters
| * Literal values of primitive types
""".stripMargin, tree.sourcePos)
}
}

/** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
private abstract class AbstractInterpreter(implicit ctx: Context) {

def checking: Boolean

type Env = Map[Name, Result]
type Result

protected def interpretQuote(tree: Tree)(implicit env: Env): Result
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result
protected def interpretLiteral(value: Any)(implicit env: Env): Result
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
protected def interpretQuoteContext()(implicit env: Env): Result
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result

private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
fnTpe match {
case tp: TermRef => removeErasedArguments(args, tp.underlying)
case tp: PolyType => removeErasedArguments(args, tp.resType)
case tp: ExprType => removeErasedArguments(args, tp.resType)
case tp: MethodType =>
val tail = removeErasedArguments(args.tail, tp.resType)
if (tp.isErasedMethod) tail else args.head :: tail
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
val tail = removeErasedArguments(args.tail, tp.args.last)
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
case tp => assert(args.isEmpty, tp); Nil
}

protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
val quoted1 = quoted match {
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
// inline proxy for by-name parameter
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
case Inlined(EmptyTree, _, quoted) => quoted
case _ => quoted
}
interpretQuote(quoted1)

case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
interpretTypeQuote(quoted)

case Literal(Constant(value)) =>
interpretLiteral(value)

case _ if tree.symbol == defn.QuoteContext_macroContext =>
interpretQuoteContext()

case Call(fn, args) =>
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
interpretNew(fn.symbol, args.flatten.map(interpretTree))
} else if (fn.symbol.is(Module)) {
interpretModuleAccess(fn.symbol)
} else if (fn.symbol.isStatic) {
val module = fn.symbol.owner
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
val module = fn.qualifier.symbol.moduleClass
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
} else if (env.contains(fn.name)) {
env(fn.name)
} else if (tree.symbol.is(InlineProxy)) {
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
} else {
unexpectedTree(tree)
}

// Interpret `foo(j = x, i = y)` which it is expanded to
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
case Block(stats, expr) => interpretBlock(stats, expr)
case NamedArg(_, arg) => interpretTree(arg)

case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)

case Typed(expr, _) =>
interpretTree(expr)

case SeqLiteral(elems, _) =>
interpretVarargs(elems.map(e => interpretTree(e)))

case _ =>
unexpectedTree(tree)
}

private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
var unexpected: Option[Result] = None
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
case stat: ValDef if stat.symbol.is(Synthetic) || !checking =>
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
case stat =>
if (unexpected.isEmpty)
unexpected = Some(unexpectedTree(stat))
accEnv
})
unexpected.getOrElse(interpretTree(expr)(newEnv))
}

object Call {
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))

object Call0 {
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
Some((fn, args))
case fn: RefTree => Some((fn, Nil))
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
case TypeApply(Call0(fn, args), _) => Some((fn, args))
case _ => None
}
object Call {
/** Matches an expression that is either a field access or an application
* It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
*/
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))

private object Call0 {
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
Some((fn, args))
case fn: RefTree => Some((fn, Nil))
case Apply(f @ Call0(fn, args1), args2) =>
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
else Some((fn, args2 :: args1))
case TypeApply(Call0(fn, args), _) => Some((fn, args))
case _ => None
}
}
}

}
Loading