Skip to content

Commit 0e4adb8

Browse files
committed
Simplify constraint handling
1 parent f4de03f commit 0e4adb8

File tree

2 files changed

+59
-76
lines changed

2 files changed

+59
-76
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 57 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import Symbols._
88
import Decorators._
99
import Flags._
1010
import config.Config
11-
import config.Printers.{constr, typr}
11+
import config.Printers.typr
1212
import dotty.tools.dotc.reporting.trace
1313

1414
/** Methods for adding constraints and solving them.
@@ -24,8 +24,7 @@ import dotty.tools.dotc.reporting.trace
2424
*/
2525
trait ConstraintHandling[AbstractContext] {
2626

27-
def constr_println(msg: => String): Unit = constr.println(msg)
28-
def typr_println(msg: => String): Unit = typr.println(msg)
27+
def constr: config.Printers.Printer = config.Printers.constr
2928

3029
implicit def ctx(implicit ac: AbstractContext): Context
3130

@@ -86,7 +85,41 @@ trait ConstraintHandling[AbstractContext] {
8685
def fullBounds(param: TypeParamRef)(implicit actx: AbstractContext): TypeBounds =
8786
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
8887

89-
protected def addOneBound(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
88+
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(using AbstractContext): Boolean =
89+
if !constraint.contains(param) then true
90+
else
91+
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
92+
val equalBounds = (if isUpper then lo else hi) eq bound
93+
if equalBounds
94+
&& !bound.existsPart(bp => bp.isInstanceOf[WildcardType] || (bp eq param))
95+
then
96+
// The narrowed bounds are equal and do not contain wildcards,
97+
// so we can remove `param` from the constraint.
98+
// (Handling wildcards requires choosing a bound, but we don't know which
99+
// bound to choose here, this is handled in `ConstraintHandling#approximation`)
100+
constraint = constraint.replace(param, bound)
101+
true
102+
else
103+
// Narrow one of the bounds of type parameter `param`
104+
// If `isUpper` is true, ensure that `param <: `bound`, otherwise ensure
105+
// that `param >: bound`.
106+
val narrowedBounds =
107+
val saved = homogenizeArgs
108+
homogenizeArgs = Config.alignArgsInAnd
109+
try
110+
if isUpper then oldBounds.derivedTypeBounds(lo, hi & bound)
111+
else oldBounds.derivedTypeBounds(lo | bound, hi)
112+
finally homogenizeArgs = saved
113+
val c1 = constraint.updateEntry(param, narrowedBounds)
114+
(c1 eq constraint)
115+
|| {
116+
constraint = c1
117+
val TypeBounds(lo, hi) = constraint.entry(param)
118+
isSubType(lo, hi)
119+
}
120+
end addOneBound
121+
122+
protected def addBoundTransitively(param: TypeParamRef, rawBound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
90123

91124
/** Adjust the bound `tp` in the following ways:
92125
*
@@ -119,69 +152,19 @@ trait ConstraintHandling[AbstractContext] {
119152
case _ =>
120153
tp
121154

122-
if !constraint.contains(param) then true
123-
else
124-
val bound = adjust(rawBound)
125-
126-
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
127-
val equalBounds = isUpper && (lo eq bound) || !isUpper && (bound eq hi)
128-
if !bound.exists then false
129-
else if equalBounds
130-
&& !bound.existsPart(bp => bp.isInstanceOf[WildcardType] || (bp eq param))
131-
then
132-
// The narrowed bounds are equal and do not contain wildcards,
133-
// so we can remove `param` from the constraint.
134-
// (Handling wildcards requires choosing a bound, but we don't know which
135-
// bound to choose here, this is handled in `ConstraintHandling#approximation`)
136-
constraint = constraint.replace(param, bound)
137-
true
138-
else
139-
// Narrow one of the bounds of type parameter `param`
140-
// If `isUpper` is true, ensure that `param <: `bound`, otherwise ensure
141-
// that `param >: bound`.
142-
val narrowedBounds =
143-
val saved = homogenizeArgs
144-
homogenizeArgs = Config.alignArgsInAnd
145-
try
146-
if (isUpper) oldBounds.derivedTypeBounds(lo, hi & bound)
147-
else oldBounds.derivedTypeBounds(lo | bound, hi)
148-
finally homogenizeArgs = saved
149-
val c1 = constraint.updateEntry(param, narrowedBounds)
150-
(c1 eq constraint) || {
151-
constraint = c1
152-
val TypeBounds(lo, hi) = constraint.entry(param)
153-
isSubType(lo, hi)
154-
}
155-
end addOneBound
156-
157-
private def location(implicit ctx: Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging
158-
159-
protected def addUpperBound(param: TypeParamRef, bound: Type)(implicit actx: AbstractContext): Boolean = {
160-
def description = i"constraint $param <: $bound to\n$constraint"
161-
if (bound.isRef(defn.NothingClass) && ctx.typerState.isGlobalCommittable) {
162-
def msg = s"!!! instantiated to Nothing: $param, constraint = ${constraint.show}"
163-
if (Config.failOnInstantiationToNothing) assert(false, msg)
155+
def description = i"constraint $param ${if isUpper then "<:" else ":>"} $rawBound to\n$constraint"
156+
constr.println(i"adding $description$location")
157+
if isUpper && rawBound.isRef(defn.NothingClass) && ctx.typerState.isGlobalCommittable then
158+
def msg = i"!!! instantiated to Nothing: $param, constraint = $constraint"
159+
if Config.failOnInstantiationToNothing
160+
then assert(false, msg)
164161
else ctx.log(msg)
165-
}
166-
constr_println(i"adding $description$location")
167-
val lower = constraint.lower(param)
168-
val res =
169-
addOneBound(param, bound, isUpper = true) &&
170-
lower.forall(addOneBound(_, bound, isUpper = true))
171-
constr_println(i"added $description = $res$location")
172-
res
173-
}
174-
175-
protected def addLowerBound(param: TypeParamRef, bound: Type)(implicit actx: AbstractContext): Boolean = {
176-
def description = i"constraint $param >: $bound to\n$constraint"
177-
constr_println(i"adding $description")
178-
val upper = constraint.upper(param)
179-
val res =
180-
addOneBound(param, bound, isUpper = false) &&
181-
upper.forall(addOneBound(_, bound, isUpper = false))
182-
constr_println(i"added $description = $res$location")
183-
res
184-
}
162+
def others = if isUpper then constraint.lower(param) else constraint.upper(param)
163+
val bound = adjust(rawBound)
164+
bound.exists
165+
&& addOneBound(param, bound, isUpper) && others.forall(addOneBound(_, bound, isUpper))
166+
.reporting(i"added $description = $result$location", constr)
167+
end addBoundTransitively
185168

186169
protected def addLess(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
187170
def description = i"ordering $p1 <: $p2 to\n$constraint"
@@ -192,20 +175,22 @@ trait ConstraintHandling[AbstractContext] {
192175
val up2 = p2 :: constraint.exclusiveUpper(p2, p1)
193176
val lo1 = constraint.nonParamBounds(p1).lo
194177
val hi2 = constraint.nonParamBounds(p2).hi
195-
constr_println(i"adding $description down1 = $down1, up2 = $up2$location")
178+
constr.println(i"adding $description down1 = $down1, up2 = $up2$location")
196179
constraint = constraint.addLess(p1, p2)
197180
down1.forall(addOneBound(_, hi2, isUpper = true)) &&
198181
up2.forall(addOneBound(_, lo1, isUpper = false))
199182
}
200-
constr_println(i"added $description = $res$location")
183+
constr.println(i"added $description = $res$location")
201184
res
202185
}
203186

187+
def location(implicit ctx: Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging
188+
204189
/** Make p2 = p1, transfer all bounds of p2 to p1
205190
* @pre less(p1)(p2)
206191
*/
207192
private def unify(p1: TypeParamRef, p2: TypeParamRef)(implicit actx: AbstractContext): Boolean = {
208-
constr_println(s"unifying $p1 $p2")
193+
constr.println(s"unifying $p1 $p2")
209194
assert(constraint.isLess(p1, p2))
210195
val down = constraint.exclusiveLower(p2, p1)
211196
val up = constraint.exclusiveUpper(p1, p2)
@@ -301,7 +286,7 @@ trait ConstraintHandling[AbstractContext] {
301286
case _: TypeBounds =>
302287
val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param)
303288
val inst = avoidParam(bound)
304-
typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
289+
typr.println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
305290
inst
306291
case inst =>
307292
assert(inst.exists, i"param = $param\nconstraint = $constraint")
@@ -407,7 +392,7 @@ trait ConstraintHandling[AbstractContext] {
407392
val upper = constraint.upper(param)
408393
if lower.nonEmpty && !bounds.lo.isRef(defn.NothingClass)
409394
|| upper.nonEmpty && !bounds.hi.isAny
410-
then constr_println(i"INIT*** $tl")
395+
then constr.println(i"INIT*** $tl")
411396
lower.forall(addOneBound(_, bounds.hi, isUpper = true)) &&
412397
upper.forall(addOneBound(_, bounds.lo, isUpper = false))
413398
case _ =>
@@ -491,8 +476,7 @@ trait ConstraintHandling[AbstractContext] {
491476
addParamBound(bound)
492477
case _ =>
493478
val pbound = avoidLambdaParams(bound)
494-
kindCompatible(param, pbound)
495-
&& (if fromBelow then addLowerBound(param, pbound) else addUpperBound(param, pbound))
479+
kindCompatible(param, pbound) && addBoundTransitively(param, pbound, !fromBelow)
496480
finally addConstraintInvocations -= 1
497481
}
498482
end addConstraint

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,7 @@ final class ProperGadtConstraint private(
155155
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
156156
else addLess(boundTvar.origin, symTvar.origin)
157157
case bound =>
158-
if (isUpper) addUpperBound(symTvar.origin, bound)
159-
else addLowerBound(symTvar.origin, bound)
158+
addBoundTransitively(symTvar.origin, bound, isUpper)
160159
}
161160
).reporting({
162161
val descr = if (isUpper) "upper" else "lower"
@@ -271,7 +270,7 @@ final class ProperGadtConstraint private(
271270

272271
// ---- Debug ------------------------------------------------------------
273272

274-
override def constr_println(msg: => String): Unit = gadtsConstr.println(msg)
273+
override def constr = gadtsConstr
275274

276275
override def toText(printer: Printer): Texts.Text = constraint.toText(printer)
277276

0 commit comments

Comments
 (0)