From 70c85a65526579e4cdc3038467c58b84d8efdf19 Mon Sep 17 00:00:00 2001 From: Olivier Blanvillain Date: Tue, 9 Jul 2019 17:42:51 +0200 Subject: [PATCH 1/2] Support erased arguments in splicer --- .../dotty/tools/dotc/transform/Splicer.scala | 28 +++++++++++++++++-- tests/run-macros/erased-arg-macro/1.scala | 21 ++++++++++++++ tests/run-macros/erased-arg-macro/2.scala | 10 +++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 tests/run-macros/erased-arg-macro/1.scala create mode 100644 tests/run-macros/erased-arg-macro/2.scala diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 6c117812cdfc..275c72ca6092 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -321,6 +321,28 @@ object Splicer { protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result + private final def removeEraisedArguments(args: List[Tree], fnTpe: Type): List[Tree] = { + var result = args + var index = 0 + def loop(tp: Type): Unit = tp match { + case tp: TermRef => loop(tp.underlying) + case tp: PolyType => loop(tp.resType) + case tp: MethodType if tp.isErasedMethod => + tp.paramInfos.foreach { _ => + result = result.updated(index, null) + index += 1 + } + loop(tp.resType) + case tp: MethodType => + index += tp.paramInfos.size + loop(tp.resType) + case _ => () + } + loop(fnTpe) + assert(index == args.size) + result.filterNot(null.eq) + } + protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match { case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => val quoted1 = quoted match { @@ -351,10 +373,12 @@ object Splicer { interpretModuleAccess(fn.symbol) } else if (fn.symbol.isStatic) { val module = fn.symbol.owner - interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) + def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg)) + interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { val module = fn.qualifier.symbol.moduleClass - interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) + def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg)) + interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (env.contains(fn.name)) { env(fn.name) } else if (tree.symbol.is(InlineProxy)) { diff --git a/tests/run-macros/erased-arg-macro/1.scala b/tests/run-macros/erased-arg-macro/1.scala new file mode 100644 index 000000000000..427ca0b1b9d7 --- /dev/null +++ b/tests/run-macros/erased-arg-macro/1.scala @@ -0,0 +1,21 @@ +import scala.quoted._ + +object Macro { + inline def foo1(i: Int) = $ { case1('{ i }) } + inline def foo2(i: Int) = $ { case2(1)('{ i }) } + inline def foo3(i: Int) = $ { case3('{ i })(1) } + inline def foo4(i: Int) = $ { case4(1)('{ i }, '{ i }) } + inline def foo5(i: Int) = $ { case5('{ i }, '{ i })(1) } + inline def foo6(i: Int) = $ { case6(1)('{ i })('{ i }) } + inline def foo7(i: Int) = $ { case7('{ i })(1)('{ i }) } + inline def foo8(i: Int) = $ { case8('{ i })('{ i })(1) } + + def case1 erased (i: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case2 (i: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case3 erased (i: Expr[Int]) (j: Int) given (QuoteContext): Expr[Int] = '{ 0 } + def case4 (h: Int) erased (i: Expr[Int], j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case5 erased (i: Expr[Int], j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 } + def case6 (h: Int) erased (i: Expr[Int]) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case7 erased (i: Expr[Int]) (h: Int) erased (j: Expr[Int]) given (QuoteContext): Expr[Int] = '{ 0 } + def case8 erased (i: Expr[Int]) erased (j: Expr[Int]) (h: Int) given (QuoteContext): Expr[Int] = '{ 0 } +} diff --git a/tests/run-macros/erased-arg-macro/2.scala b/tests/run-macros/erased-arg-macro/2.scala new file mode 100644 index 000000000000..7abfef78bcb0 --- /dev/null +++ b/tests/run-macros/erased-arg-macro/2.scala @@ -0,0 +1,10 @@ +object Test { + assert(Macro.foo1(1) == 0) + assert(Macro.foo2(1) == 0) + assert(Macro.foo3(1) == 0) + assert(Macro.foo4(1) == 0) + assert(Macro.foo5(1) == 0) + assert(Macro.foo6(1) == 0) + assert(Macro.foo7(1) == 0) + assert(Macro.foo8(1) == 0) +} From 9a431dc0aa0ab8bc0cf103de6c35f3a2464b8158 Mon Sep 17 00:00:00 2001 From: Olivier Blanvillain Date: Wed, 10 Jul 2019 17:41:37 +0200 Subject: [PATCH 2/2] Address review --- .../dotty/tools/dotc/transform/Splicer.scala | 56 +++++++++---------- tests/run-macros/erased-arg-macro/2.scala | 18 +++--- 2 files changed, 36 insertions(+), 38 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 275c72ca6092..21cdee0f0eb6 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -321,27 +321,19 @@ object Splicer { protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result - private final def removeEraisedArguments(args: List[Tree], fnTpe: Type): List[Tree] = { - var result = args - var index = 0 - def loop(tp: Type): Unit = tp match { - case tp: TermRef => loop(tp.underlying) - case tp: PolyType => loop(tp.resType) - case tp: MethodType if tp.isErasedMethod => - tp.paramInfos.foreach { _ => - result = result.updated(index, null) - index += 1 - } - loop(tp.resType) + private final def removeErasedArguments(args: List[List[Tree]], fnTpe: Type): List[List[Tree]] = + fnTpe match { + case tp: TermRef => removeErasedArguments(args, tp.underlying) + case tp: PolyType => removeErasedArguments(args, tp.resType) + case tp: ExprType => removeErasedArguments(args, tp.resType) case tp: MethodType => - index += tp.paramInfos.size - loop(tp.resType) - case _ => () + val tail = removeErasedArguments(args.tail, tp.resType) + if (tp.isErasedMethod) tail else args.head :: tail + case tp: AppliedType if defn.isImplicitFunctionType(tp) => + val tail = removeErasedArguments(args.tail, tp.args.last) + if (defn.isErasedFunctionType(tp)) tail else args.head :: tail + case tp => assert(args.isEmpty, tp); Nil } - loop(fnTpe) - assert(index == args.size) - result.filterNot(null.eq) - } protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match { case Apply(TypeApply(fn, _), quoted :: Nil) if fn.symbol == defn.InternalQuoted_exprQuote => @@ -368,16 +360,16 @@ object Splicer { case Call(fn, args) => if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { - interpretNew(fn.symbol, args.map(interpretTree)) + interpretNew(fn.symbol, args.flatten.map(interpretTree)) } else if (fn.symbol.is(Module)) { interpretModuleAccess(fn.symbol) } else if (fn.symbol.isStatic) { val module = fn.symbol.owner - def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg)) + def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { val module = fn.qualifier.symbol.moduleClass - def interpretedArgs = removeEraisedArguments(args, fn.tpe).map(arg => interpretTree(arg)) + def interpretedArgs = removeErasedArguments(args, fn.tpe).flatten.map(interpretTree) interpretStaticMethodCall(module, fn.symbol, interpretedArgs) } else if (env.contains(fn.name)) { env(fn.name) @@ -418,15 +410,19 @@ object Splicer { } object Call { - def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match { - case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => - Some((fn, args)) - case fn: RefTree => Some((fn, Nil)) - case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance - case TypeApply(Call(fn, args), _) => Some((fn, args)) - case _ => None + def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = + Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) + + object Call0 { + def unapply(arg: Tree): Option[(RefTree, List[List[Tree]])] = arg match { + case Select(Call0(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => + Some((fn, args)) + case fn: RefTree => Some((fn, Nil)) + case Apply(Call0(fn, args1), args2) => Some((fn, args2 :: args1)) + case TypeApply(Call0(fn, args), _) => Some((fn, args)) + case _ => None + } } } } - } diff --git a/tests/run-macros/erased-arg-macro/2.scala b/tests/run-macros/erased-arg-macro/2.scala index 7abfef78bcb0..1f7f8be436c7 100644 --- a/tests/run-macros/erased-arg-macro/2.scala +++ b/tests/run-macros/erased-arg-macro/2.scala @@ -1,10 +1,12 @@ object Test { - assert(Macro.foo1(1) == 0) - assert(Macro.foo2(1) == 0) - assert(Macro.foo3(1) == 0) - assert(Macro.foo4(1) == 0) - assert(Macro.foo5(1) == 0) - assert(Macro.foo6(1) == 0) - assert(Macro.foo7(1) == 0) - assert(Macro.foo8(1) == 0) + def main(args: Array[String]): Unit = { + assert(Macro.foo1(1) == 0) + assert(Macro.foo2(1) == 0) + assert(Macro.foo3(1) == 0) + assert(Macro.foo4(1) == 0) + assert(Macro.foo5(1) == 0) + assert(Macro.foo6(1) == 0) + assert(Macro.foo7(1) == 0) + assert(Macro.foo8(1) == 0) + } }