Skip to content

Commit eb2aab7

Browse files
committed
Fix scala#7554: Implement TypeTest interface
1 parent 87ea7a6 commit eb2aab7

File tree

10 files changed

+382
-9
lines changed

10 files changed

+382
-9
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,10 @@ class Definitions {
627627
@tu lazy val ClassTagModule: Symbol = ClassTagClass.companionModule
628628
@tu lazy val ClassTagModule_apply: Symbol = ClassTagModule.requiredMethod(nme.apply)
629629

630+
@tu lazy val TypeTestClass: ClassSymbol = ctx.requiredClass("scala.tasty.TypeTest")
631+
@tu lazy val TypeTestModule: Symbol = TypeTestClass.companionModule
632+
@tu lazy val TypeTestModule_unapply: Symbol = TypeTestModule.requiredMethod(nme.unapply)
633+
630634
@tu lazy val QuotedExprClass: ClassSymbol = ctx.requiredClass("scala.quoted.Expr")
631635
@tu lazy val QuotedExprModule: Symbol = QuotedExprClass.companionModule
632636
@tu lazy val QuotedExprModule_nullExpr: Symbol = QuotedExprModule.requiredMethod(nme.nullExpr)

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -680,14 +680,15 @@ class Typer extends Namer
680680
*/
681681
def tryWithClassTag(tree: Typed, pt: Type)(implicit ctx: Context): Tree = tree.tpt.tpe.dealias match {
682682
case tref: TypeRef if !tref.symbol.isClass && !ctx.isAfterTyper =>
683-
require(ctx.mode.is(Mode.Pattern))
684-
inferImplicit(defn.ClassTagClass.typeRef.appliedTo(tref),
685-
EmptyTree, tree.tpt.span)(ctx.retractMode(Mode.Pattern)) match {
686-
case SearchSuccess(clsTag, _, _) =>
687-
typed(untpd.Apply(untpd.TypedSplice(clsTag), untpd.TypedSplice(tree.expr)), pt)
688-
case _ =>
689-
tree
683+
def withTag(tpe: Type): Option[Tree] = {
684+
inferImplicit(tpe, EmptyTree, tree.tpt.span)(ctx.retractMode(Mode.Pattern)) match
685+
case SearchSuccess(typeTest, _, _) =>
686+
Some(typed(untpd.Apply(untpd.TypedSplice(typeTest), untpd.TypedSplice(tree.expr)), pt))
687+
case _ =>
688+
None
690689
}
690+
withTag(defn.TypeTestClass.typeRef.appliedTo(pt, tref)).orElse(
691+
withTag(defn.ClassTagClass.typeRef.appliedTo(tref))).getOrElse(tree)
691692
case _ => tree
692693
}
693694

@@ -1467,8 +1468,8 @@ class Typer extends Namer
14671468
val body1 = typed(tree.body, pt1)
14681469
body1 match {
14691470
case UnApply(fn, Nil, arg :: Nil)
1470-
if fn.symbol.exists && fn.symbol.owner == defn.ClassTagClass && !body1.tpe.isError =>
1471-
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]`
1471+
if fn.symbol.exists && (fn.symbol.owner == defn.ClassTagClass || fn.symbol.owner.derivesFrom(defn.TypeTestClass)) && !body1.tpe.isError =>
1472+
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]` or `ctag: TypeTest[T]`
14721473
// was rewritten to `x @ ctag(e)` by `tryWithClassTag`.
14731474
// Rewrite further to `ctag(x @ e)`
14741475
tpd.cpy.UnApply(body1)(fn, Nil,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
---
2+
layout: doc-page
3+
title: "TypeTest"
4+
---
5+
6+
TypeTest
7+
--------
8+
9+
`TypeTest` provides the a replacement for `ClassTag.unapply` where the type of the argument is generalized.
10+
`TypeTest.unapply` will return `Some(x.asInstanceOf[Y])` if `x` conforms to `Y`, otherwise it returns `None`.
11+
12+
```scala
13+
trait TypeTest[S, T <: S] extends Serializable {
14+
def unapply(s: S): Option[s.type & T]
15+
}
16+
```
17+
18+
Just like `ClassTag` used to to, it can be used to perform type checks in patterns.
19+
20+
```scala
21+
type X
22+
type Y <: X
23+
given TypeTest[X, Y] = ...
24+
(x: X) match {
25+
case y: Y => ... // safe checked downcast
26+
case _ => ...
27+
}
28+
```
29+
30+
31+
Examples
32+
--------
33+
34+
Given the following abstract definition of `Peano` numbers that provides `TypeTest[Nat, Zero]` and `TypeTest[Nat, Succ]`
35+
36+
```scala
37+
trait Peano {
38+
type Nat
39+
type Zero <: Nat
40+
type Succ <: Nat
41+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
42+
val Zero: Zero
43+
val Succ: SuccExtractor
44+
trait SuccExtractor {
45+
def apply(nat: Nat): Succ
46+
def unapply(nat: Succ): Option[Nat]
47+
}
48+
given TypeTest[Nat, Zero] = typeTestOfZero
49+
protected def typeTestOfZero: TypeTest[Nat, Zero]
50+
given TypeTest[Nat, Succ]
51+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
52+
```
53+
54+
it will be possible to write the following program
55+
56+
```scala
57+
val peano: Peano = ...
58+
import peano.{_, given}
59+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
60+
n match {
61+
case Zero => None
62+
case s @ Succ(_) => Some(safeDiv(m, s))
63+
}
64+
}
65+
val two = Succ(Succ(Zero))
66+
val five = Succ(Succ(Succ(two)))
67+
println(divOpt(five, two))
68+
```
69+
70+
Note that without the `TypeTest[Nat, Succ]` the pattern `Succ.unapply(nat: Succ)` would be unchecked.
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package scala.tasty
2+
3+
/** A `TypeTest[S, T]` (where `T <: S`) contains the logic needed to know at runtime if a value of
4+
* type `S` can be downcased to `T`.
5+
*
6+
* If a pattern match is performed on a term of type `s: S` that is uncheckable with `s.isInstanceOf[T]` and
7+
* the pattern are of the form:
8+
* - `t: T`
9+
* - `t @ X()` where the `X.unapply` has takes an argument of type `T`
10+
* then a given instance of `TypeTest[S, T]` is summoned and used to performed the test.
11+
*
12+
* Note: This is replacemet for `ClassTag.unapply` that can be sound for path dependent types
13+
*/
14+
@scala.annotation.implicitNotFound(msg = "No TypeTest available for [${S}, ${T}]")
15+
trait TypeTest[S, T <: S] extends Serializable {
16+
17+
def isInstance(x: S): TypeTest.Result[x.type & T]
18+
19+
/** A TypeTest[S, T] can serve as an extractor that matches only S of type T.
20+
*
21+
* The compiler tries to turn unchecked type tests in pattern matches into checked ones
22+
* by wrapping a `(_: T)` type pattern as `tt(_: T)`, where `ct` is the `TypeTest[S, T]` instance.
23+
* Type tests necessary before calling other extractors are treated similarly.
24+
* `SomeExtractor(...)` is turned into `tt(SomeExtractor(...))` if `T` in `SomeExtractor.unapply(x: T)`
25+
* is uncheckable, but we have an instance of `TypeTest[S, T]`.
26+
*/
27+
def unapply(x: S): Option[x.type & T] =
28+
if isInstance(x).asInstanceOf[Boolean] then Some(x.asInstanceOf[x.type & T])
29+
else None
30+
31+
}
32+
33+
object TypeTest {
34+
35+
opaque type Result[A] = Boolean
36+
37+
def success[A](x: A): Result[A] = true
38+
39+
def failure[A]: Result[A] = false
40+
41+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
import scala.tasty.TypeTest
2+
3+
trait R {
4+
type Nat
5+
type Succ <: Nat
6+
type Idx
7+
given TypeTest[Nat, Succ] = typeTestOfSucc
8+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
9+
def n: Nat
10+
def one: Succ
11+
}
12+
13+
object RI extends R {
14+
type Nat = Int
15+
type Succ = Int
16+
type Idx = Int
17+
protected def typeTestOfSucc: TypeTest[Nat, Succ] = new {
18+
def isInstance(x: Int): TypeTest.Result[x.type & Succ] =
19+
if x > 0 then TypeTest.success(x) else TypeTest.failure
20+
}
21+
def n: Nat = 4
22+
def one: Succ = 1
23+
}
24+
25+
object Test {
26+
val r1: R = RI
27+
import r1.given
28+
29+
val r2: R = RI
30+
import r2.given
31+
32+
r1.n match {
33+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
34+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
35+
case n: r1.Succ => // Ok
36+
case n: r1.Nat => // Ok
37+
}
38+
39+
r1.one match {
40+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
41+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
42+
case n: r1.Nat => // Ok
43+
}
44+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import scala.tasty.TypeTest
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val p1: T = T1
6+
import p1.given
7+
8+
val p2: T = T1
9+
import p2.given
10+
11+
(p1.y: p1.X) match {
12+
case x: p2.Y => // error: unchecked
13+
case x: p1.Y =>
14+
case _ =>
15+
}
16+
}
17+
18+
}
19+
20+
trait T {
21+
type X
22+
type Y <: X
23+
def x: X
24+
def y: Y
25+
given TypeTest[X, Y] = typeTestOfY
26+
protected def typeTestOfY: TypeTest[X, Y]
27+
}
28+
29+
object T1 extends T {
30+
type X = Boolean
31+
type Y = true
32+
def x: X = false
33+
def y: Y = true
34+
protected def typeTestOfY: TypeTest[X, Y] = new {
35+
def isInstance(x: X): TypeTest.Result[x.type & Y] = x match
36+
case x: (true & x.type) => TypeTest.success(x)
37+
case _ => TypeTest.failure
38+
}
39+
40+
}

tests/run/type-test-binding.check

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ok
2+
9

tests/run/type-test-binding.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import scala.tasty.TypeTest
2+
3+
sealed trait Foo {
4+
5+
type X
6+
type Y <: X
7+
8+
def x: X
9+
10+
def f(y: Y) = println("ok")
11+
12+
given TypeTest[X, Y] = new TypeTest {
13+
def isInstance(x: X): TypeTest.Result[x.type & Y] =
14+
TypeTest.success(x.asInstanceOf[x.type & Y])
15+
}
16+
17+
object Z {
18+
def unapply(arg: Y): Option[Int] = Some(9)
19+
}
20+
}
21+
22+
object Test {
23+
def main(args: Array[String]): Unit = {
24+
test(new Foo { type X = Int; type Y = Int; def x: X = 1 })
25+
}
26+
27+
def test(foo: Foo): Unit = {
28+
import foo.given
29+
foo.x match {
30+
case x @ foo.Z(i) => // `x` is refined to type `foo.Y`
31+
foo.f(x)
32+
println(i)
33+
}
34+
}
35+
}

tests/run/type-test-nat.check

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
Some((SuccClass(SuccClass(ZeroObject)),SuccClass(ZeroObject)))
2+
Some((ZeroObject,SuccClass(SuccClass(ZeroObject))))
3+
None
4+
Some((2,1))
5+
Some((0,2))
6+
None

0 commit comments

Comments
 (0)