Skip to content

Commit f7b993a

Browse files
Merge pull request #5299 from dotty-staging/fix-macro-with-implicit-fun-type
Allow macros to call method returning implicit functions
2 parents 5a76013 + 659e7ee commit f7b993a

File tree

8 files changed

+132
-62
lines changed

8 files changed

+132
-62
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 81 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,13 @@ import dotty.tools.dotc.core.Contexts._
1010
import dotty.tools.dotc.core.Decorators._
1111
import dotty.tools.dotc.core.Flags._
1212
import dotty.tools.dotc.core.NameKinds.FlatName
13-
import dotty.tools.dotc.core.Names.Name
13+
import dotty.tools.dotc.core.Names.{Name, TermName}
14+
import dotty.tools.dotc.core.StdNames.nme
1415
import dotty.tools.dotc.core.StdNames.str.MODULE_INSTANCE_FIELD
1516
import dotty.tools.dotc.core.quoted._
1617
import dotty.tools.dotc.core.Types._
1718
import dotty.tools.dotc.core.Symbols._
18-
import dotty.tools.dotc.core.TypeErasure
19+
import dotty.tools.dotc.core.{NameKinds, TypeErasure}
1920
import dotty.tools.dotc.core.Constants.Constant
2021
import dotty.tools.dotc.tastyreflect.TastyImpl
2122

@@ -105,23 +106,30 @@ object Splicer {
105106
protected def interpretVarargs(args: List[Object])(implicit env: Env): Object =
106107
args.toSeq
107108

108-
protected def interpretTastyContext()(implicit env: Env): Object =
109+
protected def interpretTastyContext()(implicit env: Env): Object = {
109110
new TastyImpl(ctx) {
110111
override def rootPosition: SourcePosition = pos
111112
}
113+
}
112114

113-
protected def interpretStaticMethodCall(fn: Tree, args: => List[Object])(implicit env: Env): Object = {
114-
val instance = loadModule(fn.symbol.owner)
115-
val method = getMethod(instance.getClass, fn.symbol.name, paramsSig(fn.symbol))
115+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
116+
val instance = loadModule(fn.owner)
117+
def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match {
118+
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
119+
getDirectName(tp.args.last, NameKinds.DirectMethodName(name))
120+
case _ => name
121+
}
122+
val name = getDirectName(fn.info.finalResultType, fn.name.asTermName)
123+
val method = getMethod(instance.getClass, name, paramsSig(fn))
116124
stopIfRuntimeException(method.invoke(instance, args: _*))
117125
}
118126

119-
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Object =
120-
loadModule(fn.symbol.moduleClass)
127+
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object =
128+
loadModule(fn.moduleClass)
121129

122-
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Object = {
123-
val clazz = loadClass(fn.symbol.owner.fullName)
124-
val constr = clazz.getConstructor(paramsSig(fn.symbol): _*)
130+
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Object = {
131+
val clazz = loadClass(fn.owner.fullName)
132+
val constr = clazz.getConstructor(paramsSig(fn): _*)
125133
constr.newInstance(args: _*).asInstanceOf[Object]
126134
}
127135

@@ -190,50 +198,58 @@ object Splicer {
190198

191199
/** List of classes of the parameters of the signature of `sym` */
192200
private def paramsSig(sym: Symbol): List[Class[_]] = {
193-
TypeErasure.erasure(sym.info) match {
194-
case meth: MethodType =>
195-
meth.paramInfos.map { param =>
196-
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
197-
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
198-
case _ => (tpe, depth)
199-
}
200-
def javaArraySig(tpe: Type): String = {
201-
val (elemType, depth) = arrayDepth(tpe, 0)
202-
val sym = elemType.classSymbol
203-
val suffix =
204-
if (sym == defn.BooleanClass) "Z"
205-
else if (sym == defn.ByteClass) "B"
206-
else if (sym == defn.ShortClass) "S"
207-
else if (sym == defn.IntClass) "I"
208-
else if (sym == defn.LongClass) "J"
209-
else if (sym == defn.FloatClass) "F"
210-
else if (sym == defn.DoubleClass) "D"
211-
else if (sym == defn.CharClass) "C"
212-
else "L" + javaSig(elemType) + ";"
213-
("[" * depth) + suffix
214-
}
215-
def javaSig(tpe: Type): String = tpe match {
216-
case tpe: JavaArrayType => javaArraySig(tpe)
217-
case _ =>
218-
// Take the flatten name of the class and the full package name
219-
val pack = tpe.classSymbol.topLevelClass.owner
220-
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
221-
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
222-
}
223-
224-
val sym = param.classSymbol
225-
if (sym == defn.BooleanClass) classOf[Boolean]
226-
else if (sym == defn.ByteClass) classOf[Byte]
227-
else if (sym == defn.CharClass) classOf[Char]
228-
else if (sym == defn.ShortClass) classOf[Short]
229-
else if (sym == defn.IntClass) classOf[Int]
230-
else if (sym == defn.LongClass) classOf[Long]
231-
else if (sym == defn.FloatClass) classOf[Float]
232-
else if (sym == defn.DoubleClass) classOf[Double]
233-
else java.lang.Class.forName(javaSig(param), false, classLoader)
234-
}
201+
def paramClass(param: Type): Class[_] = {
202+
def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match {
203+
case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1)
204+
case _ => (tpe, depth)
205+
}
206+
def javaArraySig(tpe: Type): String = {
207+
val (elemType, depth) = arrayDepth(tpe, 0)
208+
val sym = elemType.classSymbol
209+
val suffix =
210+
if (sym == defn.BooleanClass) "Z"
211+
else if (sym == defn.ByteClass) "B"
212+
else if (sym == defn.ShortClass) "S"
213+
else if (sym == defn.IntClass) "I"
214+
else if (sym == defn.LongClass) "J"
215+
else if (sym == defn.FloatClass) "F"
216+
else if (sym == defn.DoubleClass) "D"
217+
else if (sym == defn.CharClass) "C"
218+
else "L" + javaSig(elemType) + ";"
219+
("[" * depth) + suffix
220+
}
221+
def javaSig(tpe: Type): String = tpe match {
222+
case tpe: JavaArrayType => javaArraySig(tpe)
223+
case _ =>
224+
// Take the flatten name of the class and the full package name
225+
val pack = tpe.classSymbol.topLevelClass.owner
226+
val packageName = if (pack == defn.EmptyPackageClass) "" else pack.fullName + "."
227+
packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString
228+
}
229+
230+
val sym = param.classSymbol
231+
if (sym == defn.BooleanClass) classOf[Boolean]
232+
else if (sym == defn.ByteClass) classOf[Byte]
233+
else if (sym == defn.CharClass) classOf[Char]
234+
else if (sym == defn.ShortClass) classOf[Short]
235+
else if (sym == defn.IntClass) classOf[Int]
236+
else if (sym == defn.LongClass) classOf[Long]
237+
else if (sym == defn.FloatClass) classOf[Float]
238+
else if (sym == defn.DoubleClass) classOf[Double]
239+
else java.lang.Class.forName(javaSig(param), false, classLoader)
240+
}
241+
def getExtraParams(tp: Type): List[Type] = tp.widenDealias match {
242+
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
243+
// Call implicit function type direct method
244+
tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last)
235245
case _ => Nil
236246
}
247+
val extraParams = getExtraParams(sym.info.finalResultType)
248+
val allParams = TypeErasure.erasure(sym.info) match {
249+
case meth: MethodType => meth.paramInfos ::: extraParams
250+
case _ => extraParams
251+
}
252+
allParams.map(paramClass)
237253
}
238254

239255
/** Exception that stops interpretation if some issue is found */
@@ -253,9 +269,10 @@ object Splicer {
253269
protected def interpretLiteral(value: Any)(implicit env: Env): Boolean = true
254270
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
255271
protected def interpretTastyContext()(implicit env: Env): Boolean = true
256-
protected def interpretStaticMethodCall(fn: tpd.Tree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
257-
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Boolean = true
258-
protected def interpretNew(fn: RefTree, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
272+
protected def interpretQuoteContext()(implicit env: Env): Boolean = true
273+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
274+
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true
275+
protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
259276

260277
def unexpectedTree(tree: tpd.Tree)(implicit env: Env): Boolean = {
261278
// Assuming that top-level splices can only be in inline methods
@@ -275,9 +292,9 @@ object Splicer {
275292
protected def interpretLiteral(value: Any)(implicit env: Env): Result
276293
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
277294
protected def interpretTastyContext()(implicit env: Env): Result
278-
protected def interpretStaticMethodCall(fn: Tree, args: => List[Result])(implicit env: Env): Result
279-
protected def interpretModuleAccess(fn: Tree)(implicit env: Env): Result
280-
protected def interpretNew(fn: RefTree, args: => List[Result])(implicit env: Env): Result
295+
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result
296+
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
297+
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
281298
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
282299

283300
protected final def interpretTree(tree: Tree)(implicit env: Env): Result = tree match {
@@ -295,10 +312,10 @@ object Splicer {
295312

296313
case Call(fn, args) =>
297314
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
298-
interpretNew(fn, args.map(interpretTree))
315+
interpretNew(fn.symbol, args.map(interpretTree))
299316
} else if (fn.symbol.isStatic) {
300-
if (fn.symbol.is(Module)) interpretModuleAccess(fn)
301-
else interpretStaticMethodCall(fn, args.map(arg => interpretTree(arg)))
317+
if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol)
318+
else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg)))
302319
} else if (env.contains(fn.name)) {
303320
env(fn.name)
304321
} else {
@@ -330,6 +347,8 @@ object Splicer {
330347

331348
object Call {
332349
def unapply(arg: Tree): Option[(RefTree, List[Tree])] = arg match {
350+
case Select(Call(fn, args), nme.apply) if defn.isImplicitFunctionType(fn.tpe.widenDealias.finalResultType) =>
351+
Some((fn, args))
333352
case fn: RefTree => Some((fn, Nil))
334353
case Apply(Call(fn, args1), args2) => Some((fn, args1 ::: args2)) // TODO improve performance
335354
case TypeApply(Call(fn, args), _) => Some((fn, args))

compiler/test/dotc/run-test-pickling.blacklist

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,9 @@ tasty-extractors-constants-1
7272
tasty-extractors-owners
7373
tasty-extractors-types
7474
tasty-getfile
75+
tasty-getfile-implicit-fun-context
7576
tasty-indexed-map
77+
tasty-implicit-fun-context-2
7678
tasty-linenumber
7779
tasty-linenumber-2
7880
tasty-location
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
App_2.scala
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
println(SourceFiles.getThisFile)
5+
}
6+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.quoted._
2+
import scala.tasty.Tasty
3+
4+
object SourceFiles {
5+
6+
type Macro[X] = implicit Tasty => Expr[X]
7+
def tastyContext(implicit ctx: Tasty): Tasty = ctx
8+
9+
implicit inline def getThisFile: String =
10+
~getThisFileImpl
11+
12+
def getThisFileImpl: Macro[String] = {
13+
val tasty = tastyContext
14+
import tasty._
15+
rootContext.source.getFileName.toString.toExpr
16+
}
17+
18+
19+
}
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
abc
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
println(Foo.foo)
5+
}
6+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import scala.quoted._
2+
import scala.tasty.Tasty
3+
4+
object Foo {
5+
6+
type Macro[X] = implicit Tasty => Expr[X]
7+
type Tastier[X] = implicit Tasty => X
8+
9+
implicit inline def foo: String =
10+
~fooImpl
11+
12+
def fooImpl(implicit tasty: Tasty): implicit Tasty => Tastier[implicit Tasty => Macro[String]] = {
13+
'("abc")
14+
}
15+
16+
}

0 commit comments

Comments
 (0)