Skip to content

Commit 60bae2f

Browse files
authored
Inline GADT state restoring in TypeComparer (#16564)
2 parents 00a6c4a + a560dc9 commit 60bae2f

File tree

3 files changed

+46
-22
lines changed

3 files changed

+46
-22
lines changed

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ sealed trait GadtConstraint (
2626

2727
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
2828

29+
private[core] def getConstraint: Constraint = constraint
30+
private[core] def getMapping: SimpleIdentityMap[Symbol, TypeVar] = mapping
31+
private[core] def getReverseMapping: SimpleIdentityMap[TypeParamRef, Symbol] = reverseMapping
32+
private[core] def getWasConstrained: Boolean = wasConstrained
33+
2934
/** Exposes ConstraintHandling.subsumes */
3035
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
3136
def extractConstraint(g: GadtConstraint) = g.constraint
@@ -198,6 +203,25 @@ sealed trait GadtConstraint (
198203
this.reverseMapping = other.reverseMapping
199204
this.wasConstrained = other.wasConstrained
200205

206+
def restore(constr: Constraint, mapping: SimpleIdentityMap[Symbol, TypeVar], revMapping: SimpleIdentityMap[TypeParamRef, Symbol], wasConstrained: Boolean): Unit =
207+
this.myConstraint = constr
208+
this.mapping = mapping
209+
this.reverseMapping = revMapping
210+
this.wasConstrained = wasConstrained
211+
212+
def rollbackGadtUnless(op: => Boolean): Boolean =
213+
val savedConstr = myConstraint
214+
val savedMapping = mapping
215+
val savedReverseMapping = reverseMapping
216+
val savedWasConstrained = wasConstrained
217+
var result = false
218+
try
219+
result = op
220+
finally
221+
if !result then
222+
restore(savedConstr, savedMapping, savedReverseMapping, savedWasConstrained)
223+
result
224+
201225
// ---- Protected/internal -----------------------------------------------
202226

203227
override protected def constraint = myConstraint

compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -265,26 +265,26 @@ trait PatternTypeConstrainer { self: TypeComparer =>
265265
(tp, pt) match {
266266
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
267267
val saved = state.nn.constraint
268-
val savedGadt = ctx.gadt.fresh
269268
val result =
270-
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
271-
val variance = param.paramVarianceSign
272-
if variance == 0 || assumeInvariantRefinement ||
273-
// As a special case, when pattern and scrutinee types have the same type constructor,
274-
// we infer better bounds for pattern-bound abstract types.
275-
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
276-
then
277-
val TypeBounds(loS, hiS) = argS.bounds
278-
val TypeBounds(loP, hiP) = argP.bounds
279-
var res = true
280-
if variance < 1 then res &&= isSubType(loS, hiP)
281-
if variance > -1 then res &&= isSubType(loP, hiS)
282-
res
283-
else true
269+
ctx.gadt.rollbackGadtUnless {
270+
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
271+
val variance = param.paramVarianceSign
272+
if variance == 0 || assumeInvariantRefinement ||
273+
// As a special case, when pattern and scrutinee types have the same type constructor,
274+
// we infer better bounds for pattern-bound abstract types.
275+
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
276+
then
277+
val TypeBounds(loS, hiS) = argS.bounds
278+
val TypeBounds(loP, hiP) = argP.bounds
279+
var res = true
280+
if variance < 1 then res &&= isSubType(loS, hiP)
281+
if variance > -1 then res &&= isSubType(loP, hiS)
282+
res
283+
else true
284+
}
284285
}
285286
if !result then
286287
constraint = saved
287-
ctx.gadt.restore(savedGadt)
288288
result
289289
case _ =>
290290
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,10 +1443,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
14431443
else if tp1 eq tp2 then true
14441444
else
14451445
val saved = constraint
1446-
val savedGadt = ctx.gadt.fresh
1446+
val savedGadtConstr = ctx.gadt.getConstraint
1447+
val savedMapping = ctx.gadt.getMapping
1448+
val savedReverseMapping = ctx.gadt.getReverseMapping
1449+
val savedWasConstrained = ctx.gadt.getWasConstrained
14471450
inline def restore() =
14481451
state.constraint = saved
1449-
ctx.gadt.restore(savedGadt)
1452+
ctx.gadt.restore(savedGadtConstr, savedMapping, savedReverseMapping, savedWasConstrained)
14501453
val savedSuccessCount = successCount
14511454
try
14521455
recCount += 1
@@ -2050,10 +2053,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
20502053
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
20512054
if (bound.isRef(tparam)) false
20522055
else
2053-
val savedGadt = ctx.gadt.fresh
2054-
val success = gadtAddBound(tparam, bound, isUpper)
2055-
if !success then ctx.gadt.restore(savedGadt)
2056-
success
2056+
ctx.gadt.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper))
20572057
}
20582058
}
20592059

0 commit comments

Comments
 (0)