diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index aecb5c0446ee..f37e1d9d1d31 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -64,15 +64,85 @@ object Splicer { */ def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match { case Quoted(_) => // ok - case _ => (new CheckValidMacroBody).apply(tree) + case _ => + def checkValidStat(tree: Tree): Unit = tree match { + case tree: ValDef if tree.symbol.is(Synthetic) => + // Check val from `foo(j = x, i = y)` which it is expanded to + // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)` + checkIfValidArgument(tree.rhs) + case _ => + ctx.error("Macro should not have statements", tree.sourcePos) + } + def checkIfValidArgument(tree: Tree): Unit = tree match { + case Block(Nil, expr) => checkIfValidArgument(expr) + case Typed(expr, _) => checkIfValidArgument(expr) + + case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => + // OK + + case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote => + // OK + + case Literal(Constant(value)) => + // OK + + case _ if tree.symbol == defn.QuoteContext_macroContext => + // OK + + case Call(fn, args) + if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) || + fn.symbol.is(Module) || fn.symbol.isStatic || + (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) => + args.foreach(_.foreach(checkIfValidArgument)) + + case NamedArg(_, arg) => + checkIfValidArgument(arg) + + case SeqLiteral(elems, _) => + elems.foreach(checkIfValidArgument) + + case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) => + // OK + + case _ => + ctx.error( + """Malformed macro parameter + | + |Parameters may be: + | * Quoted parameters or fields + | * References to inline parameters + | * Literal values of primitive types + |""".stripMargin, tree.sourcePos) + } + def checkIfValidStaticCall(tree: Tree): Unit = tree match { + case Block(stats, expr) => + stats.foreach(checkValidStat) + checkIfValidStaticCall(expr) + + case Typed(expr, _) => + checkIfValidStaticCall(expr) + + case Call(fn, args) + if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) || + fn.symbol.is(Module) || fn.symbol.isStatic || + (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) => + args.flatten.foreach(checkIfValidArgument) + + case _ => + ctx.error( + """Malformed macro. + | + |Expected the splice ${...} to contain a single call to a static method. + |""".stripMargin, tree.sourcePos) + } + + checkIfValidStaticCall(tree) } /** Tree interpreter that evaluates the tree */ - private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter { + private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) { - def checking: Boolean = false - - type Result = Object + type Env = Map[Name, Object] /** Returns the interpreted result of interpreting the code a call to the symbol with default arguments. * Return Some of the result or None if some error happen during the interpretation. @@ -93,22 +163,92 @@ object Splicer { } } - protected def interpretQuote(tree: Tree)(implicit env: Env): Object = + def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { + case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => + val quoted1 = quoted match { + case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) => + // inline proxy for by-name parameter + quoted.symbol.defTree.asInstanceOf[DefDef].rhs + case Inlined(EmptyTree, _, quoted) => quoted + case _ => quoted + } + interpretQuote(quoted1) + + case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote => + interpretTypeQuote(quoted) + + case Literal(Constant(value)) => + interpretLiteral(value) + + case _ if tree.symbol == defn.QuoteContext_macroContext => + interpretQuoteContext() + + // TODO disallow interpreted method calls as arguments + case Call(fn, args) => + if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { + interpretNew(fn.symbol, args.flatten.map(interpretTree)) + } else if (fn.symbol.is(Module)) { + interpretModuleAccess(fn.symbol) + } else if (fn.symbol.isStatic) { + val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol) + staticMethodCall(args.flatten.map(interpretTree)) + } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { + val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol) + staticMethodCall(args.flatten.map(interpretTree)) + } else if (env.contains(fn.name)) { + env(fn.name) + } else if (tree.symbol.is(InlineProxy)) { + interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs) + } else { + unexpectedTree(tree) + } + + // Interpret `foo(j = x, i = y)` which it is expanded to + // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)` + case Block(stats, expr) => interpretBlock(stats, expr) + case NamedArg(_, arg) => interpretTree(arg) + + case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion) + + case Typed(expr, _) => + interpretTree(expr) + + case SeqLiteral(elems, _) => + interpretVarargs(elems.map(e => interpretTree(e))) + + case _ => + unexpectedTree(tree) + } + + private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = { + var unexpected: Option[Object] = None + val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match { + case stat: ValDef => + accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv)) + case stat => + if (unexpected.isEmpty) + unexpected = Some(unexpectedTree(stat)) + accEnv + }) + unexpected.getOrElse(interpretTree(expr)(newEnv)) + } + + private def interpretQuote(tree: Tree)(implicit env: Env): Object = new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span)) - protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object = + private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object = new scala.internal.quoted.TreeType(tree) - protected def interpretLiteral(value: Any)(implicit env: Env): Object = + private def interpretLiteral(value: Any)(implicit env: Env): Object = value.asInstanceOf[Object] - protected def interpretVarargs(args: List[Object])(implicit env: Env): Object = + private def interpretVarargs(args: List[Object])(implicit env: Env): Object = args.toSeq - protected def interpretQuoteContext()(implicit env: Env): Object = + private def interpretQuoteContext()(implicit env: Env): Object = new scala.quoted.QuoteContext(ReflectionImpl(ctx, pos)) - protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = { + private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = { val (inst, clazz) = if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) { (null, loadReplLineClass(moduleClass)) @@ -125,19 +265,20 @@ object Splicer { val name = getDirectName(fn.info.finalResultType, fn.name.asTermName) val method = getMethod(clazz, name, paramsSig(fn)) - stopIfRuntimeException(method.invoke(inst, args: _*)) + + (args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*)) } - protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = + private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = loadModule(fn.moduleClass) - protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = { + private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { val clazz = loadClass(fn.owner.fullName.toString) val constr = clazz.getConstructor(paramsSig(fn): _*) constr.newInstance(args: _*).asInstanceOf[Object] } - protected def unexpectedTree(tree: Tree)(implicit env: Env): Object = + private def unexpectedTree(tree: Tree)(implicit env: Env): Object = throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.sourcePos) private def loadModule(sym: Symbol): Object = { @@ -265,158 +406,25 @@ object Splicer { } - /** Tree interpreter that tests if tree can be interpreted */ - private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter { - def checking: Boolean = true - - type Result = Unit - - def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty) - - protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = () - protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = () - protected def interpretLiteral(value: Any)(implicit env: Env): Unit = () - protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = () - protected def interpretQuoteContext()(implicit env: Env): Unit = () - protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity) - protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = () - protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity) - - def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = { - // Assuming that top-level splices can only be in inline methods - // and splices are expanded at inline site, references to inline values - // will be known literal constant trees. - if (!tree.symbol.is(Inline)) - ctx.error( - """Malformed macro. - | - |Expected the splice ${...} to contain a single call to a static method. - | - |Where parameters may be: - | * Quoted paramers or fields - | * References to inline parameters - | * Literal values of primitive types - """.stripMargin, tree.sourcePos) - } - } - - /** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */ - private abstract class AbstractInterpreter(implicit ctx: Context) { - - def checking: Boolean - - type Env = Map[Name, Result] - type Result - - protected def interpretQuote(tree: Tree)(implicit env: Env): Result - protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result - protected def interpretLiteral(value: Any)(implicit env: Env): Result - protected def interpretVarargs(args: List[Result])(implicit env: Env): Result - protected def interpretQuoteContext()(implicit env: Env): Result - protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result - protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result - protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result - protected def unexpectedTree(tree: Tree)(implicit env: Env): Result - - private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] = - fnTpe match { - case tp: TermRef => removeErasedArguments(args, tp.underlying) - case tp: PolyType => removeErasedArguments(args, tp.resType) - case tp: ExprType => removeErasedArguments(args, tp.resType) - case tp: MethodType => - val tail = removeErasedArguments(args.tail, tp.resType) - if (tp.isErasedMethod) tail else args.head :: tail - case tp: AppliedType if defn.isImplicitFunctionType(tp) => - val tail = removeErasedArguments(args.tail, tp.args.last) - if (defn.isErasedFunctionType(tp)) tail else args.head :: tail - case tp => assert(args.isEmpty, tp); Nil - } - - protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match { - case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => - val quoted1 = quoted match { - case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) => - // inline proxy for by-name parameter - quoted.symbol.defTree.asInstanceOf[DefDef].rhs - case Inlined(EmptyTree, _, quoted) => quoted - case _ => quoted - } - interpretQuote(quoted1) - - case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote => - interpretTypeQuote(quoted) - - case Literal(Constant(value)) => - interpretLiteral(value) - - case _ if tree.symbol == defn.QuoteContext_macroContext => - interpretQuoteContext() - - case Call(fn, args) => - if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { - interpretNew(fn.symbol, args.flatten.map(interpretTree)) - } else if (fn.symbol.is(Module)) { - interpretModuleAccess(fn.symbol) - } else if (fn.symbol.isStatic) { - val module = fn.symbol.owner - def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) - interpretStaticMethodCall(module, fn.symbol, interpretedArgs) - } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { - val module = fn.qualifier.symbol.moduleClass - def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) - interpretStaticMethodCall(module, fn.symbol, interpretedArgs) - } else if (env.contains(fn.name)) { - env(fn.name) - } else if (tree.symbol.is(InlineProxy)) { - interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs) - } else { - unexpectedTree(tree) - } - - // Interpret `foo(j = x, i = y)` which it is expanded to - // `val j$1 = x; val i$1 = y; foo(i = y, j = x)` - case Block(stats, expr) => interpretBlock(stats, expr) - case NamedArg(_, arg) => interpretTree(arg) - - case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion) - - case Typed(expr, _) => - interpretTree(expr) - - case SeqLiteral(elems, _) => - interpretVarargs(elems.map(e => interpretTree(e))) - - case _ => - unexpectedTree(tree) - } - - private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = { - var unexpected: Option[Result] = None - val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match { - case stat: ValDef if stat.symbol.is(Synthetic) || !checking => - accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv)) - case stat => - if (unexpected.isEmpty) - unexpected = Some(unexpectedTree(stat)) - accEnv - }) - unexpected.getOrElse(interpretTree(expr)(newEnv)) - } - - object Call { - def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = - Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) - - object Call0 { - def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match { - case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => - Some((fn, args)) - case fn: RefTree => Some((fn, Nil)) - case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1)) - case TypeApply(Call0(fn, args), _) => Some((fn, args)) - case _ => None - } + object Call { + /** Matches an expression that is either a field access or an application + * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it. + */ + def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = + Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) + + private object Call0 { + def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match { + case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => + Some((fn, args)) + case fn: RefTree => Some((fn, Nil)) + case Apply(f @ Call0(fn, args1), args2) => + if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1)) + else Some((fn, args2 :: args1)) + case TypeApply(Call0(fn, args), _) => Some((fn, args)) + case _ => None } } } + } diff --git a/tests/neg-macros/quote-complex-top-splice.scala b/tests/neg-macros/quote-complex-top-splice.scala index f561f1d086b9..94d816ed90f8 100644 --- a/tests/neg-macros/quote-complex-top-splice.scala +++ b/tests/neg-macros/quote-complex-top-splice.scala @@ -6,18 +6,22 @@ object Test { inline def foo1: Unit = ${ val x = 1 // error - impl(x) + impl(x) // error } - inline def foo2: Unit = ${ impl({ - val x = 1 // error - x - }) } - - inline def foo3: Unit = ${ impl({ - println("foo3") // error - 3 - }) } + inline def foo2: Unit = ${ impl( + { // error + val x = 1 + x + } + ) } + + inline def foo3: Unit = ${ impl( + { // error + println("foo3") + 3 + } + ) } inline def foo4: Unit = ${ println("foo4") // error