From 434ac1d1e11eb7ebc4e452b0887e66615690090e Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Fri, 19 Oct 2018 19:28:35 +0200 Subject: [PATCH 1/2] Allow macros to call method returning implicit functions --- .../dotty/tools/dotc/transform/Splicer.scala | 119 ++++++++++-------- .../test/dotc/run-test-pickling.blacklist | 1 + .../tasty-getfile-implicit-fun-context.check | 1 + .../App_2.scala | 6 + .../Macro_1.scala | 19 +++ 5 files changed, 95 insertions(+), 51 deletions(-) create mode 100644 tests/run/tasty-getfile-implicit-fun-context.check create mode 100644 tests/run/tasty-getfile-implicit-fun-context/App_2.scala create mode 100644 tests/run/tasty-getfile-implicit-fun-context/Macro_1.scala diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 7900810e0185..122f21e1019b 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -15,7 +15,7 @@ import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD import dotty.tools.dotc.core.quoted._ import dotty.tools.dotc.core.Types._ import dotty.tools.dotc.core.Symbols._ -import dotty.tools.dotc.core.TypeErasure +import dotty.tools.dotc.core.{NameKinds, TypeErasure} import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.tastyreflect.TastyImpl @@ -105,14 +105,18 @@ object Splicer { protected def interpretVarargs(args: List[Object])(implicit env: Env): Object = args.toSeq - protected def interpretTastyContext()(implicit env: Env): Object = + protected def interpretTastyContext()(implicit env: Env): Object = { new TastyImpl(ctx) { override def rootPosition: SourcePosition = pos } + } - protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = { - val instance = loadModule(fn.symbol.owner) - val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol)) + protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { + val instance = loadModule(fn.owner) + val name = + if (!defn.isImplicitFunctionType(fn.info.finalResultType)) fn.name + else NameKinds.DirectMethodName(fn.name.asTermName) // Call implicit function type direct method + val method = getMethod(instance.getClass, name, paramsSig(fn)) stopIfRuntimeException(method.invoke(instance, args: _*)) } @@ -190,50 +194,57 @@ object Splicer { /** List of classes of the parameters of the signature of `sym` */ private def paramsSig(sym: Symbol): List[Class[_]] = { - TypeErasure.erasure(sym.info) match { - case meth: MethodType => - meth.paramInfos.map { param => - def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match { - case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1) - case _ => (tpe, depth) - } - def javaArraySig(tpe: Type): String = { - val (elemType, depth) = arrayDepth(tpe, 0) - val sym = elemType.classSymbol - val suffix = - if (sym == defn.BooleanClass) "Z" - else if (sym == defn.ByteClass) "B" - else if (sym == defn.ShortClass) "S" - else if (sym == defn.IntClass) "I" - else if (sym == defn.LongClass) "J" - else if (sym == defn.FloatClass) "F" - else if (sym == defn.DoubleClass) "D" - else if (sym == defn.CharClass) "C" - else "L" + javaSig(elemType) + ";" - ("[" * depth) + suffix - } - def javaSig(tpe: Type): String = tpe match { - case tpe: JavaArrayType => javaArraySig(tpe) - case _ => - // Take the flatten name of the class and the full package name - val pack = tpe.classSymbol.topLevelClass.owner - val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "." - packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString - } - - val sym = param.classSymbol - if (sym == defn.BooleanClass) classOf[Boolean] - else if (sym == defn.ByteClass) classOf[Byte] - else if (sym == defn.CharClass) classOf[Char] - else if (sym == defn.ShortClass) classOf[Short] - else if (sym == defn.IntClass) classOf[Int] - else if (sym == defn.LongClass) classOf[Long] - else if (sym == defn.FloatClass) classOf[Float] - else if (sym == defn.DoubleClass) classOf[Double] - else java.lang.Class.forName(javaSig(param), false, classLoader) - } + def paramClass(param: Type): Class[_] = { + def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match { + case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1) + case _ => (tpe, depth) + } + def javaArraySig(tpe: Type): String = { + val (elemType, depth) = arrayDepth(tpe, 0) + val sym = elemType.classSymbol + val suffix = + if (sym == defn.BooleanClass) "Z" + else if (sym == defn.ByteClass) "B" + else if (sym == defn.ShortClass) "S" + else if (sym == defn.IntClass) "I" + else if (sym == defn.LongClass) "J" + else if (sym == defn.FloatClass) "F" + else if (sym == defn.DoubleClass) "D" + else if (sym == defn.CharClass) "C" + else "L" + javaSig(elemType) + ";" + ("[" * depth) + suffix + } + def javaSig(tpe: Type): String = tpe match { + case tpe: JavaArrayType => javaArraySig(tpe) + case _ => + // Take the flatten name of the class and the full package name + val pack = tpe.classSymbol.topLevelClass.owner + val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "." + packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString + } + + val sym = param.classSymbol + if (sym == defn.BooleanClass) classOf[Boolean] + else if (sym == defn.ByteClass) classOf[Byte] + else if (sym == defn.CharClass) classOf[Char] + else if (sym == defn.ShortClass) classOf[Short] + else if (sym == defn.IntClass) classOf[Int] + else if (sym == defn.LongClass) classOf[Long] + else if (sym == defn.FloatClass) classOf[Float] + else if (sym == defn.DoubleClass) classOf[Double] + else java.lang.Class.forName(javaSig(param), false, classLoader) + } + val extraParams = sym.info.finalResultType.widenDealias match { + case tp: AppliedType if defn.isImplicitFunctionType(tp) => + // Call implicit function type direct method + tp.args.init.map(arg => TypeErasure.erasure(arg)) case _ => Nil } + val allParams = TypeErasure.erasure(sym.info) match { + case meth: MethodType => (meth.paramInfos ::: extraParams) + case _ => extraParams + } + allParams.map(paramClass) } /** Exception that stops interpretation if some issue is found */ @@ -253,7 +264,8 @@ object Splicer { protected def interpretLiteral(value: Any)(implicit env: Env): Boolean = true protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity) protected def interpretTastyContext()(implicit env: Env): Boolean = true - protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) + protected def interpretQuoteContext()(implicit env: Env): Boolean = true + protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) @@ -275,7 +287,7 @@ object Splicer { protected def interpretLiteral(value: Any)(implicit env: Env): Result protected def interpretVarargs(args: List[Result])(implicit env: Env): Result protected def interpretTastyContext()(implicit env: Env): Result - protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result + protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result @@ -298,11 +310,16 @@ object Splicer { interpretNew(fn, args.map(interpretTree)) } else if (fn.symbol.isStatic) { if (fn.symbol.is(Module)) interpretModuleAccess(fn) - else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg))) + else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg))) } else if (env.contains(fn.name)) { env(fn.name) } else { - unexpectedTree(tree) + fn match { + case fn @ Select(Call(fn0, args0), _) if fn0.symbol.isStatic && fn.symbol.info.isImplicitMethod => + // Call implicit function type direct method + interpretStaticMethodCall(fn0.symbol, (args0 ::: args).map(arg => interpretTree(arg))) + case _ => unexpectedTree(tree) + } } // Interpret `foo(j = x, i = y)` which it is expanded to diff --git a/compiler/test/dotc/run-test-pickling.blacklist b/compiler/test/dotc/run-test-pickling.blacklist index a325fa8a995d..9aa550216e51 100644 --- a/compiler/test/dotc/run-test-pickling.blacklist +++ b/compiler/test/dotc/run-test-pickling.blacklist @@ -71,6 +71,7 @@ tasty-extractors-constants-1 tasty-extractors-owners tasty-extractors-types tasty-getfile +tasty-getfile-implicit-fun-context tasty-indexed-map tasty-linenumber tasty-linenumber-2 diff --git a/tests/run/tasty-getfile-implicit-fun-context.check b/tests/run/tasty-getfile-implicit-fun-context.check new file mode 100644 index 000000000000..f414a7f6b11c --- /dev/null +++ b/tests/run/tasty-getfile-implicit-fun-context.check @@ -0,0 +1 @@ +App_2.scala diff --git a/tests/run/tasty-getfile-implicit-fun-context/App_2.scala b/tests/run/tasty-getfile-implicit-fun-context/App_2.scala new file mode 100644 index 000000000000..ca5531badfc7 --- /dev/null +++ b/tests/run/tasty-getfile-implicit-fun-context/App_2.scala @@ -0,0 +1,6 @@ + +object Test { + def main(args: Array[String]): Unit = { + println(SourceFiles.getThisFile) + } +} diff --git a/tests/run/tasty-getfile-implicit-fun-context/Macro_1.scala b/tests/run/tasty-getfile-implicit-fun-context/Macro_1.scala new file mode 100644 index 000000000000..c588d45985e2 --- /dev/null +++ b/tests/run/tasty-getfile-implicit-fun-context/Macro_1.scala @@ -0,0 +1,19 @@ +import scala.quoted._ +import scala.tasty.Tasty + +object SourceFiles { + + type Macro[X] = implicit Tasty => Expr[X] + def tastyContext(implicit ctx: Tasty): Tasty = ctx + + implicit inline def getThisFile: String = + ~getThisFileImpl + + def getThisFileImpl: Macro[String] = { + val tasty = tastyContext + import tasty._ + rootContext.source.getFileName.toString.toExpr + } + + +} From 659e7ee608fe9f0bd2ae8ff201435ca2d950bc4a Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Sat, 20 Oct 2018 09:31:24 +0200 Subject: [PATCH 2/2] Support nested implicit function types --- .../dotty/tools/dotc/transform/Splicer.scala | 50 ++++++++++--------- .../test/dotc/run-test-pickling.blacklist | 1 + tests/run/tasty-implicit-fun-context-2.check | 1 + .../tasty-implicit-fun-context-2/App_2.scala | 6 +++ .../Macro_1.scala | 16 ++++++ 5 files changed, 50 insertions(+), 24 deletions(-) create mode 100644 tests/run/tasty-implicit-fun-context-2.check create mode 100644 tests/run/tasty-implicit-fun-context-2/App_2.scala create mode 100644 tests/run/tasty-implicit-fun-context-2/Macro_1.scala diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 122f21e1019b..f3c9d5aa19b9 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -10,7 +10,8 @@ import dotty.tools.dotc.core.Contexts._ import dotty.tools.dotc.core.Decorators._ import dotty.tools.dotc.core.Flags._ import dotty.tools.dotc.core.NameKinds.FlatName -import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.Names.{Name, TermName} +import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD import dotty.tools.dotc.core.quoted._ import dotty.tools.dotc.core.Types._ @@ -113,19 +114,22 @@ object Splicer { protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { val instance = loadModule(fn.owner) - val name = - if (!defn.isImplicitFunctionType(fn.info.finalResultType)) fn.name - else NameKinds.DirectMethodName(fn.name.asTermName) // Call implicit function type direct method + def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match { + case tp: AppliedType if defn.isImplicitFunctionType(tp) => + getDirectName(tp.args.last, NameKinds.DirectMethodName(name)) + case _ => name + } + val name = getDirectName(fn.info.finalResultType, fn.name.asTermName) val method = getMethod(instance.getClass, name, paramsSig(fn)) stopIfRuntimeException(method.invoke(instance, args: _*)) } - protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object = - loadModule(fn.symbol.moduleClass) + protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = + loadModule(fn.moduleClass) - protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = { - val clazz = loadClass(fn.symbol.owner.fullName) - val constr = clazz.getConstructor(paramsSig(fn.symbol): _*) + protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = { + val clazz = loadClass(fn.owner.fullName) + val constr = clazz.getConstructor(paramsSig(fn): _*) constr.newInstance(args: _*).asInstanceOf[Object] } @@ -234,14 +238,15 @@ object Splicer { else if (sym == defn.DoubleClass) classOf[Double] else java.lang.Class.forName(javaSig(param), false, classLoader) } - val extraParams = sym.info.finalResultType.widenDealias match { + def getExtraParams(tp: Type): List[Type] = tp.widenDealias match { case tp: AppliedType if defn.isImplicitFunctionType(tp) => // Call implicit function type direct method - tp.args.init.map(arg => TypeErasure.erasure(arg)) + tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last) case _ => Nil } + val extraParams = getExtraParams(sym.info.finalResultType) val allParams = TypeErasure.erasure(sym.info) match { - case meth: MethodType => (meth.paramInfos ::: extraParams) + case meth: MethodType => meth.paramInfos ::: extraParams case _ => extraParams } allParams.map(paramClass) @@ -266,8 +271,8 @@ object Splicer { protected def interpretTastyContext()(implicit env: Env): Boolean = true protected def interpretQuoteContext()(implicit env: Env): Boolean = true protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) - protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true - protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) + protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true + protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = { // Assuming that top-level splices can only be in inline methods @@ -288,8 +293,8 @@ object Splicer { protected def interpretVarargs(args: List[Result])(implicit env: Env): Result protected def interpretTastyContext()(implicit env: Env): Result protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result - protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result - protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result + protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result + protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match { @@ -307,19 +312,14 @@ object Splicer { case Call(fn, args) => if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { - interpretNew(fn, args.map(interpretTree)) + interpretNew(fn.symbol, args.map(interpretTree)) } else if (fn.symbol.isStatic) { - if (fn.symbol.is(Module)) interpretModuleAccess(fn) + if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol) else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg))) } else if (env.contains(fn.name)) { env(fn.name) } else { - fn match { - case fn @ Select(Call(fn0, args0), _) if fn0.symbol.isStatic && fn.symbol.info.isImplicitMethod => - // Call implicit function type direct method - interpretStaticMethodCall(fn0.symbol, (args0 ::: args).map(arg => interpretTree(arg))) - case _ => unexpectedTree(tree) - } + unexpectedTree(tree) } // Interpret `foo(j = x, i = y)` which it is expanded to @@ -347,6 +347,8 @@ object Splicer { object Call { def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match { + case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) => + Some((fn, args)) case fn: RefTree => Some((fn, Nil)) case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance case TypeApply(Call(fn, args), _) => Some((fn, args)) diff --git a/compiler/test/dotc/run-test-pickling.blacklist b/compiler/test/dotc/run-test-pickling.blacklist index 9aa550216e51..7de50c8841ab 100644 --- a/compiler/test/dotc/run-test-pickling.blacklist +++ b/compiler/test/dotc/run-test-pickling.blacklist @@ -73,6 +73,7 @@ tasty-extractors-types tasty-getfile tasty-getfile-implicit-fun-context tasty-indexed-map +tasty-implicit-fun-context-2 tasty-linenumber tasty-linenumber-2 tasty-location diff --git a/tests/run/tasty-implicit-fun-context-2.check b/tests/run/tasty-implicit-fun-context-2.check new file mode 100644 index 000000000000..8baef1b4abc4 --- /dev/null +++ b/tests/run/tasty-implicit-fun-context-2.check @@ -0,0 +1 @@ +abc diff --git a/tests/run/tasty-implicit-fun-context-2/App_2.scala b/tests/run/tasty-implicit-fun-context-2/App_2.scala new file mode 100644 index 000000000000..3857b59ae684 --- /dev/null +++ b/tests/run/tasty-implicit-fun-context-2/App_2.scala @@ -0,0 +1,6 @@ + +object Test { + def main(args: Array[String]): Unit = { + println(Foo.foo) + } +} diff --git a/tests/run/tasty-implicit-fun-context-2/Macro_1.scala b/tests/run/tasty-implicit-fun-context-2/Macro_1.scala new file mode 100644 index 000000000000..4cafc46592bb --- /dev/null +++ b/tests/run/tasty-implicit-fun-context-2/Macro_1.scala @@ -0,0 +1,16 @@ +import scala.quoted._ +import scala.tasty.Tasty + +object Foo { + + type Macro[X] = implicit Tasty => Expr[X] + type Tastier[X] = implicit Tasty => X + + implicit inline def foo: String = + ~fooImpl + + def fooImpl(implicit tasty: Tasty): implicit Tasty => Tastier[implicit Tasty => Macro[String]] = { + '("abc") + } + +}