Skip to content

Commit 7f9d8dd

Browse files
committed
Fix scala#7554: Implement TypeTest interface
Using tests from: https://gist.github.com/Blaisorblade/a0eebb6a4f35344e48c4c60dc2a14ce6
1 parent 904f407 commit 7f9d8dd

File tree

16 files changed

+523
-10
lines changed

16 files changed

+523
-10
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,10 @@ class Definitions {
785785
@tu lazy val ClassTagModule_apply: Symbol = ClassTagModule.requiredMethod(nme.apply)
786786

787787

788+
@tu lazy val TypeTestClass: ClassSymbol = requiredClass("scala.reflect.TypeTest")
789+
@tu lazy val TypeTestModule: Symbol = TypeTestClass.companionModule
790+
@tu lazy val TypeTestModule_identity: Symbol = TypeTestModule.requiredMethod(nme.identity)
791+
788792
@tu lazy val QuotedExprClass: ClassSymbol = requiredClass("scala.quoted.Expr")
789793
@tu lazy val QuotedExprModule: Symbol = QuotedExprClass.companionModule
790794

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,37 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
4545
case _ => EmptyTree
4646
end synthesizedClassTag
4747

48+
val synthesizedTypeTest: SpecialHandler =
49+
(formal, span) => formal.argInfos match {
50+
case arg1 :: arg2 :: Nil if !defn.isBottomClass(arg2.typeSymbol) =>
51+
val tp1 = fullyDefinedType(arg1, "TypeTest argument", span)
52+
val tp2 = fullyDefinedType(arg2, "TypeTest argument", span)
53+
val sym2 = tp2.typeSymbol
54+
if tp1 <:< tp2 then
55+
ref(defn.TypeTestModule_identity).appliedToType(tp2).withSpan(span)
56+
else if sym2 == defn.AnyValClass || sym2 == defn.AnyRefAlias || sym2 == defn.ObjectClass then
57+
EmptyTree
58+
else
59+
// Generate SAM: (s: <tp1>) => if s.isInstanceOf[s.type & <tp2>] then Some(s.asInstanceOf[s.type & <tp2>]) else None
60+
def body(args: List[Tree]): Tree = {
61+
val arg :: Nil = args
62+
val t = arg.tpe & tp2
63+
If(
64+
arg.select(defn.Any_isInstanceOf).appliedToType(t),
65+
ref(defn.SomeClass.companionModule.termRef).select(nme.apply)
66+
.appliedToType(t)
67+
.appliedTo(arg.select(nme.asInstanceOf_).appliedToType(t)),
68+
ref(defn.NoneModule))
69+
}
70+
val tpe = MethodType(List(nme.s))(_ => List(tp1), mth => defn.OptionClass.typeRef.appliedTo(mth.newParamRef(0) & tp2))
71+
val meth = newSymbol(ctx.owner, nme.ANON_FUN, Synthetic | Method, tpe, coord = span)
72+
val typeTestType = defn.TypeTestClass.typeRef.appliedTo(List(tp1, tp2))
73+
Closure(meth, tss => body(tss.head).changeOwner(ctx.owner, meth), targetType = typeTestType).withSpan(span)
74+
case _ =>
75+
EmptyTree
76+
}
77+
end synthesizedTypeTest
78+
4879
val synthesizedTupleFunction: SpecialHandler = (formal, span) =>
4980
formal match
5081
case AppliedType(_, funArgs @ fun :: tupled :: Nil) =>
@@ -374,6 +405,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
374405

375406
val specialHandlers = List(
376407
defn.ClassTagClass -> synthesizedClassTag,
408+
defn.TypeTestClass -> synthesizedTypeTest,
377409
defn.EqlClass -> synthesizedEql,
378410
defn.TupledFunctionClass -> synthesizedTupleFunction,
379411
defn.ValueOfClass -> synthesizedValueOf,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -787,15 +787,18 @@ class Typer extends Namer
787787
*/
788788
def tryWithClassTag(tree: Typed, pt: Type)(using Context): Tree = tree.tpt.tpe.dealias match {
789789
case tref: TypeRef if !tref.symbol.isClass && !ctx.isAfterTyper && !(tref =:= pt) =>
790-
require(ctx.mode.is(Mode.Pattern))
791-
withoutMode(Mode.Pattern)(
792-
inferImplicit(defn.ClassTagClass.typeRef.appliedTo(tref), EmptyTree, tree.tpt.span)
793-
) match {
794-
case SearchSuccess(clsTag, _, _) =>
795-
typed(untpd.Apply(untpd.TypedSplice(clsTag), untpd.TypedSplice(tree.expr)), pt)
796-
case _ =>
797-
tree
790+
def withTag(tpe: Type): Option[Tree] = {
791+
require(ctx.mode.is(Mode.Pattern))
792+
withoutMode(Mode.Pattern)(
793+
inferImplicit(tpe, EmptyTree, tree.tpt.span)
794+
) match
795+
case SearchSuccess(clsTag, _, _) =>
796+
Some(typed(untpd.Apply(untpd.TypedSplice(clsTag), untpd.TypedSplice(tree.expr)), pt))
797+
case _ =>
798+
None
798799
}
800+
withTag(defn.TypeTestClass.typeRef.appliedTo(pt, tref)).orElse(
801+
withTag(defn.ClassTagClass.typeRef.appliedTo(tref))).getOrElse(tree)
799802
case _ => tree
800803
}
801804

@@ -1838,8 +1841,8 @@ class Typer extends Namer
18381841
val body1 = typed(tree.body, pt)
18391842
body1 match {
18401843
case UnApply(fn, Nil, arg :: Nil)
1841-
if fn.symbol.exists && fn.symbol.owner == defn.ClassTagClass && !body1.tpe.isError =>
1842-
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]`
1844+
if fn.symbol.exists && (fn.symbol.owner == defn.ClassTagClass || fn.symbol.owner.derivesFrom(defn.TypeTestClass)) && !body1.tpe.isError =>
1845+
// A typed pattern `x @ (e: T)` with an implicit `ctag: ClassTag[T]` or `ctag: TypeTest[T]`
18431846
// was rewritten to `x @ ctag(e)` by `tryWithClassTag`.
18441847
// Rewrite further to `ctag(x @ e)`
18451848
tpd.cpy.UnApply(body1)(fn, Nil,
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
---
2+
layout: doc-page
3+
title: "TypeTest"
4+
---
5+
6+
TypeTest
7+
--------
8+
9+
`TypeTest` provides a replacement for `ClassTag.unapply` where the type of the argument is generalized.
10+
`TypeTest.unapply` will return `Some(x.asInstanceOf[Y])` if `x` conforms to `Y`, otherwise it returns `None`.
11+
12+
```scala
13+
trait TypeTest[-S, T] extends Serializable {
14+
def unapply(s: S): Option[s.type & T]
15+
}
16+
```
17+
18+
Just like `ClassTag` used to do, it can be used to perform type checks in patterns.
19+
20+
```scala
21+
type X
22+
type Y <: X
23+
given TypeTest[X, Y] = ...
24+
(x: X) match {
25+
case y: Y => ... // safe checked downcast
26+
case _ => ...
27+
}
28+
```
29+
30+
31+
Examples
32+
--------
33+
34+
Given the following abstract definition of `Peano` numbers that provides `TypeTest[Nat, Zero]` and `TypeTest[Nat, Succ]`
35+
36+
```scala
37+
trait Peano {
38+
type Nat
39+
type Zero <: Nat
40+
type Succ <: Nat
41+
def safeDiv(m: Nat, n: Succ): (Nat, Nat)
42+
val Zero: Zero
43+
val Succ: SuccExtractor
44+
trait SuccExtractor {
45+
def apply(nat: Nat): Succ
46+
def unapply(nat: Succ): Option[Nat]
47+
}
48+
given TypeTest[Nat, Zero] = typeTestOfZero
49+
protected def typeTestOfZero: TypeTest[Nat, Zero]
50+
given TypeTest[Nat, Succ]
51+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
52+
```
53+
54+
it will be possible to write the following program
55+
56+
```scala
57+
val peano: Peano = ...
58+
import peano.{_, given}
59+
def divOpt(m: Nat, n: Nat): Option[(Nat, Nat)] = {
60+
n match {
61+
case Zero => None
62+
case s @ Succ(_) => Some(safeDiv(m, s))
63+
}
64+
}
65+
val two = Succ(Succ(Zero))
66+
val five = Succ(Succ(Succ(two)))
67+
println(divOpt(five, two))
68+
```
69+
70+
Note that without the `TypeTest[Nat, Succ]` the pattern `Succ.unapply(nat: Succ)` would be unchecked.
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package scala.reflect
2+
3+
/** A `TypeTest[S, T] contains the logic needed to know at runtime if a value of
4+
* type `S` can be downcasted to `T`.
5+
*
6+
* If a pattern match is performed on a term of type `s: S` that is uncheckable with `s.isInstanceOf[T]` and
7+
* the pattern are of the form:
8+
* - `t: T`
9+
* - `t @ X()` where the `X.unapply` has takes an argument of type `T`
10+
* then a given instance of `TypeTest[S, T]` is summoned and used to perform the test.
11+
*/
12+
@scala.annotation.implicitNotFound(msg = "No TypeTest available for [${S}, ${T}]")
13+
trait TypeTest[-S, T] extends Serializable {
14+
15+
/** A TypeTest[S, T] can serve as an extractor that matches only S of type T.
16+
*
17+
* The compiler tries to turn unchecked type tests in pattern matches into checked ones
18+
* by wrapping a `(_: T)` type pattern as `tt(_: T)`, where `tt` is the `TypeTest[S, T]` instance.
19+
* Type tests necessary before calling other extractors are treated similarly.
20+
* `SomeExtractor(...)` is turned into `tt(SomeExtractor(...))` if `T` in `SomeExtractor.unapply(x: T)`
21+
* is uncheckable, but we have an instance of `TypeTest[S, T]`.
22+
*/
23+
def unapply(x: S): Option[x.type & T]
24+
25+
}
26+
27+
object TypeTest {
28+
29+
def identity[T]: TypeTest[T, T] = Some(_)
30+
31+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import scala.reflect.ClassTag
2+
3+
object IsInstanceOfClassTag {
4+
def safeCast[T: ClassTag](x: Any): Option[T] = {
5+
x match {
6+
case x: T => Some(x) // TODO error: deprecation waring
7+
case _ => None
8+
}
9+
}
10+
11+
def main(args: Array[String]): Unit = {
12+
safeCast[List[String]](List[Int](1)) match {
13+
case None =>
14+
case Some(xs) =>
15+
xs.head.substring(0)
16+
}
17+
18+
safeCast[List[_]](List[Int](1)) match {
19+
case None =>
20+
case Some(xs) =>
21+
xs.head.substring(0) // error
22+
}
23+
}
24+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import scala.reflect.TypeTest
2+
3+
object IsInstanceOfClassTag {
4+
def safeCast[T](x: Any)(using TypeTest[Any, T]): Option[T] = {
5+
x match {
6+
case x: T => Some(x)
7+
case _ => None
8+
}
9+
}
10+
11+
def main(args: Array[String]): Unit = {
12+
safeCast[List[String]](List[Int](1)) match { // error
13+
case None =>
14+
case Some(xs) =>
15+
}
16+
17+
safeCast[List[_]](List[Int](1)) match {
18+
case None =>
19+
case Some(xs) =>
20+
}
21+
}
22+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import scala.reflect.TypeTest
2+
3+
trait R {
4+
type Nat
5+
type Succ <: Nat
6+
type Idx
7+
given TypeTest[Nat, Succ] = typeTestOfSucc
8+
protected def typeTestOfSucc: TypeTest[Nat, Succ]
9+
def n: Nat
10+
def one: Succ
11+
}
12+
13+
object RI extends R {
14+
type Nat = Int
15+
type Succ = Int
16+
type Idx = Int
17+
protected def typeTestOfSucc: TypeTest[Nat, Succ] = new {
18+
def unapply(x: Int): Option[x.type & Succ] =
19+
if x > 0 then Some(x) else None
20+
}
21+
def n: Nat = 4
22+
def one: Succ = 1
23+
}
24+
25+
object Test {
26+
val r1: R = RI
27+
val r2: R = RI
28+
29+
r1.n match {
30+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
31+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
32+
case n: r1.Succ => // Ok
33+
case n: r1.Nat => // Ok
34+
}
35+
36+
r1.one match {
37+
case n: r2.Nat => // error: the type test for Test.r2.Nat cannot be checked at runtime
38+
case n: r1.Idx => // error: the type test for Test.r1.Idx cannot be checked at runtime
39+
case n: r1.Nat => // Ok
40+
}
41+
}
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def main(args: Array[String]): Unit = {
5+
val p1: T = T1
6+
val p2: T = T1
7+
8+
(p1.y: p1.X) match {
9+
case x: p2.Y => // error: unchecked
10+
case x: p1.Y =>
11+
case _ =>
12+
}
13+
}
14+
15+
}
16+
17+
trait T {
18+
type X
19+
type Y <: X
20+
def x: X
21+
def y: Y
22+
given TypeTest[X, Y] = typeTestOfY
23+
protected def typeTestOfY: TypeTest[X, Y]
24+
}
25+
26+
object T1 extends T {
27+
type X = Boolean
28+
type Y = true
29+
def x: X = false
30+
def y: Y = true
31+
protected def typeTestOfY: TypeTest[X, Y] = new {
32+
def unapply(x: X): Option[x.type & Y] = x match
33+
case x: (true & x.type) => Some(x)
34+
case _ => None
35+
}
36+
37+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def test[S, T](using TypeTest[S, T]): Unit = ()
5+
val a: A = ???
6+
7+
test[Any, Any]
8+
test[Int, Int]
9+
10+
test[Int, Any]
11+
test[String, Any]
12+
test[String, AnyRef]
13+
14+
test[Any, Int]
15+
test[Any, String]
16+
test[Any, Some[_]]
17+
test[Any, Array[Int]]
18+
test[Seq[Int], List[Int]]
19+
20+
test[Any, Some[Int]] // error
21+
test[Any, a.X] // error
22+
test[a.X, a.Y] // error
23+
24+
}
25+
26+
class A {
27+
type X
28+
type Y <: X
29+
}

tests/neg/type-test-syntesize.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import scala.reflect.TypeTest
2+
3+
object Test {
4+
def test[S, T](using x: TypeTest[S, T]): Unit = ()
5+
6+
test[Any, AnyRef] // error
7+
test[Any, AnyVal] // error
8+
test[Any, Object] // error
9+
10+
test[Any, Nothing] // error
11+
test[AnyRef, Nothing] // error
12+
test[AnyVal, Nothing] // error
13+
test[Null, Nothing] // error
14+
test[Unit, Nothing] // error
15+
test[Int, Nothing] // error
16+
test[8, Nothing] // error
17+
test[List[_], Nothing] // error
18+
test[Nothing, Nothing] // error
19+
}

0 commit comments

Comments
 (0)