Skip to content

Commit 7899ae0

Browse files
committed
Fix #6622: Add code interpolation
Allows to get string representations for code passed in the interpolated values ```scala inline def logged(p1: => Any) = { val c = code"code: $p1" val res = p1 (c, p1) } logged(indentity("foo")) ``` is equivalent to: ```scala ("code: indentity("foo")", indentity("foo")) ```
1 parent 9ca016e commit 7899ae0

File tree

12 files changed

+149
-15
lines changed

12 files changed

+149
-15
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,8 @@ class Definitions {
234234
def Compiletime_constValue(implicit ctx: Context): Symbol = Compiletime_constValueR.symbol
235235
@threadUnsafe lazy val Compiletime_constValueOptR: TermRef = CompiletimePackageObjectRef.symbol.requiredMethodRef("constValueOpt")
236236
def Compiletime_constValueOpt(implicit ctx: Context): Symbol = Compiletime_constValueOptR.symbol
237+
@threadUnsafe lazy val Compiletime_codeR: TermRef = CompiletimePackageObjectRef.symbol.requiredMethodRef("code")
238+
def Compiletime_code(implicit ctx: Context): Symbol = Compiletime_codeR.symbol
237239

238240
/** The `scalaShadowing` package is used to safely modify classes and
239241
* objects in scala so that they can be used from dotty. They will

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
472472
val expansion = inliner.transform(rhsToInline)
473473

474474
def issueError() = callValueArgss match {
475-
case (msgArg :: rest) :: Nil =>
475+
case (msgArg :: Nil) :: Nil =>
476476
msgArg.tpe match {
477477
case ConstantType(Constant(msg: String)) =>
478478
// Usually `error` is called from within a rewrite method. In this
@@ -482,23 +482,49 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
482482
val callToReport = if (enclosingInlineds.nonEmpty) enclosingInlineds.last else call
483483
val ctxToReport = ctx.outersIterator.dropWhile(enclosingInlineds(_).nonEmpty).next
484484
def issueInCtx(implicit ctx: Context) = {
485-
def decompose(arg: Tree): String = arg match {
486-
case Typed(arg, _) => decompose(arg)
487-
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
488-
case arg =>
489-
arg.tpe.widenTermRefExpr match {
490-
case ConstantType(Constant(c)) => c.toString
491-
case _ => arg.show
492-
}
493-
}
494-
ctx.error(s"$msg${rest.map(decompose).mkString(", ")}", callToReport.sourcePos)
485+
ctx.error(msg, callToReport.sourcePos)
495486
}
496487
issueInCtx(ctxToReport)
497488
case _ =>
498489
}
499490
case _ =>
500491
}
501492

493+
def issueCode()(implicit ctx: Context): Literal = {
494+
def decompose(arg: Tree): String = arg match {
495+
case Typed(arg, _) => decompose(arg)
496+
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
497+
case Block(Nil, expr) => decompose(expr)
498+
case Inlined(_, Nil, expr) => decompose(expr)
499+
case arg =>
500+
arg.tpe.widenTermRefExpr match {
501+
case ConstantType(Constant(c)) => c.toString
502+
case _ => arg.show
503+
}
504+
}
505+
506+
def malformedString(): String = {
507+
ctx.error("Malformed part `code` string interpolator", call.sourcePos)
508+
""
509+
}
510+
511+
callValueArgss match {
512+
case List(List(Apply(_,List(Typed(SeqLiteral(Literal(headConst) :: parts,_),_)))), List(Typed(SeqLiteral(interpolatedParts,_),_)))
513+
if parts.size == interpolatedParts.size =>
514+
val constantParts = parts.map {
515+
case Literal(const) => const.stringValue
516+
case _ => malformedString()
517+
}
518+
val decomposedInterpolations = interpolatedParts.map(decompose)
519+
val constantString = decomposedInterpolations.zip(constantParts)
520+
.foldLeft(headConst.stringValue) { case (acc, (p1, p2)) => acc + p1 + p2 }
521+
522+
Literal(Constant(constantString)).withSpan(call.span)
523+
case _ =>
524+
Literal(Constant(malformedString()))
525+
}
526+
}
527+
502528
trace(i"inlining $call", inlining, show = true) {
503529

504530
// The normalized bindings collected in `bindingsBuf`
@@ -522,9 +548,13 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) {
522548

523549
if (inlinedMethod == defn.Compiletime_error) issueError()
524550

525-
// Take care that only argument bindings go into `bindings`, since positions are
526-
// different for bindings from arguments and bindings from body.
527-
tpd.Inlined(call, finalBindings, finalExpansion)
551+
if (inlinedMethod == defn.Compiletime_code) {
552+
issueCode()(ctx.fresh.setSetting(ctx.settings.color, "never"))
553+
} else {
554+
// Take care that only argument bindings go into `bindings`, since positions are
555+
// different for bindings from arguments and bindings from body.
556+
tpd.Inlined(call, finalBindings, finalExpansion)
557+
}
528558
}
529559
}
530560

library/src-3.x/scala/compiletime/package.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,23 @@ package object compiletime {
44

55
erased def erasedValue[T]: T = ???
66

7-
inline def error(inline msg: String, objs: Any*): Nothing = ???
7+
inline def error(inline msg: String): Nothing = ???
8+
9+
/** Returns the string representations for code passed in the interpolated values
10+
* ```scala
11+
* inline def logged(p1: => Any) = {
12+
* val c = code"code: $p1"
13+
* val res = p1
14+
* (c, p1)
15+
* }
16+
* logged(indentity("foo"))
17+
* ```
18+
* is equivalent to:
19+
* ```scala
20+
* ("code: indentity("foo")", indentity("foo"))
21+
* ```
22+
*/
23+
inline def (self: => StringContext) code (args: => Any*): String = ???
824

925
inline def constValueOpt[T]: Option[T] = ???
1026

tests/neg/i6622.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(StringContext("abc ", "", "").code(println(34))) // error
7+
}
8+
9+
}

tests/neg/i6622a.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def nonConstant: String = ""
6+
7+
def main(args: Array[String]): Unit = {
8+
println(StringContext("abc ", nonConstant).code(println(34))) // error
9+
}
10+
11+
}

tests/neg/i6622b.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(StringContext("abc ").code(println(34), 34)) // error
7+
}
8+
9+
}

tests/neg/i6622c.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(StringContext(Seq.empty[String]:_*).code(println(34))) // error
7+
}
8+
9+
}

tests/neg/i6622d.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(StringContext("abc").code(Seq.empty[Any]:_*)) // error
7+
}
8+
9+
}

tests/neg/i6622e.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
println(StringContext(Seq.empty[String]:_*).code(Seq.empty[Any]:_*)) // error
7+
}
8+
9+
}

tests/neg/i6622f.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
-- Error: tests/neg/i6622f.scala:6:8 -----------------------------------------------------------------------------------
2+
6 | fail(println("foo")) // error
3+
| ^^^^^^^^^^^^^^^^^^^^
4+
| failed: println("foo") ...

tests/neg/i6622f.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
fail(println("foo")) // error
7+
}
8+
9+
inline def fail(p1: => Any) = error(code"failed: $p1 ...")
10+
11+
}

tests/run/i6622.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.compiletime._
2+
3+
object Test {
4+
5+
def main(args: Array[String]): Unit = {
6+
assert(code"abc ${println(34)} ..." == "abc println(34) ...")
7+
assert(code"abc ${println(34)}" == "abc println(34)")
8+
assert(code"${println(34)} ..." == "println(34) ...")
9+
assert(code"${println(34)}" == "println(34)")
10+
assert(code"..." == "...")
11+
assert(testConstant(code"") == "")
12+
}
13+
14+
inline def testConstant(inline msg: String): String = msg
15+
}

0 commit comments

Comments
 (0)