Skip to content

Commit 4e62ca9

Browse files
Merge pull request #5989 from dotty-staging/fix-3248
Fix #3248: support product-seq pattern
2 parents 0496200 + ec5a391 commit 4e62ca9

15 files changed

+398
-139
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,12 @@ object desugar {
436436
appliedTypeTree(tycon, targs)
437437
}
438438

439+
def isRepeated(tree: Tree): Boolean = tree match {
440+
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
441+
case ByNameTypeTree(tree1) => isRepeated(tree1)
442+
case _ => false
443+
}
444+
439445
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
440446
val classTypeRef = appliedRef(classTycon)
441447

@@ -486,11 +492,6 @@ object desugar {
486492
}
487493
def enumTagMeths = if (isEnumCase) enumTagMeth(CaseKind.Class)._1 :: Nil else Nil
488494
def copyMeths = {
489-
def isRepeated(tree: Tree): Boolean = tree match {
490-
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
491-
case ByNameTypeTree(tree1) => isRepeated(tree1)
492-
case _ => false
493-
}
494495
val hasRepeatedParam = constrVparamss.exists(_.exists {
495496
case ValDef(_, tpt, _) => isRepeated(tpt)
496497
})
@@ -564,7 +565,8 @@ object desugar {
564565
// companion definitions include:
565566
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
566567
// def apply[Ts](p1: T1, ..., pN: TN)(moreParams) = new C[Ts](p1, ..., pN)(moreParams) (unless C is abstract)
567-
// def unapply[Ts]($1: C[Ts]) = $1
568+
// def unapply[Ts]($1: C[Ts]) = $1 // if not repeated
569+
// def unapplySeq[Ts]($1: C[Ts]) = $1 // if repeated
568570
// 2. The default getters of the constructor
569571
// The parent of the companion object of a non-parameterized case class
570572
// (T11, ..., T1N) => ... => (TM1, ..., TMN) => C
@@ -613,9 +615,13 @@ object desugar {
613615
app :: widenDefs
614616
}
615617
val unapplyMeth = {
618+
val hasRepeatedParam = constrVparamss.head.exists {
619+
case ValDef(_, tpt, _) => isRepeated(tpt)
620+
}
621+
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
616622
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
617623
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
618-
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
624+
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
619625
.withMods(synthetic)
620626
}
621627
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import collection.mutable
88
import Symbols._, Contexts._, Types._, StdNames._, NameOps._
99
import ast.Trees._
1010
import util.Spans._
11-
import typer.Applications.{isProductMatch, isGetMatch, productSelectors}
11+
import typer.Applications.{isProductMatch, isGetMatch, isProductSeqMatch, productSelectors, productArity}
1212
import SymUtils._
1313
import Flags._, Constants._
1414
import Decorators._
@@ -286,6 +286,21 @@ object PatternMatcher {
286286
matchElemsPlan(getResult, args, exact = true, onSuccess)
287287
}
288288

289+
/** Plan for matching the sequence in `getResult`
290+
*
291+
* `getResult` is a product, where the last element is a sequence of elements.
292+
*/
293+
def unapplyProductSeqPlan(getResult: Symbol, args: List[Tree], arity: Int): Plan = {
294+
assert(arity <= args.size + 1)
295+
val selectors = productSelectors(getResult.info).map(ref(getResult).select(_))
296+
297+
val matchSeq =
298+
letAbstract(selectors.last) { seqResult =>
299+
unapplySeqPlan(seqResult, args.drop(arity - 1))
300+
}
301+
matchArgsPlan(selectors.take(arity - 1), args.take(arity - 1), matchSeq)
302+
}
303+
289304
/** Plan for matching the result of an unapply against argument patterns `args` */
290305
def unapplyPlan(unapp: Tree, args: List[Tree]): Plan = {
291306
def caseClass = unapp.symbol.owner.linkedClass
@@ -306,12 +321,20 @@ object PatternMatcher {
306321
.map(ref(unappResult).select(_))
307322
matchArgsPlan(selectors, args, onSuccess)
308323
}
324+
else if (isProductSeqMatch(unapp.tpe.widen, args.length, unapp.sourcePos) && isUnapplySeq) {
325+
val arity = productArity(unapp.tpe.widen, unapp.sourcePos)
326+
unapplyProductSeqPlan(unappResult, args, arity)
327+
}
309328
else {
310329
assert(isGetMatch(unapp.tpe))
311330
val argsPlan = {
312331
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
332+
val arity = productArity(get.tpe, unapp.sourcePos)
313333
if (isUnapplySeq)
314-
letAbstract(get)(unapplySeqPlan(_, args))
334+
letAbstract(get) { getResult =>
335+
if (arity > 0) unapplyProductSeqPlan(getResult, args, arity)
336+
else unapplySeqPlan(getResult, args)
337+
}
315338
else
316339
letAbstract(get) { getResult =>
317340
val selectors =

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import ProtoTypes._
2020
import transform.SymUtils._
2121
import reporting.diagnostic.messages._
2222
import config.Printers.{exhaustivity => debug}
23+
import util.SourcePosition
2324

2425
/** Space logic for checking exhaustivity and unreachability of pattern matching
2526
*
@@ -336,8 +337,13 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
336337
if (fun.symbol.name == nme.unapplySeq)
337338
if (fun.symbol.owner == scalaSeqFactoryClass)
338339
projectSeq(pats)
339-
else
340-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
340+
else {
341+
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.sourcePos)
342+
if (elemTp.exists)
343+
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
344+
else
345+
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), irrefutable(fun))
346+
}
341347
else
342348
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), irrefutable(fun))
343349
case Typed(pat @ UnApply(_, _, _), _) => project(pat)
@@ -352,6 +358,18 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
352358
Empty
353359
}
354360

361+
private def unapplySeqInfo(resTp: Type, pos: SourcePosition)(implicit ctx: Context): (Int, Type, Type) = {
362+
var resultTp = resTp
363+
var elemTp = unapplySeqTypeElemTp(resultTp)
364+
var arity = productArity(resultTp, pos)
365+
if (!elemTp.exists && arity <= 0) {
366+
resultTp = resTp.select(nme.get).finalResultType
367+
elemTp = unapplySeqTypeElemTp(resultTp.widen)
368+
arity = productSelectorTypes(resultTp, pos).size
369+
}
370+
(arity, elemTp, resultTp)
371+
}
372+
355373
/* Erase pattern bound types with WildcardType */
356374
def erase(tp: Type): Type = {
357375
def isPatternTypeSymbol(sym: Symbol) = !sym.isClass && sym.is(Case)
@@ -422,17 +440,26 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
422440
List()
423441
else {
424442
val isUnapplySeq = unappSym.name == nme.unapplySeq
425-
if (isProductMatch(mt.finalResultType, argLen) && !isUnapplySeq) {
426-
productSelectors(mt.finalResultType).take(argLen)
427-
.map(_.info.asSeenFrom(mt.finalResultType, mt.resultType.classSymbol).widenExpr)
443+
444+
if (isUnapplySeq) {
445+
val (arity, elemTp, resultTp) = unapplySeqInfo(mt.finalResultType, unappSym.sourcePos)
446+
if (elemTp.exists) scalaListType.appliedTo(elemTp) :: Nil
447+
else {
448+
val sels = productSeqSelectors(resultTp, arity, unappSym.sourcePos)
449+
sels.init :+ scalaListType.appliedTo(sels.last)
450+
}
428451
}
429452
else {
430-
val resTp = mt.finalResultType.select(nme.get).finalResultType.widen
431-
if (isUnapplySeq) scalaListType.appliedTo(resTp.argTypes.head) :: Nil
432-
else if (argLen == 0) Nil
433-
else if (isProductMatch(resTp, argLen))
434-
productSelectors(resTp).map(_.info.asSeenFrom(resTp, resTp.classSymbol).widenExpr)
435-
else resTp :: Nil
453+
val arity = productArity(mt.finalResultType, unappSym.sourcePos)
454+
if (arity > 0)
455+
productSelectors(mt.finalResultType)
456+
.map(_.info.asSeenFrom(mt.finalResultType, mt.resultType.classSymbol).widenExpr)
457+
else {
458+
val resTp = mt.finalResultType.select(nme.get).finalResultType.widen
459+
val arity = productArity(resTp, unappSym.sourcePos)
460+
if (argLen == 1) resTp :: Nil
461+
else productSelectors(resTp).map(_.info.asSeenFrom(resTp, resTp.classSymbol).widenExpr)
462+
}
436463
}
437464
}
438465

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 65 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,22 @@ object Applications {
4747

4848
/** Does `tp` fit the "product match" conditions as an unapply result type
4949
* for a pattern with `numArgs` subpatterns?
50-
* This is the case of `tp` has members `_1` to `_N` where `N == numArgs`.
50+
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
5151
*/
5252
def isProductMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean =
5353
numArgs > 0 && productArity(tp, errorPos) == numArgs
5454

55+
/** Does `tp` fit the "product-seq match" conditions as an unapply result type
56+
* for a pattern with `numArgs` subpatterns?
57+
* This is the case if (1) `tp` has members `_1` to `_N` where `N <= numArgs + 1`.
58+
* (2) `tp._N` conforms to Seq match
59+
*/
60+
def isProductSeqMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean = {
61+
val arity = productArity(tp, errorPos)
62+
arity > 0 && arity <= numArgs + 1 &&
63+
unapplySeqTypeElemTp(productSelectorTypes(tp, errorPos).last).exists
64+
}
65+
5566
/** Does `tp` fit the "get match" conditions as an unapply result type?
5667
* This is the case of `tp` has a `get` member as well as a
5768
* parameterless `isEmpty` member of result type `Boolean`.
@@ -60,6 +71,39 @@ object Applications {
6071
extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) &&
6172
extractorMemberType(tp, nme.get, errorPos).exists
6273

74+
/** If `getType` is of the form:
75+
* ```
76+
* {
77+
* def lengthCompare(len: Int): Int // or, def length: Int
78+
* def apply(i: Int): T = a(i)
79+
* def drop(n: Int): scala.Seq[T]
80+
* def toSeq: scala.Seq[T]
81+
* }
82+
* ```
83+
* returns `T`, otherwise NoType.
84+
*/
85+
def unapplySeqTypeElemTp(getTp: Type)(implicit ctx: Context): Type = {
86+
def lengthTp = ExprType(defn.IntType)
87+
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
88+
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
89+
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
90+
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))
91+
92+
// the result type of `def apply(i: Int): T`
93+
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
94+
95+
def hasMethod(name: Name, tp: Type) =
96+
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists
97+
98+
val isValid =
99+
elemTp.exists &&
100+
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
101+
hasMethod(nme.drop, dropTp(elemTp)) &&
102+
hasMethod(nme.toSeq, toSeqTp(elemTp))
103+
104+
if (isValid) elemTp else NoType
105+
}
106+
63107
def productSelectorTypes(tp: Type, errorPos: SourcePosition)(implicit ctx: Context): List[Type] = {
64108
def tupleSelectors(n: Int, tp: Type): List[Type] = {
65109
val sel = extractorMemberType(tp, nme.selectorName(n), errorPos)
@@ -92,57 +136,35 @@ object Applications {
92136
else tp :: Nil
93137
} else tp :: Nil
94138

139+
def productSeqSelectors(tp: Type, argsNum: Int, pos: SourcePosition)(implicit ctx: Context): List[Type] = {
140+
val selTps = productSelectorTypes(tp, pos)
141+
val arity = selTps.length
142+
val elemTp = unapplySeqTypeElemTp(selTps.last)
143+
(0 until argsNum).map(i => if (i < arity - 1) selTps(i) else elemTp).toList
144+
}
145+
95146
def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: SourcePosition)(implicit ctx: Context): List[Type] = {
96147

97148
val unapplyName = unapplyFn.symbol.name
98-
def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil)
99149
def getTp = extractorMemberType(unapplyResult, nme.get, pos)
100150

101151
def fail = {
102152
ctx.error(UnapplyInvalidReturnType(unapplyResult, unapplyName), pos)
103153
Nil
104154
}
105155

106-
/** If `getType` is of the form:
107-
* ```
108-
* {
109-
* def lengthCompare(len: Int): Int // or, def length: Int
110-
* def apply(i: Int): T = a(i)
111-
* def drop(n: Int): scala.Seq[T]
112-
* def toSeq: scala.Seq[T]
113-
* }
114-
* ```
115-
* returns `T`, otherwise NoType.
116-
*/
117-
def unapplySeqTypeElemTp(getTp: Type): Type = {
118-
def lengthTp = ExprType(defn.IntType)
119-
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
120-
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
121-
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
122-
def toSeqTp(elemTp: Type) = defn.SeqType.appliedTo(elemTp)
123-
124-
// the result type of `def apply(i: Int): T`
125-
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
126-
127-
def hasMethod(name: Name, tp: Type) =
128-
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists
129-
130-
val isValid =
131-
elemTp.exists &&
132-
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
133-
hasMethod(nme.drop, dropTp(elemTp)) &&
134-
hasMethod(nme.toSeq, toSeqTp(elemTp))
135-
136-
if (isValid) elemTp else NoType
156+
def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] = {
157+
val elemTp = unapplySeqTypeElemTp(tp)
158+
if (elemTp.exists) args.map(Function.const(elemTp))
159+
else if (isProductSeqMatch(tp, args.length, pos)) productSeqSelectors(tp, args.length, pos)
160+
else fallback
137161
}
138162

139163
if (unapplyName == nme.unapplySeq) {
140-
if (isGetMatch(unapplyResult, pos)) {
141-
val elemTp = unapplySeqTypeElemTp(getTp)
142-
if (elemTp.exists) args.map(Function.const(elemTp))
164+
unapplySeq(unapplyResult) {
165+
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
143166
else fail
144167
}
145-
else fail
146168
}
147169
else {
148170
assert(unapplyName == nme.unapply)
@@ -1106,19 +1128,12 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>
11061128

11071129
var argTypes = unapplyArgs(unapplyApp.tpe, unapplyFn, args, tree.sourcePos)
11081130
for (argType <- argTypes) assert(!isBounds(argType), unapplyApp.tpe.show)
1109-
val bunchedArgs =
1110-
if (argTypes.nonEmpty && argTypes.last.isRepeatedParam)
1111-
args.lastOption match {
1112-
case Some(arg @ Typed(argSeq, _)) if untpd.isWildcardStarArg(arg) =>
1113-
args.init :+ argSeq
1114-
case _ =>
1115-
val (regularArgs, varArgs) = args.splitAt(argTypes.length - 1)
1116-
regularArgs :+ untpd.SeqLiteral(varArgs, untpd.TypeTree()).withSpan(tree.span)
1117-
}
1118-
else if (argTypes.lengthCompare(1) == 0 && args.lengthCompare(1) > 0 && ctx.canAutoTuple)
1119-
untpd.Tuple(args) :: Nil
1120-
else
1121-
args
1131+
val bunchedArgs = argTypes match {
1132+
case argType :: Nil =>
1133+
if (args.lengthCompare(1) > 0 && ctx.canAutoTuple) untpd.Tuple(args) :: Nil
1134+
else args
1135+
case _ => args
1136+
}
11221137
if (argTypes.length != bunchedArgs.length) {
11231138
ctx.error(UnapplyInvalidNumberOfArguments(qual, argTypes), tree.sourcePos)
11241139
argTypes = argTypes.take(args.length) ++

0 commit comments

Comments
 (0)