Skip to content

Commit e2130b9

Browse files
Merge pull request #6831 from dotty-staging/split-macro-body-check-from-interpreter
Split macro body check from interpreter
2 parents e94662a + 5f48cd4 commit e2130b9

File tree

2 files changed

+188
-176
lines changed

2 files changed

+188
-176
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 174 additions & 166 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,85 @@ object Splicer {
6464
*/
6565
def checkValidMacroBody(tree: Tree)(implicit ctx: Context): Unit = tree match {
6666
case Quoted(_) => // ok
67-
case _ => (new CheckValidMacroBody).apply(tree)
67+
case _ =>
68+
def checkValidStat(tree: Tree): Unit = tree match {
69+
case tree: ValDef if tree.symbol.is(Synthetic) =>
70+
// Check val from `foo(j = x, i = y)` which it is expanded to
71+
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
72+
checkIfValidArgument(tree.rhs)
73+
case _ =>
74+
ctx.error("Macro should not have statements", tree.sourcePos)
75+
}
76+
def checkIfValidArgument(tree: Tree): Unit = tree match {
77+
case Block(Nil, expr) => checkIfValidArgument(expr)
78+
case Typed(expr, _) => checkIfValidArgument(expr)
79+
80+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
81+
// OK
82+
83+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
84+
// OK
85+
86+
case Literal(Constant(value)) =>
87+
// OK
88+
89+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
90+
// OK
91+
92+
case Call(fn, args)
93+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
94+
fn.symbol.is(Module) || fn.symbol.isStatic ||
95+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
96+
args.foreach(_.foreach(checkIfValidArgument))
97+
98+
case NamedArg(_, arg) =>
99+
checkIfValidArgument(arg)
100+
101+
case SeqLiteral(elems, _) =>
102+
elems.foreach(checkIfValidArgument)
103+
104+
case tree: Ident if tree.symbol.is(Inline) || tree.symbol.is(Synthetic) =>
105+
// OK
106+
107+
case _ =>
108+
ctx.error(
109+
"""Malformed macro parameter
110+
|
111+
|Parameters may be:
112+
| * Quoted parameters or fields
113+
| * References to inline parameters
114+
| * Literal values of primitive types
115+
|""".stripMargin, tree.sourcePos)
116+
}
117+
def checkIfValidStaticCall(tree: Tree): Unit = tree match {
118+
case Block(stats, expr) =>
119+
stats.foreach(checkValidStat)
120+
checkIfValidStaticCall(expr)
121+
122+
case Typed(expr, _) =>
123+
checkIfValidStaticCall(expr)
124+
125+
case Call(fn, args)
126+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) ||
127+
fn.symbol.is(Module) || fn.symbol.isStatic ||
128+
(fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) =>
129+
args.flatten.foreach(checkIfValidArgument)
130+
131+
case _ =>
132+
ctx.error(
133+
"""Malformed macro.
134+
|
135+
|Expected the splice ${...} to contain a single call to a static method.
136+
|""".stripMargin, tree.sourcePos)
137+
}
138+
139+
checkIfValidStaticCall(tree)
68140
}
69141

70142
/** Tree interpreter that evaluates the tree */
71-
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) extends AbstractInterpreter {
143+
private class Interpreter(pos: SourcePosition, classLoader: ClassLoader)(implicit ctx: Context) {
72144

73-
def checking: Boolean = false
74-
75-
type Result = Object
145+
type Env = Map[Name, Object]
76146

77147
/** Returns the interpreted result of interpreting the code a call to the symbol with default arguments.
78148
* Return Some of the result or None if some error happen during the interpretation.
@@ -93,22 +163,92 @@ object Splicer {
93163
}
94164
}
95165

96-
protected def interpretQuote(tree: Tree)(implicit env: Env): Object =
166+
def interpretTree(tree: Tree)(implicit env: Env): Object = tree match {
167+
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
168+
val quoted1 = quoted match {
169+
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
170+
// inline proxy for by-name parameter
171+
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
172+
case Inlined(EmptyTree, _, quoted) => quoted
173+
case _ => quoted
174+
}
175+
interpretQuote(quoted1)
176+
177+
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
178+
interpretTypeQuote(quoted)
179+
180+
case Literal(Constant(value)) =>
181+
interpretLiteral(value)
182+
183+
case _ if tree.symbol == defn.QuoteContext_macroContext =>
184+
interpretQuoteContext()
185+
186+
// TODO disallow interpreted method calls as arguments
187+
case Call(fn, args) =>
188+
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
189+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
190+
} else if (fn.symbol.is(Module)) {
191+
interpretModuleAccess(fn.symbol)
192+
} else if (fn.symbol.isStatic) {
193+
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
194+
staticMethodCall(args.flatten.map(interpretTree))
195+
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
196+
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
197+
staticMethodCall(args.flatten.map(interpretTree))
198+
} else if (env.contains(fn.name)) {
199+
env(fn.name)
200+
} else if (tree.symbol.is(InlineProxy)) {
201+
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
202+
} else {
203+
unexpectedTree(tree)
204+
}
205+
206+
// Interpret `foo(j = x, i = y)` which it is expanded to
207+
// `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)`
208+
case Block(stats, expr) => interpretBlock(stats, expr)
209+
case NamedArg(_, arg) => interpretTree(arg)
210+
211+
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
212+
213+
case Typed(expr, _) =>
214+
interpretTree(expr)
215+
216+
case SeqLiteral(elems, _) =>
217+
interpretVarargs(elems.map(e => interpretTree(e)))
218+
219+
case _ =>
220+
unexpectedTree(tree)
221+
}
222+
223+
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
224+
var unexpected: Option[Object] = None
225+
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
226+
case stat: ValDef =>
227+
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
228+
case stat =>
229+
if (unexpected.isEmpty)
230+
unexpected = Some(unexpectedTree(stat))
231+
accEnv
232+
})
233+
unexpected.getOrElse(interpretTree(expr)(newEnv))
234+
}
235+
236+
private def interpretQuote(tree: Tree)(implicit env: Env): Object =
97237
new scala.internal.quoted.TastyTreeExpr(Inlined(EmptyTree, Nil, tree).withSpan(tree.span))
98238

99-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
239+
private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object =
100240
new scala.internal.quoted.TreeType(tree)
101241

102-
protected def interpretLiteral(value: Any)(implicit env: Env): Object =
242+
private def interpretLiteral(value: Any)(implicit env: Env): Object =
103243
value.asInstanceOf[Object]
104244

105-
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
245+
private def interpretVarargs(args: List[Object])(implicit env: Env): Object =
106246
args.toSeq
107247

108-
protected def interpretQuoteContext()(implicit env: Env): Object =
248+
private def interpretQuoteContext()(implicit env: Env): Object =
109249
new scala.quoted.QuoteContext(ReflectionImpl(ctx, pos))
110250

111-
protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
251+
private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = {
112252
val (inst, clazz) =
113253
if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) {
114254
(null, loadReplLineClass(moduleClass))
@@ -125,19 +265,20 @@ object Splicer {
125265

126266
val name = getDirectName(fn.info.finalResultType, fn.name.asTermName)
127267
val method = getMethod(clazz, name, paramsSig(fn))
128-
stopIfRuntimeException(method.invoke(inst, args: _*))
268+
269+
(args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*))
129270
}
130271

131-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
272+
private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
132273
loadModule(fn.moduleClass)
133274

134-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
275+
private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
135276
val clazz = loadClass(fn.owner.fullName.toString)
136277
val constr = clazz.getConstructor(paramsSig(fn): _*)
137278
constr.newInstance(args: _*).asInstanceOf[Object]
138279
}
139280

140-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Object =
281+
private def unexpectedTree(tree: Tree)(implicit env: Env): Object =
141282
throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.sourcePos)
142283

143284
private def loadModule(sym: Symbol): Object = {
@@ -265,158 +406,25 @@ object Splicer {
265406

266407
}
267408

268-
/** Tree interpreter that tests if tree can be interpreted */
269-
private class CheckValidMacroBody(implicit ctx: Context) extends AbstractInterpreter {
270-
def checking: Boolean = true
271-
272-
type Result = Unit
273-
274-
def apply(tree: Tree): Unit = interpretTree(tree)(Map.empty)
275-
276-
protected def interpretQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
277-
protected def interpretTypeQuote(tree: tpd.Tree)(implicit env: Env): Unit = ()
278-
protected def interpretLiteral(value: Any)(implicit env: Env): Unit = ()
279-
protected def interpretVarargs(args: List[Unit])(implicit env: Env): Unit = ()
280-
protected def interpretQuoteContext()(implicit env: Env): Unit = ()
281-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
282-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Unit = ()
283-
protected def interpretNew(fn: Symbol, args: => List[Unit])(implicit env: Env): Unit = args.foreach(identity)
284-
285-
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Unit = {
286-
// Assuming that top-level splices can only be in inline methods
287-
// and splices are expanded at inline site, references to inline values
288-
// will be known literal constant trees.
289-
if (!tree.symbol.is(Inline))
290-
ctx.error(
291-
"""Malformed macro.
292-
|
293-
|Expected the splice ${...} to contain a single call to a static method.
294-
|
295-
|Where parameters may be:
296-
| * Quoted paramers or fields
297-
| * References to inline parameters
298-
| * Literal values of primitive types
299-
""".stripMargin, tree.sourcePos)
300-
}
301-
}
302-
303-
/** Abstract Tree interpreter that can interpret calls to static methods with quoted or inline arguments */
304-
private abstract class AbstractInterpreter(implicit ctx: Context) {
305-
306-
def checking: Boolean
307-
308-
type Env = Map[Name, Result]
309-
type Result
310-
311-
protected def interpretQuote(tree: Tree)(implicit env: Env): Result
312-
protected def interpretTypeQuote(tree: Tree)(implicit env: Env): Result
313-
protected def interpretLiteral(value: Any)(implicit env: Env): Result
314-
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
315-
protected def interpretQuoteContext()(implicit env: Env): Result
316-
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
317-
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
318-
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
319-
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
320-
321-
private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
322-
fnTpe match {
323-
case tp: TermRef => removeErasedArguments(args, tp.underlying)
324-
case tp: PolyType => removeErasedArguments(args, tp.resType)
325-
case tp: ExprType => removeErasedArguments(args, tp.resType)
326-
case tp: MethodType =>
327-
val tail = removeErasedArguments(args.tail, tp.resType)
328-
if (tp.isErasedMethod) tail else args.head :: tail
329-
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
330-
val tail = removeErasedArguments(args.tail, tp.args.last)
331-
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
332-
case tp => assert(args.isEmpty, tp); Nil
333-
}
334-
335-
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
336-
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
337-
val quoted1 = quoted match {
338-
case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) =>
339-
// inline proxy for by-name parameter
340-
quoted.symbol.defTree.asInstanceOf[DefDef].rhs
341-
case Inlined(EmptyTree, _, quoted) => quoted
342-
case _ => quoted
343-
}
344-
interpretQuote(quoted1)
345-
346-
case TypeApply(fn, quoted :: Nil) if fn.symbol == defn.InternalQuoted_typeQuote =>
347-
interpretTypeQuote(quoted)
348-
349-
case Literal(Constant(value)) =>
350-
interpretLiteral(value)
351-
352-
case _ if tree.symbol == defn.QuoteContext_macroContext =>
353-
interpretQuoteContext()
354-
355-
case Call(fn, args) =>
356-
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
357-
interpretNew(fn.symbol, args.flatten.map(interpretTree))
358-
} else if (fn.symbol.is(Module)) {
359-
interpretModuleAccess(fn.symbol)
360-
} else if (fn.symbol.isStatic) {
361-
val module = fn.symbol.owner
362-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
363-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
364-
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
365-
val module = fn.qualifier.symbol.moduleClass
366-
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
367-
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
368-
} else if (env.contains(fn.name)) {
369-
env(fn.name)
370-
} else if (tree.symbol.is(InlineProxy)) {
371-
interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs)
372-
} else {
373-
unexpectedTree(tree)
374-
}
375-
376-
// Interpret `foo(j = x, i = y)` which it is expanded to
377-
// `val j$1 = x; val i$1 = y; foo(i = y, j = x)`
378-
case Block(stats, expr) => interpretBlock(stats, expr)
379-
case NamedArg(_, arg) => interpretTree(arg)
380-
381-
case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion)
382-
383-
case Typed(expr, _) =>
384-
interpretTree(expr)
385-
386-
case SeqLiteral(elems, _) =>
387-
interpretVarargs(elems.map(e => interpretTree(e)))
388-
389-
case _ =>
390-
unexpectedTree(tree)
391-
}
392-
393-
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
394-
var unexpected: Option[Result] = None
395-
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {
396-
case stat: ValDef if stat.symbol.is(Synthetic) || !checking =>
397-
accEnv.updated(stat.name, interpretTree(stat.rhs)(accEnv))
398-
case stat =>
399-
if (unexpected.isEmpty)
400-
unexpected = Some(unexpectedTree(stat))
401-
accEnv
402-
})
403-
unexpected.getOrElse(interpretTree(expr)(newEnv))
404-
}
405-
406-
object Call {
407-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
408-
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
409-
410-
object Call0 {
411-
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
412-
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
413-
Some((fn, args))
414-
case fn: RefTree => Some((fn, Nil))
415-
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
416-
case TypeApply(Call0(fn, args), _) => Some((fn, args))
417-
case _ => None
418-
}
409+
object Call {
410+
/** Matches an expression that is either a field access or an application
411+
* It retruns a TermRef containing field accessed or a method reference and the arguments passed to it.
412+
*/
413+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] =
414+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
415+
416+
private object Call0 {
417+
def unapply(arg: Tree)(implicit ctx: Context): Option[(RefTree, List[List[Tree]])] = arg match {
418+
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
419+
Some((fn, args))
420+
case fn: RefTree => Some((fn, Nil))
421+
case Apply(f @ Call0(fn, args1), args2) =>
422+
if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1))
423+
else Some((fn, args2 :: args1))
424+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
425+
case _ => None
419426
}
420427
}
421428
}
429+
422430
}

0 commit comments

Comments
 (0)