Skip to content

Allow macros to call method returning implicit functions #5299

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 81 additions & 62 deletions compiler/src/dotty/tools/dotc/transform/Splicer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import dotty.tools.dotc.core.Contexts._
import dotty.tools.dotc.core.Decorators._
import dotty.tools.dotc.core.Flags._
import dotty.tools.dotc.core.NameKinds.FlatName
import dotty.tools.dotc.core.Names.Name
import dotty.tools.dotc.core.Names.{Name, TermName}
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD
import dotty.tools.dotc.core.quoted._
import dotty.tools.dotc.core.Types._
import dotty.tools.dotc.core.Symbols._
import dotty.tools.dotc.core.TypeErasure
import dotty.tools.dotc.core.{NameKinds, TypeErasure}
import dotty.tools.dotc.core.Constants.Constant
import dotty.tools.dotc.tastyreflect.TastyImpl

Expand Down Expand Up @@ -105,23 +106,30 @@ object Splicer {
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
args.toSeq

protected def interpretTastyContext()(implicit env: Env): Object =
protected def interpretTastyContext()(implicit env: Env): Object = {
new TastyImpl(ctx) {
override def rootPosition: SourcePosition = pos
}
}

protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
val instance = loadModule(fn.symbol.owner)
val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol))
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
val instance = loadModule(fn.owner)
def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match {
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
getDirectName(tp.args.last, NameKinds.DirectMethodName(name))
case _ => name
}
val name = getDirectName(fn.info.finalResultType, fn.name.asTermName)
val method = getMethod(instance.getClass, name, paramsSig(fn))
stopIfRuntimeException(method.invoke(instance, args: _*))
}

protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object =
loadModule(fn.symbol.moduleClass)
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
loadModule(fn.moduleClass)

protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = {
val clazz = loadClass(fn.symbol.owner.fullName)
val constr = clazz.getConstructor(paramsSig(fn.symbol): _*)
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
val clazz = loadClass(fn.owner.fullName)
val constr = clazz.getConstructor(paramsSig(fn): _*)
constr.newInstance(args: _*).asInstanceOf[Object]
}

Expand Down Expand Up @@ -190,50 +198,58 @@ object Splicer {

/** List of classes of the parameters of the signature of `sym` */
private def paramsSig(sym: Symbol): List[Class[_]] = {
TypeErasure.erasure(sym.info) match {
case meth: MethodType =>
meth.paramInfos.map { param =>
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
case _ => (tpe, depth)
}
def javaArraySig(tpe: Type): String = {
val (elemType, depth) = arrayDepth(tpe, 0)
val sym = elemType.classSymbol
val suffix =
if (sym == defn.BooleanClass) "Z"
else if (sym == defn.ByteClass) "B"
else if (sym == defn.ShortClass) "S"
else if (sym == defn.IntClass) "I"
else if (sym == defn.LongClass) "J"
else if (sym == defn.FloatClass) "F"
else if (sym == defn.DoubleClass) "D"
else if (sym == defn.CharClass) "C"
else "L" + javaSig(elemType) + ";"
("[" * depth) + suffix
}
def javaSig(tpe: Type): String = tpe match {
case tpe: JavaArrayType => javaArraySig(tpe)
case _ =>
// Take the flatten name of the class and the full package name
val pack = tpe.classSymbol.topLevelClass.owner
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
}

val sym = param.classSymbol
if (sym == defn.BooleanClass) classOf[Boolean]
else if (sym == defn.ByteClass) classOf[Byte]
else if (sym == defn.CharClass) classOf[Char]
else if (sym == defn.ShortClass) classOf[Short]
else if (sym == defn.IntClass) classOf[Int]
else if (sym == defn.LongClass) classOf[Long]
else if (sym == defn.FloatClass) classOf[Float]
else if (sym == defn.DoubleClass) classOf[Double]
else java.lang.Class.forName(javaSig(param), false, classLoader)
}
def paramClass(param: Type): Class[_] = {
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
case _ => (tpe, depth)
}
def javaArraySig(tpe: Type): String = {
val (elemType, depth) = arrayDepth(tpe, 0)
val sym = elemType.classSymbol
val suffix =
if (sym == defn.BooleanClass) "Z"
else if (sym == defn.ByteClass) "B"
else if (sym == defn.ShortClass) "S"
else if (sym == defn.IntClass) "I"
else if (sym == defn.LongClass) "J"
else if (sym == defn.FloatClass) "F"
else if (sym == defn.DoubleClass) "D"
else if (sym == defn.CharClass) "C"
else "L" + javaSig(elemType) + ";"
("[" * depth) + suffix
}
def javaSig(tpe: Type): String = tpe match {
case tpe: JavaArrayType => javaArraySig(tpe)
case _ =>
// Take the flatten name of the class and the full package name
val pack = tpe.classSymbol.topLevelClass.owner
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
}

val sym = param.classSymbol
if (sym == defn.BooleanClass) classOf[Boolean]
else if (sym == defn.ByteClass) classOf[Byte]
else if (sym == defn.CharClass) classOf[Char]
else if (sym == defn.ShortClass) classOf[Short]
else if (sym == defn.IntClass) classOf[Int]
else if (sym == defn.LongClass) classOf[Long]
else if (sym == defn.FloatClass) classOf[Float]
else if (sym == defn.DoubleClass) classOf[Double]
else java.lang.Class.forName(javaSig(param), false, classLoader)
}
def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
// Call implicit function type direct method
tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
case _ => Nil
}
val extraParams = getExtraParams(sym.info.finalResultType)
val allParams = TypeErasure.erasure(sym.info) match {
case meth: MethodType => meth.paramInfos ::: extraParams
case _ => extraParams
}
allParams.map(paramClass)
}

/** Exception that stops interpretation if some issue is found */
Expand All @@ -253,9 +269,10 @@ object Splicer {
protected def interpretLiteral(value: Any)(implicit env: Env): Boolean = true
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretTastyContext()(implicit env: Env): Boolean = true
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretQuoteContext()(implicit env: Env): Boolean = true
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true
protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)

def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
// Assuming that top-level splices can only be in inline methods
Expand All @@ -275,9 +292,9 @@ object Splicer {
protected def interpretLiteral(value: Any)(implicit env: Env): Result
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
protected def interpretTastyContext()(implicit env: Env): Result
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result

protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
Expand All @@ -295,10 +312,10 @@ object Splicer {

case Call(fn, args) =>
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
interpretNew(fn, args.map(interpretTree))
interpretNew(fn.symbol, args.map(interpretTree))
} else if (fn.symbol.isStatic) {
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol)
else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg)))
} else if (env.contains(fn.name)) {
env(fn.name)
} else {
Expand Down Expand Up @@ -330,6 +347,8 @@ object Splicer {

object Call {
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was missing this guard.

Some((fn, args))
case fn: RefTree => Some((fn, Nil))
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
case TypeApply(Call(fn, args), _) => Some((fn, args))
Expand Down
2 changes: 2 additions & 0 deletions compiler/test/dotc/run-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ tasty-extractors-constants-1
tasty-extractors-owners
tasty-extractors-types
tasty-getfile
tasty-getfile-implicit-fun-context
tasty-indexed-map
tasty-implicit-fun-context-2
tasty-linenumber
tasty-linenumber-2
tasty-location
Expand Down
1 change: 1 addition & 0 deletions tests/run/tasty-getfile-implicit-fun-context.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
App_2.scala
6 changes: 6 additions & 0 deletions tests/run/tasty-getfile-implicit-fun-context/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

object Test {
def main(args: Array[String]): Unit = {
println(SourceFiles.getThisFile)
}
}
19 changes: 19 additions & 0 deletions tests/run/tasty-getfile-implicit-fun-context/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import scala.quoted._
import scala.tasty.Tasty

object SourceFiles {

type Macro[X] = implicit Tasty => Expr[X]
def tastyContext(implicit ctx: Tasty): Tasty = ctx

implicit inline def getThisFile: String =
~getThisFileImpl

def getThisFileImpl: Macro[String] = {
val tasty = tastyContext
import tasty._
rootContext.source.getFileName.toString.toExpr
}


}
1 change: 1 addition & 0 deletions tests/run/tasty-implicit-fun-context-2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
abc
6 changes: 6 additions & 0 deletions tests/run/tasty-implicit-fun-context-2/App_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@

object Test {
def main(args: Array[String]): Unit = {
println(Foo.foo)
}
}
16 changes: 16 additions & 0 deletions tests/run/tasty-implicit-fun-context-2/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.quoted._
import scala.tasty.Tasty

object Foo {

type Macro[X] = implicit Tasty => Expr[X]
type Tastier[X] = implicit Tasty => X

implicit inline def foo: String =
~fooImpl

def fooImpl(implicit tasty: Tasty): implicit Tasty => Tastier[implicit Tasty => Macro[String]] = {
'("abc")
}

}