Skip to content

Commit b1553cb

Browse files
committed
Implement new rules for name-based pattern matching
This implements the rules laid down in scala#1805.
1 parent 9bf5809 commit b1553cb

File tree

6 files changed

+70
-26
lines changed

6 files changed

+70
-26
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ object desugar {
2424

2525
/** Names of methods that are added unconditionally to case classes */
2626
def isDesugaredCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean =
27-
name == nme.isDefined ||
2827
name == nme.copy ||
2928
name == nme.productArity ||
3029
name.isSelectorName
@@ -343,7 +342,6 @@ object desugar {
343342
if (isCaseClass) {
344343
def syntheticProperty(name: TermName, rhs: Tree) =
345344
DefDef(name, Nil, Nil, TypeTree(), rhs).withMods(synthetic)
346-
val isDefinedMeth = syntheticProperty(nme.isDefined, Literal(Constant(true)))
347345
val caseParams = constrVparamss.head.toArray
348346
val productElemMeths = for (i <- 0 until arity) yield
349347
syntheticProperty(nme.selectorName(i), Select(This(EmptyTypeIdent), caseParams(i).name))
@@ -369,7 +367,7 @@ object desugar {
369367
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
370368
.withMods(synthetic) :: Nil
371369
}
372-
copyMeths ::: isDefinedMeth :: productElemMeths.toList
370+
copyMeths ::: productElemMeths.toList
373371
}
374372
else Nil
375373

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,7 @@ class Definitions {
675675

676676
private def isVarArityClass(cls: Symbol, prefix: Name) = {
677677
val name = scalaClassName(cls)
678-
name.startsWith(prefix) &&
678+
name.startsWith(prefix) &&
679679
name.length > prefix.length &&
680680
name.drop(prefix.length).forall(_.isDigit)
681681
}
@@ -737,6 +737,14 @@ class Definitions {
737737
def isProductSubType(tp: Type)(implicit ctx: Context) =
738738
(tp derivesFrom ProductType.symbol) && tp.baseClasses.exists(isProductClass)
739739

740+
def productArity(tp: Type)(implicit ctx: Context) =
741+
if (tp derivesFrom ProductType.symbol)
742+
tp.baseClasses.find(isProductClass) match {
743+
case Some(prod) => prod.typeParams.length
744+
case None => -1
745+
}
746+
else -1
747+
740748
def isFunctionType(tp: Type)(implicit ctx: Context) =
741749
isFunctionClass(tp.dealias.typeSymbol) && {
742750
val arity = functionArity(tp)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ object StdNames {
424424
val info: N = "info"
425425
val inlinedEquals: N = "inlinedEquals"
426426
val isArray: N = "isArray"
427-
val isDefined: N = "isDefined"
428427
val isDefinedAt: N = "isDefinedAt"
429428
val isDefinedAtImpl: N = "$isDefinedAt"
430429
val isEmpty: N = "isEmpty"

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
235235
// next: MatchMonad[U]
236236
// returns MatchMonad[U]
237237
def flatMap(prev: Tree, b: Symbol, next: Tree): Tree = {
238-
if (isProductMatch(prev.tpe)) {
238+
val resultArity = defn.productArity(b.info)
239+
if (isProductMatch(prev.tpe, resultArity)) {
239240
val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null)))
240241
ifThenElseZero(
241242
nullCheck,
@@ -1429,7 +1430,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
14291430

14301431
def resultInMonad =
14311432
if (aligner.isBool) defn.UnitType
1432-
else if (isProductMatch(resultType)) resultType
1433+
else if (isProductMatch(resultType, aligner.prodArity)) resultType
14331434
else if (isGetMatch(resultType)) extractorMemberType(resultType, nme.get)
14341435
else resultType
14351436

@@ -1630,7 +1631,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
16301631
ref(binder) :: Nil
16311632
}
16321633
else if ((aligner.isSingle && aligner.extractor.prodArity == 1) &&
1633-
!isProductMatch(binderTypeTested) && isGetMatch(binderTypeTested))
1634+
!isProductMatch(binderTypeTested, aligner.prodArity) && isGetMatch(binderTypeTested))
16341635
List(ref(binder))
16351636
else
16361637
subPatRefs(binder)
@@ -1885,7 +1886,7 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
18851886
else if (result.classSymbol is Flags.CaseClass) result.decls.filter(x => x.is(Flags.CaseAccessor) && x.is(Flags.Method)).map(_.info).toList
18861887
else result.select(nme.get) :: Nil
18871888
)*/
1888-
if (isProductMatch(resultType)) productSelectorTypes(resultType)
1889+
if (isProductMatch(resultType, args.length)) productSelectorTypes(resultType)
18891890
else if (isGetMatch(resultType)) getUnapplySelectors(resultOfGet, args)
18901891
else if (resultType isRef defn.BooleanClass) Nil
18911892
else {

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ object Applications {
4747
ref.info.widenExpr.dealias
4848
}
4949

50-
/** Does `tp` fit the "product match" conditions as an unapply result type?
51-
* This is the case of `tp` is a subtype of a ProductN class and `tp` has a
52-
* parameterless `isDefined` member of result type `Boolean`.
50+
/** Does `tp` fit the "product match" conditions as an unapply result type
51+
* for a pattern with `numArgs` subpatterns>
52+
* This is the case of `tp` is a subtype of the Product<numArgs> class.
5353
*/
54-
def isProductMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) =
55-
extractorMemberType(tp, nme.isDefined, errorPos).isRef(defn.BooleanClass) &&
56-
defn.isProductSubType(tp)
54+
def isProductMatch(tp: Type, numArgs: Int)(implicit ctx: Context) =
55+
0 <= numArgs && numArgs <= Definitions.MaxTupleArity &&
56+
tp.derivesFrom(defn.ProductNType(numArgs).typeSymbol)
5757

5858
/** Does `tp` fit the "get match" conditions as an unapply result type?
5959
* This is the case of `tp` has a `get` member as well as a
@@ -82,28 +82,38 @@ object Applications {
8282

8383
def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: Position = NoPosition)(implicit ctx: Context): List[Type] = {
8484

85+
val unapplyName = unapplyFn.symbol.name
8586
def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil)
87+
def getTp = extractorMemberType(unapplyResult, nme.get, pos)
8688

8789
def fail = {
88-
ctx.error(i"$unapplyResult is not a valid result type of an unapply method of an extractor", pos)
90+
ctx.error(i"$unapplyResult is not a valid result type of an $unapplyName method of an extractor", pos)
8991
Nil
9092
}
9193

92-
// println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}")
93-
if (isProductMatch(unapplyResult))
94-
productSelectorTypes(unapplyResult)
95-
else if (isGetMatch(unapplyResult)) {
96-
val getTp = extractorMemberType(unapplyResult, nme.get, pos)
97-
if (unapplyFn.symbol.name == nme.unapplySeq) {
94+
if (unapplyName == nme.unapplySeq) {
95+
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
96+
else if (isGetMatch(unapplyResult, pos)) {
9897
val seqArg = boundsToHi(getTp.elemType)
9998
if (seqArg.exists) args.map(Function.const(seqArg))
10099
else fail
101100
}
102-
else getUnapplySelectors(getTp, args, pos)
101+
else fail
102+
}
103+
else {
104+
assert(unapplyName == nme.unapply)
105+
if (isProductMatch(unapplyResult, args.length))
106+
productSelectorTypes(unapplyResult)
107+
else if (isGetMatch(unapplyResult, pos))
108+
getUnapplySelectors(getTp, args, pos)
109+
else if (unapplyResult isRef defn.BooleanClass)
110+
Nil
111+
else if (defn.isProductSubType(unapplyResult))
112+
productSelectorTypes(unapplyResult)
113+
// this will cause a "wrong number of arguments in pattern" error later on,
114+
// which is better than the message in `fail`.
115+
else fail
103116
}
104-
else if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
105-
else if (unapplyResult isRef defn.BooleanClass) Nil
106-
else fail
107117
}
108118

109119
def wrapDefs(defs: mutable.ListBuffer[Tree], tree: Tree)(implicit ctx: Context): Tree =

tests/pos/Patterns.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,31 @@ object NestedPattern {
108108
val xss: List[List[String]] = ???
109109
val List(List(x)) = xss
110110
}
111+
112+
// Tricky case (exercised by Scala parser combinators) where we use
113+
// both get/isEmpty and product-based pattern matching in different
114+
// matches on the same types.
115+
object ProductAndGet {
116+
117+
trait Result[+T]
118+
case class Success[+T](in: String, x: T) extends Result[T] {
119+
def isEmpty = false
120+
def get: T = x
121+
}
122+
case class Failure[+T](in: String, msg: String) extends Result[T] {
123+
def isEmpty = false
124+
def get: String = msg
125+
}
126+
127+
val r: Result[Int] = ???
128+
129+
r match {
130+
case Success(in, x) => x
131+
case Failure(in, msg) => -1
132+
}
133+
134+
r match {
135+
case Success(x) => x
136+
case Failure(msg) => -1
137+
}
138+
}

0 commit comments

Comments
 (0)