diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index e1cc8fa34062..d07d44c63057 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -112,8 +112,8 @@ object Splicer { } } - protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { - val instance = loadModule(fn.owner) + protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = { + val instance = loadModule(moduleClass) def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match { case tp: AppliedType if defn.isImplicitFunctionType(tp) => getDirectName(tp.args.last, NameKinds.DirectMethodName(name)) @@ -270,7 +270,7 @@ object Splicer { protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity) protected def interpretTastyContext()(implicit env: Env): Boolean = true protected def interpretQuoteContext()(implicit env: Env): Boolean = true - protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) + protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity) @@ -292,7 +292,7 @@ object Splicer { protected def interpretLiteral(value: Any)(implicit env: Env): Result protected def interpretVarargs(args: List[Result])(implicit env: Env): Result protected def interpretTastyContext()(implicit env: Env): Result - protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result + protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result protected def unexpectedTree(tree: Tree)(implicit env: Env): Result @@ -313,9 +313,14 @@ object Splicer { case Call(fn, args) => if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) { interpretNew(fn.symbol, args.map(interpretTree)) + } else if (fn.symbol.is(Module)) { + interpretModuleAccess(fn.symbol) } else if (fn.symbol.isStatic) { - if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol) - else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg))) + val module = fn.symbol.owner + interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) + } else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) { + val module = fn.qualifier.symbol.moduleClass + interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg))) } else if (env.contains(fn.name)) { env(fn.name) } else { diff --git a/tests/run/tasty-interpolation-1.check b/tests/run/tasty-interpolation-1.check new file mode 100644 index 000000000000..bdf79ca83b8a --- /dev/null +++ b/tests/run/tasty-interpolation-1.check @@ -0,0 +1,3 @@ +Hello world! +Hello world!\n +Hello foo! diff --git a/tests/run/tasty-interpolation-1/Macro.scala b/tests/run/tasty-interpolation-1/Macro.scala new file mode 100644 index 000000000000..c7db2734098c --- /dev/null +++ b/tests/run/tasty-interpolation-1/Macro.scala @@ -0,0 +1,95 @@ + +import scala.quoted._ +import scala.tasty.Reflection +import scala.language.implicitConversions +import scala.quoted.Exprs.LiftedExpr +import scala.quoted.Toolbox.Default._ + +object Macro { + + class StringContextOps(strCtx: => StringContext) { + inline def s2(args: Any*): String = ~SIntepolator('(strCtx), '(args)) + inline def raw2(args: Any*): String = ~RawIntepolator('(strCtx), '(args)) + inline def foo(args: Any*): String = ~FooIntepolator('(strCtx), '(args)) + } + implicit inline def SCOps(strCtx: => StringContext): StringContextOps = new StringContextOps(strCtx) +} + +object SIntepolator extends MacroStringInterpolator[String] { + protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] = + '((~strCtx.toExpr).s(~args.toExprOfList: _*)) +} + +object RawIntepolator extends MacroStringInterpolator[String] { + protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] = + '((~strCtx.toExpr).raw(~args.toExprOfList: _*)) +} + +object FooIntepolator extends MacroStringInterpolator[String] { + protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] = + '((~strCtx.toExpr).s(~args.map(_ => '("foo")).toExprOfList: _*)) +} + +// TODO put this class in the stdlib or separate project? +abstract class MacroStringInterpolator[T] { + + final def apply(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] = { + try interpolate(strCtxExpr, argsExpr) + catch { + case ex: NotStaticlyKnownError => + // TODO use ex.expr to recover the position + throw new QuoteError(ex.getMessage) + case ex: StringContextError => + // TODO use ex.idx to recover the position + throw new QuoteError(ex.getMessage) + case ex: ArgumentError => + // TODO use ex.idx to recover the position + throw new QuoteError(ex.getMessage) + } + } + + protected def interpolate(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] = + interpolate(getStaticStringContext(strCtxExpr), getArgsList(argsExpr)) + + protected def interpolate(strCtx: StringContext, argExprs: List[Expr[Any]])(implicit reflect: Reflection): Expr[T] + + protected def getStaticStringContext(strCtxExpr: Expr[StringContext])(implicit reflect: Reflection): StringContext = { + import reflect._ + strCtxExpr.unseal.underlyingArgument match { + case Term.Select(Term.Typed(Term.Apply(_, List(Term.Apply(_, List(Term.Typed(Term.Repeated(strCtxArgTrees), TypeTree.Inferred()))))), _), _) => + val strCtxArgs = strCtxArgTrees.map { + case Term.Literal(Constant.String(str)) => str + case tree => throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any]) + } + StringContext(strCtxArgs: _*) + case tree => + throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any]) + } + } + + protected def getArgsList(argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): List[Expr[Any]] = { + import reflect._ + argsExpr.unseal.underlyingArgument match { + case Term.Typed(Term.Repeated(args), _) => args.map(_.seal[Any]) + case tree => throw new NotStaticlyKnownError("Expected statically known argument list", tree.seal[Any]) + } + } + + protected implicit def StringContextIsLiftable: Liftable[StringContext] = new Liftable[StringContext] { + def toExpr(strCtx: StringContext): Expr[StringContext] = { + // TODO define in stdlib? + implicit def ListIsLiftable: Liftable[List[String]] = new Liftable[List[String]] { + override def toExpr(list: List[String]): Expr[List[String]] = list match { + case x :: xs => '(~x.toExpr :: ~toExpr(xs)) + case Nil => '(Nil) + } + } + '(StringContext(~strCtx.parts.toList.toExpr: _*)) + } + } + + protected class NotStaticlyKnownError(msg: String, expr: Expr[Any]) extends Exception(msg) + protected class StringContextError(msg: String, idx: Int, start: Int = -1, end: Int = -1) extends Exception(msg) + protected class ArgumentError(msg: String, idx: Int) extends Exception(msg) + +} diff --git a/tests/run/tasty-interpolation-1/Test_2.scala b/tests/run/tasty-interpolation-1/Test_2.scala new file mode 100644 index 000000000000..716adac21f45 --- /dev/null +++ b/tests/run/tasty-interpolation-1/Test_2.scala @@ -0,0 +1,10 @@ +import Macro._ + +object Test { + def main(args: Array[String]): Unit = { + val w = "world" + println(s2"Hello $w!") + println(raw2"Hello $w!\n") + println(foo"Hello $w!") + } +}