Skip to content

Commit 434ac1d

Browse files
committed
Allow macros to call method returning implicit functions
1 parent 2f64743 commit 434ac1d

File tree

5 files changed

+95
-51
lines changed

5 files changed

+95
-51
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 68 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD
1515
import dotty.tools.dotc.core.quoted._
1616
import dotty.tools.dotc.core.Types._
1717
import dotty.tools.dotc.core.Symbols._
18-
import dotty.tools.dotc.core.TypeErasure
18+
import dotty.tools.dotc.core.{NameKinds, TypeErasure}
1919
import dotty.tools.dotc.core.Constants.Constant
2020
import dotty.tools.dotc.tastyreflect.TastyImpl
2121

@@ -105,14 +105,18 @@ object Splicer {
105105
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
106106
args.toSeq
107107

108-
protected def interpretTastyContext()(implicit env: Env): Object =
108+
protected def interpretTastyContext()(implicit env: Env): Object = {
109109
new TastyImpl(ctx) {
110110
override def rootPosition: SourcePosition = pos
111111
}
112+
}
112113

113-
protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
114-
val instance = loadModule(fn.symbol.owner)
115-
val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol))
114+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
115+
val instance = loadModule(fn.owner)
116+
val name =
117+
if (!defn.isImplicitFunctionType(fn.info.finalResultType)) fn.name
118+
else NameKinds.DirectMethodName(fn.name.asTermName) // Call implicit function type direct method
119+
val method = getMethod(instance.getClass, name, paramsSig(fn))
116120
stopIfRuntimeException(method.invoke(instance, args: _*))
117121
}
118122

@@ -190,50 +194,57 @@ object Splicer {
190194

191195
/** List of classes of the parameters of the signature of `sym` */
192196
private def paramsSig(sym: Symbol): List[Class[_]] = {
193-
TypeErasure.erasure(sym.info) match {
194-
case meth: MethodType =>
195-
meth.paramInfos.map { param =>
196-
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
197-
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
198-
case _ => (tpe, depth)
199-
}
200-
def javaArraySig(tpe: Type): String = {
201-
val (elemType, depth) = arrayDepth(tpe, 0)
202-
val sym = elemType.classSymbol
203-
val suffix =
204-
if (sym == defn.BooleanClass) "Z"
205-
else if (sym == defn.ByteClass) "B"
206-
else if (sym == defn.ShortClass) "S"
207-
else if (sym == defn.IntClass) "I"
208-
else if (sym == defn.LongClass) "J"
209-
else if (sym == defn.FloatClass) "F"
210-
else if (sym == defn.DoubleClass) "D"
211-
else if (sym == defn.CharClass) "C"
212-
else "L" + javaSig(elemType) + ";"
213-
("[" * depth) + suffix
214-
}
215-
def javaSig(tpe: Type): String = tpe match {
216-
case tpe: JavaArrayType => javaArraySig(tpe)
217-
case _ =>
218-
// Take the flatten name of the class and the full package name
219-
val pack = tpe.classSymbol.topLevelClass.owner
220-
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
221-
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
222-
}
223-
224-
val sym = param.classSymbol
225-
if (sym == defn.BooleanClass) classOf[Boolean]
226-
else if (sym == defn.ByteClass) classOf[Byte]
227-
else if (sym == defn.CharClass) classOf[Char]
228-
else if (sym == defn.ShortClass) classOf[Short]
229-
else if (sym == defn.IntClass) classOf[Int]
230-
else if (sym == defn.LongClass) classOf[Long]
231-
else if (sym == defn.FloatClass) classOf[Float]
232-
else if (sym == defn.DoubleClass) classOf[Double]
233-
else java.lang.Class.forName(javaSig(param), false, classLoader)
234-
}
197+
def paramClass(param: Type): Class[_] = {
198+
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
199+
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
200+
case _ => (tpe, depth)
201+
}
202+
def javaArraySig(tpe: Type): String = {
203+
val (elemType, depth) = arrayDepth(tpe, 0)
204+
val sym = elemType.classSymbol
205+
val suffix =
206+
if (sym == defn.BooleanClass) "Z"
207+
else if (sym == defn.ByteClass) "B"
208+
else if (sym == defn.ShortClass) "S"
209+
else if (sym == defn.IntClass) "I"
210+
else if (sym == defn.LongClass) "J"
211+
else if (sym == defn.FloatClass) "F"
212+
else if (sym == defn.DoubleClass) "D"
213+
else if (sym == defn.CharClass) "C"
214+
else "L" + javaSig(elemType) + ";"
215+
("[" * depth) + suffix
216+
}
217+
def javaSig(tpe: Type): String = tpe match {
218+
case tpe: JavaArrayType => javaArraySig(tpe)
219+
case _ =>
220+
// Take the flatten name of the class and the full package name
221+
val pack = tpe.classSymbol.topLevelClass.owner
222+
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
223+
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
224+
}
225+
226+
val sym = param.classSymbol
227+
if (sym == defn.BooleanClass) classOf[Boolean]
228+
else if (sym == defn.ByteClass) classOf[Byte]
229+
else if (sym == defn.CharClass) classOf[Char]
230+
else if (sym == defn.ShortClass) classOf[Short]
231+
else if (sym == defn.IntClass) classOf[Int]
232+
else if (sym == defn.LongClass) classOf[Long]
233+
else if (sym == defn.FloatClass) classOf[Float]
234+
else if (sym == defn.DoubleClass) classOf[Double]
235+
else java.lang.Class.forName(javaSig(param), false, classLoader)
236+
}
237+
val extraParams = sym.info.finalResultType.widenDealias match {
238+
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
239+
// Call implicit function type direct method
240+
tp.args.init.map(arg => TypeErasure.erasure(arg))
235241
case _ => Nil
236242
}
243+
val allParams = TypeErasure.erasure(sym.info) match {
244+
case meth: MethodType => (meth.paramInfos ::: extraParams)
245+
case _ => extraParams
246+
}
247+
allParams.map(paramClass)
237248
}
238249

239250
/** Exception that stops interpretation if some issue is found */
@@ -253,7 +264,8 @@ object Splicer {
253264
protected def interpretLiteral(value: Any)(implicit env: Env): Boolean = true
254265
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
255266
protected def interpretTastyContext()(implicit env: Env): Boolean = true
256-
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
267+
protected def interpretQuoteContext()(implicit env: Env): Boolean = true
268+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
257269
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
258270
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
259271

@@ -275,7 +287,7 @@ object Splicer {
275287
protected def interpretLiteral(value: Any)(implicit env: Env): Result
276288
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
277289
protected def interpretTastyContext()(implicit env: Env): Result
278-
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
290+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result
279291
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
280292
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
281293
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
@@ -298,11 +310,16 @@ object Splicer {
298310
interpretNew(fn, args.map(interpretTree))
299311
} else if (fn.symbol.isStatic) {
300312
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
301-
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
313+
else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg)))
302314
} else if (env.contains(fn.name)) {
303315
env(fn.name)
304316
} else {
305-
unexpectedTree(tree)
317+
fn match {
318+
case fn @ Select(Call(fn0, args0), _) if fn0.symbol.isStatic && fn.symbol.info.isImplicitMethod =>
319+
// Call implicit function type direct method
320+
interpretStaticMethodCall(fn0.symbol, (args0 ::: args).map(arg => interpretTree(arg)))
321+
case _ => unexpectedTree(tree)
322+
}
306323
}
307324

308325
// Interpret `foo(j = x, i = y)` which it is expanded to

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ tasty-extractors-constants-1
7171
tasty-extractors-owners
7272
tasty-extractors-types
7373
tasty-getfile
74+
tasty-getfile-implicit-fun-context
7475
tasty-indexed-map
7576
tasty-linenumber
7677
tasty-linenumber-2
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
App_2.scala
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
println(SourceFiles.getThisFile)
5+
}
6+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.quoted._
2+
import scala.tasty.Tasty
3+
4+
object SourceFiles {
5+
6+
type Macro[X] = implicit Tasty => Expr[X]
7+
def tastyContext(implicit ctx: Tasty): Tasty = ctx
8+
9+
implicit inline def getThisFile: String =
10+
~getThisFileImpl
11+
12+
def getThisFileImpl: Macro[String] = {
13+
val tasty = tastyContext
14+
import tasty._
15+
rootContext.source.getFileName.toString.toExpr
16+
}
17+
18+
19+
}

0 commit comments

Comments
 (0)