Skip to content

Commit 4250b26

Browse files
committed
Special case for pattern matching tagged abstract types.
Add special case when pattern matching against an abstract type that comes with a class tag
1 parent b7ba84d commit 4250b26

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,10 +426,25 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
426426
ifExpr = seqToRepeated(typedExpr(tree.expr, defn.SeqType)),
427427
wildName = nme.WILDCARD_STAR)
428428
else {
429-
def tpt1 = checkSimpleKinded(typedType(tree.tpt))
429+
def typedTpt = checkSimpleKinded(typedType(tree.tpt))
430+
def handlePattern: Tree = {
431+
val tpt1 = typedTpt
432+
// special case for an abstract type that comes with a class tag
433+
tpt1.tpe.dealias match {
434+
case tref: TypeRef if !tref.symbol.isClass =>
435+
inferImplicit(defn.ClassTagType.appliedTo(tref),
436+
EmptyTree, tpt1.pos)(ctx.retractMode(Mode.Pattern)) match {
437+
case SearchSuccess(arg, _, _) =>
438+
return typed(untpd.Apply(untpd.TypedSplice(arg), tree.expr), pt)
439+
case _ =>
440+
}
441+
case _ =>
442+
}
443+
ascription(tpt1, isWildcard = true)
444+
}
430445
cases(
431-
ifPat = ascription(tpt1, isWildcard = true),
432-
ifExpr = ascription(tpt1, isWildcard = false),
446+
ifPat = handlePattern,
447+
ifExpr = ascription(typedTpt, isWildcard = false),
433448
wildName = nme.WILDCARD)
434449
}
435450
}

tests/run/i1099.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.reflect.ClassTag
2+
object Test {
3+
def foo[T: ClassTag](x: Any) =
4+
x match {
5+
case t: T => true
6+
case _ => false
7+
}
8+
// This is what `foo` expands to
9+
def foo2[T](x: Any)(implicit ev: ClassTag[T]) =
10+
x match {
11+
case t @ ev(_) => true
12+
case _ => false
13+
}
14+
def main(args: Array[String]): Unit = {
15+
assert(foo[String]("a"))
16+
assert(!foo[String](new Integer(1)))
17+
assert(foo[Int](1))
18+
assert(!foo[Int](true))
19+
20+
assert(foo2[String]("a"))
21+
assert(!foo2[String](new Integer(1)))
22+
assert(foo2[Int](1))
23+
assert(!foo2[Int](true))
24+
}
25+
}

0 commit comments

Comments
 (0)