diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 6c117812cdfc..21cdee0f0eb6 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -321,6 +321,20 @@ object Splicer { protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result + private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] = + fnTpe match { + case tp: TermRef => removeErasedArguments(args, tp.underlying) + case tp: PolyType => removeErasedArguments(args, tp.resType) + case tp: ExprType => removeErasedArguments(args, tp.resType) + case tp: MethodType => + val tail = removeErasedArguments(args.tail, tp.resType) + if (tp.isErasedMethod) tail else args.head :: tail + case tp: AppliedType if defn.isImplicitFunctionType(tp) => + val tail = removeErasedArguments(args.tail, tp.args.last) + if (defn.isErasedFunctionType(tp)) tail else args.head :: tail + case tp => assert(args.isEmpty, tp); Nil + } + protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match { case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => val quoted1 = quoted match { @@ -346,15 +360,17 @@ object Splicer { case Call(fn, args) => if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { - interpretNew(fn.symbol, args.map(interpretTree)) + interpretNew(fn.symbol, args.flatten.map(interpretTree)) } else if (fn.symbol.is(Module)) { interpretModuleAccess(fn.symbol) } else if (fn.symbol.isStatic) { val module = fn.symbol.owner - interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) + def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) + interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { val module = fn.qualifier.symbol.moduleClass - interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) + def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) + interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (env.contains(fn.name)) { env(fn.name) } else if (tree.symbol.is(InlineProxy)) { @@ -394,15 +410,19 @@ object Splicer { } object Call { - def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match { - case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => - Some((fn, args)) - case fn: RefTree => Some((fn, Nil)) - case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance - case TypeApply(Call(fn, args), _) => Some((fn, args)) - case _ => None + def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = + Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) + + object Call0 { + def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match { + case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => + Some((fn, args)) + case fn: RefTree => Some((fn, Nil)) + case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1)) + case TypeApply(Call0(fn, args), _) => Some((fn, args)) + case _ => None + } } } } - } diff --git a/tests/run-macros/erased-arg-macro/1.scala b/tests/run-macros/erased-arg-macro/1.scala new file mode 100644 index 000000000000..427ca0b1b9d7 --- /dev/null +++ b/tests/run-macros/erased-arg-macro/1.scala @@ -0,0 +1,21 @@ +import scala.quoted._ + +object Macro { + inline def foo1(i: Int) = $ { case1('{ i }) } + inline def foo2(i: Int) = $ { case2(1)('{ i }) } + inline def foo3(i: Int) = $ { case3('{ i })(1) } + inline def foo4(i: Int) = $ { case4(1)('{ i }, '{ i }) } + inline def foo5(i: Int) = $ { case5('{ i }, '{ i })(1) } + inline def foo6(i: Int) = $ { case6(1)('{ i })('{ i }) } + inline def foo7(i: Int) = $ { case7('{ i })(1)('{ i }) } + inline def foo8(i: Int) = $ { case8('{ i })('{ i })(1) } + + def case1 erased (i: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case2 (i: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case3 erased (i: Expr[Int]) (j: Int) given (QuoteContext): Expr[Int] = '{ 0 } + def case4 (h: Int) erased (i: Expr[Int], j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case5 erased (i: Expr[Int], j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 } + def case6 (h: Int) erased (i: Expr[Int]) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case7 erased (i: Expr[Int]) (h: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case8 erased (i: Expr[Int]) erased (j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 } +} diff --git a/tests/run-macros/erased-arg-macro/2.scala b/tests/run-macros/erased-arg-macro/2.scala new file mode 100644 index 000000000000..1f7f8be436c7 --- /dev/null +++ b/tests/run-macros/erased-arg-macro/2.scala @@ -0,0 +1,12 @@ +object Test { + def main(args: Array[String]): Unit = { + assert(Macro.foo1(1) == 0) + assert(Macro.foo2(1) == 0) + assert(Macro.foo3(1) == 0) + assert(Macro.foo4(1) == 0) + assert(Macro.foo5(1) == 0) + assert(Macro.foo6(1) == 0) + assert(Macro.foo7(1) == 0) + assert(Macro.foo8(1) == 0) + } +}