Skip to content

Commit 73bf699

Browse files
committed
Split macro body checks from interpreter
1 parent 15b9cd8 commit 73bf699

File tree

2 files changed

+180
-175
lines changed

2 files changed

+180
-175
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 166 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,82 @@ object Splicer {
6464
*/
6565
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match {
6666
case Quoted(_) => // ok
67-
case _ => (new CheckValidMacroBody).apply(tree)
67+
case _ =>
68+
def checkValidStat(tree: Tree): Unit = tree match {
69+
case tree: ValDef if tree.symbol.is(Synthetic) =>
70+
// Check val from `foo(j = x, i = y)` which it is expanded to
71+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
72+
checkIfValidArgument(tree.rhs)
73+
case _ =>
74+
ctx.error("Macro should not have statements", tree.sourcePos)
75+
}
76+
def checkIfValidArgument(tree: Tree): Unit = tree match {
77+
case Block(Nil, expr) => checkIfValidArgument(expr)
78+
case Typed(expr, _) => checkIfValidArgument(expr)
79+
80+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
81+
// OK
82+
83+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
84+
// OK
85+
86+
case Literal(Constant(value)) =>
87+
// OK
88+
89+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
90+
// OK
91+
92+
case Call(fn, args)
93+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
94+
fn.symbol.is(Module) || fn.symbol.isStatic ||
95+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
96+
args.foreach(_.foreach(checkIfValidArgument))
97+
98+
case NamedArg(_, arg) =>
99+
checkIfValidArgument(arg)
100+
101+
case SeqLiteral(elems, _) =>
102+
elems.foreach(checkIfValidArgument)
103+
104+
case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) =>
105+
// OK
106+
107+
case _ =>
108+
ctx.error(
109+
"""Malformed macro parameter
110+
|
111+
|Parameters may be:
112+
| * Quoted parameters or fields
113+
| * References to inline parameters
114+
| * Literal values of primitive types
115+
|""".stripMargin, tree.sourcePos)
116+
}
117+
def checkIfValidStaticCall(tree: Tree): Unit = tree match {
118+
case Block(stats, expr) =>
119+
stats.foreach(checkValidStat)
120+
checkIfValidStaticCall(expr)
121+
122+
case Typed(expr, _) => checkIfValidStaticCall(expr)
123+
case Call(fn, args)
124+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
125+
fn.symbol.is(Module) || fn.symbol.isStatic ||
126+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
127+
args.flatten.foreach(checkIfValidArgument)
128+
case _ =>
129+
ctx.error(
130+
"""Malformed macro.
131+
|
132+
|Expected the splice ${...} to contain a single call to a static method.
133+
|""".stripMargin, tree.sourcePos)
134+
}
135+
136+
checkIfValidStaticCall(tree)
68137
}
69138

70139
/** Tree interpreter that evaluates the tree */
71-
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {
140+
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {
72141

73-
def checking: Boolean = false
74-
75-
type Result = Object
142+
type Env = Map[Name, Object]
76143

77144
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
78145
* Return Some of the result or None if some error happen during the interpretation.
@@ -93,22 +160,92 @@ object Splicer {
93160
}
94161
}
95162

96-
protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
163+
def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
164+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
165+
val quoted1 = quoted match {
166+
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
167+
// inline proxy for by-name parameter
168+
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
169+
case Inlined(EmptyTree, _, quoted) => quoted
170+
case _ => quoted
171+
}
172+
interpretQuote(quoted1)
173+
174+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
175+
interpretTypeQuote(quoted)
176+
177+
case Literal(Constant(value)) =>
178+
interpretLiteral(value)
179+
180+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
181+
interpretQuoteContext()
182+
183+
// TODO disallow interpreted method calls as arguments
184+
case Call(fn, args) =>
185+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
186+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
187+
} else if (fn.symbol.is(Module)) {
188+
interpretModuleAccess(fn.symbol)
189+
} else if (fn.symbol.isStatic) {
190+
val module = fn.symbol.owner
191+
interpretStaticMethodCall(module, fn.symbol, args.flatten.map(interpretTree))
192+
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
193+
val module = fn.qualifier.symbol.moduleClass
194+
interpretStaticMethodCall(module, fn.symbol, args.flatten.map(interpretTree))
195+
} else if (env.contains(fn.name)) {
196+
env(fn.name)
197+
} else if (tree.symbol.is(InlineProxy)) {
198+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
199+
} else {
200+
unexpectedTree(tree)
201+
}
202+
203+
// Interpret `foo(j = x, i = y)` which it is expanded to
204+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
205+
case Block(stats, expr) => interpretBlock(stats, expr)
206+
case NamedArg(_, arg) => interpretTree(arg)
207+
208+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
209+
210+
case Typed(expr, _) =>
211+
interpretTree(expr)
212+
213+
case SeqLiteral(elems, _) =>
214+
interpretVarargs(elems.map(e => interpretTree(e)))
215+
216+
case _ =>
217+
unexpectedTree(tree)
218+
}
219+
220+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
221+
var unexpected: Option[Object] = None
222+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
223+
case stat: ValDef =>
224+
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
225+
case stat =>
226+
if (unexpected.isEmpty)
227+
unexpected = Some(unexpectedTree(stat))
228+
accEnv
229+
})
230+
unexpected.getOrElse(interpretTree(expr)(newEnv))
231+
}
232+
233+
private def interpretQuote(tree: Tree)(implicit env: Env): Object =
97234
new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span))
98235

99-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
236+
private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
100237
new scala.internal.quoted.TreeType(tree)
101238

102-
protected def interpretLiteral(value: Any)(implicit env: Env): Object =
239+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
103240
value.asInstanceOf[Object]
104241

105-
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
242+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
106243
args.toSeq
107244

108-
protected def interpretQuoteContext()(implicit env: Env): Object =
245+
private def interpretQuoteContext()(implicit env: Env): Object =
109246
new scala.quoted.QuoteContext(ReflectionImpl(ctx, pos))
110247

111-
protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
248+
private def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
112249
val (inst, clazz) =
113250
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) {
114251
(null, loadReplLineClass(moduleClass))
@@ -128,16 +265,16 @@ object Splicer {
128265
stopIfRuntimeException(method.invoke(inst, args: _*))
129266
}
130267

131-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
268+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
132269
loadModule(fn.moduleClass)
133270

134-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
271+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
135272
val clazz = loadClass(fn.owner.fullName.toString)
136273
val constr = clazz.getConstructor(paramsSig(fn): _*)
137274
constr.newInstance(args: _*).asInstanceOf[Object]
138275
}
139276

140-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
277+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
141278
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.sourcePos)
142279

143280
private def loadModule(sym: Symbol): Object = {
@@ -265,158 +402,22 @@ object Splicer {
265402

266403
}
267404

268-
/** Tree interpreter that tests if tree can be interpreted */
269-
private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter {
270-
def checking: Boolean = true
271-
272-
type Result = Unit
273-
274-
def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty)
275-
276-
protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
277-
protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
278-
protected def interpretLiteral(value: Any)(implicit env: Env): Unit = ()
279-
protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = ()
280-
protected def interpretQuoteContext()(implicit env: Env): Unit = ()
281-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
282-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = ()
283-
protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
284-
285-
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = {
286-
// Assuming that top-level splices can only be in inline methods
287-
// and splices are expanded at inline site, references to inline values
288-
// will be known literal constant trees.
289-
if (!tree.symbol.is(Inline))
290-
ctx.error(
291-
"""Malformed macro.
292-
|
293-
|Expected the splice ${...} to contain a single call to a static method.
294-
|
295-
|Where parameters may be:
296-
| * Quoted paramers or fields
297-
| * References to inline parameters
298-
| * Literal values of primitive types
299-
""".stripMargin, tree.sourcePos)
300-
}
301-
}
302-
303-
/** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
304-
private abstract class AbstractInterpreter(implicit ctx: Context) {
305-
306-
def checking: Boolean
307-
308-
type Env = Map[Name, Result]
309-
type Result
310-
311-
protected def interpretQuote(tree: Tree)(implicit env: Env): Result
312-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result
313-
protected def interpretLiteral(value: Any)(implicit env: Env): Result
314-
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
315-
protected def interpretQuoteContext()(implicit env: Env): Result
316-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
317-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
318-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
319-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
320-
321-
private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
322-
fnTpe match {
323-
case tp: TermRef => removeErasedArguments(args, tp.underlying)
324-
case tp: PolyType => removeErasedArguments(args, tp.resType)
325-
case tp: ExprType => removeErasedArguments(args, tp.resType)
326-
case tp: MethodType =>
327-
val tail = removeErasedArguments(args.tail, tp.resType)
328-
if (tp.isErasedMethod) tail else args.head :: tail
329-
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
330-
val tail = removeErasedArguments(args.tail, tp.args.last)
331-
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
332-
case tp => assert(args.isEmpty, tp); Nil
333-
}
334-
335-
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
336-
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
337-
val quoted1 = quoted match {
338-
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
339-
// inline proxy for by-name parameter
340-
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
341-
case Inlined(EmptyTree, _, quoted) => quoted
342-
case _ => quoted
343-
}
344-
interpretQuote(quoted1)
345-
346-
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
347-
interpretTypeQuote(quoted)
348-
349-
case Literal(Constant(value)) =>
350-
interpretLiteral(value)
351-
352-
case _ if tree.symbol == defn.QuoteContext_macroContext =>
353-
interpretQuoteContext()
354-
355-
case Call(fn, args) =>
356-
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
357-
interpretNew(fn.symbol, args.flatten.map(interpretTree))
358-
} else if (fn.symbol.is(Module)) {
359-
interpretModuleAccess(fn.symbol)
360-
} else if (fn.symbol.isStatic) {
361-
val module = fn.symbol.owner
362-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
363-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
364-
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
365-
val module = fn.qualifier.symbol.moduleClass
366-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
367-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
368-
} else if (env.contains(fn.name)) {
369-
env(fn.name)
370-
} else if (tree.symbol.is(InlineProxy)) {
371-
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
372-
} else {
373-
unexpectedTree(tree)
374-
}
375-
376-
// Interpret `foo(j = x, i = y)` which it is expanded to
377-
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
378-
case Block(stats, expr) => interpretBlock(stats, expr)
379-
case NamedArg(_, arg) => interpretTree(arg)
380-
381-
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
382-
383-
case Typed(expr, _) =>
384-
interpretTree(expr)
385-
386-
case SeqLiteral(elems, _) =>
387-
interpretVarargs(elems.map(e => interpretTree(e)))
388-
389-
case _ =>
390-
unexpectedTree(tree)
391-
}
392-
393-
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
394-
var unexpected: Option[Result] = None
395-
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
396-
case stat: ValDef if stat.symbol.is(Synthetic) || !checking =>
397-
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
398-
case stat =>
399-
if (unexpected.isEmpty)
400-
unexpected = Some(unexpectedTree(stat))
401-
accEnv
402-
})
403-
unexpected.getOrElse(interpretTree(expr)(newEnv))
404-
}
405-
406-
object Call {
407-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
408-
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
409-
410-
object Call0 {
411-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
412-
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
413-
Some((fn, args))
414-
case fn: RefTree => Some((fn, Nil))
415-
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
416-
case TypeApply(Call0(fn, args), _) => Some((fn, args))
417-
case _ => None
418-
}
405+
object Call {
406+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
407+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
408+
409+
private object Call0 {
410+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
411+
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
412+
Some((fn, args))
413+
case fn: RefTree => Some((fn, Nil))
414+
case Apply(f @ Call0(fn, args1), args2) =>
415+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
416+
else Some((fn, args2 :: args1))
417+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
418+
case _ => None
419419
}
420420
}
421421
}
422+
422423
}

0 commit comments

Comments
 (0)