Skip to content

Commit bbd8084

Browse files
committed
Fix #9740: harden type checking for pattern match on objects
We enforce that when pattern match on an object, the object should be a subtype of the scrutinee type. Reasons for doing so: - such code patterns usually implies hidden errors in the code - it's always safe/sound to reject the code We could check whether `equals` is overridden in the object, but - it complicates the protocol - overriding `equals` of object is also a bad practice - there is no sign that the slightly improved completeness is useful
1 parent 53a2dc5 commit bbd8084

File tree

2 files changed

+30
-5
lines changed

2 files changed

+30
-5
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3780,11 +3780,20 @@ class Typer extends Namer
37803780
withMode(Mode.GadtConstraintInference) {
37813781
TypeComparer.constrainPatternType(tree.tpe, pt)
37823782
}
3783-
val cmp =
3784-
untpd.Apply(
3785-
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
3786-
untpd.TypedSplice(dummyTreeOfType(pt)))
3787-
typedExpr(cmp, defn.BooleanType)
3783+
3784+
if tree.symbol.is(Module) && !(tree.tpe <:< pt) then
3785+
// We could check whether `equals` is overriden.
3786+
// Reasons for not doing so:
3787+
// - it complicates the protocol
3788+
// - such code patterns usually implies hidden errors in the code
3789+
// - it's safe/sound to reject the code
3790+
report.error(TypeMismatch(tree.tpe, pt, "\npattern type is incompatible with expected type"), tree.srcPos)
3791+
else
3792+
val cmp =
3793+
untpd.Apply(
3794+
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
3795+
untpd.TypedSplice(dummyTreeOfType(pt)))
3796+
typedExpr(cmp, defn.BooleanType)
37883797
case _ =>
37893798

37903799
private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol)(using Context): Unit =

tests/neg/i9740.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
abstract class RecoveryCompleted
2+
object RecoveryCompleted extends RecoveryCompleted
3+
4+
abstract class TypedRecoveryCompleted
5+
object TypedRecoveryCompleted extends TypedRecoveryCompleted
6+
7+
class Test {
8+
TypedRecoveryCompleted match {
9+
case RecoveryCompleted => println("Recovery completed") // error
10+
case TypedRecoveryCompleted => println("Typed recovery completed")
11+
}
12+
13+
def foo(x: TypedRecoveryCompleted) = x match
14+
case RecoveryCompleted => // error
15+
case TypedRecoveryCompleted =>
16+
}

0 commit comments

Comments
 (0)