Skip to content

Commit 8643876

Browse files
authored
Merge pull request #1801 from dotty-staging/fix-#1790
Fix #1790: Change by-name pattern matching.
2 parents a5620ab + b1553cb commit 8643876

File tree

11 files changed

+142
-61
lines changed

11 files changed

+142
-61
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ object desugar {
2424

2525
/** Names of methods that are added unconditionally to case classes */
2626
def isDesugaredCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean =
27-
name == nme.isDefined ||
2827
name == nme.copy ||
2928
name == nme.productArity ||
3029
name.isSelectorName
@@ -343,7 +342,6 @@ object desugar {
343342
if (isCaseClass) {
344343
def syntheticProperty(name: TermName, rhs: Tree) =
345344
DefDef(name, Nil, Nil, TypeTree(), rhs).withMods(synthetic)
346-
val isDefinedMeth = syntheticProperty(nme.isDefined, Literal(Constant(true)))
347345
val caseParams = constrVparamss.head.toArray
348346
val productElemMeths = for (i <- 0 until arity) yield
349347
syntheticProperty(nme.selectorName(i), Select(This(EmptyTypeIdent), caseParams(i).name))
@@ -369,7 +367,7 @@ object desugar {
369367
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
370368
.withMods(synthetic) :: Nil
371369
}
372-
copyMeths ::: isDefinedMeth :: productElemMeths.toList
370+
copyMeths ::: productElemMeths.toList
373371
}
374372
else Nil
375373

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -675,7 +675,9 @@ class Definitions {
675675

676676
private def isVarArityClass(cls: Symbol, prefix: Name) = {
677677
val name = scalaClassName(cls)
678-
name.startsWith(prefix) && name.drop(prefix.length).forall(_.isDigit)
678+
name.startsWith(prefix) &&
679+
name.length > prefix.length &&
680+
name.drop(prefix.length).forall(_.isDigit)
679681
}
680682

681683
def isBottomClass(cls: Symbol) =
@@ -735,6 +737,14 @@ class Definitions {
735737
def isProductSubType(tp: Type)(implicit ctx: Context) =
736738
(tp derivesFrom ProductType.symbol) && tp.baseClasses.exists(isProductClass)
737739

740+
def productArity(tp: Type)(implicit ctx: Context) =
741+
if (tp derivesFrom ProductType.symbol)
742+
tp.baseClasses.find(isProductClass) match {
743+
case Some(prod) => prod.typeParams.length
744+
case None => -1
745+
}
746+
else -1
747+
738748
def isFunctionType(tp: Type)(implicit ctx: Context) =
739749
isFunctionClass(tp.dealias.typeSymbol) && {
740750
val arity = functionArity(tp)

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,6 @@ object StdNames {
424424
val info: N = "info"
425425
val inlinedEquals: N = "inlinedEquals"
426426
val isArray: N = "isArray"
427-
val isDefined: N = "isDefined"
428427
val isDefinedAt: N = "isDefinedAt"
429428
val isDefinedAtImpl: N = "$isDefinedAt"
430429
val isEmpty: N = "isEmpty"

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,21 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
235235
// next: MatchMonad[U]
236236
// returns MatchMonad[U]
237237
def flatMap(prev: Tree, b: Symbol, next: Tree): Tree = {
238-
239-
val getTp = extractorMemberType(prev.tpe, nme.get)
240-
val isDefined = extractorMemberType(prev.tpe, nme.isDefined)
241-
242-
if ((isDefined isRef defn.BooleanClass) && getTp.exists) {
243-
// isDefined and get may be overloaded
244-
val getDenot = prev.tpe.member(nme.get).suchThat(_.info.isParameterless)
245-
val isDefinedDenot = prev.tpe.member(nme.isDefined).suchThat(_.info.isParameterless)
238+
val resultArity = defn.productArity(b.info)
239+
if (isProductMatch(prev.tpe, resultArity)) {
240+
val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null)))
241+
ifThenElseZero(
242+
nullCheck,
243+
Block(
244+
List(ValDef(b.asTerm, prev)),
245+
next //Substitution(b, ref(prevSym))(next)
246+
)
247+
)
248+
}
249+
else {
250+
val getDenot = extractorMember(prev.tpe, nme.get)
251+
val isEmptyDenot = extractorMember(prev.tpe, nme.isEmpty)
252+
assert(getDenot.exists && isEmptyDenot.exists, i"${prev.tpe}")
246253

247254
val tmpSym = freshSym(prev.pos, prev.tpe, "o")
248255
val prevValue = ref(tmpSym).select(getDenot.symbol).ensureApplied
@@ -251,20 +258,10 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
251258
List(ValDef(tmpSym, prev)),
252259
// must be isEmpty and get as we don't control the target of the call (prev is an extractor call)
253260
ifThenElseZero(
254-
ref(tmpSym).select(isDefinedDenot.symbol),
261+
ref(tmpSym).select(isEmptyDenot.symbol).select(defn.Boolean_!),
255262
Block(List(ValDef(b.asTerm, prevValue)), next)
256263
)
257264
)
258-
} else {
259-
assert(defn.isProductSubType(prev.tpe))
260-
val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null)))
261-
ifThenElseZero(
262-
nullCheck,
263-
Block(
264-
List(ValDef(b.asTerm, prev)),
265-
next //Substitution(b, ref(prevSym))(next)
266-
)
267-
)
268265
}
269266
}
270267

@@ -1431,12 +1428,12 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
14311428
case _ => or
14321429
}
14331430

1434-
def resultInMonad = if (aligner.isBool) defn.UnitType else {
1435-
val getTp = extractorMemberType(resultType, nme.get)
1436-
if ((extractorMemberType(resultType, nme.isDefined) isRef defn.BooleanClass) && getTp.exists)
1437-
getTp
1431+
def resultInMonad =
1432+
if (aligner.isBool) defn.UnitType
1433+
else if (isProductMatch(resultType, aligner.prodArity)) resultType
1434+
else if (isGetMatch(resultType)) extractorMemberType(resultType, nme.get)
14381435
else resultType
1439-
}
1436+
14401437
def resultType: Type
14411438

14421439
/** Create the TreeMaker that embodies this extractor call
@@ -1632,13 +1629,12 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
16321629
//val spr = subPatRefs(binder)
16331630
assert(go && go1)
16341631
ref(binder) :: Nil
1635-
} else {
1636-
lazy val getTp = extractorMemberType(binderTypeTested, nme.get)
1637-
if ((aligner.isSingle && aligner.extractor.prodArity == 1) && ((extractorMemberType(binderTypeTested, nme.isDefined) isRef defn.BooleanClass) && getTp.exists))
1638-
List(ref(binder))
1639-
else
1640-
subPatRefs(binder)
16411632
}
1633+
else if ((aligner.isSingle && aligner.extractor.prodArity == 1) &&
1634+
!isProductMatch(binderTypeTested, aligner.prodArity) && isGetMatch(binderTypeTested))
1635+
List(ref(binder))
1636+
else
1637+
subPatRefs(binder)
16421638
}
16431639

16441640
/*protected def spliceApply(binder: Symbol): Tree = {
@@ -1890,9 +1886,8 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {
18901886
else if (result.classSymbol is Flags.CaseClass) result.decls.filter(x => x.is(Flags.CaseAccessor) && x.is(Flags.Method)).map(_.info).toList
18911887
else result.select(nme.get) :: Nil
18921888
)*/
1893-
if ((extractorMemberType(resultType, nme.isDefined) isRef defn.BooleanClass) && resultOfGet.exists)
1894-
getUnapplySelectors(resultOfGet, args)
1895-
else if (defn.isProductSubType(resultType)) productSelectorTypes(resultType)
1889+
if (isProductMatch(resultType, args.length)) productSelectorTypes(resultType)
1890+
else if (isGetMatch(resultType)) getUnapplySelectors(resultOfGet, args)
18961891
else if (resultType isRef defn.BooleanClass) Nil
18971892
else {
18981893
ctx.error(i"invalid return type in Unapply node: $resultType")

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,16 +32,37 @@ import reporting.diagnostic.Message
3232
object Applications {
3333
import tpd._
3434

35+
def extractorMember(tp: Type, name: Name)(implicit ctx: Context) = {
36+
def isPossibleExtractorType(tp: Type) = tp match {
37+
case _: MethodType | _: PolyType => false
38+
case _ => true
39+
}
40+
tp.member(name).suchThat(d => isPossibleExtractorType(d.info))
41+
}
42+
3543
def extractorMemberType(tp: Type, name: Name, errorPos: Position = NoPosition)(implicit ctx: Context) = {
36-
val ref = tp.member(name).suchThat(_.info.isParameterless)
44+
val ref = extractorMember(tp, name)
3745
if (ref.isOverloaded)
3846
errorType(i"Overloaded reference to $ref is not allowed in extractor", errorPos)
39-
else if (ref.info.isInstanceOf[PolyType])
40-
errorType(i"Reference to polymorphic $ref: ${ref.info} is not allowed in extractor", errorPos)
41-
else
42-
ref.info.widenExpr.dealias
47+
ref.info.widenExpr.dealias
4348
}
4449

50+
/** Does `tp` fit the "product match" conditions as an unapply result type
51+
* for a pattern with `numArgs` subpatterns>
52+
* This is the case of `tp` is a subtype of the Product<numArgs> class.
53+
*/
54+
def isProductMatch(tp: Type, numArgs: Int)(implicit ctx: Context) =
55+
0 <= numArgs && numArgs <= Definitions.MaxTupleArity &&
56+
tp.derivesFrom(defn.ProductNType(numArgs).typeSymbol)
57+
58+
/** Does `tp` fit the "get match" conditions as an unapply result type?
59+
* This is the case of `tp` has a `get` member as well as a
60+
* parameterless `isDefined` member of result type `Boolean`.
61+
*/
62+
def isGetMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) =
63+
extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) &&
64+
extractorMemberType(tp, nme.get, errorPos).exists
65+
4566
def productSelectorTypes(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context): List[Type] = {
4667
val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos)
4768
sels.takeWhile(_.exists).toList
@@ -61,24 +82,37 @@ object Applications {
6182

6283
def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: Position = NoPosition)(implicit ctx: Context): List[Type] = {
6384

85+
val unapplyName = unapplyFn.symbol.name
6486
def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil)
6587
def getTp = extractorMemberType(unapplyResult, nme.get, pos)
6688

67-
// println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}")
68-
if (extractorMemberType(unapplyResult, nme.isDefined, pos) isRef defn.BooleanClass) {
69-
if (getTp.exists)
70-
if (unapplyFn.symbol.name == nme.unapplySeq) {
71-
val seqArg = boundsToHi(getTp.elemType)
72-
if (seqArg.exists) return args map Function.const(seqArg)
73-
}
74-
else return getUnapplySelectors(getTp, args, pos)
75-
else if (defn.isProductSubType(unapplyResult)) return productSelectorTypes(unapplyResult, pos)
89+
def fail = {
90+
ctx.error(i"$unapplyResult is not a valid result type of an $unapplyName method of an extractor", pos)
91+
Nil
92+
}
93+
94+
if (unapplyName == nme.unapplySeq) {
95+
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
96+
else if (isGetMatch(unapplyResult, pos)) {
97+
val seqArg = boundsToHi(getTp.elemType)
98+
if (seqArg.exists) args.map(Function.const(seqArg))
99+
else fail
100+
}
101+
else fail
76102
}
77-
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
78-
else if (unapplyResult isRef defn.BooleanClass) Nil
79103
else {
80-
ctx.error(i"$unapplyResult is not a valid result type of an unapply method of an extractor", pos)
81-
Nil
104+
assert(unapplyName == nme.unapply)
105+
if (isProductMatch(unapplyResult, args.length))
106+
productSelectorTypes(unapplyResult)
107+
else if (isGetMatch(unapplyResult, pos))
108+
getUnapplySelectors(getTp, args, pos)
109+
else if (unapplyResult isRef defn.BooleanClass)
110+
Nil
111+
else if (defn.isProductSubType(unapplyResult))
112+
productSelectorTypes(unapplyResult)
113+
// this will cause a "wrong number of arguments in pattern" error later on,
114+
// which is better than the message in `fail`.
115+
else fail
82116
}
83117
}
84118

compiler/test/dotc/scala-collections.whitelist

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,5 @@
280280
../scala-scala/src/library/scala/collection/generic/Subtractable.scala
281281
../scala-scala/src/library/scala/collection/generic/TraversableFactory.scala
282282
../scala-scala/src/library/scala/collection/generic/package.scala
283+
284+
../scala-scala/src/library/scala/util/Try.scala

tests/pos/Patterns.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,31 @@ object NestedPattern {
108108
val xss: List[List[String]] = ???
109109
val List(List(x)) = xss
110110
}
111+
112+
// Tricky case (exercised by Scala parser combinators) where we use
113+
// both get/isEmpty and product-based pattern matching in different
114+
// matches on the same types.
115+
object ProductAndGet {
116+
117+
trait Result[+T]
118+
case class Success[+T](in: String, x: T) extends Result[T] {
119+
def isEmpty = false
120+
def get: T = x
121+
}
122+
case class Failure[+T](in: String, msg: String) extends Result[T] {
123+
def isEmpty = false
124+
def get: String = msg
125+
}
126+
127+
val r: Result[Int] = ???
128+
129+
r match {
130+
case Success(in, x) => x
131+
case Failure(in, msg) => -1
132+
}
133+
134+
r match {
135+
case Success(x) => x
136+
case Failure(msg) => -1
137+
}
138+
}

tests/pos/i1540.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class Casey1(val a: Int) {
2-
def isDefined: Boolean = true
3-
def isDefined(x: Int): Boolean = ???
2+
def isEmpty: Boolean = false
3+
def isEmpty(x: Int): Boolean = ???
44
def get: Int = a
55
def get(x: Int): String = ???
66
}

tests/pos/i1540b.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class Casey1[T](val a: T) {
2-
def isDefined: Boolean = true
3-
def isDefined(x: T): Boolean = ???
2+
def isEmpty: Boolean = false
3+
def isEmpty(x: T): Boolean = ???
44
def get: T = a
55
def get(x: T): String = ???
66
}

tests/pos/i1790.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import scala.util.control.NonFatal
2+
3+
class Try[+T] {
4+
def transform[U](s: T => Try[U], f: Throwable => Try[U]): Try[U] =
5+
try this match {
6+
case Success(v) => s(v)
7+
case Failure(e) => f(e)
8+
} catch {
9+
case NonFatal(e) => Failure(e)
10+
}
11+
}
12+
final case class Success[+T](value: T) extends Try[T]
13+
final case class Failure[+T](exception: Throwable) extends Try[T] {
14+
def get: T = throw exception
15+
}

tests/pos/pos_valueclasses/optmatch.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ package optmatch
77

88
class NonZeroLong(val value: Long) extends AnyVal {
99
def get: Long = value
10-
def isDefined: Boolean = get != 0l
10+
def isEmpty: Boolean = get == 0l
1111
}
1212
object NonZeroLong {
1313
def unapply(value: Long): NonZeroLong = new NonZeroLong(value)

0 commit comments

Comments
 (0)