Skip to content

Commit 1efdbe7

Browse files
committed
Fix scala#4984: support name-based unapplySeq
1 parent 354611b commit 1efdbe7

File tree

5 files changed

+86
-4
lines changed

5 files changed

+86
-4
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ class Definitions {
422422
def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol
423423
lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare)
424424
def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol
425+
lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq)
426+
def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol
425427

426428
lazy val ArrayType: TypeRef = ctx.requiredClassRef("scala.Array")
427429
def ArrayClass(implicit ctx: Context) = ArrayType.symbol.asClass

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -265,8 +265,13 @@ object PatternMatcher {
265265
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
266266
case Some(VarArgPattern(arg)) =>
267267
val matchRemaining =
268-
if (args.length == 1)
269-
patternPlan(getResult, arg, onSuccess)
268+
if (args.length == 1) {
269+
val toSeq = ref(getResult)
270+
.select(defn.Seq_toSeq.matchingMember(getResult.info))
271+
letAbstract(toSeq) { toSeqResult =>
272+
patternPlan(toSeqResult, arg, onSuccess)
273+
}
274+
}
270275
else {
271276
val dropped = ref(getResult)
272277
.select(defn.Seq_drop.matchingMember(getResult.info))

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,27 @@ object Applications {
100100
Nil
101101
}
102102

103+
def vaidUnapplySeqType(getTp: Type): Boolean = {
104+
def superType(elemTp: Type) = {
105+
val tps = List(
106+
MethodType(List("len".toTermName))(_ => defn.IntType :: Nil, _ => defn.IntType),
107+
MethodType(List("i".toTermName))(_ => defn.IntType :: Nil, _ => elemTp),
108+
MethodType(List("n".toTermName))(_ => defn.IntType :: Nil, _ => defn.SeqType.appliedTo(elemTp)),
109+
ExprType(defn.SeqType.appliedTo(elemTp)),
110+
)
111+
val names = List(nme.lengthCompare, nme.apply, nme.drop, nme.toSeq)
112+
RefinedType.make(defn.AnyType, names, tps)
113+
}
114+
getTp <:< superType(WildcardType) && {
115+
val seqArg = getTp.member(nme.toSeq).info.elemType.hiBound
116+
getTp <:< superType(seqArg)
117+
}
118+
}
119+
103120
if (unapplyName == nme.unapplySeq) {
104121
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
105-
else if (isGetMatch(unapplyResult, pos) && getTp.derivesFrom(defn.SeqClass)) {
106-
val seqArg = getTp.elemType.hiBound
122+
else if (isGetMatch(unapplyResult, pos) && vaidUnapplySeqType(getTp)) {
123+
val seqArg = getTp.member(nme.apply).info.finalResultType
107124
if (seqArg.exists) args.map(Function.const(seqArg))
108125
else fail
109126
}

tests/pos/i4984.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object Array2 {
2+
def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x)
3+
4+
final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal {
5+
def isEmpty: Boolean = false
6+
def get: UnapplySeqWrapper[T] = this
7+
def lengthCompare(len: Int): Int = a.lengthCompare(len)
8+
def apply(i: Int): T = a(i)
9+
def drop(n: Int): scala.Seq[T] = ???
10+
def toSeq: scala.Seq[T] = a.toSeq // clones the array
11+
}
12+
}
13+
14+
class Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
}

tests/run/i4984b.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
4+
final class Data {
5+
def isEmpty: Boolean = false
6+
def get: Data = this
7+
def lengthCompare(len: Int): Int = 0
8+
def apply(i: Int): Int = 3
9+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
10+
def toSeq: scala.Seq[Int] = Seq(6, 7)
11+
}
12+
}
13+
14+
object Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
27+
def main(args: Array[String]): Unit = {
28+
test1(Array(3, 5))
29+
test2(Array(3, 5))
30+
test3(Array(3, 5))
31+
}
32+
}

0 commit comments

Comments
 (0)