Skip to content

Commit 1764e28

Browse files
committed
Split macro body checks from interpreter
1 parent 15b9cd8 commit 1764e28

File tree

2 files changed

+186
-167
lines changed

2 files changed

+186
-167
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 172 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,82 @@ object Splicer {
6464
*/
6565
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match {
6666
case Quoted(_) => // ok
67-
case _ => (new CheckValidMacroBody).apply(tree)
67+
case _ =>
68+
def checkValidStat(tree: Tree): Unit = tree match {
69+
case tree: ValDef if tree.symbol.is(Synthetic) =>
70+
// Check val from `foo(j = x, i = y)` which it is expanded to
71+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
72+
checkIfValidArgument(tree.rhs)
73+
case _ =>
74+
ctx.error("Macro should not have statements", tree.sourcePos)
75+
}
76+
def checkIfValidArgument(tree: Tree): Unit = tree match {
77+
case Block(Nil, expr) => checkIfValidArgument(expr)
78+
case Typed(expr, _) => checkIfValidArgument(expr)
79+
80+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
81+
// OK
82+
83+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
84+
// OK
85+
86+
case Literal(Constant(value)) =>
87+
// OK
88+
89+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
90+
// OK
91+
92+
case Call(fn, args)
93+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
94+
fn.symbol.is(Module) || fn.symbol.isStatic ||
95+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
96+
args.flatten.foreach(checkIfValidArgument)
97+
98+
case NamedArg(_, arg) =>
99+
checkIfValidArgument(arg)
100+
101+
case SeqLiteral(elems, _) =>
102+
elems.foreach(checkIfValidArgument)
103+
104+
case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) =>
105+
// OK
106+
107+
case _ =>
108+
ctx.error(
109+
"""Malformed macro parameter
110+
|
111+
|Parameters may be:
112+
| * Quoted parameters or fields
113+
| * References to inline parameters
114+
| * Literal values of primitive types
115+
|""".stripMargin, tree.sourcePos)
116+
}
117+
def checkIfValidStaticCall(tree: Tree): Unit = tree match {
118+
case Block(stats, expr) =>
119+
stats.foreach(checkValidStat)
120+
checkIfValidStaticCall(expr)
121+
122+
case Typed(expr, _) => checkIfValidStaticCall(expr)
123+
case Call(fn, args)
124+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
125+
fn.symbol.is(Module) || fn.symbol.isStatic ||
126+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
127+
args.flatten.foreach(checkIfValidArgument)
128+
case _ =>
129+
ctx.error(
130+
"""Malformed macro.
131+
|
132+
|Expected the splice ${...} to contain a single call to a static method.
133+
|""".stripMargin, tree.sourcePos)
134+
}
135+
136+
checkIfValidStaticCall(tree)
68137
}
69138

70139
/** Tree interpreter that evaluates the tree */
71-
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {
72-
73-
def checking: Boolean = false
140+
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {
74141

75-
type Result = Object
142+
type Env = Map[Name, Object]
76143

77144
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
78145
* Return Some of the result or None if some error happen during the interpretation.
@@ -93,6 +160,92 @@ object Splicer {
93160
}
94161
}
95162

163+
protected final def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
164+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
165+
val quoted1 = quoted match {
166+
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
167+
// inline proxy for by-name parameter
168+
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
169+
case Inlined(EmptyTree, _, quoted) => quoted
170+
case _ => quoted
171+
}
172+
interpretQuote(quoted1)
173+
174+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
175+
interpretTypeQuote(quoted)
176+
177+
case Literal(Constant(value)) =>
178+
interpretLiteral(value)
179+
180+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
181+
interpretQuoteContext()
182+
183+
// TODO disallow interpreted method calls as arguments
184+
case Call(fn, args) =>
185+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
186+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
187+
} else if (fn.symbol.is(Module)) {
188+
interpretModuleAccess(fn.symbol)
189+
} else if (fn.symbol.isStatic) {
190+
val module = fn.symbol.owner
191+
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
192+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
193+
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
194+
val module = fn.qualifier.symbol.moduleClass
195+
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
196+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
197+
} else if (env.contains(fn.name)) {
198+
env(fn.name)
199+
} else if (tree.symbol.is(InlineProxy)) {
200+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
201+
} else {
202+
unexpectedTree(tree)
203+
}
204+
205+
// Interpret `foo(j = x, i = y)` which it is expanded to
206+
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
207+
case Block(stats, expr) => interpretBlock(stats, expr)
208+
case NamedArg(_, arg) => interpretTree(arg)
209+
210+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
211+
212+
case Typed(expr, _) =>
213+
interpretTree(expr)
214+
215+
case SeqLiteral(elems, _) =>
216+
interpretVarargs(elems.map(e => interpretTree(e)))
217+
218+
case _ =>
219+
unexpectedTree(tree)
220+
}
221+
222+
private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
223+
fnTpe match {
224+
case tp: TermRef => removeErasedArguments(args, tp.underlying)
225+
case tp: PolyType => removeErasedArguments(args, tp.resType)
226+
case tp: ExprType => removeErasedArguments(args, tp.resType)
227+
case tp: MethodType =>
228+
val tail = removeErasedArguments(args.tail, tp.resType)
229+
if (tp.isErasedMethod) tail else args.head :: tail
230+
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
231+
val tail = removeErasedArguments(args.tail, tp.args.last)
232+
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
233+
case tp => assert(args.isEmpty, tp); Nil
234+
}
235+
236+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
237+
var unexpected: Option[Object] = None
238+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
239+
case stat: ValDef =>
240+
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
241+
case stat =>
242+
if (unexpected.isEmpty)
243+
unexpected = Some(unexpectedTree(stat))
244+
accEnv
245+
})
246+
unexpected.getOrElse(interpretTree(expr)(newEnv))
247+
}
248+
96249
protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
97250
new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span))
98251

@@ -131,7 +284,7 @@ object Splicer {
131284
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
132285
loadModule(fn.moduleClass)
133286

134-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
287+
protected def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
135288
val clazz = loadClass(fn.owner.fullName.toString)
136289
val constr = clazz.getConstructor(paramsSig(fn): _*)
137290
constr.newInstance(args: _*).asInstanceOf[Object]
@@ -265,158 +418,20 @@ object Splicer {
265418

266419
}
267420

268-
/** Tree interpreter that tests if tree can be interpreted */
269-
private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter {
270-
def checking: Boolean = true
271-
272-
type Result = Unit
273-
274-
def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty)
275-
276-
protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
277-
protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
278-
protected def interpretLiteral(value: Any)(implicit env: Env): Unit = ()
279-
protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = ()
280-
protected def interpretQuoteContext()(implicit env: Env): Unit = ()
281-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
282-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = ()
283-
protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
284-
285-
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = {
286-
// Assuming that top-level splices can only be in inline methods
287-
// and splices are expanded at inline site, references to inline values
288-
// will be known literal constant trees.
289-
if (!tree.symbol.is(Inline))
290-
ctx.error(
291-
"""Malformed macro.
292-
|
293-
|Expected the splice ${...} to contain a single call to a static method.
294-
|
295-
|Where parameters may be:
296-
| * Quoted paramers or fields
297-
| * References to inline parameters
298-
| * Literal values of primitive types
299-
""".stripMargin, tree.sourcePos)
300-
}
301-
}
302-
303-
/** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
304-
private abstract class AbstractInterpreter(implicit ctx: Context) {
305-
306-
def checking: Boolean
307-
308-
type Env = Map[Name, Result]
309-
type Result
310-
311-
protected def interpretQuote(tree: Tree)(implicit env: Env): Result
312-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result
313-
protected def interpretLiteral(value: Any)(implicit env: Env): Result
314-
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
315-
protected def interpretQuoteContext()(implicit env: Env): Result
316-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
317-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
318-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
319-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
320-
321-
private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
322-
fnTpe match {
323-
case tp: TermRef => removeErasedArguments(args, tp.underlying)
324-
case tp: PolyType => removeErasedArguments(args, tp.resType)
325-
case tp: ExprType => removeErasedArguments(args, tp.resType)
326-
case tp: MethodType =>
327-
val tail = removeErasedArguments(args.tail, tp.resType)
328-
if (tp.isErasedMethod) tail else args.head :: tail
329-
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
330-
val tail = removeErasedArguments(args.tail, tp.args.last)
331-
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
332-
case tp => assert(args.isEmpty, tp); Nil
333-
}
334-
335-
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
336-
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
337-
val quoted1 = quoted match {
338-
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
339-
// inline proxy for by-name parameter
340-
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
341-
case Inlined(EmptyTree, _, quoted) => quoted
342-
case _ => quoted
343-
}
344-
interpretQuote(quoted1)
345-
346-
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
347-
interpretTypeQuote(quoted)
348-
349-
case Literal(Constant(value)) =>
350-
interpretLiteral(value)
351-
352-
case _ if tree.symbol == defn.QuoteContext_macroContext =>
353-
interpretQuoteContext()
354-
355-
case Call(fn, args) =>
356-
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
357-
interpretNew(fn.symbol, args.flatten.map(interpretTree))
358-
} else if (fn.symbol.is(Module)) {
359-
interpretModuleAccess(fn.symbol)
360-
} else if (fn.symbol.isStatic) {
361-
val module = fn.symbol.owner
362-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
363-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
364-
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
365-
val module = fn.qualifier.symbol.moduleClass
366-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
367-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
368-
} else if (env.contains(fn.name)) {
369-
env(fn.name)
370-
} else if (tree.symbol.is(InlineProxy)) {
371-
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
372-
} else {
373-
unexpectedTree(tree)
374-
}
375-
376-
// Interpret `foo(j = x, i = y)` which it is expanded to
377-
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
378-
case Block(stats, expr) => interpretBlock(stats, expr)
379-
case NamedArg(_, arg) => interpretTree(arg)
380-
381-
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
382-
383-
case Typed(expr, _) =>
384-
interpretTree(expr)
385-
386-
case SeqLiteral(elems, _) =>
387-
interpretVarargs(elems.map(e => interpretTree(e)))
388-
389-
case _ =>
390-
unexpectedTree(tree)
391-
}
392-
393-
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
394-
var unexpected: Option[Result] = None
395-
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
396-
case stat: ValDef if stat.symbol.is(Synthetic) || !checking =>
397-
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
398-
case stat =>
399-
if (unexpected.isEmpty)
400-
unexpected = Some(unexpectedTree(stat))
401-
accEnv
402-
})
403-
unexpected.getOrElse(interpretTree(expr)(newEnv))
404-
}
405-
406-
object Call {
407-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
408-
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
409-
410-
object Call0 {
411-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
412-
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
413-
Some((fn, args))
414-
case fn: RefTree => Some((fn, Nil))
415-
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
416-
case TypeApply(Call0(fn, args), _) => Some((fn, args))
417-
case _ => None
418-
}
421+
object Call {
422+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
423+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
424+
425+
object Call0 {
426+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
427+
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
428+
Some((fn, args))
429+
case fn: RefTree => Some((fn, Nil))
430+
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
431+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
432+
case _ => None
419433
}
420434
}
421435
}
436+
422437
}

tests/neg-macros/quote-complex-top-splice.scala

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,22 @@ object Test {
66

77
inline def foo1: Unit = ${
88
val x = 1 // error
9-
impl(x)
9+
impl(x) // error
1010
}
1111

12-
inline def foo2: Unit = ${ impl({
13-
val x = 1 // error
14-
x
15-
}) }
16-
17-
inline def foo3: Unit = ${ impl({
18-
println("foo3") // error
19-
3
20-
}) }
12+
inline def foo2: Unit = ${ impl(
13+
{ // error
14+
val x = 1
15+
x
16+
}
17+
) }
18+
19+
inline def foo3: Unit = ${ impl(
20+
{ // error
21+
println("foo3")
22+
3
23+
}
24+
) }
2125

2226
inline def foo4: Unit = ${
2327
println("foo4") // error

0 commit comments

Comments
 (0)