Skip to content

Commit efe3a1c

Browse files
authored
Merge pull request #12059 from dotty-staging/fix-typetest
Fix TypeTest exhaustivity check
2 parents 4f83ada + ed8c657 commit efe3a1c

File tree

12 files changed

+118
-79
lines changed

12 files changed

+118
-79
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -768,6 +768,7 @@ class Definitions {
768768
@tu lazy val ClassTagModule_apply: Symbol = ClassTagModule.requiredMethod(nme.apply)
769769

770770
@tu lazy val TypeTestClass: ClassSymbol = requiredClass("scala.reflect.TypeTest")
771+
@tu lazy val TypeTest_unapply: Symbol = TypeTestClass.requiredMethod(nme.unapply)
771772
@tu lazy val TypeTestModule_identity: Symbol = TypeTestClass.companionModule.requiredMethod(nme.identity)
772773

773774
@tu lazy val QuotedExprClass: ClassSymbol = requiredClass("scala.quoted.Expr")

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

Lines changed: 75 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import ProtoTypes._
2020
import transform.SymUtils._
2121
import reporting._
2222
import config.Printers.{exhaustivity => debug}
23-
import util.SrcPos
23+
import util.{SrcPos, NoSourcePosition}
2424
import NullOpsDecorator._
2525
import collection.mutable
2626

@@ -69,7 +69,7 @@ case object Empty extends Space
6969
case class Typ(tp: Type, decomposed: Boolean = true) extends Space
7070

7171
/** Space representing an extractor pattern */
72-
case class Prod(tp: Type, unappTp: TermRef, params: List[Space], full: Boolean) extends Space
72+
case class Prod(tp: Type, unappTp: TermRef, params: List[Space]) extends Space
7373

7474
/** Union of spaces */
7575
case class Or(spaces: List[Space]) extends Space
@@ -107,6 +107,9 @@ trait SpaceLogic {
107107
/** Get components of decomposable types */
108108
def decompose(tp: Type): List[Space]
109109

110+
/** Whether the extractor covers the given type */
111+
def covers(unapp: TermRef, scrutineeTp: Type): Boolean
112+
110113
/** Display space in string format */
111114
def show(sp: Space): String
112115

@@ -118,8 +121,8 @@ trait SpaceLogic {
118121
* This reduces noise in counterexamples.
119122
*/
120123
def simplify(space: Space, aggressive: Boolean = false)(using Context): Space = trace(s"simplify ${show(space)}, aggressive = $aggressive --> ", debug, x => show(x.asInstanceOf[Space]))(space match {
121-
case Prod(tp, fun, spaces, full) =>
122-
val sp = Prod(tp, fun, spaces.map(simplify(_)), full)
124+
case Prod(tp, fun, spaces) =>
125+
val sp = Prod(tp, fun, spaces.map(simplify(_)))
123126
if (sp.params.contains(Empty)) Empty
124127
else if (canDecompose(tp) && decompose(tp).isEmpty) Empty
125128
else sp
@@ -151,14 +154,14 @@ trait SpaceLogic {
151154

152155
/** Flatten space to get rid of `Or` for pretty print */
153156
def flatten(space: Space)(using Context): Seq[Space] = space match {
154-
case Prod(tp, fun, spaces, full) =>
157+
case Prod(tp, fun, spaces) =>
155158
val ss = LazyList(spaces: _*).map(flatten)
156159

157160
ss.foldLeft(LazyList(Nil : List[Space])) { (acc, flat) =>
158161
for { sps <- acc; s <- flat }
159162
yield sps :+ s
160163
}.map { sps =>
161-
Prod(tp, fun, sps, full)
164+
Prod(tp, fun, sps)
162165
}
163166

164167
case Or(spaces) =>
@@ -184,12 +187,13 @@ trait SpaceLogic {
184187
ss.exists(isSubspace(a, _)) || tryDecompose1(tp1)
185188
case (_, Or(_)) =>
186189
simplify(minus(a, b)) == Empty
187-
case (Prod(tp1, _, _, _), Typ(tp2, _)) =>
190+
case (Prod(tp1, _, _), Typ(tp2, _)) =>
191+
isSubType(tp1, tp2)
192+
case (Typ(tp1, _), Prod(tp2, fun, ss)) =>
188193
isSubType(tp1, tp2)
189-
case (Typ(tp1, _), Prod(tp2, fun, ss, full)) =>
190-
// approximation: a type can never be fully matched by a partial extractor
191-
full && isSubType(tp1, tp2) && isSubspace(Prod(tp2, fun, signature(fun, tp2, ss.length).map(Typ(_, false)), full), b)
192-
case (Prod(_, fun1, ss1, _), Prod(_, fun2, ss2, _)) =>
194+
&& covers(fun, tp1)
195+
&& isSubspace(Prod(tp2, fun, signature(fun, tp2, ss.length).map(Typ(_, false))), b)
196+
case (Prod(_, fun1, ss1), Prod(_, fun2, ss2)) =>
193197
isSameUnapply(fun1, fun2) && ss1.zip(ss2).forall((isSubspace _).tupled)
194198
}
195199
}
@@ -209,28 +213,20 @@ trait SpaceLogic {
209213
else if (canDecompose(tp1)) tryDecompose1(tp1)
210214
else if (canDecompose(tp2)) tryDecompose2(tp2)
211215
else intersectUnrelatedAtomicTypes(tp1, tp2)
212-
case (Typ(tp1, _), Prod(tp2, fun, ss, true)) =>
216+
case (Typ(tp1, _), Prod(tp2, fun, ss)) =>
213217
if (isSubType(tp2, tp1)) b
214-
else if (isSubType(tp1, tp2)) a // problematic corner case: inheriting a case class
215-
else if (canDecompose(tp1)) tryDecompose1(tp1)
216-
else Empty
217-
case (Typ(tp1, _), Prod(tp2, _, _, false)) =>
218-
if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) b // prefer extractor space for better approximation
219218
else if (canDecompose(tp1)) tryDecompose1(tp1)
219+
else if (isSubType(tp1, tp2)) a // problematic corner case: inheriting a case class
220220
else Empty
221-
case (Prod(tp1, fun, ss, true), Typ(tp2, _)) =>
221+
case (Prod(tp1, fun, ss), Typ(tp2, _)) =>
222222
if (isSubType(tp1, tp2)) a
223-
else if (isSubType(tp2, tp1)) a // problematic corner case: inheriting a case class
224-
else if (canDecompose(tp2)) tryDecompose2(tp2)
225-
else Empty
226-
case (Prod(tp1, _, _, false), Typ(tp2, _)) =>
227-
if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) a
228223
else if (canDecompose(tp2)) tryDecompose2(tp2)
224+
else if (isSubType(tp2, tp1)) a // problematic corner case: inheriting a case class
229225
else Empty
230-
case (Prod(tp1, fun1, ss1, full), Prod(tp2, fun2, ss2, _)) =>
226+
case (Prod(tp1, fun1, ss1), Prod(tp2, fun2, ss2)) =>
231227
if (!isSameUnapply(fun1, fun2)) Empty
232228
else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) Empty
233-
else Prod(tp1, fun1, ss1.zip(ss2).map((intersect _).tupled), full)
229+
else Prod(tp1, fun1, ss1.zip(ss2).map((intersect _).tupled))
234230
}
235231
}
236232

@@ -247,27 +243,31 @@ trait SpaceLogic {
247243
else if (canDecompose(tp1)) tryDecompose1(tp1)
248244
else if (canDecompose(tp2)) tryDecompose2(tp2)
249245
else a
250-
case (Typ(tp1, _), Prod(tp2, fun, ss, true)) =>
246+
case (Typ(tp1, _), Prod(tp2, fun, ss)) =>
251247
// rationale: every instance of `tp1` is covered by `tp2(_)`
252-
if (isSubType(tp1, tp2)) minus(Prod(tp1, fun, signature(fun, tp1, ss.length).map(Typ(_, false)), true), b)
253-
else if (canDecompose(tp1)) tryDecompose1(tp1)
254-
else a
248+
if isSubType(tp1, tp2) && covers(fun, tp1) then
249+
minus(Prod(tp1, fun, signature(fun, tp1, ss.length).map(Typ(_, false))), b)
250+
else if canDecompose(tp1) then
251+
tryDecompose1(tp1)
252+
else
253+
a
255254
case (_, Or(ss)) =>
256255
ss.foldLeft(a)(minus)
257256
case (Or(ss), _) =>
258257
Or(ss.map(minus(_, b)))
259-
case (Prod(tp1, fun, ss, true), Typ(tp2, _)) =>
260-
// uncovered corner case: tp2 :< tp1
261-
if (isSubType(tp1, tp2)) Empty
262-
else if (simplify(a) == Empty) Empty
263-
else if (canDecompose(tp2)) tryDecompose2(tp2)
264-
else a
265-
case (Prod(tp1, _, _, false), Typ(tp2, _)) =>
266-
if (isSubType(tp1, tp2)) Empty
267-
else a
268-
case (Typ(tp1, _), Prod(tp2, _, _, false)) =>
269-
a // approximation
270-
case (Prod(tp1, fun1, ss1, full), Prod(tp2, fun2, ss2, _)) =>
258+
case (Prod(tp1, fun, ss), Typ(tp2, _)) =>
259+
// uncovered corner case: tp2 :< tp1, may happen when inheriting case class
260+
if (isSubType(tp1, tp2))
261+
Empty
262+
else if (simplify(a) == Empty)
263+
Empty
264+
else if (canDecompose(tp2))
265+
tryDecompose2(tp2)
266+
else if (isSubType(tp2, tp1) &&covers(fun, tp2))
267+
minus(a, Prod(tp1, fun, signature(fun, tp1, ss.length).map(Typ(_, false))))
268+
else
269+
a
270+
case (Prod(tp1, fun1, ss1), Prod(tp2, fun2, ss2)) =>
271271
if (!isSameUnapply(fun1, fun2)) return a
272272

273273
val range = (0 until ss1.size).toList
@@ -282,40 +282,36 @@ trait SpaceLogic {
282282
else if cache.forall(sub => isSubspace(sub, Empty)) then Empty
283283
else
284284
// `(_, _, _) - (Some, None, _)` becomes `(None, _, _) | (_, Some, _) | (_, _, Empty)`
285-
Or(range.map { i => Prod(tp1, fun1, ss1.updated(i, sub(i)), full) })
285+
Or(range.map { i => Prod(tp1, fun1, ss1.updated(i, sub(i))) })
286286
}
287287
}
288288
}
289289

290290
object SpaceEngine {
291291

292-
/** Is the unapply irrefutable?
292+
/** Is the unapply or unapplySeq irrefutable?
293293
* @param unapp The unapply function reference
294294
*/
295-
def isIrrefutableUnapply(unapp: tpd.Tree, patSize: Int)(using Context): Boolean = {
296-
val unappResult = unapp.tpe.widen.finalResultType
297-
unappResult.isRef(defn.SomeClass) ||
298-
unappResult <:< ConstantType(Constant(true)) ||
299-
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) || // scala2 compatibility
300-
(patSize != -1 && productArity(unappResult) == patSize) || {
301-
val isEmptyTp = extractorMemberType(unappResult, nme.isEmpty, unapp.srcPos)
295+
def isIrrefutable(unapp: TermRef)(using Context): Boolean = {
296+
val unappResult = unapp.widen.finalResultType
297+
unappResult.isRef(defn.SomeClass)
298+
|| unappResult <:< ConstantType(Constant(true)) // only for unapply
299+
|| (unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) // scala2 compatibility
300+
|| unapplySeqTypeElemTp(unappResult).exists // only for unapplySeq
301+
|| productArity(unappResult) > 0
302+
|| {
303+
val isEmptyTp = extractorMemberType(unappResult, nme.isEmpty, NoSourcePosition)
302304
isEmptyTp <:< ConstantType(Constant(false))
303305
}
304306
}
305307

306-
/** Is the unapplySeq irrefutable?
307-
* @param unapp The unapplySeq function reference
308+
/** Is the unapply or unapplySeq irrefutable?
309+
* @param unapp The unapply function tree
308310
*/
309-
def isIrrefutableUnapplySeq(unapp: tpd.Tree, patSize: Int)(using Context): Boolean = {
310-
val unappResult = unapp.tpe.widen.finalResultType
311-
unappResult.isRef(defn.SomeClass) ||
312-
(unapp.symbol.is(Synthetic) && unapp.symbol.owner.linkedClass.is(Case)) || // scala2 compatibility
313-
unapplySeqTypeElemTp(unappResult).exists ||
314-
isProductSeqMatch(unappResult, patSize) ||
315-
{
316-
val isEmptyTp = extractorMemberType(unappResult, nme.isEmpty, unapp.srcPos)
317-
isEmptyTp <:< ConstantType(Constant(false))
318-
}
311+
def isIrrefutable(unapp: tpd.Tree)(using Context): Boolean = {
312+
val fun1 = tpd.funPart(unapp)
313+
val funRef = fun1.tpe.asInstanceOf[TermRef]
314+
isIrrefutable(funRef)
319315
}
320316
}
321317

@@ -396,12 +392,12 @@ class SpaceEngine(using Context) extends SpaceLogic {
396392
else {
397393
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.srcPos)
398394
if (elemTp.exists)
399-
Prod(erase(pat.tpe.stripAnnots), funRef, projectSeq(pats) :: Nil, isIrrefutableUnapplySeq(fun, pats.size))
395+
Prod(erase(pat.tpe.stripAnnots), funRef, projectSeq(pats) :: Nil)
400396
else
401-
Prod(erase(pat.tpe.stripAnnots), funRef, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), isIrrefutableUnapplySeq(fun, pats.size))
397+
Prod(erase(pat.tpe.stripAnnots), funRef, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)))
402398
}
403399
else
404-
Prod(erase(pat.tpe.stripAnnots), funRef, pats.map(project), isIrrefutableUnapply(fun, pats.length))
400+
Prod(erase(pat.tpe.stripAnnots), funRef, pats.map(project))
405401

406402
case Typed(pat @ UnApply(_, _, _), _) =>
407403
project(pat)
@@ -509,7 +505,7 @@ class SpaceEngine(using Context) extends SpaceLogic {
509505
val unapplyTp = scalaConsType.classSymbol.companionModule.termRef.select(nme.unapply)
510506
items.foldRight[Space](zero) { (pat, acc) =>
511507
val consTp = scalaConsType.appliedTo(pats.head.tpe.widen)
512-
Prod(consTp, unapplyTp, project(pat) :: acc :: Nil, true)
508+
Prod(consTp, unapplyTp, project(pat) :: acc :: Nil)
513509
}
514510
}
515511

@@ -525,7 +521,9 @@ class SpaceEngine(using Context) extends SpaceLogic {
525521
}
526522

527523
def isSameUnapply(tp1: TermRef, tp2: TermRef): Boolean =
528-
tp1.prefix.isStable && tp2.prefix.isStable && tp1 =:= tp2
524+
// always assume two TypeTest[S, T].unapply are the same if they are equal in types
525+
(tp1.prefix.isStable && tp2.prefix.isStable || tp1.symbol == defn.TypeTest_unapply)
526+
&& tp1 =:= tp2
529527

530528
/** Parameter types of the case class type `tp`. Adapted from `unapplyPlan` in patternMatcher */
531529
def signature(unapp: TermRef, scrutineeTp: Type, argLen: Int): List[Type] = {
@@ -558,7 +556,7 @@ class SpaceEngine(using Context) extends SpaceLogic {
558556
// Case unapplySeq:
559557
// 1. return the type `List[T]` where `T` is the element type of the unapplySeq return type `Seq[T]`
560558

561-
val resTp = mt.finalResultType
559+
val resTp = mt.instantiate(scrutineeTp :: Nil).finalResultType
562560

563561
val sig =
564562
if (resTp.isRef(defn.BooleanClass))
@@ -591,6 +589,13 @@ class SpaceEngine(using Context) extends SpaceLogic {
591589
sig.map(_.annotatedToRepeated)
592590
}
593591

592+
/** Whether the extractor covers the given type */
593+
def covers(unapp: TermRef, scrutineeTp: Type): Boolean =
594+
SpaceEngine.isIrrefutable(unapp) || unapp.symbol == defn.TypeTest_unapply && {
595+
val AppliedType(_, _ :: tp :: Nil) = unapp.prefix.widen.dealias
596+
scrutineeTp <:< tp
597+
}
598+
594599
/** Decompose a type into subspaces -- assume the type can be decomposed */
595600
def decompose(tp: Type): List[Space] =
596601
tp.dealias match {
@@ -710,7 +715,7 @@ class SpaceEngine(using Context) extends SpaceLogic {
710715
def impossible: Nothing = throw new AssertionError("`satisfiable` only accepts flattened space.")
711716

712717
def genConstraint(space: Space): List[(Type, Type)] = space match {
713-
case Prod(tp, unappTp, ss, _) =>
718+
case Prod(tp, unappTp, ss) =>
714719
val tps = signature(unappTp, tp, ss.length)
715720
ss.zip(tps).flatMap {
716721
case (sp : Prod, tp) => sp.tp -> tp :: genConstraint(sp)
@@ -772,7 +777,7 @@ class SpaceEngine(using Context) extends SpaceLogic {
772777
showType(tp) + params(tp).map(_ => "_").mkString("(", ", ", ")")
773778
else if (decomposed) "_: " + showType(tp, showTypeArgs = true)
774779
else "_"
775-
case Prod(tp, fun, params, _) =>
780+
case Prod(tp, fun, params) =>
776781
if (ctx.definitions.isTupleType(tp))
777782
"(" + params.map(doShow(_)).mkString(", ") + ")"
778783
else if (tp.isRef(scalaConsType.symbol))

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import NameKinds.DefaultGetterName
3333
import NameOps._
3434
import SymDenotations.{NoCompleter, NoDenotation}
3535
import Applications.unapplyArgs
36-
import transform.patmat.SpaceEngine.isIrrefutableUnapply
36+
import transform.patmat.SpaceEngine.isIrrefutable
3737
import config.Feature._
3838
import config.SourceVersion._
3939

@@ -742,7 +742,7 @@ trait Checking {
742742
recur(pat1, pt)
743743
case UnApply(fn, _, pats) =>
744744
check(pat, pt) &&
745-
(isIrrefutableUnapply(fn, pats.length) || fail(pat, pt)) && {
745+
(isIrrefutable(fn) || fail(pat, pt)) && {
746746
val argPts = unapplyArgs(fn.tpe.widen.finalResultType, fn, pats, pat.srcPos)
747747
pats.corresponds(argPts)(recur)
748748
}

library/src/scala/reflect/TypeTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package scala.reflect
22

33
/** A `TypeTest[S, T] contains the logic needed to know at runtime if a value of
4-
* type `S` can be downcasted to `T`.
4+
* type `S` is an instance of `T`.
55
*
66
* If a pattern match is performed on a term of type `s: S` that is uncheckable with `s.isInstanceOf[T]` and
77
* the pattern are of the form:
@@ -12,7 +12,7 @@ package scala.reflect
1212
@scala.annotation.implicitNotFound(msg = "No TypeTest available for [${S}, ${T}]")
1313
trait TypeTest[-S, T] extends Serializable:
1414

15-
/** A TypeTest[S, T] can serve as an extractor that matches only S of type T.
15+
/** A TypeTest[S, T] can serve as an extractor that matches if and only if S of type T.
1616
*
1717
* The compiler tries to turn unchecked type tests in pattern matches into checked ones
1818
* by wrapping a `(_: T)` type pattern as `tt(_: T)`, where `tt` is the `TypeTest[S, T]` instance.

tests/patmat/i12020.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
19: Pattern Match Exhaustivity: _: TypeDef

tests/patmat/i12020.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import scala.quoted.*
2+
3+
def qwe(using Quotes) = {
4+
import quotes.reflect.*
5+
6+
def ko_1(param: ValDef | TypeDef) =
7+
param match {
8+
case _: ValDef =>
9+
case _: TypeDef =>
10+
}
11+
12+
def ko_2(params: List[ValDef] | List[TypeDef]) =
13+
params.map {
14+
case x: ValDef =>
15+
case y: TypeDef =>
16+
}
17+
18+
def ko_3(param: ValDef | TypeDef) =
19+
param match {
20+
case _: ValDef =>
21+
}
22+
}

tests/patmat/i12026.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def test[A, B](a: A|B)(using reflect.TypeTest[Any, A], reflect.TypeTest[Any, B]) =
2+
a match {
3+
case a: A =>
4+
case b: B =>
5+
}

tests/patmat/i12026b.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
def test[A, B](a: A|B)(tta: reflect.TypeTest[Any, A], ttb: reflect.TypeTest[Any, B]) =
2+
a match {
3+
case tta(a: A) =>
4+
case ttb(b: B) =>
5+
}

tests/patmat/i2363.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
15: Pattern Match Exhaustivity: List(_, _*)
2-
21: Pattern Match Exhaustivity: _: Expr
2+
21: Pattern Match Exhaustivity: _: IntExpr, _: BooleanExpr

tests/patmat/irrefutable.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
22: Pattern Match Exhaustivity: _: Base
2-
65: Pattern Match Exhaustivity: _: M
1+
22: Pattern Match Exhaustivity: _: A, _: B, C(_, _)
2+
65: Pattern Match Exhaustivity: ExM(_, _)

tests/patmat/optionless.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
28: Pattern Match Exhaustivity: _: Tree
1+
28: Pattern Match Exhaustivity: Ident(_)

tests/patmat/patmat-extractor.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
13: Pattern Match Exhaustivity: _: Node
1+
13: Pattern Match Exhaustivity: NodeA(_), NodeB(_), NodeC(_)
22
15: Match case Unreachable

0 commit comments

Comments
 (0)