Skip to content

Commit 10230b8

Browse files
committed
Fix #i8690: the signature for product should come from expected type
1 parent 2352d90 commit 10230b8

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

compiler/src/dotty/tools/dotc/transform/patmat/Space.scala

Lines changed: 60 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ case object Empty extends Space
6969
case class Typ(tp: Type, decomposed: Boolean = true) extends Space
7070

7171
/** Space representing an extractor pattern */
72-
case class Prod(tp: Type, unappTp: Type, unappSym: Symbol, params: List[Space], full: Boolean) extends Space
72+
case class Prod(tp: Type, unappTp: TermRef, params: List[Space], full: Boolean) extends Space
7373

7474
/** Union of spaces */
7575
case class Or(spaces: List[Space]) extends Space
@@ -97,7 +97,7 @@ trait SpaceLogic {
9797
def canDecompose(tp: Type): Boolean
9898

9999
/** Return term parameter types of the extractor `unapp` */
100-
def signature(unapp: Type, unappSym: Symbol, argLen: Int): List[Type]
100+
def signature(unapp: TermRef, scrutineeTp: Type, argLen: Int): List[Type]
101101

102102
/** Get components of decomposable types */
103103
def decompose(tp: Type): List[Space]
@@ -113,8 +113,8 @@ trait SpaceLogic {
113113
* This reduces noise in counterexamples.
114114
*/
115115
def simplify(space: Space, aggressive: Boolean = false)(implicit ctx: Context): Space = trace(s"simplify ${show(space)}, aggressive = $aggressive --> ", debug, x => show(x.asInstanceOf[Space]))(space match {
116-
case Prod(tp, fun, sym, spaces, full) =>
117-
val sp = Prod(tp, fun, sym, spaces.map(simplify(_)), full)
116+
case Prod(tp, fun, spaces, full) =>
117+
val sp = Prod(tp, fun, spaces.map(simplify(_)), full)
118118
if (sp.params.contains(Empty)) Empty
119119
else if (canDecompose(tp) && decompose(tp).isEmpty) Empty
120120
else sp
@@ -143,13 +143,13 @@ trait SpaceLogic {
143143

144144
/** Flatten space to get rid of `Or` for pretty print */
145145
def flatten(space: Space)(implicit ctx: Context): List[Space] = space match {
146-
case Prod(tp, fun, sym, spaces, full) =>
146+
case Prod(tp, fun, spaces, full) =>
147147
spaces.map(flatten) match {
148-
case Nil => Prod(tp, fun, sym, Nil, full) :: Nil
148+
case Nil => Prod(tp, fun, Nil, full) :: Nil
149149
case ss =>
150150
ss.foldLeft(List[Prod]()) { (acc, flat) =>
151-
if (acc.isEmpty) flat.map(s => Prod(tp, fun, sym, s :: Nil, full))
152-
else for (Prod(tp, fun, sym, ss, full) <- acc; s <- flat) yield Prod(tp, fun, sym, ss :+ s, full)
151+
if (acc.isEmpty) flat.map(s => Prod(tp, fun, s :: Nil, full))
152+
else for (Prod(tp, fun, ss, full) <- acc; s <- flat) yield Prod(tp, fun, ss :+ s, full)
153153
}
154154
}
155155
case Or(spaces) =>
@@ -173,13 +173,13 @@ trait SpaceLogic {
173173
ss.exists(isSubspace(a, _)) || tryDecompose1(tp1)
174174
case (_, Or(_)) =>
175175
simplify(minus(a, b)) == Empty
176-
case (Prod(tp1, _, _, _, _), Typ(tp2, _)) =>
176+
case (Prod(tp1, _, _, _), Typ(tp2, _)) =>
177177
isSubType(tp1, tp2)
178-
case (Typ(tp1, _), Prod(tp2, fun, sym, ss, full)) =>
178+
case (Typ(tp1, _), Prod(tp2, fun, ss, full)) =>
179179
// approximation: a type can never be fully matched by a partial extractor
180-
full && isSubType(tp1, tp2) && isSubspace(Prod(tp2, fun, sym, signature(fun, sym, ss.length).map(Typ(_, false)), full), b)
181-
case (Prod(_, fun1, sym1, ss1, _), Prod(_, fun2, sym2, ss2, _)) =>
182-
sym1 == sym2 && isEqualType(fun1, fun2) && ss1.zip(ss2).forall((isSubspace _).tupled)
180+
full && isSubType(tp1, tp2) && isSubspace(Prod(tp2, fun, signature(fun, tp2, ss.length).map(Typ(_, false)), full), b)
181+
case (Prod(_, fun1, ss1, _), Prod(_, fun2, ss2, _)) =>
182+
isEqualType(fun1, fun2) && ss1.zip(ss2).forall((isSubspace _).tupled)
183183
}
184184
}
185185

@@ -198,28 +198,28 @@ trait SpaceLogic {
198198
else if (canDecompose(tp1)) tryDecompose1(tp1)
199199
else if (canDecompose(tp2)) tryDecompose2(tp2)
200200
else intersectUnrelatedAtomicTypes(tp1, tp2)
201-
case (Typ(tp1, _), Prod(tp2, fun, _, ss, true)) =>
201+
case (Typ(tp1, _), Prod(tp2, fun, ss, true)) =>
202202
if (isSubType(tp2, tp1)) b
203203
else if (isSubType(tp1, tp2)) a // problematic corner case: inheriting a case class
204204
else if (canDecompose(tp1)) tryDecompose1(tp1)
205205
else Empty
206-
case (Typ(tp1, _), Prod(tp2, _, _, _, false)) =>
206+
case (Typ(tp1, _), Prod(tp2, _, _, false)) =>
207207
if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) b // prefer extractor space for better approximation
208208
else if (canDecompose(tp1)) tryDecompose1(tp1)
209209
else Empty
210-
case (Prod(tp1, fun, _, ss, true), Typ(tp2, _)) =>
210+
case (Prod(tp1, fun, ss, true), Typ(tp2, _)) =>
211211
if (isSubType(tp1, tp2)) a
212212
else if (isSubType(tp2, tp1)) a // problematic corner case: inheriting a case class
213213
else if (canDecompose(tp2)) tryDecompose2(tp2)
214214
else Empty
215-
case (Prod(tp1, _, _, _, false), Typ(tp2, _)) =>
215+
case (Prod(tp1, _, _, false), Typ(tp2, _)) =>
216216
if (isSubType(tp1, tp2) || isSubType(tp2, tp1)) a
217217
else if (canDecompose(tp2)) tryDecompose2(tp2)
218218
else Empty
219-
case (Prod(tp1, fun1, sym1, ss1, full), Prod(tp2, fun2, sym2, ss2, _)) =>
220-
if (sym1 != sym2 || !isEqualType(fun1, fun2)) Empty
219+
case (Prod(tp1, fun1, ss1, full), Prod(tp2, fun2, ss2, _)) =>
220+
if (!isEqualType(fun1, fun2)) Empty
221221
else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) Empty
222-
else Prod(tp1, fun1, sym1, ss1.zip(ss2).map((intersect _).tupled), full)
222+
else Prod(tp1, fun1, ss1.zip(ss2).map((intersect _).tupled), full)
223223
}
224224
}
225225

@@ -236,34 +236,34 @@ trait SpaceLogic {
236236
else if (canDecompose(tp1)) tryDecompose1(tp1)
237237
else if (canDecompose(tp2)) tryDecompose2(tp2)
238238
else a
239-
case (Typ(tp1, _), Prod(tp2, fun, sym, ss, true)) =>
239+
case (Typ(tp1, _), Prod(tp2, fun, ss, true)) =>
240240
// rationale: every instance of `tp1` is covered by `tp2(_)`
241-
if (isSubType(tp1, tp2)) minus(Prod(tp1, fun, sym, signature(fun, sym, ss.length).map(Typ(_, false)), true), b)
241+
if (isSubType(tp1, tp2)) minus(Prod(tp1, fun, signature(fun, tp1, ss.length).map(Typ(_, false)), true), b)
242242
else if (canDecompose(tp1)) tryDecompose1(tp1)
243243
else a
244244
case (_, Or(ss)) =>
245245
ss.foldLeft(a)(minus)
246246
case (Or(ss), _) =>
247247
Or(ss.map(minus(_, b)))
248-
case (Prod(tp1, fun, _, ss, true), Typ(tp2, _)) =>
248+
case (Prod(tp1, fun, ss, true), Typ(tp2, _)) =>
249249
// uncovered corner case: tp2 :< tp1
250250
if (isSubType(tp1, tp2)) Empty
251251
else if (simplify(a) == Empty) Empty
252252
else if (canDecompose(tp2)) tryDecompose2(tp2)
253253
else a
254-
case (Prod(tp1, _, _, _, false), Typ(tp2, _)) =>
254+
case (Prod(tp1, _, _, false), Typ(tp2, _)) =>
255255
if (isSubType(tp1, tp2)) Empty
256256
else a
257-
case (Typ(tp1, _), Prod(tp2, _, _, _, false)) =>
257+
case (Typ(tp1, _), Prod(tp2, _, _, false)) =>
258258
a // approximation
259-
case (Prod(tp1, fun1, sym1, ss1, full), Prod(tp2, fun2, sym2, ss2, _)) =>
260-
if (sym1 != sym2 || !isEqualType(fun1, fun2)) a
259+
case (Prod(tp1, fun1, ss1, full), Prod(tp2, fun2, ss2, _)) =>
260+
if (!isEqualType(fun1, fun2)) a
261261
else if (ss1.zip(ss2).exists(p => simplify(intersect(p._1, p._2)) == Empty)) a
262262
else if (ss1.zip(ss2).forall((isSubspace _).tupled)) Empty
263263
else
264264
// `(_, _, _) - (Some, None, _)` becomes `(None, _, _) | (_, Some, _) | (_, _, Empty)`
265265
Or(ss1.zip(ss2).map((minus _).tupled).zip(0 to ss2.length - 1).map {
266-
case (ri, i) => Prod(tp1, fun1, sym1, ss1.updated(i, ri), full)
266+
case (ri, i) => Prod(tp1, fun1, ss1.updated(i, ri), full)
267267
})
268268

269269
}
@@ -360,18 +360,20 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
360360
case Bind(_, pat) => project(pat)
361361
case SeqLiteral(pats, _) => projectSeq(pats)
362362
case UnApply(fun, _, pats) =>
363+
val (fun1, _, _) = decomposeCall(fun)
364+
val funRef = fun1.tpe.asInstanceOf[TermRef]
363365
if (fun.symbol.name == nme.unapplySeq)
364366
if (fun.symbol.owner == scalaSeqFactoryClass)
365367
projectSeq(pats)
366368
else {
367369
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.sourcePos)
368370
if (elemTp.exists)
369-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, isIrrefutableUnapplySeq(fun, pats.size))
371+
Prod(erase(pat.tpe.stripAnnots), funRef, projectSeq(pats) :: Nil, isIrrefutableUnapplySeq(fun, pats.size))
370372
else
371-
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), isIrrefutableUnapplySeq(fun, pats.size))
373+
Prod(erase(pat.tpe.stripAnnots), funRef, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), isIrrefutableUnapplySeq(fun, pats.size))
372374
}
373375
else
374-
Prod(erase(pat.tpe.stripAnnots), erase(fun.tpe), fun.symbol, pats.map(project), isIrrefutableUnapply(fun, pats.length))
376+
Prod(erase(pat.tpe.stripAnnots), funRef, pats.map(project), isIrrefutableUnapply(fun, pats.length))
375377
case Typed(pat @ UnApply(_, _, _), _) => project(pat)
376378
case Typed(expr, _) =>
377379
Typ(erase(expr.tpe.stripAnnots), true)
@@ -386,6 +388,11 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
386388
Typ(pat.tpe.narrow, false)
387389
}
388390

391+
private def project(tp: Type): Space = tp match {
392+
case OrType(tp1, tp2) => Or(project(tp1) :: project(tp2) :: Nil)
393+
case tp => Typ(tp, decomposed = true)
394+
}
395+
389396
private def unapplySeqInfo(resTp: Type, pos: SourcePosition)(implicit ctx: Context): (Int, Type, Type) = {
390397
var resultTp = resTp
391398
var elemTp = unapplySeqTypeElemTp(resultTp)
@@ -458,11 +465,10 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
458465
else
459466
(pats, Typ(scalaNilType, false))
460467

468+
val unapplyTp = scalaConsType.classSymbol.companionModule.termRef.select(nme.unapply)
461469
items.foldRight[Space](zero) { (pat, acc) =>
462470
val consTp = scalaConsType.appliedTo(pats.head.tpe.widen)
463-
val unapplySym = consTp.classSymbol.linkedClass.info.member(nme.unapply).symbol
464-
val unapplyTp = unapplySym.info.appliedTo(pats.head.tpe.widen)
465-
Prod(consTp, unapplyTp, unapplySym, project(pat) :: acc :: Nil, true)
471+
Prod(consTp, unapplyTp, project(pat) :: acc :: Nil, true)
466472
}
467473
}
468474

@@ -480,15 +486,26 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
480486
def isEqualType(tp1: Type, tp2: Type): Boolean = tp1 =:= tp2
481487

482488
/** Parameter types of the case class type `tp`. Adapted from `unapplyPlan` in patternMatcher */
483-
def signature(unapp: Type, unappSym: Symbol, argLen: Int): List[Type] = {
489+
def signature(unapp: TermRef, scrutineeTp: Type, argLen: Int): List[Type] = {
490+
val unappSym = unapp.symbol
484491
def caseClass = unappSym.owner.linkedClass
485492

486493
lazy val caseAccessors = caseClass.caseAccessors.filter(_.is(Method))
487494

488495
def isSyntheticScala2Unapply(sym: Symbol) =
489496
sym.isAllOf(SyntheticCase) && sym.owner.is(Scala2x)
490497

491-
val mt @ MethodType(_) = unapp.widen
498+
val mt: MethodType = unapp.widen match {
499+
case mt: MethodType => mt
500+
case pt: PolyType =>
501+
inContext(ctx.fresh.setNewTyperState()) {
502+
val tvars = pt.paramInfos.map(newTypeVar)
503+
val mt = pt.instantiate(tvars).asInstanceOf[MethodType]
504+
scrutineeTp <:< mt.paramInfos(0)
505+
instantiateSelected(mt, tvars)
506+
mt
507+
}
508+
}
492509

493510
// Case unapply:
494511
// 1. return types of constructor fields if the extractor is synthesized for Scala2 case classes & length match
@@ -657,8 +674,8 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
657674
def impossible: Nothing = throw new AssertionError("`satisfiable` only accepts flattened space.")
658675

659676
def genConstraint(space: Space): List[(Type, Type)] = space match {
660-
case Prod(tp, unappTp, unappSym, ss, _) =>
661-
val tps = signature(unappTp, unappSym, ss.length)
677+
case Prod(tp, unappTp, ss, _) =>
678+
val tps = signature(unappTp, tp, ss.length)
662679
ss.zip(tps).flatMap {
663680
case (sp : Prod, tp) => sp.tp -> tp :: genConstraint(sp)
664681
case (Typ(tp1, _), tp2) => tp1 -> tp2 :: Nil
@@ -717,13 +734,14 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
717734
showType(tp) + params(tp).map(_ => "_").mkString("(", ", ", ")")
718735
else if (decomposed) "_: " + showType(tp, showTypeArgs = true)
719736
else "_"
720-
case Prod(tp, fun, sym, params, _) =>
737+
case Prod(tp, fun, params, _) =>
721738
if (ctx.definitions.isTupleType(tp))
722739
"(" + params.map(doShow(_)).mkString(", ") + ")"
723740
else if (tp.isRef(scalaConsType.symbol))
724741
if (flattenList) params.map(doShow(_, flattenList)).mkString(", ")
725742
else params.map(doShow(_, flattenList = true)).filter(!_.isEmpty).mkString("List(", ", ", ")")
726743
else {
744+
val sym = fun.symbol
727745
val isUnapplySeq = sym.name.eq(nme.unapplySeq)
728746
val paramsStr = params.map(doShow(_, flattenList = isUnapplySeq)).mkString("(", ", ", ")")
729747
showType(sym.owner.typeRef) + paramsStr
@@ -781,7 +799,7 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
781799
val checkGADTSAT = shouldCheckExamples(selTyp)
782800

783801
val uncovered =
784-
flatten(simplify(minus(Typ(selTyp, true), patternSpace), aggressive = true)).filter { s =>
802+
flatten(simplify(minus(project(selTyp), patternSpace), aggressive = true)).filter { s =>
785803
s != Empty && (!checkGADTSAT || satisfiable(s))
786804
}
787805

@@ -805,9 +823,9 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
805823

806824
val targetSpace =
807825
if (ctx.explicitNulls || selTyp.classSymbol.isPrimitiveValueClass)
808-
Typ(selTyp, true)
826+
project(selTyp)
809827
else
810-
Or(Typ(selTyp, true) :: constantNullSpace :: Nil)
828+
project(OrType(selTyp, constantNullType))
811829

812830
// in redundancy check, take guard as false in order to soundly approximate
813831
def projectPrevCases(cases: List[CaseDef]): Space =

tests/patmat/i8690.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class A
2+
class B
3+
4+
def test(x: (A, B) | (B, A)) = x match {
5+
case (u: A, v) => (u, v)
6+
case (u: B, v) => (v, u)
7+
}

0 commit comments

Comments
 (0)