Skip to content

Commit 050c9af

Browse files
committed
Special case for pattern matching tagged abstract types.
Add special case when pattern matching against an abstract type that comes with a class tag
1 parent 14096e3 commit 050c9af

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,25 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
427427
ifExpr = seqToRepeated(typedExpr(tree.expr, defn.SeqType)),
428428
wildName = nme.WILDCARD_STAR)
429429
else {
430-
def tpt1 = checkSimpleKinded(typedType(tree.tpt))
430+
def typedTpt = checkSimpleKinded(typedType(tree.tpt))
431+
def handlePattern: Tree = {
432+
val tpt1 = typedTpt
433+
// special case for an abstract type that comes with a class tag
434+
tpt1.tpe.dealias match {
435+
case tref: TypeRef if !tref.symbol.isClass =>
436+
inferImplicit(defn.ClassTagType.appliedTo(tref),
437+
EmptyTree, tpt1.pos)(ctx.retractMode(Mode.Pattern)) match {
438+
case SearchSuccess(arg, _, _) =>
439+
return typed(untpd.Apply(untpd.TypedSplice(arg), tree.expr), pt)
440+
case _ =>
441+
}
442+
case _ =>
443+
}
444+
ascription(tpt1, isWildcard = true)
445+
}
431446
cases(
432-
ifPat = ascription(tpt1, isWildcard = true),
433-
ifExpr = ascription(tpt1, isWildcard = false),
447+
ifPat = handlePattern,
448+
ifExpr = ascription(typedTpt, isWildcard = false),
434449
wildName = nme.WILDCARD)
435450
}
436451
}

tests/run/i1099.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.reflect.ClassTag
2+
object Test {
3+
def foo[T: ClassTag](x: Any) =
4+
x match {
5+
case t: T => true
6+
case _ => false
7+
}
8+
// This is what `foo` expands to
9+
def foo2[T](x: Any)(implicit ev: ClassTag[T]) =
10+
x match {
11+
case t @ ev(_) => true
12+
case _ => false
13+
}
14+
def main(args: Array[String]): Unit = {
15+
assert(foo[String]("a"))
16+
assert(!foo[String](new Integer(1)))
17+
assert(foo[Int](1))
18+
assert(!foo[Int](true))
19+
20+
assert(foo2[String]("a"))
21+
assert(!foo2[String](new Integer(1)))
22+
assert(foo2[Int](1))
23+
assert(!foo2[Int](true))
24+
}
25+
}

0 commit comments

Comments
 (0)