Skip to content

Commit 6dd8543

Browse files
committed
WIP: transform string based matches
1 parent c908b13 commit 6dd8543

File tree

2 files changed

+138
-17
lines changed

2 files changed

+138
-17
lines changed

compiler/src/dotty/tools/backend/jvm/BCodeBodyBuilder.scala

Lines changed: 116 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,11 @@ import dotty.tools.dotc.transform.SymUtils._
2121
import dotty.tools.dotc.util.Spans._
2222
import dotty.tools.dotc.core.Contexts._
2323
import dotty.tools.dotc.core.Phases._
24+
import dotty.tools.dotc.core.NameKinds.{PatMatResultName, PatMatAltsName, UniqueName, PatMatStdBinderName}
2425
import dotty.tools.dotc.report
2526

27+
import scala.collection.mutable
28+
2629
/*
2730
*
2831
* @author Miguel Garcia, http://lamp.epfl.ch/~magarcia/ScalaCompilerCornerReloaded/
@@ -67,6 +70,97 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
6770
}
6871
}
6972

73+
/** Rewrites a labelled switch over String literals into two Int switches.
 *
 *  `tree` is expected to be a `Labeled` whose body is either a `Match` or a
 *  `Block` whose final expression is a `Match`. The result is a `Block` that:
 *
 *   1. keeps the statements that preceded the `Match` (if any),
 *   2. declares a mutable Int "ordinal" variable initialised to -1,
 *   3. switches on the selector's `hashCode` (0 when the selector is `null`);
 *      inside each hash bucket the candidate literals are compared against the
 *      selector with `equals`, and the ordinal of the matching literal is
 *      assigned to the ordinal variable,
 *   4. switches on the ordinal, under the original label `tree.bind`, to run
 *      the original case bodies. The original default case, when present, is
 *      placed last.
 *
 *  @param tree the labelled String match to transform
 *  @return     the replacement tree
 */
private def transformStringMatchLabelled(tree: Labeled): Tree =
  import dotty.tools.dotc.core.{Flags => flg}
  // report.echo(i"transforming $tree")

  /** Splits the labelled body into: leading statements, a reference to the
   *  selector, the (string literal, ordinal literal) pairs, and the cases of
   *  the ordinal switch.
   */
  def extract(labeled: Labeled) =
    val litBuilder = mutable.ArrayBuilder.ofRef[(Literal, Literal)]
    val caseBuilder = mutable.ListBuffer.empty[CaseDef]
    var default: CaseDef = null

    def extractMatch(m: Match): Unit =
      var idx = 0 // next free ordinal
      m.cases.foreach { caseDef =>
        caseDef.pat match
          case lit: Literal =>
            val ord = Literal(Constant(idx))
            litBuilder += ((lit, ord))
            caseBuilder += CaseDef(ord, EmptyTree, caseDef.body)
            idx += 1
          case _: Ident =>
            assert(default eq null, s"multiple default targets in a Match node, at ${m.span}")
            default = CaseDef(Underscore(defn.IntType), EmptyTree, caseDef.body)
          case Alternative(alts) =>
            // Every alternative gets its own ordinal; the rewritten case matches all of them.
            val indices = alts.indices
            val ords = indices.map(i => Literal(Constant(i + idx)))
            alts.lazyZip(ords) foreach {
              case (lit: Literal, ord) =>
                litBuilder += ((lit, ord))
              case _ =>
                abort(s"Invalid alternative in alternative pattern in Match node: $m at: ${m.span}")
            }
            caseBuilder += CaseDef(Alternative(ords.toList), EmptyTree, caseDef.body)
            idx += indices.size
          case other =>
            // Fail loudly: skipping the case would silently discard its body.
            abort(s"Unexpected pattern in string Match node: $other at: ${m.span}")
      }
      // The default case goes last so that it follows every ordinal case.
      if default ne null then
        caseBuilder += default
    end extractMatch

    val (inits, matchTree) = labeled.expr match
      case Block(stats, m: Match) => (stats, m)
      case m: Match => (Nil, m)
      case other =>
        // Explicit diagnostic instead of a bare MatchError.
        abort(s"Unexpected body of labelled string match: $other at: ${labeled.span}")
    extractMatch(matchTree)
    (inits, ref(matchTree.selector.symbol), litBuilder.result(), caseBuilder.toList)
  end extract

  val (inits, selectorRef, lits, ordinalCases) = extract(tree)
  val labelOwner = tree.bind.symbol.owner

  // Mutable ordinal; stays at -1 when no literal equals the selector.
  val ordinalSym =
    newSymbol(labelOwner, UniqueName.fresh(nme.ordinal), flg.Synthetic | flg.Local | flg.Mutable, defn.IntType)
  val ordinalRef = ref(ordinalSym)
  val ordinalDef = ValDef(ordinalSym, Literal(Constant(-1)))

  // Label that the hash switch returns to once the ordinal has (or has not) been assigned.
  val ordinalLabel =
    newSymbol(labelOwner, PatMatResultName.fresh(), flg.Synthetic | flg.Label, defn.UnitType)

  // hash = if (selector eq null) 0 else selector.hashCode
  val hashSym =
    newSymbol(labelOwner, PatMatStdBinderName.fresh(), flg.Synthetic | flg.Local | flg.Case, defn.IntType)
  val hashDef = ValDef(hashSym, If(
    cond = selectorRef.select(nme.eq).appliedTo(tpd.nullLiteral),
    thenp = Literal(Constant(0)),
    elsep = selectorRef.select(nme.hashCode_).ensureApplied))

  // One case per distinct hash code, sorted by hash; literals sharing a hash
  // are told apart by a chain of `equals` tests.
  val bucketCases = lits.groupBy(_._1.const.stringValue.hashCode).toList.sortBy(_._1).map { (hashValue, bucket) =>
    val equalsChain = bucket.toList.foldRight(tpd.unitLiteral: Tree) { (pair, acc) =>
      val (str, ord) = pair
      If(str.select(nme.equals_).appliedTo(selectorRef), Assign(ordinalRef, ord), acc)
    }
    CaseDef(Literal(Constant(hashValue)), EmptyTree, Return(equalsChain, ordinalLabel))
  }
  val hashCases =
    bucketCases ::: CaseDef(Underscore(defn.IntType), EmptyTree, Return(tpd.unitLiteral, ordinalLabel)) :: Nil
  val hashBlock = Labeled(Bind(ordinalLabel, EmptyTree), Block(
    hashDef :: Nil,
    Match(ref(hashSym), hashCases)
  ))

  // Second switch: dispatch on the computed ordinal under the original label.
  val ordinalScrutSym =
    newSymbol(labelOwner, PatMatStdBinderName.fresh(), flg.Synthetic | flg.Local | flg.Case, defn.IntType)
  val ordinalScrutDef = ValDef(ordinalScrutSym, ref(ordinalSym))
  val ordinalBlock = Labeled(tree.bind, Block(
    ordinalScrutDef :: Nil,
    Match(ref(ordinalScrutSym), ordinalCases)
  ))

  val res = Block(
    inits :::
    ordinalDef ::
    hashBlock :: Nil,
    ordinalBlock
  )
  // report.echo(i"check res: $res")
  res
end transformStringMatchLabelled
163+
70164
/*
71165
* Emits code that adds nothing to the operand stack.
72166
* Two main cases: `tree` is an assignment,
@@ -276,6 +370,25 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
276370

277371
lineNumber(tree)
278372

373+
/** Tells whether `tree` is a `Match` with at least one String-literal pattern,
 *  either bare or as the final expression of a `Block`.
 *
 *  Only these two shapes are accepted because `transformStringMatchLabelled`
 *  destructures exactly a bare `Match` or a `Block` ending in a `Match`; a
 *  `Match` sitting among a block's statements must not qualify, otherwise the
 *  transformation would fail on a tree it cannot take apart.
 */
def isStringMatchBlock(tree: Tree): Boolean =
  def isStringLiteral(pat: Tree): Boolean = pat match
    case Literal(Constant(_: String)) => true
    case _ => false

  // A match counts as a String match when any case, or any alternative of a case,
  // is a String literal.
  def isStringMatch(m: Match): Boolean = m.cases.exists { caseDef =>
    caseDef.pat match
      case Alternative(alts) => alts.exists(isStringLiteral)
      case pat => isStringLiteral(pat)
  }

  tree match
    case Block(_, m: Match) => isStringMatch(m)
    case m: Match => isStringMatch(m)
    case _ => false
end isStringMatchBlock
391+
279392
tree match {
280393
case ValDef(nme.THIS, _, _) =>
281394
report.debuglog("skipping trivial assign to _$this: " + tree)
@@ -298,6 +411,9 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
298411
case t @ If(_, _, _) =>
299412
generatedType = genLoadIf(t, expectedType)
300413

414+
case t @ Labeled(label, expr) if (label.name.is(PatMatResultName) || label.name.is(PatMatAltsName)) && isStringMatchBlock(expr) =>
415+
genLoad(transformStringMatchLabelled(t))
416+
301417
case t @ Labeled(_, _) =>
302418
generatedType = genLabeled(t)
303419

@@ -549,7 +665,6 @@ trait BCodeBodyBuilder extends BCodeSkelBuilder {
549665

550666
private def genLabeled(tree: Labeled): BType = tree match {
551667
case Labeled(bind, expr) =>
552-
553668
val resKind = tpeTK(tree)
554669
genLoad(expr, resKind)
555670
markProgramPoint(programPoint(bind.symbol))

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -759,13 +759,18 @@ object PatternMatcher {
759759
(tpe isRef defn.IntClass) ||
760760
(tpe isRef defn.ByteClass) ||
761761
(tpe isRef defn.ShortClass) ||
762-
(tpe isRef defn.CharClass)
762+
(tpe isRef defn.CharClass) ||
763+
(tpe isRef defn.StringClass)
763764

764-
val seen = mutable.Set[Int]()
765+
val seenInt = mutable.Set[Int]()
766+
val seenString = mutable.Set[String]()
765767

766-
def isNewIntConst(tree: Tree) = tree match {
767-
case Literal(const) if const.isIntRange && !seen.contains(const.intValue) =>
768-
seen += const.intValue
768+
def isNewConst(tree: Tree) = tree match {
769+
case Literal(const) if const.isIntRange && !seenInt.contains(const.intValue) =>
770+
seenInt += const.intValue
771+
true
772+
case Literal(Constant(str: String)) if !seenString.contains(str) =>
773+
seenString += str
769774
true
770775
case _ =>
771776
false
@@ -780,7 +785,7 @@ object PatternMatcher {
780785
val alts = List.newBuilder[Tree]
781786
def rec(innerPlan: Plan): Boolean = innerPlan match {
782787
case SeqPlan(TestPlan(EqualTest(tree), scrut, _, ReturnPlan(`innerLabel`)), tail)
783-
if scrut === scrutinee && isNewIntConst(tree) =>
788+
if scrut === scrutinee && isNewConst(tree) =>
784789
alts += tree
785790
rec(tail)
786791
case ReturnPlan(`outerLabel`) =>
@@ -800,7 +805,7 @@ object PatternMatcher {
800805

801806
def recur(plan: Plan): List[(List[Tree], Plan)] = plan match {
802807
case SeqPlan(testPlan @ TestPlan(EqualTest(tree), scrut, _, ons), tail)
803-
if scrut === scrutinee && !canFallThrough(ons) && isNewIntConst(tree) =>
808+
if scrut === scrutinee && !canFallThrough(ons) && isNewConst(tree) =>
804809
(tree :: Nil, ons) :: recur(tail)
805810
case SeqPlan(AlternativesPlan(alts, ons), tail) =>
806811
(alts, ons) :: recur(tail)
@@ -823,29 +828,30 @@ object PatternMatcher {
823828

824829
/** Emit a switch-match */
825830
private def emitSwitchMatch(scrutinee: Tree, cases: List[(List[Tree], Plan)]): Match = {
826-
/* Make sure to adapt the scrutinee to Int, as well as all the alternatives
827-
* of all cases, so that only Matches on pritimive Ints survive this phase.
831+
/* Make sure to adapt the scrutinee to Int or String, as well as all the alternatives
832+
* of all cases, so that only Matches on primitive Ints or String literals survive this phase.
828833
*/
829834

830-
val intScrutinee =
831-
if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee
835+
val scrutinee1 =
836+
if (scrutinee.tpe.widen.isRef(defn.StringClass)) scrutinee
837+
else if (scrutinee.tpe.widen.isRef(defn.IntClass)) scrutinee
832838
else scrutinee.select(nme.toInt)
833839

834-
def intLiteral(lit: Tree): Tree =
840+
def literal(lit: Tree): Tree =
835841
val Literal(constant) = lit
836-
if (constant.tag == Constants.IntTag) lit
842+
if constant.tag == Constants.IntTag || constant.tag == Constants.StringTag then lit
837843
else cpy.Literal(lit)(Constant(constant.intValue))
838844

839845
val caseDefs = cases.map { (alts, ons) =>
840846
val pat = alts match {
841-
case alt :: Nil => intLiteral(alt)
847+
case alt :: Nil => literal(alt)
842848
case Nil => Underscore(defn.IntType) // default case
843-
case _ => Alternative(alts.map(intLiteral))
849+
case _ => Alternative(alts.map(literal))
844850
}
845851
CaseDef(pat, EmptyTree, emit(ons))
846852
}
847853

848-
Match(intScrutinee, caseDefs)
854+
Match(scrutinee1, caseDefs)
849855
}
850856

851857
/** If selfCheck is `true`, used to check whether a tree gets generated twice */

0 commit comments

Comments
 (0)