Skip to content

Simplify ConstraintHandling #9400

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 29 additions & 33 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,16 +22,12 @@ import reporting.trace
* By comparison: Constraint handlers are parts of type comparers and can use their functionality.
* Constraint handlers update the current constraint as a side effect.
*/
trait ConstraintHandling[AbstractContext] {
trait ConstraintHandling {

def constr: config.Printers.Printer = config.Printers.constr

def comparerCtx(using AbstractContext): Context

given (using AbstractContext) as Context = comparerCtx

protected def isSubType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
protected def isSameType(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean
protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean
protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean

protected def constraint: Constraint
protected def constraint_=(c: Constraint): Unit
Expand Down Expand Up @@ -71,23 +67,23 @@ trait ConstraintHandling[AbstractContext] {
case tp => tp
}

def nonParamBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = constraint.nonParamBounds(param)
def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)

def fullLowerBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
def fullLowerBound(param: TypeParamRef)(using Context): Type =
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)

def fullUpperBound(param: TypeParamRef)(implicit actx: AbstractContext): Type =
def fullUpperBound(param: TypeParamRef)(using Context): Type =
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)

/** Full bounds of `param`, including other lower/upper params.
*
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
* of some param when comparing types might lead to infinite recursion. Consider `bounds` instead.
*/
def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds =
def fullBounds(param: TypeParamRef)(using Context): TypeBounds =
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))

protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using AbstractContext): Boolean =
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean =
if !constraint.contains(param) then true
else if !isUpper && param.occursIn(bound)
// We don't allow recursive lower bounds when defining a type,
Expand Down Expand Up @@ -121,11 +117,11 @@ trait ConstraintHandling[AbstractContext] {
|| {
constraint = c1
val TypeBounds(lo, hi) = constraint.entry(param)
isSubType(lo, hi)
isSub(lo, hi)
}
end addOneBound

protected def addBoundTransitively(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
protected def addBoundTransitively(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =

/** Adjust the bound `tp` in the following ways:
*
Expand Down Expand Up @@ -172,7 +168,7 @@ trait ConstraintHandling[AbstractContext] {
.reporting(i"added $description = $result$location", constr)
end addBoundTransitively

protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
def description = i"ordering $p1 <: $p2 to\n$constraint"
val res =
if (constraint.isLess(p2, p1)) unify(p2, p1)
Expand All @@ -195,7 +191,7 @@ trait ConstraintHandling[AbstractContext] {
/** Make p2 = p1, transfer all bounds of p2 to p1
* @pre less(p1)(p2)
*/
private def unify(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
constr.println(s"unifying $p1 $p2")
assert(constraint.isLess(p1, p2))
val down = constraint.exclusiveLower(p2, p1)
Expand All @@ -204,16 +200,16 @@ trait ConstraintHandling[AbstractContext] {
val bounds = constraint.nonParamBounds(p1)
val lo = bounds.lo
val hi = bounds.hi
isSubType(lo, hi) &&
isSub(lo, hi) &&
down.forall(addOneBound(_, hi, isUpper = true)) &&
up.forall(addOneBound(_, lo, isUpper = false))
}

protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean)(implicit actx: AbstractContext): Boolean =
protected def isSubType(tp1: Type, tp2: Type, whenFrozen: Boolean)(using Context): Boolean =
if (whenFrozen)
isSubTypeWhenFrozen(tp1, tp2)
else
isSubType(tp1, tp2)
isSub(tp1, tp2)

inline final def inFrozenConstraint[T](op: => T): T = {
val savedFrozen = frozenConstraint
Expand All @@ -227,16 +223,16 @@ trait ConstraintHandling[AbstractContext] {
}
}

final def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSubType(tp1, tp2))
final def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(implicit actx: AbstractContext): Boolean = inFrozenConstraint(isSameType(tp1, tp2))
final def isSubTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = inFrozenConstraint(isSub(tp1, tp2))
final def isSameTypeWhenFrozen(tp1: Type, tp2: Type)(using Context): Boolean = inFrozenConstraint(isSame(tp1, tp2))

/** Test whether the lower bounds of all parameters in this
* constraint are a solution to the constraint.
*/
protected final def isSatisfiable(implicit actx: AbstractContext): Boolean =
protected final def isSatisfiable(using Context): Boolean =
constraint.forallParams { param =>
val TypeBounds(lo, hi) = constraint.entry(param)
isSubType(lo, hi) || {
isSub(lo, hi) || {
report.log(i"sub fail $lo <:< $hi")
false
}
Expand All @@ -253,7 +249,7 @@ trait ConstraintHandling[AbstractContext] {
* @return the instantiating type
* @pre `param` is in the constraint's domain.
*/
final def approximation(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
val replaceWildcards = new TypeMap {
override def stopAtStatic = true
def apply(tp: Type) = mapOver {
Expand Down Expand Up @@ -317,7 +313,7 @@ trait ConstraintHandling[AbstractContext] {
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
* as those could leak the annotation to users (see run/inferred-repeated-result).
*/
def widenInferred(inst: Type, bound: Type)(implicit actx: AbstractContext): Type =
def widenInferred(inst: Type, bound: Type)(using Context): Type =

def dropSuperTraits(tp: Type): Type =
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
Expand Down Expand Up @@ -380,7 +376,7 @@ trait ConstraintHandling[AbstractContext] {
* a lower bound instantiation can be a singleton type only if the upper bound
* is also a singleton type.
*/
def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
val approx = approximation(param, fromBelow).simplified
if (fromBelow)
val widened = widenInferred(approx, param)
Expand Down Expand Up @@ -408,7 +404,7 @@ trait ConstraintHandling[AbstractContext] {
* Both `c1` and `c2` are required to derive from constraint `pre`, without adding
* any new type variables but possibly narrowing already registered ones with further bounds.
*/
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(implicit actx: AbstractContext): Boolean =
protected final def subsumes(c1: Constraint, c2: Constraint, pre: Constraint)(using Context): Boolean =
if (c2 eq pre) true
else if (c1 eq pre) false
else {
Expand All @@ -427,7 +423,7 @@ trait ConstraintHandling[AbstractContext] {
}

/** The current bounds of type parameter `param` */
def bounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds = {
def bounds(param: TypeParamRef)(using Context): TypeBounds = {
val e = constraint.entry(param)
if (e.exists) e.bounds
else {
Expand All @@ -441,7 +437,7 @@ trait ConstraintHandling[AbstractContext] {
* and propagate all bounds.
* @param tvars See Constraint#add
*/
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(implicit actx: AbstractContext): Boolean =
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
checkPropagated(i"initialized $tl") {
constraint = constraint.add(tl, tvars)
tl.paramRefs.forall { param =>
Expand Down Expand Up @@ -470,7 +466,7 @@ trait ConstraintHandling[AbstractContext] {
* This holds if `TypeVarsMissContext` is set unless `param` is a part
* of a MatchType that is currently normalized.
*/
final def assumedTrue(param: TypeParamRef)(implicit actx: AbstractContext): Boolean =
final def assumedTrue(param: TypeParamRef)(using Context): Boolean =
ctx.mode.is(Mode.TypevarsMissContext) && (caseLambda `ne` param.binder)

/** Add constraint `param <: bound` if `fromBelow` is false, `param >: bound` otherwise.
Expand All @@ -480,7 +476,7 @@ trait ConstraintHandling[AbstractContext] {
* not be AndTypes and lower bounds may not be OrTypes. This is assured by the
* way isSubType is organized.
*/
protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit actx: AbstractContext): Boolean =
protected def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(using Context): Boolean =

/** When comparing lambdas we might get constraints such as
* `A <: X0` or `A = List[X0]` where `A` is a constrained parameter
Expand Down Expand Up @@ -514,7 +510,7 @@ trait ConstraintHandling[AbstractContext] {
case _: TypeBounds =>
if (fromBelow) addLess(bound, param) else addLess(param, bound)
case tp =>
if (fromBelow) isSubType(bound, tp) else isSubType(tp, bound)
if (fromBelow) isSub(bound, tp) else isSub(tp, bound)
}

def kindCompatible(tp1: Type, tp2: Type): Boolean =
Expand All @@ -541,7 +537,7 @@ trait ConstraintHandling[AbstractContext] {
end addConstraint

/** Check that constraint is fully propagated. See comment in Config.checkConstraintsPropagated */
def checkPropagated(msg: => String)(result: Boolean)(implicit actx: AbstractContext): Boolean = {
def checkPropagated(msg: => String)(result: Boolean)(using Context): Boolean = {
if (Config.checkConstraintsPropagated && result && addConstraintInvocations == 0)
inFrozenConstraint {
for (p <- constraint.domainParams) {
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ object Contexts {
store = initialStore
.updated(settingsStateLoc, settingsGroup.defaultState)
.updated(notNullInfosLoc, Nil)
typeComparer = new TypeComparer(this)
typeComparer = new TypeComparer(using this)
searchHistory = new SearchRoot
gadt = EmptyGadtConstraint
}
Expand Down
10 changes: 4 additions & 6 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ final class ProperGadtConstraint private(
private var myConstraint: Constraint,
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
) extends GadtConstraint with ConstraintHandling[Context] {
) extends GadtConstraint with ConstraintHandling {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

def this() = this(
Expand Down Expand Up @@ -140,7 +140,7 @@ final class ProperGadtConstraint private(
case tv: TypeVar => tv
case inst =>
gadts.println(i"instantiated: $sym -> $inst")
return if (isUpper) isSubType(inst , bound) else isSubType(bound, inst)
return if (isUpper) isSub(inst, bound) else isSub(bound, inst)
}

val internalizedBound = bound match {
Expand Down Expand Up @@ -217,13 +217,11 @@ final class ProperGadtConstraint private(

// ---- Protected/internal -----------------------------------------------

override def comparerCtx(using Context): Context = ctx

override protected def constraint = myConstraint
override protected def constraint_=(c: Constraint) = myConstraint = c

override def isSubType(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
override def isSameType(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
override protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)

override def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds =
constraint.nonParamBounds(param) match {
Expand Down
32 changes: 15 additions & 17 deletions compiler/src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,10 @@ import typer.Applications.productSelectorTypes
import reporting.trace
import NullOpsDecorator.NullOps

final class AbsentContext
object AbsentContext {
implicit val absentContext: AbsentContext = new AbsentContext
}

/** Provides methods to compare types.
*/
class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] with PatternTypeConstrainer {
class TypeComparer(using val comparerCtx: Context) extends ConstraintHandling with PatternTypeConstrainer {
import TypeComparer._
def comparerCtx(using AbsentContext): Context = initctx

val state = ctx.typerState
def constraint: Constraint = state.constraint
Expand Down Expand Up @@ -175,7 +169,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
}
}

def isSubType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean = isSubType(tp1, tp2, FreshApprox)
def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, FreshApprox)

override protected def isSub(tp1: Type, tp2: Type)(using Context): Boolean = isSubType(tp1, tp2)

/** The inner loop of the isSubType comparison.
* Recursive calls from recur should go to recur directly if the two types
Expand Down Expand Up @@ -1769,11 +1765,13 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
// Type equality =:=

/** Two types are the same if are mutual subtypes of each other */
def isSameType(tp1: Type, tp2: Type)(implicit nc: AbsentContext): Boolean =
def isSameType(tp1: Type, tp2: Type): Boolean =
if (tp1 eq NoType) false
else if (tp1 eq tp2) true
else isSubType(tp1, tp2) && isSubType(tp2, tp1)

override protected def isSame(tp1: Type, tp2: Type)(using Context): Boolean = isSameType(tp1, tp2)

/** Same as `isSameType` but also can be applied to overloaded TermRefs, where
* two overloaded refs are the same if they have pairwise equal alternatives
*/
Expand Down Expand Up @@ -2215,7 +2213,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] w
}

/** A new type comparer of the same type as this one, using the given context. */
def copyIn(ctx: Context): TypeComparer = new TypeComparer(ctx)
def copyIn(ctx: Context): TypeComparer = new TypeComparer(using ctx)

// ----------- Diagnostics --------------------------------------------------

Expand Down Expand Up @@ -2469,7 +2467,7 @@ object TypeComparer {

/** Show trace of comparison operations when performing `op` */
def explaining[T](say: String => Unit)(op: Context ?=> T)(using Context): T = {
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(_))
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(using _))
val res = try { op(using nestedCtx) } finally { say(nestedCtx.typeComparer.lastTrace()) }
res
}
Expand All @@ -2482,17 +2480,17 @@ object TypeComparer {
}
}

class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
class TrackingTypeComparer(using Context) extends TypeComparer {
import state.constraint

val footprint: mutable.Set[Type] = mutable.Set[Type]()

override def bounds(param: TypeParamRef)(implicit nc: AbsentContext): TypeBounds = {
override def bounds(param: TypeParamRef)(using Context): TypeBounds = {
if (param.binder `ne` caseLambda) footprint += param
super.bounds(param)
}

override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit nc: AbsentContext): Boolean = {
override def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Boolean = {
if (param.binder `ne` caseLambda) footprint += param
super.addOneBound(param, bound, isUpper)
}
Expand Down Expand Up @@ -2630,7 +2628,7 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
}

/** A type comparer that can record traces of subtype operations */
class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
class ExplainingTypeComparer(using Context) extends TypeComparer {
import TypeComparer._

private var indent = 0
Expand Down Expand Up @@ -2678,12 +2676,12 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
super.glb(tp1, tp2)
}

override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(implicit nc: AbsentContext): Boolean =
override def addConstraint(param: TypeParamRef, bound: Type, fromBelow: Boolean)(using Context): Boolean =
traceIndented(i"add constraint $param ${if (fromBelow) ">:" else "<:"} $bound $frozenConstraint, constraint = ${ctx.typerState.constraint}") {
super.addConstraint(param, bound, fromBelow)
}

override def copyIn(ctx: Context): ExplainingTypeComparer = new ExplainingTypeComparer(ctx)
override def copyIn(using Context): ExplainingTypeComparer = new ExplainingTypeComparer

override def lastTrace(): String = "Subtype trace:" + { try b.toString finally b.clear() }
}
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4283,7 +4283,7 @@ object Types {
override def tryNormalize(using Context): Type = reduced.normalized

def reduced(using Context): Type = {
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(_))
val trackingCtx = ctx.fresh.setTypeComparerFn(new TrackingTypeComparer(using _))
val typeComparer = trackingCtx.typeComparer.asInstanceOf[TrackingTypeComparer]

def contextInfo(tp: Type): Type = tp match {
Expand Down