Skip to content

Allow macros calling inherited methods on modules #5575

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions compiler/src/dotty/tools/dotc/transform/Splicer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ object Splicer {
}
}

protected def interpretStaticMethodCall(fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
val instance = loadModule(fn.owner)
protected def interpretStaticMethodCall(moduleClass: Symbol, fn: Symbol, args: => List[Object])(implicit env: Env): Object = {
val instance = loadModule(moduleClass)
def getDirectName(tp: Type, name: TermName): TermName = tp.widenDealias match {
case tp: AppliedType if defn.isImplicitFunctionType(tp) =>
getDirectName(tp.args.last, NameKinds.DirectMethodName(name))
Expand Down Expand Up @@ -270,7 +270,7 @@ object Splicer {
protected def interpretVarargs(args: List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretTastyContext()(implicit env: Env): Boolean = true
protected def interpretQuoteContext()(implicit env: Env): Boolean = true
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Boolean = true
protected def interpretNew(fn: Symbol, args: => List[Boolean])(implicit env: Env): Boolean = args.forall(identity)

Expand All @@ -292,7 +292,7 @@ object Splicer {
protected def interpretLiteral(value: Any)(implicit env: Env): Result
protected def interpretVarargs(args: List[Result])(implicit env: Env): Result
protected def interpretTastyContext()(implicit env: Env): Result
protected def interpretStaticMethodCall(fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def interpretStaticMethodCall(module: Symbol, fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def interpretModuleAccess(fn: Symbol)(implicit env: Env): Result
protected def interpretNew(fn: Symbol, args: => List[Result])(implicit env: Env): Result
protected def unexpectedTree(tree: Tree)(implicit env: Env): Result
Expand All @@ -313,9 +313,14 @@ object Splicer {
case Call(fn, args) =>
if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) {
interpretNew(fn.symbol, args.map(interpretTree))
} else if (fn.symbol.is(Module)) {
interpretModuleAccess(fn.symbol)
} else if (fn.symbol.isStatic) {
if (fn.symbol.is(Module)) interpretModuleAccess(fn.symbol)
else interpretStaticMethodCall(fn.symbol, args.map(arg => interpretTree(arg)))
val module = fn.symbol.owner
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
} else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) {
val module = fn.qualifier.symbol.moduleClass
interpretStaticMethodCall(module, fn.symbol, args.map(arg => interpretTree(arg)))
} else if (env.contains(fn.name)) {
env(fn.name)
} else {
Expand Down
3 changes: 3 additions & 0 deletions tests/run/tasty-interpolation-1.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Hello world!
Hello world!\n
Hello foo!
95 changes: 95 additions & 0 deletions tests/run/tasty-interpolation-1/Macro.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@

import scala.quoted._
import scala.tasty.Reflection
import scala.language.implicitConversions
import scala.quoted.Exprs.LiftedExpr
import scala.quoted.Toolbox.Default._

object Macro {

class StringContextOps(strCtx: => StringContext) {
inline def s2(args: Any*): String = ~SIntepolator('(strCtx), '(args))
inline def raw2(args: Any*): String = ~RawIntepolator('(strCtx), '(args))
inline def foo(args: Any*): String = ~FooIntepolator('(strCtx), '(args))
}
implicit inline def SCOps(strCtx: => StringContext): StringContextOps = new StringContextOps(strCtx)
}

object SIntepolator extends MacroStringInterpolator[String] {
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
'((~strCtx.toExpr).s(~args.toExprOfList: _*))
}

object RawIntepolator extends MacroStringInterpolator[String] {
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
'((~strCtx.toExpr).raw(~args.toExprOfList: _*))
}

object FooIntepolator extends MacroStringInterpolator[String] {
protected def interpolate(strCtx: StringContext, args: List[Expr[Any]])(implicit reflect: Reflection): Expr[String] =
'((~strCtx.toExpr).s(~args.map(_ => '("foo")).toExprOfList: _*))
}

// TODO put this class in the stdlib or separate project?
abstract class MacroStringInterpolator[T] {

final def apply(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] = {
try interpolate(strCtxExpr, argsExpr)
catch {
case ex: NotStaticlyKnownError =>
// TODO use ex.expr to recover the position
throw new QuoteError(ex.getMessage)
case ex: StringContextError =>
// TODO use ex.idx to recover the position
throw new QuoteError(ex.getMessage)
case ex: ArgumentError =>
// TODO use ex.idx to recover the position
throw new QuoteError(ex.getMessage)
}
}

protected def interpolate(strCtxExpr: Expr[StringContext], argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): Expr[T] =
interpolate(getStaticStringContext(strCtxExpr), getArgsList(argsExpr))

protected def interpolate(strCtx: StringContext, argExprs: List[Expr[Any]])(implicit reflect: Reflection): Expr[T]

protected def getStaticStringContext(strCtxExpr: Expr[StringContext])(implicit reflect: Reflection): StringContext = {
import reflect._
strCtxExpr.unseal.underlyingArgument match {
case Term.Select(Term.Typed(Term.Apply(_, List(Term.Apply(_, List(Term.Typed(Term.Repeated(strCtxArgTrees), TypeTree.Inferred()))))), _), _) =>
val strCtxArgs = strCtxArgTrees.map {
case Term.Literal(Constant.String(str)) => str
case tree => throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any])
}
StringContext(strCtxArgs: _*)
case tree =>
throw new NotStaticlyKnownError("Expected statically known StringContext", tree.seal[Any])
}
}

protected def getArgsList(argsExpr: Expr[Seq[Any]])(implicit reflect: Reflection): List[Expr[Any]] = {
import reflect._
argsExpr.unseal.underlyingArgument match {
case Term.Typed(Term.Repeated(args), _) => args.map(_.seal[Any])
case tree => throw new NotStaticlyKnownError("Expected statically known argument list", tree.seal[Any])
}
}

protected implicit def StringContextIsLiftable: Liftable[StringContext] = new Liftable[StringContext] {
def toExpr(strCtx: StringContext): Expr[StringContext] = {
// TODO define in stdlib?
implicit def ListIsLiftable: Liftable[List[String]] = new Liftable[List[String]] {
override def toExpr(list: List[String]): Expr[List[String]] = list match {
case x :: xs => '(~x.toExpr :: ~toExpr(xs))
case Nil => '(Nil)
}
}
'(StringContext(~strCtx.parts.toList.toExpr: _*))
}
}

protected class NotStaticlyKnownError(msg: String, expr: Expr[Any]) extends Exception(msg)
protected class StringContextError(msg: String, idx: Int, start: Int = -1, end: Int = -1) extends Exception(msg)
protected class ArgumentError(msg: String, idx: Int) extends Exception(msg)

}
10 changes: 10 additions & 0 deletions tests/run/tasty-interpolation-1/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import Macro._

object Test {
def main(args: Array[String]): Unit = {
val w = "world"
println(s2"Hello $w!")
println(raw2"Hello $w!\n")
println(foo"Hello $w!")
}
}