Skip to content

Avoid incorrect simplifications when updating bounds in the constraint #16410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Constraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ abstract class Constraint extends Showable {
* - Another type, indicating a solution for the parameter
*
* @pre `this contains param`.
* @pre `tp` does not contain top-level references to `param`
* (see `validBoundsFor`)
*/
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This

Expand Down Expand Up @@ -172,6 +174,23 @@ abstract class Constraint extends Showable {
*/
def occursAtToplevel(param: TypeParamRef, tp: Type)(using Context): Boolean

/** Sanitize `bound` to make it either a valid upper or lower bound for
* `param` depending on `isUpper`.
*
* Toplevel references to `param`, are replaced by `Any` if `isUpper` is true
* and `Nothing` otherwise.
*
* @see `occursAtTopLevel` for a definition of "toplevel"
* @see `validBoundsFor` to sanitize both the lower and upper bound at once.
*/
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type

/** Sanitize `bounds` to make them valid constraints for `param`.
*
* @see `validBoundFor` for details.
*/
def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type

/** A string that shows the reverse dependencies maintained by this constraint
* (coDeps and contraDeps for OrderingConstraints).
*/
Expand Down
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ trait ConstraintHandling {
end LevelAvoidMap

/** Approximate `rawBound` if needed to make it a legal bound of `param` by
* avoiding wildcards and types with a level strictly greater than its
* avoiding cycles, wildcards and types with a level strictly greater than its
* `nestingLevel`.
*
* Note that level-checking must be performed here and cannot be delayed
Expand All @@ -283,7 +283,7 @@ trait ConstraintHandling {
// This is necessary for i8900-unflip.scala to typecheck.
val v = if necessaryConstraintsOnly then -this.variance else this.variance
atVariance(v)(super.legalVar(tp))
approx(rawBound)
constraint.validBoundFor(param, approx(rawBound), isUpper)
end legalBound

protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(using Context): Boolean =
Expand Down Expand Up @@ -413,8 +413,10 @@ trait ConstraintHandling {

constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)

val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
val boundKept = constraint.validBoundsFor(pKept,
constraint.nonParamBounds( pKept).substParam(pRemoved, pKept).bounds)
var boundRemoved = constraint.validBoundsFor(pKept,
constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept).bounds)

if level1 != level2 then
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
Expand Down
53 changes: 22 additions & 31 deletions compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -525,20 +525,11 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

// ---------- Updates ------------------------------------------------------------

/** If `inst` is a TypeBounds, make sure it does not contain toplevel
* references to `param` (see `Constraint#occursAtToplevel` for a definition
* of "toplevel").
* Any such references are replaced by `Nothing` in the lower bound and `Any`
* in the upper bound.
* References can be direct or indirect through instantiations of other
* parameters in the constraint.
*/
private def ensureNonCyclic(param: TypeParamRef, inst: Type)(using Context): Type =

def recur(tp: Type, fromBelow: Boolean): Type = tp match
def validBoundFor(param: TypeParamRef, bound: Type, isUpper: Boolean)(using Context): Type =
def recur(tp: Type): Type = tp match
case tp: AndOrType =>
val r1 = recur(tp.tp1, fromBelow)
val r2 = recur(tp.tp2, fromBelow)
val r1 = recur(tp.tp1)
val r2 = recur(tp.tp2)
if (r1 eq tp.tp1) && (r2 eq tp.tp2) then tp
else tp.match
case tp: OrType =>
Expand All @@ -547,35 +538,34 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
r1 & r2
case tp: TypeParamRef =>
if tp eq param then
if fromBelow then defn.NothingType else defn.AnyType
if isUpper then defn.AnyType else defn.NothingType
else entry(tp) match
case NoType => tp
case TypeBounds(lo, hi) => if lo eq hi then recur(lo, fromBelow) else tp
case inst => recur(inst, fromBelow)
case TypeBounds(lo, hi) => if lo eq hi then recur(lo) else tp
case inst => recur(inst)
case tp: TypeVar =>
val underlying1 = recur(tp.underlying, fromBelow)
val underlying1 = recur(tp.underlying)
if underlying1 ne tp.underlying then underlying1 else tp
case CapturingType(parent, refs) =>
val parent1 = recur(parent, fromBelow)
val parent1 = recur(parent)
if parent1 ne parent then tp.derivedCapturingType(parent1, refs) else tp
case tp: AnnotatedType =>
val parent1 = recur(tp.parent, fromBelow)
val parent1 = recur(tp.parent)
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
case _ =>
val tp1 = tp.dealiasKeepAnnots
if tp1 ne tp then
val tp2 = recur(tp1, fromBelow)
val tp2 = recur(tp1)
if tp2 ne tp1 then tp2 else tp
else tp

inst match
case bounds: TypeBounds =>
bounds.derivedTypeBounds(
recur(bounds.lo, fromBelow = true),
recur(bounds.hi, fromBelow = false))
case _ =>
inst
end ensureNonCyclic
recur(bound)
end validBoundFor

def validBoundsFor(param: TypeParamRef, bounds: TypeBounds)(using Context): Type =
bounds.derivedTypeBounds(
validBoundFor(param, bounds.lo, isUpper = false),
validBoundFor(param, bounds.hi, isUpper = true))

/** Add the fact `param1 <: param2` to the constraint `current` and propagate
* `<:<` relationships between parameters ("edges") but not bounds.
Expand Down Expand Up @@ -658,9 +648,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
current1
}

/** The public version of `updateEntry`. Guarantees that there are no cycles */
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
updateEntry(this, param, ensureNonCyclic(param, tp)).checkWellFormed()
updateEntry(this, param, tp).checkWellFormed()

def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
order(this, param1, param2, direction).checkWellFormed()
Expand Down Expand Up @@ -703,7 +692,9 @@ class OrderingConstraint(private val boundsMap: ParamBounds,

def replaceParamIn(other: TypeParamRef) =
val oldEntry = current.entry(other)
val newEntry = current.ensureNonCyclic(other, oldEntry.substParam(param, replacement))
val newEntry = oldEntry.substParam(param, replacement) match
case tp: TypeBounds => validBoundsFor(other, tp)
case tp => tp
current = boundsLens.update(this, current, other, newEntry)
var oldDepEntry = oldEntry
var newDepEntry = newEntry
Expand Down
24 changes: 24 additions & 0 deletions compiler/test/dotty/tools/dotc/core/ConstraintsTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,27 @@ class ConstraintsTest:
i"Merging constraints `?S <: ?T` and `Int <: ?S` should result in `Int <:< ?T`: ${ctx.typerState.constraint}")
}
end mergeBoundsTransitivity

@Test def validBoundsInit: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= defn.StringType, i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}") // used to be Any
}

@Test def validBoundsUnify: Unit = inCompilerContext(
TestConfiguration.basicClasspath,
scalaSources = "trait A { def foo[S >: T <: T | Int, T <: String | Int]: Any }") {
val tvars = constrained(requiredClass("A").typeRef.select("foo".toTermName).info.asInstanceOf[TypeLambda], EmptyTree, alwaysAddTypeVars = true)._2
val List(s, t) = tvars.tpes

s <:< t

val TypeBounds(lo, hi) = ctx.typerState.constraint.entry(t.asInstanceOf[TypeVar].origin): @unchecked
assert(lo =:= defn.NothingType, i"Unexpected lower bound $lo for $t: ${ctx.typerState.constraint}")
assert(hi =:= (defn.StringType | defn.IntType), i"Unexpected upper bound $hi for $t: ${ctx.typerState.constraint}")
}