Skip to content

Inline GADT state restoring in TypeComparer #16564

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ sealed trait GadtConstraint (

import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

private[core] def getConstraint: Constraint = constraint
private[core] def getMapping: SimpleIdentityMap[Symbol, TypeVar] = mapping
private[core] def getReverseMapping: SimpleIdentityMap[TypeParamRef, Symbol] = reverseMapping
private[core] def getWasConstrained: Boolean = wasConstrained

/** Exposes ConstraintHandling.subsumes */
def subsumes(left: GadtConstraint, right: GadtConstraint, pre: GadtConstraint)(using Context): Boolean = {
def extractConstraint(g: GadtConstraint) = g.constraint
Expand Down Expand Up @@ -198,6 +203,25 @@ sealed trait GadtConstraint (
this.reverseMapping = other.reverseMapping
this.wasConstrained = other.wasConstrained

def restore(constr: Constraint, mapping: SimpleIdentityMap[Symbol, TypeVar], revMapping: SimpleIdentityMap[TypeParamRef, Symbol], wasConstrained: Boolean): Unit =
this.myConstraint = constr
this.mapping = mapping
this.reverseMapping = revMapping
this.wasConstrained = wasConstrained

def rollbackGadtUnless(op: => Boolean): Boolean =
val savedConstr = myConstraint
val savedMapping = mapping
val savedReverseMapping = reverseMapping
val savedWasConstrained = wasConstrained
var result = false
try
result = op
finally
if !result then
restore(savedConstr, savedMapping, savedReverseMapping, savedWasConstrained)
result

// ---- Protected/internal -----------------------------------------------

override protected def constraint = myConstraint
Expand Down
32 changes: 16 additions & 16 deletions compiler/src/dotty/tools/dotc/core/PatternTypeConstrainer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -265,26 +265,26 @@ trait PatternTypeConstrainer { self: TypeComparer =>
(tp, pt) match {
case (AppliedType(tyconS, argsS), AppliedType(tyconP, argsP)) =>
val saved = state.nn.constraint
val savedGadt = ctx.gadt.fresh
val result =
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
// As a special case, when pattern and scrutinee types have the same type constructor,
// we infer better bounds for pattern-bound abstract types.
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
then
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
var res = true
if variance < 1 then res &&= isSubType(loS, hiP)
if variance > -1 then res &&= isSubType(loP, hiS)
res
else true
ctx.gadt.rollbackGadtUnless {
tyconS.typeParams.lazyZip(argsS).lazyZip(argsP).forall { (param, argS, argP) =>
val variance = param.paramVarianceSign
if variance == 0 || assumeInvariantRefinement ||
// As a special case, when pattern and scrutinee types have the same type constructor,
// we infer better bounds for pattern-bound abstract types.
argP.typeSymbol.isPatternBound && patternTp.classSymbol == scrutineeTp.classSymbol
then
val TypeBounds(loS, hiS) = argS.bounds
val TypeBounds(loP, hiP) = argP.bounds
var res = true
if variance < 1 then res &&= isSubType(loS, hiP)
if variance > -1 then res &&= isSubType(loP, hiS)
res
else true
}
}
if !result then
constraint = saved
ctx.gadt.restore(savedGadt)
result
case _ =>
// Give up if we don't get AppliedType, e.g. if we upcasted to Any.
Expand Down
12 changes: 6 additions & 6 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1443,10 +1443,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
else if tp1 eq tp2 then true
else
val saved = constraint
val savedGadt = ctx.gadt.fresh
val savedGadtConstr = ctx.gadt.getConstraint
val savedMapping = ctx.gadt.getMapping
val savedReverseMapping = ctx.gadt.getReverseMapping
val savedWasConstrained = ctx.gadt.getWasConstrained
inline def restore() =
state.constraint = saved
ctx.gadt.restore(savedGadt)
ctx.gadt.restore(savedGadtConstr, savedMapping, savedReverseMapping, savedWasConstrained)
val savedSuccessCount = successCount
try
recCount += 1
Expand Down Expand Up @@ -2050,10 +2053,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
gadts.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.toString} ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else
val savedGadt = ctx.gadt.fresh
val success = gadtAddBound(tparam, bound, isUpper)
if !success then ctx.gadt.restore(savedGadt)
success
ctx.gadt.rollbackGadtUnless(gadtAddBound(tparam, bound, isUpper))
}
}

Expand Down