Skip to content

Commit 15b9cd8

Browse files
Merge pull request #6830 from dotty-staging/support-erased-arguments-in-splicer
Support erased arguments in splicer
2 parents 9375804 + 9a431dc commit 15b9cd8

File tree

3 files changed

+64
-11
lines changed

3 files changed

+64
-11
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,20 @@ object Splicer {
318318
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
319319
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
320320

321+
private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] =
322+
fnTpe match {
323+
case tp: TermRef => removeErasedArguments(args, tp.underlying)
324+
case tp: PolyType => removeErasedArguments(args, tp.resType)
325+
case tp: ExprType => removeErasedArguments(args, tp.resType)
326+
case tp: MethodType =>
327+
val tail = removeErasedArguments(args.tail, tp.resType)
328+
if (tp.isErasedMethod) tail else args.head :: tail
329+
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
330+
val tail = removeErasedArguments(args.tail, tp.args.last)
331+
if (defn.isErasedFunctionType(tp)) tail else args.head :: tail
332+
case tp => assert(args.isEmpty, tp); Nil
333+
}
334+
321335
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
322336
case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote =>
323337
val quoted1 = quoted match {
@@ -340,15 +354,17 @@ object Splicer {
340354

341355
case Call(fn, args) =>
342356
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
343-
interpretNew(fn.symbol, args.map(interpretTree))
357+
interpretNew(fn.symbol, args.flatten.map(interpretTree))
344358
} else if (fn.symbol.is(Module)) {
345359
interpretModuleAccess(fn.symbol)
346360
} else if (fn.symbol.isStatic) {
347361
val module = fn.symbol.owner
348-
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
362+
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
363+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
349364
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
350365
val module = fn.qualifier.symbol.moduleClass
351-
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
366+
def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree)
367+
interpretStaticMethodCall(module, fn.symbol, interpretedArgs)
352368
} else if (env.contains(fn.name)) {
353369
env(fn.name)
354370
} else if (tree.symbol.is(InlineProxy)) {
@@ -388,15 +404,19 @@ object Splicer {
388404
}
389405

390406
object Call {
391-
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
392-
case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
393-
Some((fn, args))
394-
case fn: RefTree => Some((fn, Nil))
395-
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
396-
case TypeApply(Call(fn, args), _) => Some((fn, args))
397-
case _ => None
407+
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] =
408+
Call0.unapply(arg).map((fn, args) => (fn, args.reverse))
409+
410+
object Call0 {
411+
def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match {
412+
case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
413+
Some((fn, args))
414+
case fn: RefTree => Some((fn, Nil))
415+
case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1))
416+
case TypeApply(Call0(fn, args), _) => Some((fn, args))
417+
case _ => None
418+
}
398419
}
399420
}
400421
}
401-
402422
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import scala.quoted._
2+
3+
object Macro {
4+
inline def foo1(i: Int) = $ { case1('{ i }) }
5+
inline def foo2(i: Int) = $ { case2(1)('{ i }) }
6+
inline def foo3(i: Int) = $ { case3('{ i })(1) }
7+
inline def foo4(i: Int) = $ { case4(1)('{ i }, '{ i }) }
8+
inline def foo5(i: Int) = $ { case5('{ i }, '{ i })(1) }
9+
inline def foo6(i: Int) = $ { case6(1)('{ i })('{ i }) }
10+
inline def foo7(i: Int) = $ { case7('{ i })(1)('{ i }) }
11+
inline def foo8(i: Int) = $ { case8('{ i })('{ i })(1) }
12+
13+
def case1 erased (i: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
14+
def case2 (i: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
15+
def case3 erased (i: Expr[Int]) (j: Int) given (QuoteContext): Expr[Int] = '{ 0 }
16+
def case4 (h: Int) erased (i: Expr[Int], j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
17+
def case5 erased (i: Expr[Int], j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 }
18+
def case6 (h: Int) erased (i: Expr[Int]) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
19+
def case7 erased (i: Expr[Int]) (h: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 }
20+
def case8 erased (i: Expr[Int]) erased (j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 }
21+
}
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
object Test {
2+
def main(args: Array[String]): Unit = {
3+
assert(Macro.foo1(1) == 0)
4+
assert(Macro.foo2(1) == 0)
5+
assert(Macro.foo3(1) == 0)
6+
assert(Macro.foo4(1) == 0)
7+
assert(Macro.foo5(1) == 0)
8+
assert(Macro.foo6(1) == 0)
9+
assert(Macro.foo7(1) == 0)
10+
assert(Macro.foo8(1) == 0)
11+
}
12+
}

0 commit comments

Comments
 (0)