diff --git a/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala b/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala index 4810929822ea..0b7d063cdc9d 100644 --- a/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala +++ b/compiler/src/dotty/tools/dotc/transform/patmat/Space.scala @@ -13,6 +13,8 @@ import Symbols._ import StdNames._ import NameOps._ import Constants._ +import typer.Applications._ +import transform.SymUtils._ import reporting.diagnostic.messages._ import config.Printers.{ exhaustivity => debug } @@ -28,9 +30,7 @@ import config.Printers.{ exhaustivity => debug } * 1. `Empty` is a space * 2. For a type T, `Typ(T)` is a space * 3. A union of spaces `S1 | S2 | ...` is a space - * 4. For a case class Kon(x1: T1, x2: T2, .., xn: Tn), if S1, S2, ..., Sn - * are spaces, then `Kon(S1, S2, ..., Sn)` is a space. - * 5. `Fun(S1, S2, ..., Sn)` is an extractor space. + * 4. `Prod(S1, S2, ..., Sn)` is a product space. * * For the problem of exhaustivity check, its formulation in terms of space is as follows: * @@ -62,11 +62,8 @@ case object Empty extends Space */ case class Typ(tp: Type, decomposed: Boolean) extends Space -/** Space representing a constructor pattern */ -case class Kon(tp: Type, params: List[Space]) extends Space - /** Space representing an extractor pattern */ -case class Fun(tp: Type, fun: Type, params: List[Space]) extends Space +case class Prod(tp: Type, unappTp: Type, unappSym: Symbol, params: List[Space], full: Boolean) extends Space /** Union of spaces */ case class Or(spaces: List[Space]) extends Space @@ -93,8 +90,8 @@ trait SpaceLogic { */ def canDecompose(tp: Type): Boolean - /** Return term parameter types of the case class `tp` */ - def signature(tp: Type): List[Type] + /** Return term parameter types of the extractor `unapp` */ + def signature(unapp: Type, unappSym: Symbol, argLen: Int): List[Type] /** Get components of decomposable types */ def decompose(tp: Type): List[Space] @@ -110,12 +107,8 @@ trait SpaceLogic { * This reduces noise in counterexamples. */ def simplify(space: Space, aggressive: Boolean = false): Space = space match { - case Kon(tp, spaces) => - val sp = Kon(tp, spaces.map(simplify(_))) - if (sp.params.contains(Empty)) Empty - else sp - case Fun(tp, fun, spaces) => - val sp = Fun(tp, fun, spaces.map(simplify(_))) + case Prod(tp, fun, sym, spaces, full) => + val sp = Prod(tp, fun, sym, spaces.map(simplify(_)), full) if (sp.params.contains(Empty)) Empty else sp case Or(spaces) => @@ -143,12 +136,14 @@ trait SpaceLogic { /** Flatten space to get rid of `Or` for pretty print */ def flatten(space: Space): List[Space] = space match { - case Kon(tp, spaces) => - val flats = spaces.map(flatten _) - - flats.foldLeft(List[Kon]()) { (acc, flat) => - if (acc.isEmpty) flat.map(s => Kon(tp, Nil :+ s)) - else for (Kon(tp, ss) <- acc; s <- flat) yield Kon(tp, ss :+ s) + case Prod(tp, fun, sym, spaces, full) => + spaces.map(flatten) match { + case Nil => Prod(tp, fun, sym, Nil, full) :: Nil + case ss => + ss.foldLeft(List[Prod]()) { (acc, flat) => + if (acc.isEmpty) flat.map(s => Prod(tp, fun, sym, s :: Nil, full)) + else for (Prod(tp, fun, sym, ss, full) <- acc; s <- flat) yield Prod(tp, fun, sym, ss :+ s, full) + } } case Or(spaces) => spaces.flatMap(flatten _) @@ -171,22 +166,13 @@ trait SpaceLogic { ss.exists(isSubspace(a, _)) || tryDecompose1(tp1) case (_, Or(_)) => simplify(minus(a, b)) == Empty - case (Typ(tp1, _), Kon(tp2, ss)) => - isSubType(tp1, tp2) && isSubspace(Kon(tp2, signature(tp2).map(Typ(_, false))), b) - case (Kon(tp1, ss), Typ(tp2, _)) => - isSubType(tp1, tp2) - case (Kon(tp1, ss1), Kon(tp2, ss2)) => - isEqualType(tp1, tp2) && ss1.zip(ss2).forall((isSubspace _).tupled) - case (Fun(tp1, fun, ss), Typ(tp2, _)) => + case (Prod(tp1, _, _, _, _), Typ(tp2, _)) => isSubType(tp1, tp2) - case (Typ(tp2, _), Fun(tp1, fun, ss)) => - false // approximation: assume a type can never be fully matched by an extractor - case (Kon(_, _), Fun(_, _, _)) => - false // approximation - case (Fun(_, _, _), Kon(_, _)) => - false // approximation - case (Fun(_, fun1, ss1), Fun(_, fun2, ss2)) => - isEqualType(fun1, fun2) && ss1.zip(ss2).forall((isSubspace _).tupled) + case (Typ(tp1, _), Prod(tp2, fun, sym, ss, full)) => + // approximation: a type can never be fully matched by a partial extractor + full && isSubType(tp1, tp2) && isSubspace(Prod(tp2, fun, sym, signature(fun, sym, ss.length).map(Typ(_, false)), full), b) + case (Prod(_, fun1, sym1, ss1, _), Prod(_, fun2, sym2, ss2, _)) => + sym1 == sym2 && isEqualType(fun1, fun2) && ss1.zip(ss2).forall((isSubspace _).tupled) } debug.println(s"${show(a)} < ${show(b)} = $res") @@ -209,38 +195,28 @@ trait SpaceLogic { else if (canDecompose(tp1)) tryDecompose1(tp1) else if (canDecompose(tp2)) tryDecompose2(tp2) else intersectUnrelatedAtomicTypes(tp1, tp2) - case (Typ(tp1, _), Kon(tp2, ss)) => + case (Typ(tp1, _), Prod(tp2, fun, _, ss, true)) => if (isSubType(tp2, tp1)) b else if (isSubType(tp1, tp2)) a // problematic corner case: inheriting a case class else if (canDecompose(tp1)) tryDecompose1(tp1) else Empty - case (Kon(tp1, ss), Typ(tp2, _)) => - if (isSubType(tp1, tp2)) a - else if (isSubType(tp2, tp1)) a // problematic corner case: inheriting a case class - else if (canDecompose(tp2)) tryDecompose2(tp2) - else Empty - case (Kon(tp1, ss1), Kon(tp2, ss2)) => - if (!isEqualType(tp1, tp2)) Empty - else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) Empty - else Kon(tp1, ss1.zip(ss2).map((intersect _).tupled)) - case (Typ(tp1, _), Fun(tp2, _, _)) => + case (Typ(tp1, _), Prod(tp2, _, _, _, false)) => if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) b // prefer extractor space for better approximation else if (canDecompose(tp1)) tryDecompose1(tp1) else Empty - case (Fun(tp1, _, _), Typ(tp2, _)) => - if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) a + case (Prod(tp1, fun, _, ss, true), Typ(tp2, _)) => + if (isSubType(tp1, tp2)) a + else if (isSubType(tp2, tp1)) a // problematic corner case: inheriting a case class else if (canDecompose(tp2)) tryDecompose2(tp2) else Empty - case (Fun(tp1, _, _), Kon(tp2, _)) => + case (Prod(tp1, _, _, _, false), Typ(tp2, _)) => if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) a + else if (canDecompose(tp2)) tryDecompose2(tp2) else Empty - case (Kon(tp1, _), Fun(tp2, _, _)) => - if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) b - else Empty - case (Fun(tp1, fun1, ss1), Fun(tp2, fun2, ss2)) => - if (!isEqualType(fun1, fun2)) Empty + case (Prod(tp1, fun1, sym1, ss1, full), Prod(tp2, fun2, sym2, ss2, _)) => + if (sym1 != sym2 || !isEqualType(fun1, fun2)) Empty else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) Empty - else Fun(tp1, fun1, ss1.zip(ss2).map((intersect _).tupled)) + else Prod(tp1, fun1, sym1, ss1.zip(ss2).map((intersect _).tupled), full) } debug.println(s"${show(a)} & ${show(b)} = ${show(res)}") @@ -261,48 +237,34 @@ trait SpaceLogic { else if (canDecompose(tp1)) tryDecompose1(tp1) else if (canDecompose(tp2)) tryDecompose2(tp2) else a - case (Typ(tp1, _), Kon(tp2, ss)) => - // corner case: inheriting a case class + case (Typ(tp1, _), Prod(tp2, fun, sym, ss, true)) => // rationale: every instance of `tp1` is covered by `tp2(_)` - if (isSubType(tp1, tp2)) minus(Kon(tp2, signature(tp2).map(Typ(_, false))), b) + if (isSubType(tp1, tp2)) minus(Prod(tp2, fun, sym, signature(fun, sym, ss.length).map(Typ(_, false)), true), b) else if (canDecompose(tp1)) tryDecompose1(tp1) else a case (_, Or(ss)) => ss.foldLeft(a)(minus) case (Or(ss), _) => Or(ss.map(minus(_, b))) - case (Kon(tp1, ss), Typ(tp2, _)) => + case (Prod(tp1, fun, _, ss, true), Typ(tp2, _)) => // uncovered corner case: tp2 :< tp1 if (isSubType(tp1, tp2)) Empty else if (simplify(a) == Empty) Empty else if (canDecompose(tp2)) tryDecompose2(tp2) else a - case (Kon(tp1, ss1), Kon(tp2, ss2)) => - if (!isEqualType(tp1, tp2)) a - else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) a - else if (ss1.zip(ss2).forall((isSubspace _).tupled)) Empty - else - // `(_, _, _) - (Some, None, _)` becomes `(None, _, _) | (_, Some, _) | (_, _, Empty)` - Or(ss1.zip(ss2).map((minus _).tupled).zip(0 to ss2.length - 1).map { - case (ri, i) => Kon(tp1, ss1.updated(i, ri)) - }) - case (Fun(tp1, _, _), Typ(tp2, _)) => + case (Prod(tp1, _, _, _, false), Typ(tp2, _)) => if (isSubType(tp1, tp2)) Empty else a - case (Typ(tp1, _), Fun(tp2, _, _)) => + case (Typ(tp1, _), Prod(tp2, _, _, _, false)) => a // approximation - case (Fun(_, _, _), Kon(_, _)) => - a - case (Kon(_, _), Fun(_, _, _)) => - a - case (Fun(tp1, fun1, ss1), Fun(tp2, fun2, ss2)) => - if (!isEqualType(fun1, fun2)) a + case (Prod(tp1, fun1, sym1, ss1, full), Prod(tp2, fun2, sym2, ss2, _)) => + if (sym1 != sym2 || !isEqualType(fun1, fun2)) a else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) a else if (ss1.zip(ss2).forall((isSubspace _).tupled)) Empty else // `(_, _, _) - (Some, None, _)` becomes `(None, _, _) | (_, Some, _) | (_, _, Empty)` Or(ss1.zip(ss2).map((minus _).tupled).zip(0 to ss2.length - 1).map { - case (ri, i) => Fun(tp1, fun1, ss1.updated(i, ri)) + case (ri, i) => Prod(tp1, fun1, sym1, ss1.updated(i, ri), full) }) } @@ -409,6 +371,13 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { } } + /* Whether the extractor is irrefutable */ + def irrefutable(unapp: tpd.Tree): Boolean = { + // TODO: optionless patmat + unapp.tpe.widen.resultType.isRef(scalaSomeClass) || + (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) + } + /** Return the space that represents the pattern `pat` */ def project(pat: Tree): Space = pat match { @@ -423,24 +392,24 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { case Alternative(trees) => Or(trees.map(project(_))) case Bind(_, pat) => project(pat) case UnApply(fun, _, pats) => - if (pat.tpe.classSymbol.is(CaseClass)) - Kon(pat.tpe.stripAnnots, pats.map(pat => project(pat))) - else if (fun.symbol.owner == scalaSeqFactoryClass && fun.symbol.name == nme.unapplySeq) - projectList(pats) - else if (fun.symbol.info.finalResultType.isRef(scalaSomeClass)) - Kon(pat.tpe.stripAnnots, pats.map(pat => project(pat))) + if (fun.symbol.name == nme.unapplySeq) + if (fun.symbol.owner == scalaSeqFactoryClass) + projectSeq(pats) + else + Prod(pat.tpe.stripAnnots, fun.tpe.widen, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun)) else - Fun(pat.tpe.stripAnnots, fun.tpe, pats.map(pat => project(pat))) + Prod(pat.tpe.stripAnnots, fun.tpe.widen, fun.symbol, pats.map(project), irrefutable(fun)) case Typed(pat @ UnApply(_, _, _), _) => project(pat) case Typed(expr, _) => Typ(expr.tpe.stripAnnots, true) case _ => + debug.println(s"unknown pattern: $pat") Empty } - /** Space of the pattern: List(a, b, c: _*) + /** Space of the pattern: unapplySeq(a, b, c: _*) */ - def projectList(pats: List[Tree]): Space = { + def projectSeq(pats: List[Tree]): Space = { if (pats.isEmpty) return Typ(scalaNilType, false) val (items, zero) = if (pats.last.tpe.isRepeatedParam) @@ -449,7 +418,10 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { (pats, Typ(scalaNilType, false)) items.foldRight[Space](zero) { (pat, acc) => - Kon(scalaConsType.appliedTo(pats.head.tpe.widen), project(pat) :: acc :: Nil) + val consTp = scalaConsType.appliedTo(pats.head.tpe.widen) + val unapplySym = consTp.classSymbol.linkedClass.info.member(nme.unapply).symbol + val unapplyTp = unapplySym.info.appliedTo(pats.head.tpe.widen) + Prod(consTp, unapplyTp, unapplySym, project(pat) :: acc :: Nil, true) } } @@ -494,18 +466,33 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { def isEqualType(tp1: Type, tp2: Type): Boolean = tp1 =:= tp2 - /** Parameter types of the case class type `tp` */ - def signature(tp: Type): List[Type] = { - val ktor = tp.classSymbol.primaryConstructor.info - - val meth = ktor match { - case ktor: PolyType => - ktor.instantiate(tp.classSymbol.typeParams.map(_.typeRef)).asSeenFrom(tp, tp.classSymbol) - case _ => ktor + /** Parameter types of the case class type `tp`. Adapted from `unapplyPlan` in patternMatcher */ + def signature(unapp: Type, unappSym: Symbol, argLen: Int): List[Type] = { + def caseClass = unappSym.owner.linkedClass + lazy val caseAccessors = caseClass.caseAccessors.filter(_.is(Method)) + + def isSyntheticScala2Unapply(sym: Symbol) = + sym.is(SyntheticCase) && sym.owner.is(Scala2x) + + val mt @ MethodType(_) = unapp.widen + + if (isSyntheticScala2Unapply(unappSym) && caseAccessors.length == argLen) + caseAccessors.map(_.info.asSeenFrom(mt.paramInfos.head, caseClass).widen) + else if (mt.resultType.isRef(defn.BooleanClass)) + List() + else { + val isUnapplySeq = unappSym.name == nme.unapplySeq + if (isProductMatch(mt.resultType, argLen) && !isUnapplySeq) { + productSelectors(mt.resultType).take(argLen) + .map(_.info.asSeenFrom(mt.resultType, mt.resultType.classSymbol).widen) + } + else { + val resTp = mt.resultType.select(nme.get).resultType.widen + if (isUnapplySeq) scalaListType.appliedTo(resTp.argTypes.head) :: Nil + else if (argLen == 0) Nil + else productSelectors(resTp).map(_.info.asSeenFrom(resTp, resTp.classSymbol).widen) + } } - - // refine path-dependent type in params. refer to t9672 - meth.firstParamTypes.map(_.asSeenFrom(tp, tp.classSymbol)) } /** Decompose a type into subspaces -- assume the type can be decomposed */ @@ -641,6 +628,7 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { /** Display spaces */ def show(s: Space): String = { + def params(tp: Type): List[Type] = tp.classSymbol.primaryConstructor.info.firstParamTypes /** does the companion object of the given symbol have custom unapply */ def hasCustomUnapply(sym: Symbol): Boolean = { @@ -657,17 +645,17 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { val sym = tp.widen.classSymbol if (ctx.definitions.isTupleType(tp)) - signature(tp).map(_ => "_").mkString("(", ", ", ")") + params(tp).map(_ => "_").mkString("(", ", ", ")") else if (scalaListType.isRef(sym)) - if (mergeList) "_*" else "_: List" + if (mergeList) "_: _*" else "_: List" else if (scalaConsType.isRef(sym)) - if (mergeList) "_, _*" else "List(_, _*)" + if (mergeList) "_, _: _*" else "List(_, _: _*)" else if (tp.classSymbol.is(CaseClass) && !hasCustomUnapply(tp.classSymbol)) // use constructor syntax for case class - showType(tp) + signature(tp).map(_ => "_").mkString("(", ", ", ")") + showType(tp) + params(tp).map(_ => "_").mkString("(", ", ", ")") else if (decomposed) "_: " + showType(tp) else "_" - case Kon(tp, params) => + case Prod(tp, fun, _, params, _) => if (ctx.definitions.isTupleType(tp)) "(" + params.map(doShow(_)).mkString(", ") + ")" else if (tp.isRef(scalaConsType.symbol)) @@ -675,8 +663,6 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic { else params.map(doShow(_, true)).filter(_ != "Nil").mkString("List(", ", ", ")") else showType(tp) + params.map(doShow(_)).mkString("(", ", ", ")") - case Fun(tp, fun, params) => - showType(fun) + params.map(doShow(_)).mkString("(", ", ", ")") case Or(_) => throw new Exception("incorrect flatten result " + s) } diff --git a/tests/patmat/exhausting.check b/tests/patmat/exhausting.check index 51d2272add39..d17c5f36bc80 100644 --- a/tests/patmat/exhausting.check +++ b/tests/patmat/exhausting.check @@ -1,6 +1,6 @@ -21: Pattern Match Exhaustivity: List(_), List(_, _, _, _*) +21: Pattern Match Exhaustivity: List(_), List(_, _, _, _: _*) 27: Pattern Match Exhaustivity: Nil -32: Pattern Match Exhaustivity: List(_, _*) +32: Pattern Match Exhaustivity: List(_, _: _*) 39: Pattern Match Exhaustivity: Bar3 44: Pattern Match Exhaustivity: (Bar2, Bar2) 53: Pattern Match Exhaustivity: (Bar2, Bar2), (Bar2, Bar1), (Bar1, Bar3), (Bar1, Bar2) diff --git a/tests/patmat/i2363.check b/tests/patmat/i2363.check index 482144a87e72..5a6fdcbd378a 100644 --- a/tests/patmat/i2363.check +++ b/tests/patmat/i2363.check @@ -1,2 +1,2 @@ -15: Pattern Match Exhaustivity: List(_, _*) +15: Pattern Match Exhaustivity: List(_, _: _*) 21: Pattern Match Exhaustivity: _: Expr \ No newline at end of file diff --git a/tests/patmat/i3004.scala b/tests/patmat/i3004.scala new file mode 100644 index 000000000000..a0c95688077f --- /dev/null +++ b/tests/patmat/i3004.scala @@ -0,0 +1,12 @@ +object O { + sealed trait Fruit + object Apple extends Fruit + object Banana extends Fruit + sealed class C(f1: Fruit, f2: Fruit) + + object C { + def unapply(c: C): Some[Banana.type] = Some(Banana) + } + + def m(c: C) = c match { case C(b) => b } +} diff --git a/tests/patmat/t4408.check b/tests/patmat/t4408.check index c8f296e85367..3936743b977c 100644 --- a/tests/patmat/t4408.check +++ b/tests/patmat/t4408.check @@ -1 +1 @@ -2: Pattern Match Exhaustivity: List(_, _, _, _*) +2: Pattern Match Exhaustivity: List(_, _, _, _: _*) diff --git a/tests/patmat/t5440.check b/tests/patmat/t5440.check index 7c04baaa0d81..a72052e635c6 100644 --- a/tests/patmat/t5440.check +++ b/tests/patmat/t5440.check @@ -1 +1 @@ -2: Pattern Match Exhaustivity: (Nil, List(_, _*)), (List(_, _*), Nil) +2: Pattern Match Exhaustivity: (Nil, List(_, _: _*)), (List(_, _: _*), Nil) diff --git a/tests/patmat/t6420.check b/tests/patmat/t6420.check index a96c7b64bd0e..cacdd5c8c92e 100644 --- a/tests/patmat/t6420.check +++ b/tests/patmat/t6420.check @@ -1 +1 @@ -5: Pattern Match Exhaustivity: (_: List, Nil), (_: List, List(true, _*)), (_: List, List(false, _*)) +5: Pattern Match Exhaustivity: (_: List, Nil), (_: List, List(true, _: _*)), (_: List, List(false, _: _*)) diff --git a/tests/patmat/t7020.check b/tests/patmat/t7020.check index 410b921957fe..4a9c8c9baaaf 100644 --- a/tests/patmat/t7020.check +++ b/tests/patmat/t7020.check @@ -1,4 +1,4 @@ -3: Pattern Match Exhaustivity: List(_, _*) -10: Pattern Match Exhaustivity: List(_, _*) -17: Pattern Match Exhaustivity: List(_, _*) -24: Pattern Match Exhaustivity: List(_, _*) +3: Pattern Match Exhaustivity: List(_, _: _*) +10: Pattern Match Exhaustivity: List(_, _: _*) +17: Pattern Match Exhaustivity: List(_, _: _*) +24: Pattern Match Exhaustivity: List(_, _: _*) diff --git a/tests/patmat/t9232.check b/tests/patmat/t9232.check index e2372315727f..fdf9df06c5f3 100644 --- a/tests/patmat/t9232.check +++ b/tests/patmat/t9232.check @@ -1 +1,3 @@ -13: Pattern Match Exhaustivity: Node2(), Node1(Foo(_)) +13: Pattern Match Exhaustivity: Node2() +17: Pattern Match Exhaustivity: Node2(), Node1(Foo(Nil)), Node1(Foo(List(_, _: _*))) +21: Pattern Match Exhaustivity: Node2(), Node1(Foo(Nil)) diff --git a/tests/patmat/t9232.scala b/tests/patmat/t9232.scala index 975ec58db823..2ca677b81f52 100644 --- a/tests/patmat/t9232.scala +++ b/tests/patmat/t9232.scala @@ -11,6 +11,14 @@ case class Node2() extends Tree object Test { def transformTree(tree: Tree): Any = tree match { - case Node1(Foo(1)) => ??? + case Node1(Foo(_: _*)) => ??? + } + + def transformTree2(tree: Tree): Any = tree match { + case Node1(Foo(1, _: _*)) => ??? + } + + def transformTree3(tree: Tree): Any = tree match { + case Node1(Foo(x, _: _*)) => ??? } }