diff --git a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala index 53fc58595472..37c984b86934 100644 --- a/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala +++ b/compiler/src/dotty/tools/dotc/core/GadtConstraint.scala @@ -26,6 +26,11 @@ sealed trait GadtConstraint ( import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr} + private[core] def getConstraint: Constraint = constraint + private[core] def getMapping: SimpleIdentityMap[Symbol, TypeVar] = mapping + private[core] def getReverseMapping: SimpleIdentityMap[TypeParamRef, Symbol] = reverseMapping + private[core] def getWasConstrained: Boolean = wasConstrained + /** Exposes ConstraintHandling.subsumes */ def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = { def extractConstraint(g: GadtConstraint) = g.constraint @@ -198,6 +203,25 @@ sealed trait GadtConstraint ( this.reverseMapping = other.reverseMapping this.wasConstrained = other.wasConstrained + def restore(constr: Constraint, mapping: SimpleIdentityMap[Symbol, TypeVar], revMapping: SimpleIdentityMap[TypeParamRef, Symbol], wasConstrained: Boolean): Unit = + this.myConstraint = constr + this.mapping = mapping + this.reverseMapping = revMapping + this.wasConstrained = wasConstrained + + def rollbackGadtUnless(op: => Boolean): Boolean = + val savedConstr = myConstraint + val savedMapping = mapping + val savedReverseMapping = reverseMapping + val savedWasConstrained = wasConstrained + var result = false + try + result = op + finally + if !result then + restore(savedConstr, savedMapping, savedReverseMapping, savedWasConstrained) + result + // ---- Protected/internal ----------------------------------------------- override protected def constraint = myConstraint diff --git a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala index ff9a5cd0aed7..e7f54d088c09 100644 --- a/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala +++ b/compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala @@ -265,26 +265,26 @@ trait PatternTypeConstrainer { self: TypeComparer => (tp, pt) match { case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) => val saved = state.nn.constraint - val savedGadt = ctx.gadt.fresh val result = - tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => - val variance = param.paramVarianceSign - if variance == 0 || assumeInvariantRefinement || - // As a special case, when pattern and scrutinee types have the same type constructor, - // we infer better bounds for pattern-bound abstract types. - argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol - then - val TypeBounds(loS, hiS) = argS.bounds - val TypeBounds(loP, hiP) = argP.bounds - var res = true - if variance < 1 then res &&= isSubType(loS, hiP) - if variance > -1 then res &&= isSubType(loP, hiS) - res - else true + ctx.gadt.rollbackGadtUnless { + tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) => + val variance = param.paramVarianceSign + if variance == 0 || assumeInvariantRefinement || + // As a special case, when pattern and scrutinee types have the same type constructor, + // we infer better bounds for pattern-bound abstract types. + argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol + then + val TypeBounds(loS, hiS) = argS.bounds + val TypeBounds(loP, hiP) = argP.bounds + var res = true + if variance < 1 then res &&= isSubType(loS, hiP) + if variance > -1 then res &&= isSubType(loP, hiS) + res + else true + } } if !result then constraint = saved - ctx.gadt.restore(savedGadt) result case _ => // Give up if we don't get AppliedType, e.g. if we upcasted to Any. diff --git a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala index a4c476568818..a0eb5139eb07 100644 --- a/compiler/src/dotty/tools/dotc/core/TypeComparer.scala +++ b/compiler/src/dotty/tools/dotc/core/TypeComparer.scala @@ -1443,10 +1443,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling else if tp1 eq tp2 then true else val saved = constraint - val savedGadt = ctx.gadt.fresh + val savedGadtConstr = ctx.gadt.getConstraint + val savedMapping = ctx.gadt.getMapping + val savedReverseMapping = ctx.gadt.getReverseMapping + val savedWasConstrained = ctx.gadt.getWasConstrained inline def restore() = state.constraint = saved - ctx.gadt.restore(savedGadt) + ctx.gadt.restore(savedGadtConstr, savedMapping, savedReverseMapping, savedWasConstrained) val savedSuccessCount = successCount try recCount += 1 @@ -2050,10 +2053,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}") if (bound.isRef(tparam)) false else - val savedGadt = ctx.gadt.fresh - val success = gadtAddBound(tparam, bound, isUpper) - if !success then ctx.gadt.restore(savedGadt) - success + ctx.gadt.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper)) } }