Skip to content

Commit e5a789b

Browse files
Merge pull request #9811 from dotty-staging/fix-#9802
Fix #9802: Interpret by-name params as `() => interpret(arg)`
2 parents 6766ebd + 31e9d1f commit e5a789b

File tree

5 files changed

+54
-2
lines changed

5 files changed

+54
-2
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ object Splicer {
250250
interpretModuleAccess(fn.symbol)
251251
else if (fn.symbol.is(Method) && fn.symbol.isStatic) {
252252
val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol)
253-
staticMethodCall(args.flatten.map(interpretTree))
253+
staticMethodCall(interpretArgs(args, fn.symbol.info))
254254
}
255255
else if fn.symbol.isStatic then
256256
assert(args.isEmpty)
@@ -260,7 +260,7 @@ object Splicer {
260260
interpretModuleAccess(fn.qualifier.symbol)
261261
else {
262262
val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol)
263-
staticMethodCall(args.flatten.map(interpretTree))
263+
staticMethodCall(interpretArgs(args, fn.symbol.info))
264264
}
265265
else if (env.contains(fn.symbol))
266266
env(fn.symbol)
@@ -289,6 +289,32 @@ object Splicer {
289289
unexpectedTree(tree)
290290
}
291291

292+
private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = {
293+
def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] =
294+
assert(args.size == argTypes.size)
295+
val view =
296+
for (arg, info) <- args.lazyZip(argTypes) yield
297+
info match
298+
case _: ExprType => () => interpretTree(arg) // by-name argument
299+
case _ => interpretTree(arg) // by-value argument
300+
view.toList
301+
302+
fnType.dealias match
303+
case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType)
304+
case fnType: MethodType =>
305+
val argTypes = fnType.paramInfos
306+
assert(argss.head.size == argTypes.size)
307+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType)
308+
case fnType: AppliedType if defn.isContextFunctionType(fnType) =>
309+
val argTypes :+ resType = fnType.args
310+
interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType)
311+
case fnType: PolyType => interpretArgs(argss, fnType.resType)
312+
case fnType: ExprType => interpretArgs(argss, fnType.resType)
313+
case _ =>
314+
assert(argss.isEmpty)
315+
Nil
316+
}
317+
292318
private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = {
293319
var unexpected: Option[Object] = None
294320
val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match {

tests/pos-macros/i9802/Macro_1.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
import scala.quoted._
2+
3+
inline def fun(inline prog: Double): Double = ${impl('prog)}
4+
5+
def impl(prog: => Expr[Double])(using QuoteContext) : Expr[Double] = '{ 42.0 }

tests/pos-macros/i9802/Test_2.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2+
def test: Unit = fun(4)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
object Test {
3+
def main(args: Array[String]): Unit = {
4+
println(SourceFiles.getThisFile)
5+
}
6+
}
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.quoted._
2+
3+
object SourceFiles {
4+
5+
type Macro[X] = (=> QuoteContext) ?=> Expr[X]
6+
7+
implicit inline def getThisFile: String =
8+
${getThisFileImpl}
9+
10+
def getThisFileImpl: Macro[String] =
11+
Expr(qctx.tasty.Source.path.getFileName.toString)
12+
13+
}

0 commit comments

Comments
 (0)