Skip to content

Commit e90c953

Browse files
authored
Merge pull request scala#11327 from dotty-staging/fix-9740b
Fix scala#9740: harden type checking for pattern match on objects
2 parents abbbcef + b04150c commit e90c953

File tree

9 files changed

+109
-15
lines changed

9 files changed

+109
-15
lines changed

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3784,11 +3784,32 @@ class Typer extends Namer
37843784
withMode(Mode.GadtConstraintInference) {
37853785
TypeComparer.constrainPatternType(tree.tpe, pt)
37863786
}
3787-
val cmp =
3788-
untpd.Apply(
3789-
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
3790-
untpd.TypedSplice(dummyTreeOfType(pt)))
3791-
typedExpr(cmp, defn.BooleanType)
3787+
3788+
// approximate type params with bounds
3789+
def approx = new ApproximatingTypeMap {
3790+
def apply(tp: Type) = tp.dealias match
3791+
case tp: TypeRef if !tp.symbol.isClass =>
3792+
expandBounds(tp.info.bounds)
3793+
case _ =>
3794+
mapOver(tp)
3795+
}
3796+
3797+
if tree.symbol.is(Module)
3798+
&& !(tree.tpe frozen_<:< pt) // fast track
3799+
&& !(tree.tpe frozen_<:< approx(pt))
3800+
then
3801+
// We could check whether `equals` is overriden.
3802+
// Reasons for not doing so:
3803+
// - it complicates the protocol
3804+
// - such code patterns usually implies hidden errors in the code
3805+
// - it's safe/sound to reject the code
3806+
report.error(TypeMismatch(tree.tpe, pt, "\npattern type is incompatible with expected type"), tree.srcPos)
3807+
else
3808+
val cmp =
3809+
untpd.Apply(
3810+
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
3811+
untpd.TypedSplice(dummyTreeOfType(pt)))
3812+
typedExpr(cmp, defn.BooleanType)
37923813
case _ =>
37933814

37943815
private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol)(using Context): Unit =

tests/neg/i5077.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
trait Is[A]
2+
case object IsInt extends Is[Int]
3+
case object IsString extends Is[String]
4+
case class C[A](is: Is[A], value: A)
5+
6+
@main
7+
def Test = {
8+
val c_string: C[String] = C(IsString, "name")
9+
val c_any: C[_] = c_string
10+
val any: Any = c_string
11+
12+
// Case 1: error
13+
c_string match {
14+
case C(IsInt, _) => println(s"An Int") // error
15+
case C(IsString, s) => println(s"A String with length ${s.length}")
16+
case _ => println("No match")
17+
}
18+
19+
// Case 2: Should match the second case and print the length of the string
20+
c_any match {
21+
case C(IsInt, i) if i < 10 => println(s"An Int less than 10")
22+
case C(IsString, s) => println(s"A String with length ${s.length}")
23+
case _ => println("No match")
24+
}
25+
26+
// Case 3: Same as above; should match the second case and print the length of the string
27+
any match {
28+
case C(IsInt, i) if i < 10 => println(s"An Int less than 10")
29+
case C(IsString, s) => println(s"A String with length ${s.length}")
30+
case _ => println("No match")
31+
}
32+
}

tests/neg/i9740.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
abstract class RecoveryCompleted
2+
object RecoveryCompleted extends RecoveryCompleted
3+
4+
abstract class TypedRecoveryCompleted
5+
object TypedRecoveryCompleted extends TypedRecoveryCompleted
6+
7+
class Test {
8+
TypedRecoveryCompleted match {
9+
case RecoveryCompleted => println("Recovery completed") // error
10+
case TypedRecoveryCompleted => println("Typed recovery completed")
11+
}
12+
13+
def foo(x: TypedRecoveryCompleted) = x match
14+
case RecoveryCompleted => // error
15+
case TypedRecoveryCompleted =>
16+
}

tests/pos/autoTuplingTest.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
object autoTupling {
22

3-
val x = Some(1, 2)
3+
val x: Option[(Int, Int)] = Some(1, 2)
44

55
x match {
66
case Some(a, b) => a + b

tests/pos/i7516.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
val foo: Int => Int = Some(7) match
1+
val foo: Int => Int = Option(7) match
22
case Some(y) => x => y
33
case None => identity[Int]

tests/pos/i9740b.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
sealed trait Exp[T]
2+
case class IntExp(x: Int) extends Exp[Int]
3+
case class StrExp(x: String) extends Exp[String]
4+
object UnitExp extends Exp[Unit]
5+
6+
class Foo[U <: Int, T <: U] {
7+
def bar[A <: T](x: Exp[A]): Unit = x match
8+
case IntExp(x) =>
9+
case StrExp(x) =>
10+
case UnitExp =>
11+
}

tests/pos/i9740c.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
sealed trait Exp[T]
2+
case class IntExp(x: Int) extends Exp[Int]
3+
case class StrExp(x: String) extends Exp[String]
4+
object UnitExp extends Exp[Unit]
5+
6+
trait Txn[T <: Txn[T]]
7+
case class Obj(o: AnyRef) extends Txn[Obj] with Exp[AnyRef]
8+
9+
10+
class Foo {
11+
def bar[A <: Txn[A]](x: Exp[A]): Unit = x match
12+
case IntExp(x) =>
13+
case StrExp(x) =>
14+
case UnitExp =>
15+
case Obj(o) =>
16+
}

tests/run/i5077.check

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
A String with length 4
22
A String with length 4
3-
A String with length 4

tests/run/i5077.scala

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@ def Test = {
99
val c_any: C[_] = c_string
1010
val any: Any = c_string
1111

12-
// Case 1: no error
13-
// `IsInt.equals` might be overridden to match a value of `C[String]`
14-
c_string match {
15-
case C(IsInt, _) => println(s"An Int") // Can't possibly happen!
16-
case C(IsString, s) => println(s"A String with length ${s.length}")
17-
case _ => println("No match")
18-
}
12+
// Case 1: error, tested in tests/neg/i5077.scala
13+
// c_string match {
14+
// case C(IsInt, _) => println(s"An Int") // Can't possibly happen!
15+
// case C(IsString, s) => println(s"A String with length ${s.length}")
16+
// case _ => println("No match")
17+
// }
1918

2019
// Case 2: Should match the second case and print the length of the string
2120
c_any match {

0 commit comments

Comments
 (0)