Skip to content

Commit 303bfcc

Browse files
Merge pull request #5575 from dotty-staging/add-macro-interpolation-core
Allow macros calling inherited methods on modules
2 parents f256a73 + 4d178b9 commit 303bfcc

File tree

4 files changed

+119
-6
lines changed

4 files changed

+119
-6
lines changed

compiler/src/dotty/tools/dotc/transform/Splicer.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,8 @@ object Splicer {
112112
}
113113
}
114114

115-
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
116-
val instance = loadModule(fn.owner)
115+
protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
116+
val instance = loadModule(moduleClass)
117117
def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match {
118118
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
119119
getDirectName(tp.args.last, NameKinds.DirectMethodName(name))
@@ -270,7 +270,7 @@ object Splicer {
270270
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
271271
protected def interpretTastyContext()(implicit env: Env): Boolean = true
272272
protected def interpretQuoteContext()(implicit env: Env): Boolean = true
273-
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
273+
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
274274
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true
275275
protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
276276

@@ -292,7 +292,7 @@ object Splicer {
292292
protected def interpretLiteral(value: Any)(implicit env: Env): Result
293293
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
294294
protected def interpretTastyContext()(implicit env: Env): Result
295-
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result
295+
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
296296
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
297297
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
298298
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
@@ -313,9 +313,14 @@ object Splicer {
313313
case Call(fn, args) =>
314314
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
315315
interpretNew(fn.symbol, args.map(interpretTree))
316+
} else if (fn.symbol.is(Module)) {
317+
interpretModuleAccess(fn.symbol)
316318
} else if (fn.symbol.isStatic) {
317-
if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol)
318-
else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg)))
319+
val module = fn.symbol.owner
320+
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
321+
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
322+
val module = fn.qualifier.symbol.moduleClass
323+
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
319324
} else if (env.contains(fn.name)) {
320325
env(fn.name)
321326
} else {

tests/run/tasty-interpolation-1.check

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Hello world!
2+
Hello world!\n
3+
Hello foo!
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
2+
import scala.quoted._
3+
import scala.tasty.Reflection
4+
import scala.language.implicitConversions
5+
import scala.quoted.Exprs.LiftedExpr
6+
import scala.quoted.Toolbox.Default._
7+
8+
object Macro {
9+
10+
class StringContextOps(strCtx: => StringContext) {
11+
inline def s2(args: Any*): String = ~SIntepolator('(strCtx), '(args))
12+
inline def raw2(args: Any*): String = ~RawIntepolator('(strCtx), '(args))
13+
inline def foo(args: Any*): String = ~FooIntepolator('(strCtx), '(args))
14+
}
15+
implicit inline def SCOps(strCtx: => StringContext): StringContextOps = new StringContextOps(strCtx)
16+
}
17+
18+
object SIntepolator extends MacroStringInterpolator[String] {
19+
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
20+
'((~strCtx.toExpr).s(~args.toExprOfList: _*))
21+
}
22+
23+
object RawIntepolator extends MacroStringInterpolator[String] {
24+
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
25+
'((~strCtx.toExpr).raw(~args.toExprOfList: _*))
26+
}
27+
28+
object FooIntepolator extends MacroStringInterpolator[String] {
29+
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
30+
'((~strCtx.toExpr).s(~args.map(_ => '("foo")).toExprOfList: _*))
31+
}
32+
33+
// TODO put this class in the stdlib or separate project?
34+
abstract class MacroStringInterpolator[T] {
35+
36+
final def apply(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] = {
37+
try interpolate(strCtxExpr, argsExpr)
38+
catch {
39+
case ex: NotStaticlyKnownError =>
40+
// TODO use ex.expr to recover the position
41+
throw new QuoteError(ex.getMessage)
42+
case ex: StringContextError =>
43+
// TODO use ex.idx to recover the position
44+
throw new QuoteError(ex.getMessage)
45+
case ex: ArgumentError =>
46+
// TODO use ex.idx to recover the position
47+
throw new QuoteError(ex.getMessage)
48+
}
49+
}
50+
51+
protected def interpolate(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] =
52+
interpolate(getStaticStringContext(strCtxExpr), getArgsList(argsExpr))
53+
54+
protected def interpolate(strCtx: StringContext, argExprs: List[Expr[Any]])(implicit reflect: Reflection): Expr[T]
55+
56+
protected def getStaticStringContext(strCtxExpr: Expr[StringContext])(implicit reflect: Reflection): StringContext = {
57+
import reflect._
58+
strCtxExpr.unseal.underlyingArgument match {
59+
case Term.Select(Term.Typed(Term.Apply(_, List(Term.Apply(_, List(Term.Typed(Term.Repeated(strCtxArgTrees), TypeTree.Inferred()))))), _), _) =>
60+
val strCtxArgs = strCtxArgTrees.map {
61+
case Term.Literal(Constant.String(str)) => str
62+
case tree => throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any])
63+
}
64+
StringContext(strCtxArgs: _*)
65+
case tree =>
66+
throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any])
67+
}
68+
}
69+
70+
protected def getArgsList(argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): List[Expr[Any]] = {
71+
import reflect._
72+
argsExpr.unseal.underlyingArgument match {
73+
case Term.Typed(Term.Repeated(args), _) => args.map(_.seal[Any])
74+
case tree => throw new NotStaticlyKnownError("Expected statically known argument list", tree.seal[Any])
75+
}
76+
}
77+
78+
protected implicit def StringContextIsLiftable: Liftable[StringContext] = new Liftable[StringContext] {
79+
def toExpr(strCtx: StringContext): Expr[StringContext] = {
80+
// TODO define in stdlib?
81+
implicit def ListIsLiftable: Liftable[List[String]] = new Liftable[List[String]] {
82+
override def toExpr(list: List[String]): Expr[List[String]] = list match {
83+
case x :: xs => '(~x.toExpr :: ~toExpr(xs))
84+
case Nil => '(Nil)
85+
}
86+
}
87+
'(StringContext(~strCtx.parts.toList.toExpr: _*))
88+
}
89+
}
90+
91+
protected class NotStaticlyKnownError(msg: String, expr: Expr[Any]) extends Exception(msg)
92+
protected class StringContextError(msg: String, idx: Int, start: Int = -1, end: Int = -1) extends Exception(msg)
93+
protected class ArgumentError(msg: String, idx: Int) extends Exception(msg)
94+
95+
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
import Macro._
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val w = "world"
6+
println(s2"Hello $w!")
7+
println(raw2"Hello $w!\n")
8+
println(foo"Hello $w!")
9+
}
10+
}

0 commit comments

Comments
 (0)