Skip to content

Commit b3e2a91

Browse files
committed
wip merge tests
1 parent 9afb336 commit b3e2a91

File tree

1 file changed

+118
-3
lines changed

1 file changed

+118
-3
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ object PatternMatcher {
103103
LabeledPlan(label, expr(ReturnPlan(label)), next)
104104
}
105105

106+
/** The plan `let l = labelled in body(l)` where `l` is a fresh label */
107+
private def labeledAbstract2(next: Plan)(expr: TermSymbol => Plan): Plan = {
108+
val label = ctx.newSymbol(ctx.owner, PatMatCaseName.fresh(), Synthetic | Label,
109+
defn.UnitType)
110+
LabeledPlan(label, expr(label), next)
111+
}
112+
106113
/** Test whether a type refers to a pattern-generated variable */
107114
private val refersToInternal = new TypeAccumulator[Boolean] {
108115
def apply(x: Boolean, tp: Type) =
@@ -141,7 +148,7 @@ object PatternMatcher {
141148

142149
case class LetPlan(sym: TermSymbol, var body: Plan) extends Plan
143150
case class LabeledPlan(sym: TermSymbol, var expr: Plan, var next: Plan) extends Plan
144-
case class ReturnPlan(label: TermSymbol) extends Plan
151+
case class ReturnPlan(var label: TermSymbol) extends Plan
145152
case class ResultPlan(var tree: Tree) extends Plan
146153

147154
object TestPlan {
@@ -459,6 +466,111 @@ object PatternMatcher {
459466
refCounter.count
460467
}
461468

469+
/** Merge identical tests from consecutive cases.
470+
*
471+
* When we have the following shape:
472+
*
473+
* caseM: {
474+
* if (testA) plan1 else plan2
475+
* }
476+
* caseN: {
477+
* if (testA) plan3 else plan4
478+
* }
479+
* nextPlan
480+
*
481+
* transform it to
482+
*
483+
* caseN: {
484+
* if (testA) {
485+
* case M: {
486+
* plan1
487+
* }
488+
* plan3
489+
* } else {
490+
* case M2: {
491+
* plan2[caseM2/caseM]
492+
* }
493+
* plan4
494+
* }
495+
* }
496+
* nextPlan
497+
*
498+
* where plan2[caseM2/caseM] means substituting caseM2 for caseM in plan2.
499+
*
500+
* We use some tricks to identify a let pointing to an unapply and the
501+
* NonEmptyTest that follows it as a single `UnappTest` test.
502+
*/
503+
def mergeTests(plan: Plan): Plan = {
504+
def isUnapply(sym: Symbol) = sym.name == nme.unapply || sym.name == nme.unapplySeq
505+
506+
/** A locally used test value that represents combos of
507+
*
508+
* let x = X.unapply(...) in if !x.isEmpty then ... else ...
509+
*/
510+
case object UnappTest extends Test
511+
512+
/** If `plan` is the NonEmptyTest part of an unapply, the corresponding UnappTest
513+
* otherwise the original plan
514+
*/
515+
def normalize(plan: TestPlan): TestPlan = plan.scrutinee match {
516+
case id: Ident
517+
if plan.test == NonEmptyTest &&
518+
isPatmatGenerated(id.symbol) &&
519+
isUnapply(initializer(id.symbol).symbol) =>
520+
TestPlan(UnappTest, initializer(id.symbol), plan.pos, plan.onSuccess, plan.onFailure)
521+
case _ =>
522+
plan
523+
}
524+
525+
/** Extractor for Let/NonEmptyTest combos that represent unapplies */
526+
object UnappTestPlan {
527+
def unapply(plan: Plan): Option[TestPlan] = plan match {
528+
case LetPlan(sym, body: TestPlan) =>
529+
val RHS = initializer(sym)
530+
normalize(body) match {
531+
case normPlan @ TestPlan(UnappTest, RHS, _, _, _) => Some(normPlan)
532+
case _ => None
533+
}
534+
case _ => None
535+
}
536+
}
537+
538+
class SubstituteLabel(from: TermSymbol, to: TermSymbol) extends PlanTransform {
539+
override def apply(plan: ReturnPlan): Plan = {
540+
if (plan.label == from)
541+
plan.label = to
542+
plan
543+
}
544+
}
545+
546+
class MergeTests extends PlanTransform {
547+
override def apply(plan: LabeledPlan): Plan = {
548+
plan.next = apply(plan.next)
549+
plan match {
550+
case LabeledPlan(label1, testPlan1: TestPlan, LabeledPlan(label2, testPlan2: TestPlan, nextNext)) =>
551+
val normTestPlan1 = normalize(testPlan1)
552+
val normTestPlan2 = normalize(testPlan2)
553+
if (normTestPlan1 == normTestPlan2) {
554+
val onFailure = labeledAbstract2(testPlan2.onFailure) { label12 =>
555+
new SubstituteLabel(label1, label12)(testPlan1.onFailure)
556+
}
557+
val onSuccess = LabeledPlan(label1, testPlan1.onSuccess, testPlan2.onSuccess)
558+
testPlan1.onSuccess = apply(onSuccess)
559+
testPlan1.onFailure = apply(onFailure)
560+
LabeledPlan(label2, testPlan1, nextNext)
561+
} else {
562+
plan.expr = apply(plan.expr)
563+
plan
564+
}
565+
case _ =>
566+
plan.expr = apply(plan.expr)
567+
plan
568+
}
569+
}
570+
}
571+
new MergeTests()(plan)
572+
}
573+
462574
/** Eliminate tests that are redundant (known to be true or false).
463575
* Two parts:
464576
*
@@ -949,6 +1061,7 @@ object PatternMatcher {
9491061
}
9501062

9511063
val optimizations: List[(String, Plan => Plan)] = List(
1064+
//"mergeTests" -> mergeTests
9521065
/*
9531066
"hoistLabels" -> hoistLabels,
9541067
"elimRedundantTests" -> elimRedundantTests,
@@ -961,11 +1074,13 @@ object PatternMatcher {
9611074
/** Translate pattern match to sequence of tests. */
9621075
def translateMatch(tree: Match): Tree = {
9631076
var plan = matchPlan(tree)
964-
patmatch.println(i"Plan for $tree: ${show(plan)}")
1077+
//patmatch.println(i"Plan for $tree: ${show(plan)}")
1078+
System.err.println(i"Plan for $tree: ${show(plan)}")
9651079
if (!ctx.settings.YnoPatmatOpt.value)
9661080
for ((title, optimization) <- optimizations) {
9671081
plan = optimization(plan)
968-
patmatch.println(s"After $title: ${show(plan)}")
1082+
//patmatch.println(s"After $title: ${show(plan)}")
1083+
System.err.println(s"After $title: ${show(plan)}")
9691084
}
9701085
val result = emit(plan)
9711086
//checkSwitch(tree, result)

0 commit comments

Comments
 (0)