Skip to content

Commit ff2922c

Browse files
committed
Refactor path-dependent structural GADT reasoning
1 parent cd368fa commit ff2922c

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ sealed abstract class GadtConstraint extends Showable {
6262
*/
6363
def addAllPDTsFrom(path: Type)(using Context): List[TypeRef]
6464

65+
def replacePath(from: Type, to: Type)(using Context): Unit
66+
6567
/** Supplies the singleton type of the scrutinee when typechecking pattern-matches.
6668
*/
6769
def withScrutinee[T](path: TermRef)(body: T): T
@@ -218,6 +220,17 @@ final class ProperGadtConstraint private(
218220
m.values.toList map { tv => externalize(tv.origin).asInstanceOf[TypeRef] }
219221
}
220222

223+
override def replacePath(from: Type, to: Type)(using Context): Unit =
224+
val originalPairs = mapping.toList
225+
226+
originalPairs foreach { (tpr, tvar) =>
227+
if tpr.prefix eq from then
228+
val extType = TypeRef(to, tpr.symbol)
229+
mapping = mapping.updated(extType, tvar)
230+
mapping = mapping.remove(tpr)
231+
reverseMapping = reverseMapping.updated(tvar.origin, extType)
232+
}
233+
221234
override def withScrutinee[T](path: TermRef)(body: T): T =
222235
val saved = this.scrutinee
223236
this.scrutinee = path
@@ -554,6 +567,7 @@ final class ProperGadtConstraint private(
554567
override def isConstrainablePDT(tp: Type)(using Context): Boolean = false
555568
override def addPDT(tp: Type)(using Context): Boolean = false
556569
override def addAllPDTsFrom(path: Type)(using Context): List[TypeRef] = null
570+
override def replacePath(from: Type, to: Type)(using Context): Unit = ()
557571
override def withScrutinee[T](path: TermRef)(body: T): T = body
558572

559573
override def addToConstraint(params: List[Symbol])(using Context): Boolean = unsupported("EmptyGadtConstraint.addToConstraint")

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ trait PatternTypeConstrainer { self: TypeComparer =>
284284
}
285285

286286
/** Derive GADT bounds on type members of the scrutinee and the pattern. */
287-
def constrainTypeMembers(scrut: Type, pat: Type) = trace.force(i"constraining type members $scrut >:< $pat", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
287+
def constrainTypeMembers(scrut: Type, pat: Type, realScrutPath: TermRef, realPatPath: TermRef) = trace(i"constraining type members $scrut >:< $pat", gadts, res => s"$res\n${ctx.gadt.debugBoundsDescription}") {
288288
val saved = state.constraint
289289
val savedGadt = ctx.gadt.fresh
290290

@@ -294,10 +294,6 @@ trait PatternTypeConstrainer { self: TypeComparer =>
294294
val scrutPDTs = ctx.gadt.addAllPDTsFrom(scrutPath)
295295
val patPDTs = ctx.gadt.addAllPDTsFrom(patPath)
296296

297-
println(i"scrut pdts: $scrutPDTs")
298-
println(i"pat pdts: $patPDTs")
299-
println(i"after adding PDTs: gadt = ${ctx.gadt.debugBoundsDescription}")
300-
301297
val scrutSyms = Map.from {
302298
scrutPDTs map { pdt => pdt.symbol.name -> pdt }
303299
}
@@ -316,6 +312,9 @@ trait PatternTypeConstrainer { self: TypeComparer =>
316312
if !result then
317313
constraint = saved
318314
ctx.gadt.restore(savedGadt)
315+
else
316+
if realScrutPath ne null then ctx.gadt.replacePath(scrutPath, realScrutPath)
317+
if realPatPath ne null then ctx.gadt.replacePath(patPath, realPatPath)
319318

320319
result
321320
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
241241
* code would have two extra parameters for each of the many calls that go from
242242
* one sub-part of isSubType to another.
243243
*/
244-
protected def recur(tp1: Type, tp2: Type): Boolean = trace.force(s"isSubType ${traceInfo(tp1, tp2)}${approx.show}", subtyping) {
244+
protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)}${approx.show}", subtyping) {
245245

246246
def monitoredIsSubType = {
247247
if (pendingSubTypes == null) {
@@ -2831,8 +2831,8 @@ object TypeComparer {
28312831
def constrainPatternType(pat: Type, scrut: Type, forceInvariantRefinement: Boolean = false)(using Context): Boolean =
28322832
comparing(_.constrainPatternType(pat, scrut, forceInvariantRefinement))
28332833

2834-
def constrainTypeMembers(scrut: Type, pat: Type)(using Context): Boolean =
2835-
comparing(_.constrainTypeMembers(scrut, pat))
2834+
def constrainTypeMembers(scrut: Type, pat: Type, scrutPath: TermRef, patPath: TermRef)(using Context): Boolean =
2835+
comparing(_.constrainTypeMembers(scrut, pat, scrutPath, patPath))
28362836

28372837
def explained[T](op: ExplainingTypeComparer => T, header: String = "Subtype trace:")(using Context): String =
28382838
comparing(_.explained(op, header))

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1648,17 +1648,20 @@ class Typer extends Namer
16481648

16491649
val pat1 = typedPattern(tree.pat, wideSelType)(using gadtCtx)
16501650

1651-
println(i"*** typed a match case ***")
16521651
val scrutType = sel.tpe.widen
16531652
val patType = pat1.tpe.widen match
16541653
case AndType(tp1, tp2) => tp2
16551654
case tp => tp
1656-
println(i"scrutinee: ${scrutType} ${scrutType.toString}")
1657-
println(i"pat: ${patType} ${patType.toString}")
1658-
1659-
withMode(Mode.GadtConstraintInference) {
1660-
TypeComparer.constrainTypeMembers(scrutType, patType)(using gadtCtx)
1661-
}
1655+
val scrutPath = sel.tpe match
1656+
case tp: TermRef => tp
1657+
case _ => null
1658+
val patPath = pat1.tpe match
1659+
case tp: TermRef => tp
1660+
case _ => null
1661+
1662+
withMode(Mode.GadtConstraintInference)
1663+
(TypeComparer.constrainTypeMembers(scrutType, patType, scrutPath, patPath))
1664+
(using gadtCtx)
16621665

16631666
caseRest(pat1)(
16641667
using Nullables.caseContext(sel, pat1)(

0 commit comments

Comments
 (0)