diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 8ea64cf62f40..94779bca680c 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -7,16 +7,15 @@ import Periods._ import Symbols._ import Types._ import Scopes._ -import typer.{FrontEnd, Typer, ImportInfo, RefChecks} -import reporting.{Reporter, ConsoleReporter} +import typer.{FrontEnd, ImportInfo, RefChecks, Typer} +import reporting.{ConsoleReporter, Reporter} import Phases.Phase import transform._ import util.FreshNameCreator import core.DenotTransformers.DenotTransformer import core.Denotations.SingleDenotation - -import dotty.tools.backend.jvm.{LabelDefs, GenBCode, CollectSuperCalls} -import dotty.tools.dotc.transform.localopt.Simplify +import dotty.tools.backend.jvm.{CollectSuperCalls, GenBCode, LabelDefs} +import dotty.tools.dotc.transform.localopt.{Simplify, StringInterpolatorOpt} /** The central class of the dotc compiler. The job of a compiler is to create * runs, which process given `phases` in a given `rootContext`. @@ -78,6 +77,7 @@ class Compiler { new PatternMatcher, // Compile pattern matches new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts + new StringInterpolatorOpt, // Optimizes raw and s string interpolators by rewriting them to string concatentations new CrossCastAnd, // Normalize selections involving intersection types. new Splitter) :: // Expand selections involving union types into conditionals List(new ErasedDecls, // Removes all erased defs and vals decls (except for parameters) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 95c5aa94dd37..2376ad02addc 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -583,6 +583,16 @@ class Definitions { lazy val StringAdd_plusR = StringAddClass.requiredMethodRef(nme.raw.PLUS) def StringAdd_+(implicit ctx: Context) = StringAdd_plusR.symbol + lazy val StringContextType: TypeRef = ctx.requiredClassRef("scala.StringContext") + def StringContextClass(implicit ctx: Context) = StringContextType.symbol.asClass + lazy val StringContextSR = StringContextClass.requiredMethodRef(nme.s) + def StringContextS(implicit ctx: Context) = StringContextSR.symbol + lazy val StringContextRawR = StringContextClass.requiredMethodRef(nme.raw_) + def StringContextRaw(implicit ctx: Context) = StringContextRawR.symbol + def StringContextModule(implicit ctx: Context) = StringContextClass.companionModule + lazy val StringContextModule_applyR = StringContextModule.requiredMethodRef(nme.apply) + def StringContextModule_apply(implicit ctx: Context) = StringContextModule_applyR.symbol + lazy val PartialFunctionType: TypeRef = ctx.requiredClassRef("scala.PartialFunction") def PartialFunctionClass(implicit ctx: Context) = PartialFunctionType.symbol.asClass lazy val AbstractPartialFunctionType: TypeRef = ctx.requiredClassRef("scala.runtime.AbstractPartialFunction") diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index a6543572b23f..215c1ac32ab3 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -486,6 +486,7 @@ object StdNames { val productElement: N = "productElement" val productIterator: N = "productIterator" val productPrefix: N = "productPrefix" + val raw_ : N = "raw" val readResolve: N = "readResolve" val reflect : N = "reflect" val reify : N = "reify" @@ -495,6 +496,7 @@ object StdNames { val runtime: N = "runtime" val runtimeClass: N = "runtimeClass" val runtimeMirror: N = "runtimeMirror" + val s: N = "s" val sameElements: N = "sameElements" val scala_ : N = "scala" val scalaShadowing : N = "scalaShadowing" diff --git a/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala new file mode 100644 index 000000000000..89681da9aeaa --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/localopt/StringInterpolatorOpt.scala @@ -0,0 +1,102 @@ +package dotty.tools.dotc.transform.localopt + +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Constants.Constant +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.transform.MegaPhase.MiniPhase + +/** + * MiniPhase to transform s and raw string interpolators from using StringContext to string + * concatenation. Since string concatenation uses the Java String builder, we get a performance + * improvement in terms of these two interpolators. + * + * More info here: + * https://medium.com/@dkomanov/scala-string-interpolation-performance-21dc85e83afd + */ +class StringInterpolatorOpt extends MiniPhase { + import tpd._ + + override def phaseName: String = "stringInterpolatorOpt" + + /** Matches a list of constant literals */ + private object Literals { + def unapply(tree: SeqLiteral)(implicit ctx: Context): Option[List[Literal]] = { + tree.elems match { + case literals if literals.forall(_.isInstanceOf[Literal]) => + Some(literals.map(_.asInstanceOf[Literal])) + case _ => None + } + } + } + + private object StringContextApply { + def unapply(tree: Select)(implicit ctx: Context): Boolean = { + tree.symbol.eq(defn.StringContextModule_apply) && { + val qualifier = tree.qualifier + qualifier.isInstanceOf[Ident] && qualifier.symbol.eq(defn.StringContextModule) + } + } + } + + /** Matches an s or raw string interpolator */ + private object SOrRawInterpolator { + def unapply(tree: Tree)(implicit ctx: Context): Option[(List[Literal], List[Tree])] = { + if (tree.symbol.eq(defn.StringContextRaw) || tree.symbol.eq(defn.StringContextS)) { + tree match { + case Apply(Select(Apply(StringContextApply(), List(Literals(strs))), _), + List(SeqLiteral(elems, _))) if elems.length == strs.length - 1 => + Some(strs, elems) + case _ => None + } + } else None + } + } + + /** + * Match trees that resemble s and raw string interpolations. In the case of the s + * interpolator, escapes the string constants. Exposes the string constants as well as + * the variable references. + */ + private object StringContextIntrinsic { + def unapply(tree: Apply)(implicit ctx: Context): Option[(List[Literal], List[Tree])] = { + tree match { + case SOrRawInterpolator(strs, elems) => + if (tree.symbol == defn.StringContextRaw) Some(strs, elems) + else { // tree.symbol == defn.StringContextS + try { + val escapedStrs = strs.map { str => + val escapedValue = StringContext.processEscapes(str.const.stringValue) + cpy.Literal(str)(Constant(escapedValue)) + } + Some(escapedStrs, elems) + } catch { + case _: StringContext.InvalidEscapeException => None + } + } + case _ => None + } + } + } + + override def transformApply(tree: Apply)(implicit ctx: Context): Tree = { + tree match { + case StringContextIntrinsic(strs: List[Literal], elems: List[Tree]) => + val stri = strs.iterator + val elemi = elems.iterator + var result: Tree = stri.next + def concat(tree: Tree): Unit = { + result = result.select(defn.String_+).appliedTo(tree) + } + while (elemi.hasNext) { + concat(elemi.next) + val str = stri.next + if (!str.const.stringValue.isEmpty) concat(str) + } + result + case _ => tree + } + } +} diff --git a/compiler/test/dotty/tools/backend/jvm/StringInterpolatorOptTest.scala b/compiler/test/dotty/tools/backend/jvm/StringInterpolatorOptTest.scala new file mode 100644 index 000000000000..14bcf0d83e45 --- /dev/null +++ b/compiler/test/dotty/tools/backend/jvm/StringInterpolatorOptTest.scala @@ -0,0 +1,64 @@ +package dotty.tools.backend.jvm + +import org.junit.Assert._ +import org.junit.Test + +class StringInterpolatorOptTest extends DottyBytecodeTest { + import ASMConverters._ + + @Test def testRawInterpolator = { + val source = + """ + |class Foo { + | val one = 1 + | val two = "two" + | val three = 3.0 + | + | def meth1: String = raw"$one plus $two$three\n" + | def meth2: String = "" + one + " plus " + two + three + "\\n" + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "the `` string interpolator incorrectly converts to string concatenation\n" + + diffInstructions(instructions1, instructions2)) + } + } + + @Test def testSInterpolator = { + val source = + """ + |class Foo { + | val one = 1 + | val two = "two" + | val three = 3.0 + | + | def meth1: String = s"$one plus $two$three\n" + | def meth2: String = "" + one + " plus " + two + three + "\n" + |} + """.stripMargin + + checkBCode(source) { dir => + val clsIn = dir.lookupName("Foo.class", directory = false).input + val clsNode = loadClassNode(clsIn) + val meth1 = getMethod(clsNode, "meth1") + val meth2 = getMethod(clsNode, "meth2") + + val instructions1 = instructionsFromMethod(meth1) + val instructions2 = instructionsFromMethod(meth2) + + assert(instructions1 == instructions2, + "the `s` string interpolator incorrectly converts to string concatenation\n" + + diffInstructions(instructions1, instructions2)) + } + } +} diff --git a/tests/run/interpolation-opt.check b/tests/run/interpolation-opt.check new file mode 100644 index 000000000000..1330eb8ac2e7 --- /dev/null +++ b/tests/run/interpolation-opt.check @@ -0,0 +1,12 @@ +1 plus two\nis 3.0 +1 plus two +is 3.0 +a1two3.0b +a1two3.0b + + +Hello World +Side effect! +Foo Bar +Side effect n2! +Titi Toto diff --git a/tests/run/interpolation-opt.scala b/tests/run/interpolation-opt.scala new file mode 100644 index 000000000000..e7ff07ec61ac --- /dev/null +++ b/tests/run/interpolation-opt.scala @@ -0,0 +1,29 @@ +object Test extends App { + + val one = 1 + val two = "two" + val three = 3.0 + + // Test escaping + println(raw"$one plus $two\nis $three") + println(s"$one plus $two\nis $three") + + // Test empty strings between elements + println(raw"a$one$two${three}b") + println(s"a$one$two${three}b") + + // Test empty string interpolators + println(raw"") + println(s"") + + // Make sure that StringContext still works with idents + val foo = "Hello" + val bar = "World" + println(StringContext(foo, bar).s(" ")) + + def myStringContext= { println("Side effect!"); StringContext } + println(myStringContext("Foo", "Bar").s(" ")) // this shouldn't be optimised away + + // this shouldn't be optimised away + println({ println("Side effect n2!"); StringContext }.apply("Titi", "Toto").s(" ")) +}