Skip to content

Fix #9740: harden type checking for pattern match on objects #11327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 26 additions & 5 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3780,11 +3780,32 @@ class Typer extends Namer
withMode(Mode.GadtConstraintInference) {
TypeComparer.constrainPatternType(tree.tpe, pt)
}
val cmp =
untpd.Apply(
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
untpd.TypedSplice(dummyTreeOfType(pt)))
typedExpr(cmp, defn.BooleanType)

// approximate type params with bounds
def approx = new ApproximatingTypeMap {
def apply(tp: Type) = tp.dealias match
case tp: TypeRef if !tp.symbol.isClass =>
expandBounds(tp.info.bounds)
case _ =>
mapOver(tp)
}

if tree.symbol.is(Module)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to check this also for other values that are not modules?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We cannot check non-module values due to aliasing.

&& !(tree.tpe frozen_<:< pt) // fast track
&& !(tree.tpe frozen_<:< approx(pt))
then
// We could check whether `equals` is overriden.
// Reasons for not doing so:
// - it complicates the protocol
// - such code patterns usually implies hidden errors in the code
// - it's safe/sound to reject the code
report.error(TypeMismatch(tree.tpe, pt, "\npattern type is incompatible with expected type"), tree.srcPos)
else
val cmp =
untpd.Apply(
untpd.Select(untpd.TypedSplice(tree), nme.EQ),
untpd.TypedSplice(dummyTreeOfType(pt)))
typedExpr(cmp, defn.BooleanType)
case _ =>

private def checkStatementPurity(tree: tpd.Tree)(original: untpd.Tree, exprOwner: Symbol)(using Context): Unit =
Expand Down
32 changes: 32 additions & 0 deletions tests/neg/i5077.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
trait Is[A]
case object IsInt extends Is[Int]
case object IsString extends Is[String]
case class C[A](is: Is[A], value: A)

@main
def Test = {
val c_string: C[String] = C(IsString, "name")
val c_any: C[_] = c_string
val any: Any = c_string

// Case 1: error
c_string match {
case C(IsInt, _) => println(s"An Int") // error
case C(IsString, s) => println(s"A String with length ${s.length}")
case _ => println("No match")
}

// Case 2: Should match the second case and print the length of the string
c_any match {
case C(IsInt, i) if i < 10 => println(s"An Int less than 10")
case C(IsString, s) => println(s"A String with length ${s.length}")
case _ => println("No match")
}

// Case 3: Same as above; should match the second case and print the length of the string
any match {
case C(IsInt, i) if i < 10 => println(s"An Int less than 10")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no checkfile for neg tests?? Sounds very unstable to not verify the failure reasons...

I wanted to see on what grounds this now fails to compile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We tend to avoid check files as they incur maintenance overhead. In contrast, we check the line numbers of reported errors agree with the comments // error in the source code.

case C(IsString, s) => println(s"A String with length ${s.length}")
case _ => println("No match")
}
}
16 changes: 16 additions & 0 deletions tests/neg/i9740.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
abstract class RecoveryCompleted
object RecoveryCompleted extends RecoveryCompleted

abstract class TypedRecoveryCompleted
object TypedRecoveryCompleted extends TypedRecoveryCompleted

class Test {
TypedRecoveryCompleted match {
case RecoveryCompleted => println("Recovery completed") // error
case TypedRecoveryCompleted => println("Typed recovery completed")
}

def foo(x: TypedRecoveryCompleted) = x match
case RecoveryCompleted => // error
case TypedRecoveryCompleted =>
}
2 changes: 1 addition & 1 deletion tests/pos/autoTuplingTest.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
object autoTupling {

val x = Some(1, 2)
val x: Option[(Int, Int)] = Some(1, 2)

x match {
case Some(a, b) => a + b
Expand Down
2 changes: 1 addition & 1 deletion tests/pos/i7516.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
val foo: Int => Int = Some(7) match
val foo: Int => Int = Option(7) match
case Some(y) => x => y
case None => identity[Int]
11 changes: 11 additions & 0 deletions tests/pos/i9740b.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
sealed trait Exp[T]
case class IntExp(x: Int) extends Exp[Int]
case class StrExp(x: String) extends Exp[String]
object UnitExp extends Exp[Unit]

class Foo[U <: Int, T <: U] {
def bar[A <: T](x: Exp[A]): Unit = x match
case IntExp(x) =>
case StrExp(x) =>
case UnitExp =>
}
16 changes: 16 additions & 0 deletions tests/pos/i9740c.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
sealed trait Exp[T]
case class IntExp(x: Int) extends Exp[Int]
case class StrExp(x: String) extends Exp[String]
object UnitExp extends Exp[Unit]

trait Txn[T <: Txn[T]]
case class Obj(o: AnyRef) extends Txn[Obj] with Exp[AnyRef]


class Foo {
def bar[A <: Txn[A]](x: Exp[A]): Unit = x match
case IntExp(x) =>
case StrExp(x) =>
case UnitExp =>
case Obj(o) =>
}
1 change: 0 additions & 1 deletion tests/run/i5077.check
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
A String with length 4
A String with length 4
A String with length 4
13 changes: 6 additions & 7 deletions tests/run/i5077.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@ def Test = {
val c_any: C[_] = c_string
val any: Any = c_string

// Case 1: no error
// `IsInt.equals` might be overridden to match a value of `C[String]`
c_string match {
case C(IsInt, _) => println(s"An Int") // Can't possibly happen!
case C(IsString, s) => println(s"A String with length ${s.length}")
case _ => println("No match")
}
// Case 1: error, tested in tests/neg/i5077.scala
// c_string match {
// case C(IsInt, _) => println(s"An Int") // Can't possibly happen!
// case C(IsString, s) => println(s"A String with length ${s.length}")
// case _ => println("No match")
// }

// Case 2: Should match the second case and print the length of the string
c_any match {
Expand Down