Skip to content

Commit ccb813a

Browse files
committed
Fix pattern matching alternative/5402
1 parent dab02ed commit ccb813a

File tree

2 files changed

+66
-9
lines changed

2 files changed

+66
-9
lines changed

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -769,16 +769,45 @@ object PatternMatcher {
769769
}
770770

771771
/** Emit cases of a switch */
772-
private def emitSwitchCases(cases: List[(List[Tree], Plan)]): List[CaseDef] = (cases: @unchecked) match {
773-
case (alts, ons) :: cases1 =>
774-
val pat = alts match {
775-
case alt :: Nil => alt
776-
case Nil => Underscore(defn.IntType) // default case
777-
case _ => Alternative(alts)
772+
private def emitSwitchCases(cases: List[(List[Tree], Plan)]): List[CaseDef] = cases.foldLeft((List[CaseDef](), List[Tree]())) {
773+
case ((prev, collected), (alts, ons)) =>
774+
collectCases(collected, alts) match {
775+
case Some((pat, col)) => (CaseDef(pat, EmptyTree, emit(ons)) :: prev, col ::: collected)
776+
case None => (prev, collected)
778777
}
779-
CaseDef(pat, EmptyTree, emit(ons)) :: emitSwitchCases(cases1)
780-
case nil =>
781-
Nil
778+
}._1
779+
780+
/** Flattens the tree of patterns into a tree and collect all the alternative patterns in a list
781+
* returns None if the pattern is redundant
782+
*/
783+
private def collectCases(existingPatterns: List[Tree], alts: List[Tree]): Option[(Tree, List[Tree])] = {
784+
alts match {
785+
case Nil => Some((Underscore(defn.IntType), Nil))
786+
case _ => mapCases(removeRedundantCases(existingPatterns, alts))
787+
}
788+
}
789+
790+
private def mapCases(alts: List[Tree]): Option[(Tree, List[Tree])] = alts match {
791+
case alt :: Nil => Some((alt, alt :: Nil))
792+
case Nil => None
793+
case _ => Some((Alternative(alts), alts))
794+
}
795+
796+
/** Remove cases that already appear in the same pattern or in previous patterns */
797+
private def removeRedundantCases(previousCases: List[Tree], cases: List[Tree]): List[Tree] = cases.foldLeft(List[Tree]()) {
798+
case (cases, alt) =>
799+
if (cases.exists(_ === alt) || previousCases.exists(_ === alt)) {
800+
cases
801+
} else {
802+
alt :: cases
803+
}
804+
}
805+
806+
/** Flatten a list of patterns into a single tree */
807+
private def simplifyCases(alts: List[Tree]): Tree = alts match {
808+
case alt :: Nil => alt
809+
case Nil => Underscore(defn.IntType) // default case
810+
case _ => Alternative(alts)
782811
}
783812

784813
/** If selfCheck is `true`, used to check whether a tree gets generated twice */
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
2+
object Main {
3+
val a: Int = 4
4+
a match {
5+
case 1 => println("1")
6+
case 1 | 2 => println("1 or 2")
7+
}
8+
9+
a match {
10+
case 1 => 1
11+
case 0 | 0 => 0
12+
case 2 | 2 | 2 | 3 | 2 | 3 => 0
13+
case 4 | (_ @ 4) => 0
14+
case _ => -1
15+
}
16+
17+
a match {
18+
case 1 => 1
19+
case 0 | 0 => 0
20+
case 2 | 2 | 2 | 3 | 2 | 3 => 0
21+
case _ => -1
22+
}
23+
24+
a match {
25+
case 0 | 1 => 0
26+
case 1 => 1
27+
}
28+
}

0 commit comments

Comments
 (0)