Skip to content

Commit a993761

Browse files
harpocrates authored and olsdavis committed
Emit efficient code for switch over strings
The pattern matcher will now emit `Match` with `String` scrutinee as well as the existing `Int` scrutinee. The JVM backend handles this case by emitting bytecode that switches on the String's `hashCode` (this matches what Java does). The SJS already handles `String` matches. The approach is similar to scala/scala#8451 (see scala/bug#11740 too), except that instead of doing a transformation on the AST, we just emit the right bytecode straight away. This is desirable since it means that Scala.js (and any other backend) can choose their own optimised strategy for compiling a match on strings. Fixes scala#11923
1 parent 47cd35a commit a993761

File tree

8 files changed

+317
-75
lines changed

8 files changed

+317
-75
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 156 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package backend
33
package jvm
44

55
import scala.annotation.switch
6+
import scala.collection.mutable.SortedMap
67

78
import scala.tools.asm
89
import scala.tools.asm.{Handle, Label, Opcodes}
@@ -840,61 +841,170 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
840841
generatedType
841842
}
842843

843-
/*
844-
* A Match node contains one or more case clauses,
845-
* each case clause lists one or more Int values to use as keys, and a code block.
846-
* Except the "default" case clause which (if it exists) doesn't list any Int key.
847-
*
848-
* On a first pass over the case clauses, we flatten the keys and their targets (the latter represented with asm.Labels).
849-
* That representation allows JCodeMethodV to emit a lookupswitch or a tableswitch.
850-
*
851-
* On a second pass, we emit the switch blocks, one for each different target.
844+
/* A Match node contains one or more case clauses, each case clause lists one or more
845+
* Int/String values to use as keys, and a code block. The exception is the "default" case
846+
* clause which doesn't list any key (there is exactly one of these per match).
852847
*/
853848
private def genMatch(tree: Match): BType = tree match {
854849
case Match(selector, cases) =>
855850
lineNumber(tree)
856-
genLoad(selector, INT)
857851
val generatedType = tpeTK(tree)
852+
val postMatch = new asm.Label
858853

859-
var flatKeys: List[Int] = Nil
860-
var targets: List[asm.Label] = Nil
861-
var default: asm.Label = null
862-
var switchBlocks: List[(asm.Label, Tree)] = Nil
863-
864-
// collect switch blocks and their keys, but don't emit yet any switch-block.
865-
for (caze @ CaseDef(pat, guard, body) <- cases) {
866-
assert(guard == tpd.EmptyTree, guard)
867-
val switchBlockPoint = new asm.Label
868-
switchBlocks ::= (switchBlockPoint, body)
869-
pat match {
870-
case Literal(value) =>
871-
flatKeys ::= value.intValue
872-
targets ::= switchBlockPoint
873-
case Ident(nme.WILDCARD) =>
874-
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
875-
default = switchBlockPoint
876-
case Alternative(alts) =>
877-
alts foreach {
878-
case Literal(value) =>
879-
flatKeys ::= value.intValue
880-
targets ::= switchBlockPoint
881-
case _ =>
882-
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
883-
}
884-
case _ =>
885-
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
854+
// Only two possible selector types exist in `Match` trees at this point: Int and String
855+
if (tpeTK(selector) == INT) {
856+
857+
/* On a first pass over the case clauses, we flatten the keys and their
858+
* targets (the latter represented with asm.Labels). That representation
859+
* allows JCodeMethodV to emit a lookupswitch or a tableswitch.
860+
*
861+
* On a second pass, we emit the switch blocks, one for each different target.
862+
*/
863+
864+
var flatKeys: List[Int] = Nil
865+
var targets: List[asm.Label] = Nil
866+
var default: asm.Label = null
867+
var switchBlocks: List[(asm.Label, Tree)] = Nil
868+
869+
genLoad(selector, INT)
870+
871+
// collect switch blocks and their keys, but don't emit any switch-block yet.
872+
for (caze @ CaseDef(pat, guard, body) <- cases) {
873+
assert(guard == tpd.EmptyTree, guard)
874+
val switchBlockPoint = new asm.Label
875+
switchBlocks ::= (switchBlockPoint, body)
876+
pat match {
877+
case Literal(value) =>
878+
flatKeys ::= value.intValue
879+
targets ::= switchBlockPoint
880+
case Ident(nme.WILDCARD) =>
881+
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
882+
default = switchBlockPoint
883+
case Alternative(alts) =>
884+
alts foreach {
885+
case Literal(value) =>
886+
flatKeys ::= value.intValue
887+
targets ::= switchBlockPoint
888+
case _ =>
889+
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
890+
}
891+
case _ =>
892+
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
893+
}
886894
}
887-
}
888895

889-
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
896+
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
890897

891-
// emit switch-blocks.
892-
val postMatch = new asm.Label
893-
for (sb <- switchBlocks.reverse) {
894-
val (caseLabel, caseBody) = sb
895-
markProgramPoint(caseLabel)
896-
genLoad(caseBody, generatedType)
897-
bc goTo postMatch
898+
// emit switch-blocks.
899+
for (sb <- switchBlocks.reverse) {
900+
val (caseLabel, caseBody) = sb
901+
markProgramPoint(caseLabel)
902+
genLoad(caseBody, generatedType)
903+
bc goTo postMatch
904+
}
905+
} else {
906+
907+
/* Since the JVM doesn't have a way to switch on a string, we switch
908+
* on the `hashCode` of the string then do an `equals` check (with a
909+
* possible second set of jumps if blocks can be reached from multiple
910+
* string alternatives).
911+
*
912+
* This mirrors the way that Java compiles `switch` on Strings.
913+
*/
914+
915+
var default: asm.Label = null
916+
var indirectBlocks: List[(asm.Label, Tree)] = Nil
917+
918+
import scala.collection.mutable
919+
920+
// Cases grouped by their hashCode
921+
val casesByHash = SortedMap.empty[Int, List[(String, Either[asm.Label, Tree])]]
922+
var caseFallback: Tree = null
923+
924+
for (caze @ CaseDef(pat, guard, body) <- cases) {
925+
assert(guard == tpd.EmptyTree, guard)
926+
pat match {
927+
case Literal(value) =>
928+
val strValue = value.stringValue
929+
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
930+
val newCase = (strValue, Right(body))
931+
Some(newCase :: existingCasesOpt.getOrElse(Nil))
932+
}
933+
case Ident(nme.WILDCARD) =>
934+
assert(default == null, s"multiple default targets in a Match node, at ${tree.span}")
935+
default = new asm.Label
936+
indirectBlocks ::= (default, body)
937+
case Alternative(alts) =>
938+
// We need an extra basic block since multiple strings can lead to this code
939+
val indirectCaseGroupLabel = new asm.Label
940+
indirectBlocks ::= (indirectCaseGroupLabel, body)
941+
alts foreach {
942+
case Literal(value) =>
943+
val strValue = value.stringValue
944+
casesByHash.updateWith(strValue.##) { existingCasesOpt =>
945+
val newCase = (strValue, Left(indirectCaseGroupLabel))
946+
Some(newCase :: existingCasesOpt.getOrElse(Nil))
947+
}
948+
case _ =>
949+
abort(s"Invalid alternative in alternative pattern in Match node: $tree at: ${tree.span}")
950+
}
951+
952+
case _ =>
953+
abort(s"Invalid pattern in Match node: $tree at: ${tree.span}")
954+
}
955+
}
956+
957+
// Organize the hashCode options into switch cases
958+
var flatKeys: List[Int] = Nil
959+
var targets: List[asm.Label] = Nil
960+
var hashBlocks: List[(asm.Label, List[(String, Either[asm.Label, Tree])])] = Nil
961+
for ((hashValue, hashCases) <- casesByHash) {
962+
val switchBlockPoint = new asm.Label
963+
hashBlocks ::= (switchBlockPoint, hashCases)
964+
flatKeys ::= hashValue
965+
targets ::= switchBlockPoint
966+
}
967+
968+
// Push the hashCode of the string (or `0` if it is `null`) onto the stack and switch on it
969+
genLoadIf(
970+
If(
971+
tree.selector.select(defn.Any_==).appliedTo(nullLiteral),
972+
Literal(Constant(0)),
973+
tree.selector.select(defn.Any_hashCode).appliedToNone
974+
),
975+
INT
976+
)
977+
bc.emitSWITCH(mkArrayReverse(flatKeys), mkArrayL(targets.reverse), default, MIN_SWITCH_DENSITY)
978+
979+
// emit blocks for each hash case
980+
for ((hashLabel, caseAlternatives) <- hashBlocks.reverse) {
981+
markProgramPoint(hashLabel)
982+
for ((caseString, indirectLblOrBody) <- caseAlternatives) {
983+
val comparison = if (caseString == null) defn.Any_== else defn.Any_equals
984+
val condp = Literal(Constant(caseString)).select(defn.Any_==).appliedTo(tree.selector)
985+
val keepGoing = new asm.Label
986+
indirectLblOrBody match {
987+
case Left(jump) =>
988+
genCond(condp, jump, keepGoing, targetIfNoJump = keepGoing)
989+
990+
case Right(caseBody) =>
991+
val thisCaseMatches = new asm.Label
992+
genCond(condp, thisCaseMatches, keepGoing, targetIfNoJump = thisCaseMatches)
993+
markProgramPoint(thisCaseMatches)
994+
genLoad(caseBody, generatedType)
995+
bc goTo postMatch
996+
}
997+
markProgramPoint(keepGoing)
998+
}
999+
bc goTo default
1000+
}
1001+
1002+
// emit blocks for common patterns
1003+
for ((caseLabel, caseBody) <- indirectBlocks.reverse) {
1004+
markProgramPoint(caseLabel)
1005+
genLoad(caseBody, generatedType)
1006+
bc goTo postMatch
1007+
}
8981008
}
8991009

9001010
markProgramPoint(postMatch)

compiler/src/dotty/tools/backend/sjs/JSCodeGen.scala

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3130,12 +3130,6 @@ class JSCodeGen()(using genCtx: Context) {
31303130
def abortMatch(msg: String): Nothing =
31313131
throw new FatalError(s"$msg in switch-like pattern match at ${tree.span}: $tree")
31323132

3133-
/* Although GenBCode adapts the scrutinee and the cases to `int`, only
3134-
* true `int`s can reach the back-end, as asserted by the String-switch
3135-
* transformation in `cleanup`. Therefore, we do not adapt, preserving
3136-
* the `string`s and `null`s that come out of the pattern matching in
3137-
* Scala 2.13.2+.
3138-
*/
31393133
val genSelector = genExpr(selector)
31403134

31413135
// Sanity check: we can handle Ints and Strings (including `null`s), but nothing else
@@ -3192,11 +3186,6 @@ class JSCodeGen()(using genCtx: Context) {
31923186
* When no optimization applies, and any of the case values is not a
31933187
* literal int, we emit a series of `if..else` instead of a `js.Match`.
31943188
* This became necessary in 2.13.2 with strings and nulls.
3195-
*
3196-
* Note that dotc has not adopted String-switch-Matches yet, so these code
3197-
* paths are dead code at the moment. However, they already existed in the
3198-
* scalac, so were ported, to be immediately available and working when
3199-
* dotc starts emitting switch-Matches on Strings.
32003189
*/
32013190
def isInt(tree: js.Tree): Boolean = tree.tpe == jstpe.IntType
32023191

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import util.Property._
2020

2121
/** The pattern matching transform.
2222
* After this phase, the only Match nodes remaining in the code are simple switches
23-
* where every pattern is an integer constant
23+
* where every pattern is an integer or string constant
2424
*/
2525
class PatternMatcher extends MiniPhase {
2626
import ast.tpd._
@@ -768,13 +768,15 @@ object PatternMatcher {
768768
(tpe isRef defn.IntClass) ||
769769
(tpe isRef defn.ByteClass) ||
770770
(tpe isRef defn.ShortClass) ||
771-
(tpe isRef defn.CharClass)
771+
(tpe isRef defn.CharClass) ||
772+
(tpe isRef defn.StringClass)
772773

773-
val seen = mutable.Set[Int]()
774+
val seen = mutable.Set[Any]()
774775

775-
def isNewIntConst(tree: Tree) = tree match {
776-
case Literal(const) if const.isIntRange && !seen.contains(const.intValue) =>
777-
seen += const.intValue
776+
def isNewSwitchableConst(tree: Tree) = tree match {
777+
case Literal(const)
778+
if (const.isIntRange || const.tag == Constants.StringTag) && !seen.contains(const.value) =>
779+
seen += const.value
778780
true
779781
case _ =>
780782
false
@@ -789,7 +791,7 @@ object PatternMatcher {
789791
val alts = List.newBuilder[Tree]
790792
def rec(innerPlan: Plan): Boolean = innerPlan match {
791793
case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail)
792-
if scrut === scrutinee && isNewIntConst(tree) =>
794+
if scrut === scrutinee && isNewSwitchableConst(tree) =>
793795
alts += tree
794796
rec(tail)
795797
case ReturnPlan(`outerLabel`) =>
@@ -809,7 +811,7 @@ object PatternMatcher {
809811

810812
def recur(plan: Plan): List[(List[Tree], Plan)] = plan match {
811813
case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail)
812-
if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) =>
814+
if scrut === scrutinee && !canFallThrough(ons) && isNewSwitchableConst(tree) =>
813815
(tree :: Nil, ons) :: recur(tail)
814816
case SeqPlan(AlternativesPlan(alts, ons), tail) =>
815817
(alts, ons) :: recur(tail)
@@ -832,29 +834,32 @@ object PatternMatcher {
832834

833835
/** Emit a switch-match */
834836
private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = {
835-
/* Make sure to adapt the scrutinee to Int, as well as all the alternatives
836-
* of all cases, so that only Matches on primitive Ints survive this phase.
837+
/* Make sure to adapt the scrutinee to Int or String, as well as all the
838+
* alternatives, so that only Matches on primitive Ints or Strings survive
839+
* this phase.
837840
*/
838841

839-
val intScrutinee =
840-
if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee
841-
else scrutinee.select(nme.toInt)
842+
val (primScrutinee, scrutineeTpe) =
843+
if (scrutinee.tpe.widen.isRef(defn.IntClass)) (scrutinee, defn.IntType)
844+
else if (scrutinee.tpe.widen.isRef(defn.StringClass)) (scrutinee, defn.StringType)
845+
else (scrutinee.select(nme.toInt), defn.IntType)
842846

843-
def intLiteral(lit: Tree): Tree =
847+
def primLiteral(lit: Tree): Tree =
844848
val Literal(constant) = lit
845849
if (constant.tag == Constants.IntTag) lit
850+
else if (constant.tag == Constants.StringTag) lit
846851
else cpy.Literal(lit)(Constant(constant.intValue))
847852

848853
val caseDefs = cases.map { (alts, ons) =>
849854
val pat = alts match {
850-
case alt :: Nil => intLiteral(alt)
851-
case Nil => Underscore(defn.IntType) // default case
852-
case _ => Alternative(alts.map(intLiteral))
855+
case alt :: Nil => primLiteral(alt)
856+
case Nil => Underscore(scrutineeTpe) // default case
857+
case _ => Alternative(alts.map(primLiteral))
853858
}
854859
CaseDef(pat, EmptyTree, emit(ons))
855860
}
856861

857-
Match(intScrutinee, caseDefs)
862+
Match(primScrutinee, caseDefs)
858863
}
859864

860865
/** If selfCheck is `true`, used to check whether a tree gets generated twice */

compiler/test/dotty/tools/backend/jvm/DottyBytecodeTests.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,28 @@ class TestBCode extends DottyBytecodeTest {
118118
}
119119
}
120120

121+
@Test def switchOnStrings = {
122+
val source =
123+
"""
124+
|object Foo {
125+
| import scala.annotation.switch
126+
| def foo(s: String) = s match {
127+
| case "AaAa" => println(3)
128+
| case "BBBB" | "c" => println(2)
129+
| case "D" | "E" => println(1)
130+
| case _ => println(0)
131+
| }
132+
|}
133+
""".stripMargin
134+
135+
checkBCode(source) { dir =>
136+
val moduleIn = dir.lookupName("Foo$.class", directory = false)
137+
val moduleNode = loadClassNode(moduleIn.input)
138+
val methodNode = getMethod(moduleNode, "foo")
139+
assert(verifySwitch(methodNode))
140+
}
141+
}
142+
121143
@Test def matchWithDefaultNoThrowMatchError = {
122144
val source =
123145
"""class Test {
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
2
2+
-1
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import annotation.switch
2+
3+
object Test {
4+
def test(s: String): Int = {
5+
(s : @switch) match {
6+
case "1" => 0
7+
case null => -1
8+
case _ => s.toInt
9+
}
10+
}
11+
12+
def main(args: Array[String]): Unit = {
13+
println(test("2"))
14+
println(test(null))
15+
}
16+
}

0 commit comments

Comments
 (0)