diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 15cb0b665a5d..11f8b81eb26c 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -24,7 +24,6 @@ object desugar { /** Names of methods that are added unconditionally to case classes */ def isDesugaredCaseClassMethodName(name: Name)(implicit ctx: Context): Boolean = - name == nme.isDefined || name == nme.copy || name == nme.productArity || name.isSelectorName @@ -343,7 +342,6 @@ object desugar { if (isCaseClass) { def syntheticProperty(name: TermName, rhs: Tree) = DefDef(name, Nil, Nil, TypeTree(), rhs).withMods(synthetic) - val isDefinedMeth = syntheticProperty(nme.isDefined, Literal(Constant(true))) val caseParams = constrVparamss.head.toArray val productElemMeths = for (i <- 0 until arity) yield syntheticProperty(nme.selectorName(i), Select(This(EmptyTypeIdent), caseParams(i).name)) @@ -369,7 +367,7 @@ object desugar { DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr) .withMods(synthetic) :: Nil } - copyMeths ::: isDefinedMeth :: productElemMeths.toList + copyMeths ::: productElemMeths.toList } else Nil diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index e4e5761b29b6..9759e39fcc5b 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -675,7 +675,9 @@ class Definitions { private def isVarArityClass(cls: Symbol, prefix: Name) = { val name = scalaClassName(cls) - name.startsWith(prefix) && name.drop(prefix.length).forall(_.isDigit) + name.startsWith(prefix) && + name.length > prefix.length && + name.drop(prefix.length).forall(_.isDigit) } def isBottomClass(cls: Symbol) = @@ -735,6 +737,14 @@ class Definitions { def isProductSubType(tp: Type)(implicit ctx: Context) = (tp derivesFrom ProductType.symbol) && tp.baseClasses.exists(isProductClass) + def productArity(tp: Type)(implicit ctx: Context) = + if (tp derivesFrom ProductType.symbol) + tp.baseClasses.find(isProductClass) match { + case Some(prod) => prod.typeParams.length + case None => -1 + } + else -1 + def isFunctionType(tp: Type)(implicit ctx: Context) = isFunctionClass(tp.dealias.typeSymbol) && { val arity = functionArity(tp) diff --git a/compiler/src/dotty/tools/dotc/core/StdNames.scala b/compiler/src/dotty/tools/dotc/core/StdNames.scala index 741ff8b1fef2..e71893c1ee2c 100644 --- a/compiler/src/dotty/tools/dotc/core/StdNames.scala +++ b/compiler/src/dotty/tools/dotc/core/StdNames.scala @@ -424,7 +424,6 @@ object StdNames { val info: N = "info" val inlinedEquals: N = "inlinedEquals" val isArray: N = "isArray" - val isDefined: N = "isDefined" val isDefinedAt: N = "isDefinedAt" val isDefinedAtImpl: N = "$isDefinedAt" val isEmpty: N = "isEmpty" diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index 3e25cf82eb44..181dfccd9c72 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -235,14 +235,21 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { // next: MatchMonad[U] // returns MatchMonad[U] def flatMap(prev: Tree, b: Symbol, next: Tree): Tree = { - - val getTp = extractorMemberType(prev.tpe, nme.get) - val isDefined = extractorMemberType(prev.tpe, nme.isDefined) - - if ((isDefined isRef defn.BooleanClass) && getTp.exists) { - // isDefined and get may be overloaded - val getDenot = prev.tpe.member(nme.get).suchThat(_.info.isParameterless) - val isDefinedDenot = prev.tpe.member(nme.isDefined).suchThat(_.info.isParameterless) + val resultArity = defn.productArity(b.info) + if (isProductMatch(prev.tpe, resultArity)) { + val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null))) + ifThenElseZero( + nullCheck, + Block( + List(ValDef(b.asTerm, prev)), + next //Substitution(b, ref(prevSym))(next) + ) + ) + } + else { + val getDenot = extractorMember(prev.tpe, nme.get) + val isEmptyDenot = extractorMember(prev.tpe, nme.isEmpty) + assert(getDenot.exists && isEmptyDenot.exists, i"${prev.tpe}") val tmpSym = freshSym(prev.pos, prev.tpe, "o") val prevValue = ref(tmpSym).select(getDenot.symbol).ensureApplied @@ -251,20 +258,10 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { List(ValDef(tmpSym, prev)), // must be isEmpty and get as we don't control the target of the call (prev is an extractor call) ifThenElseZero( - ref(tmpSym).select(isDefinedDenot.symbol), + ref(tmpSym).select(isEmptyDenot.symbol).select(defn.Boolean_!), Block(List(ValDef(b.asTerm, prevValue)), next) ) ) - } else { - assert(defn.isProductSubType(prev.tpe)) - val nullCheck: Tree = prev.select(defn.Object_ne).appliedTo(Literal(Constant(null))) - ifThenElseZero( - nullCheck, - Block( - List(ValDef(b.asTerm, prev)), - next //Substitution(b, ref(prevSym))(next) - ) - ) } } @@ -1431,12 +1428,12 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { case _ => or } - def resultInMonad = if (aligner.isBool) defn.UnitType else { - val getTp = extractorMemberType(resultType, nme.get) - if ((extractorMemberType(resultType, nme.isDefined) isRef defn.BooleanClass) && getTp.exists) - getTp + def resultInMonad = + if (aligner.isBool) defn.UnitType + else if (isProductMatch(resultType, aligner.prodArity)) resultType + else if (isGetMatch(resultType)) extractorMemberType(resultType, nme.get) else resultType - } + def resultType: Type /** Create the TreeMaker that embodies this extractor call @@ -1632,13 +1629,12 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { //val spr = subPatRefs(binder) assert(go && go1) ref(binder) :: Nil - } else { - lazy val getTp = extractorMemberType(binderTypeTested, nme.get) - if ((aligner.isSingle && aligner.extractor.prodArity == 1) && ((extractorMemberType(binderTypeTested, nme.isDefined) isRef defn.BooleanClass) && getTp.exists)) - List(ref(binder)) - else - subPatRefs(binder) } + else if ((aligner.isSingle && aligner.extractor.prodArity == 1) && + !isProductMatch(binderTypeTested, aligner.prodArity) && isGetMatch(binderTypeTested)) + List(ref(binder)) + else + subPatRefs(binder) } /*protected def spliceApply(binder: Symbol): Tree = { @@ -1890,9 +1886,8 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer { else if (result.classSymbol is Flags.CaseClass) result.decls.filter(x => x.is(Flags.CaseAccessor) && x.is(Flags.Method)).map(_.info).toList else result.select(nme.get) :: Nil )*/ - if ((extractorMemberType(resultType, nme.isDefined) isRef defn.BooleanClass) && resultOfGet.exists) - getUnapplySelectors(resultOfGet, args) - else if (defn.isProductSubType(resultType)) productSelectorTypes(resultType) + if (isProductMatch(resultType, args.length)) productSelectorTypes(resultType) + else if (isGetMatch(resultType)) getUnapplySelectors(resultOfGet, args) else if (resultType isRef defn.BooleanClass) Nil else { ctx.error(i"invalid return type in Unapply node: $resultType") diff --git a/compiler/src/dotty/tools/dotc/typer/Applications.scala b/compiler/src/dotty/tools/dotc/typer/Applications.scala index 11121e1f33d6..d34804865f3e 100644 --- a/compiler/src/dotty/tools/dotc/typer/Applications.scala +++ b/compiler/src/dotty/tools/dotc/typer/Applications.scala @@ -32,16 +32,37 @@ import reporting.diagnostic.Message object Applications { import tpd._ + def extractorMember(tp: Type, name: Name)(implicit ctx: Context) = { + def isPossibleExtractorType(tp: Type) = tp match { + case _: MethodType | _: PolyType => false + case _ => true + } + tp.member(name).suchThat(d => isPossibleExtractorType(d.info)) + } + def extractorMemberType(tp: Type, name: Name, errorPos: Position = NoPosition)(implicit ctx: Context) = { - val ref = tp.member(name).suchThat(_.info.isParameterless) + val ref = extractorMember(tp, name) if (ref.isOverloaded) errorType(i"Overloaded reference to $ref is not allowed in extractor", errorPos) - else if (ref.info.isInstanceOf[PolyType]) - errorType(i"Reference to polymorphic $ref: ${ref.info} is not allowed in extractor", errorPos) - else - ref.info.widenExpr.dealias + ref.info.widenExpr.dealias } + /** Does `tp` fit the "product match" conditions as an unapply result type + * for a pattern with `numArgs` subpatterns> + * This is the case of `tp` is a subtype of the Product class. + */ + def isProductMatch(tp: Type, numArgs: Int)(implicit ctx: Context) = + 0 <= numArgs && numArgs <= Definitions.MaxTupleArity && + tp.derivesFrom(defn.ProductNType(numArgs).typeSymbol) + + /** Does `tp` fit the "get match" conditions as an unapply result type? + * This is the case of `tp` has a `get` member as well as a + * parameterless `isDefined` member of result type `Boolean`. + */ + def isGetMatch(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context) = + extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) && + extractorMemberType(tp, nme.get, errorPos).exists + def productSelectorTypes(tp: Type, errorPos: Position = NoPosition)(implicit ctx: Context): List[Type] = { val sels = for (n <- Iterator.from(0)) yield extractorMemberType(tp, nme.selectorName(n), errorPos) sels.takeWhile(_.exists).toList @@ -61,24 +82,37 @@ object Applications { def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: Position = NoPosition)(implicit ctx: Context): List[Type] = { + val unapplyName = unapplyFn.symbol.name def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil) def getTp = extractorMemberType(unapplyResult, nme.get, pos) - // println(s"unapply $unapplyResult ${extractorMemberType(unapplyResult, nme.isDefined)}") - if (extractorMemberType(unapplyResult, nme.isDefined, pos) isRef defn.BooleanClass) { - if (getTp.exists) - if (unapplyFn.symbol.name == nme.unapplySeq) { - val seqArg = boundsToHi(getTp.elemType) - if (seqArg.exists) return args map Function.const(seqArg) - } - else return getUnapplySelectors(getTp, args, pos) - else if (defn.isProductSubType(unapplyResult)) return productSelectorTypes(unapplyResult, pos) + def fail = { + ctx.error(i"$unapplyResult is not a valid result type of an $unapplyName method of an extractor", pos) + Nil + } + + if (unapplyName == nme.unapplySeq) { + if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil + else if (isGetMatch(unapplyResult, pos)) { + val seqArg = boundsToHi(getTp.elemType) + if (seqArg.exists) args.map(Function.const(seqArg)) + else fail + } + else fail } - if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil - else if (unapplyResult isRef defn.BooleanClass) Nil else { - ctx.error(i"$unapplyResult is not a valid result type of an unapply method of an extractor", pos) - Nil + assert(unapplyName == nme.unapply) + if (isProductMatch(unapplyResult, args.length)) + productSelectorTypes(unapplyResult) + else if (isGetMatch(unapplyResult, pos)) + getUnapplySelectors(getTp, args, pos) + else if (unapplyResult isRef defn.BooleanClass) + Nil + else if (defn.isProductSubType(unapplyResult)) + productSelectorTypes(unapplyResult) + // this will cause a "wrong number of arguments in pattern" error later on, + // which is better than the message in `fail`. + else fail } } diff --git a/compiler/test/dotc/scala-collections.whitelist b/compiler/test/dotc/scala-collections.whitelist index bb62b260a4c5..e984af6c622c 100644 --- a/compiler/test/dotc/scala-collections.whitelist +++ b/compiler/test/dotc/scala-collections.whitelist @@ -280,3 +280,5 @@ ../scala-scala/src/library/scala/collection/generic/Subtractable.scala ../scala-scala/src/library/scala/collection/generic/TraversableFactory.scala ../scala-scala/src/library/scala/collection/generic/package.scala + +../scala-scala/src/library/scala/util/Try.scala \ No newline at end of file diff --git a/tests/pos/Patterns.scala b/tests/pos/Patterns.scala index aa369a77bce2..fd0d7e97ace4 100644 --- a/tests/pos/Patterns.scala +++ b/tests/pos/Patterns.scala @@ -108,3 +108,31 @@ object NestedPattern { val xss: List[List[String]] = ??? val List(List(x)) = xss } + +// Tricky case (exercised by Scala parser combinators) where we use +// both get/isEmpty and product-based pattern matching in different +// matches on the same types. +object ProductAndGet { + + trait Result[+T] + case class Success[+T](in: String, x: T) extends Result[T] { + def isEmpty = false + def get: T = x + } + case class Failure[+T](in: String, msg: String) extends Result[T] { + def isEmpty = false + def get: String = msg + } + + val r: Result[Int] = ??? + + r match { + case Success(in, x) => x + case Failure(in, msg) => -1 + } + + r match { + case Success(x) => x + case Failure(msg) => -1 + } +} diff --git a/tests/pos/i1540.scala b/tests/pos/i1540.scala index 7aa24f459947..0fdfea23555d 100644 --- a/tests/pos/i1540.scala +++ b/tests/pos/i1540.scala @@ -1,6 +1,6 @@ class Casey1(val a: Int) { - def isDefined: Boolean = true - def isDefined(x: Int): Boolean = ??? + def isEmpty: Boolean = false + def isEmpty(x: Int): Boolean = ??? def get: Int = a def get(x: Int): String = ??? } diff --git a/tests/pos/i1540b.scala b/tests/pos/i1540b.scala index 2b4c5408ea49..f4408b0c7aab 100644 --- a/tests/pos/i1540b.scala +++ b/tests/pos/i1540b.scala @@ -1,6 +1,6 @@ class Casey1[T](val a: T) { - def isDefined: Boolean = true - def isDefined(x: T): Boolean = ??? + def isEmpty: Boolean = false + def isEmpty(x: T): Boolean = ??? def get: T = a def get(x: T): String = ??? } diff --git a/tests/pos/i1790.scala b/tests/pos/i1790.scala new file mode 100644 index 000000000000..7535255f9edc --- /dev/null +++ b/tests/pos/i1790.scala @@ -0,0 +1,15 @@ +import scala.util.control.NonFatal + +class Try[+T] { + def transform[U](s: T => Try[U], f: Throwable => Try[U]): Try[U] = + try this match { + case Success(v) => s(v) + case Failure(e) => f(e) + } catch { + case NonFatal(e) => Failure(e) + } +} +final case class Success[+T](value: T) extends Try[T] +final case class Failure[+T](exception: Throwable) extends Try[T] { + def get: T = throw exception +} diff --git a/tests/pos/pos_valueclasses/optmatch.scala b/tests/pos/pos_valueclasses/optmatch.scala index a7995a455f17..ff1a17906ddf 100644 --- a/tests/pos/pos_valueclasses/optmatch.scala +++ b/tests/pos/pos_valueclasses/optmatch.scala @@ -7,7 +7,7 @@ package optmatch class NonZeroLong(val value: Long) extends AnyVal { def get: Long = value - def isDefined: Boolean = get != 0l + def isEmpty: Boolean = get == 0l } object NonZeroLong { def unapply(value: Long): NonZeroLong = new NonZeroLong(value)