Skip to content

Commit d37bc85

Browse files
committed
Create fresh type variables to keep constraints level-correct
This completes the implementation of `LevelAvoidMap` from the previous commit: we now make sure that the nonParamBounds of a type variable does not refer to variables of a higher level by creating fresh variables of the appropriate level if necessary. Each fresh variable will be upper- or lower-bounded by the existing variable it is substituted for depending on variance (an idea that I got from [1]), in the invariant case the existing variable will be instantiated to the fresh one (unlike [2], we can't simply mutate the nestingLevel of the existing variable after running avoidance on its bounds because the constraint containing these bounds might later be retracted). Additionally: - When unifying two type variables, keep the one with the lowest level in the constraint set and make sure the bounds transferred from the other one are level-correct. This required some changes in `Constraint#addLess` which previously assumed that `unify` would always keep the second parameter. - When instantiating a type variable to its full lower- or upper-bound, we also need to avoid any type variable of a higher level among its param bound. This commit is necessary to avoid leaking local types in i8900a2.scala and i8900a3.scala, these kind of leaks will become compile-time error in the next commit. This commit required making a type parameter explicit both in SnippetChecker.scala and i13809/Macros_1.scala, in both situations the problem is that the lambda passed to `map` can only be typed if the type argument of `map` contains a wildcard, but LevelAvoidMap instead creates a fresh type variable of a lower level at a point where we don't know yet that this cannot work. Since this situation seems very rare in practice, I believe this is an acceptable trade-off for soundness. [1]: Lionel Parreaux. "The simple essence of algebraic subtyping: principal type inference with subtyping made easy (functional pearl)." https://dl.acm.org/doi/abs/10.1145/3409006 [2]: Oleg Kiselyov. "How OCaml type checker works -- or what polymorphism and garbage collection have in common" https://okmij.org/ftp/ML/generalization.html
1 parent 544cda6 commit d37bc85

21 files changed

+382
-133
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,15 @@ abstract class Constraint extends Showable {
9393
/** A constraint that includes the relationship `p1 <: p2`.
9494
* `<:` relationships between parameters ("edges") are propagated, but
9595
* non-parameter bounds are left alone.
96+
*
97+
* @param direction Must be set to `KeepParam1` or `KeepParam2` when
98+
* `p2 <: p1` is already true depending on which parameter
99+
* the caller intends to keep. This will avoid propagating
100+
* bounds that will be redundant after `p1` and `p2` are
101+
* unified.
96102
*/
97-
def addLess(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
98-
99-
/** A constraint resulting from adding p2 = p1 to this constraint, and at the same
100-
* time transferring all bounds of p2 to p1
101-
*/
102-
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This
103+
def addLess(p1: TypeParamRef, p2: TypeParamRef,
104+
direction: UnificationDirection = UnificationDirection.NoUnification)(using Context): This
103105

104106
/** A new constraint which is derived from this constraint by removing
105107
* the type parameter `param` from the domain and replacing all top-level occurrences
@@ -174,3 +176,15 @@ abstract class Constraint extends Showable {
174176
*/
175177
def checkConsistentVars()(using Context): Unit
176178
}
179+
180+
/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
181+
* unifying one parameter with the other, this enum lets `addLess` know which
182+
* direction the unification will take.
183+
*/
184+
enum UnificationDirection:
185+
/** Neither p1 nor p2 will be instantiated. */
186+
case NoUnification
187+
/** `p2 := p1`, p1 left uninstantiated. */
188+
case KeepParam1
189+
/** `p1 := p2`, p2 left uninstantiated. */
190+
case KeepParam2

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 134 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ import Flags._
1010
import config.Config
1111
import config.Printers.typr
1212
import reporting.trace
13-
import typer.ProtoTypes.newTypeVar
13+
import typer.ProtoTypes.{newTypeVar, representedParamRef}
1414
import StdNames.tpnme
15+
import UnificationDirection.*
16+
import NameKinds.AvoidNameKind
1517

1618
/** Methods for adding constraints and solving them.
1719
*
@@ -85,13 +87,30 @@ trait ConstraintHandling {
8587
case tv: TypeVar => tv.nestingLevel
8688
case _ => Int.MaxValue
8789

90+
/** If `param` is nested deeper than `maxLevel`, try to instantiate it to a
91+
* fresh type variable of level `maxLevel` and return the new variable.
92+
* If this isn't possible, throw a TypeError.
93+
*/
94+
def atLevel(maxLevel: Int, param: TypeParamRef)(using Context): TypeParamRef =
95+
if nestingLevel(param) <= maxLevel then return param
96+
LevelAvoidMap(0, maxLevel)(param) match
97+
case freshVar: TypeVar => freshVar.origin
98+
case _ => throw new TypeError(
99+
i"Could not decrease the nesting level of ${param} from ${nestingLevel(param)} to $maxLevel in $constraint")
100+
88101
def nonParamBounds(param: TypeParamRef)(using Context): TypeBounds = constraint.nonParamBounds(param)
89102

90-
def fullLowerBound(param: TypeParamRef)(using Context): Type =
91-
constraint.minLower(param).foldLeft(nonParamBounds(param).lo)(_ | _)
103+
def fullLowerBound(param: TypeParamRef, maxLevel: Int = Int.MaxValue)(using Context): Type =
104+
var loParams = constraint.minLower(param)
105+
if maxLevel != Int.MaxValue then
106+
loParams = loParams.mapConserve(atLevel(maxLevel, _))
107+
loParams.foldLeft(nonParamBounds(param).lo)(_ | _)
92108

93-
def fullUpperBound(param: TypeParamRef)(using Context): Type =
94-
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi)(_ & _)
109+
def fullUpperBound(param: TypeParamRef, maxLevel: Int = Int.MaxValue)(using Context): Type =
110+
var hiParams = constraint.minUpper(param)
111+
if maxLevel != Int.MaxValue then
112+
hiParams = hiParams.mapConserve(atLevel(maxLevel, _))
113+
hiParams.foldLeft(nonParamBounds(param).hi)(_ & _)
95114

96115
/** Full bounds of `param`, including other lower/upper params.
97116
*
@@ -116,10 +135,64 @@ trait ConstraintHandling {
116135
def toAvoid(tp: NamedType): Boolean =
117136
tp.prefix == NoPrefix && !tp.symbol.isStatic && !levelOK(tp.symbol.nestingLevel)
118137

138+
/** Return a (possibly fresh) type variable of a level no greater than `maxLevel` which is:
139+
* - lower-bounded by `tp` if variance >= 0
140+
* - upper-bounded by `tp` if variance <= 0
141+
* If this isn't possible, return the empty range.
142+
*/
143+
def legalVar(tp: TypeVar): Type =
144+
val oldParam = tp.origin
145+
val nameKind =
146+
if variance > 0 then AvoidNameKind.UpperBound
147+
else if variance < 0 then AvoidNameKind.LowerBound
148+
else AvoidNameKind.BothBounds
149+
150+
/** If it exists, return the first param in the list created in a previous call to `legalVar(tp)`
151+
* with the appropriate level and variance.
152+
*/
153+
def findParam(params: List[TypeParamRef]): Option[TypeParamRef] =
154+
params.find(p =>
155+
nestingLevel(p) <= maxLevel && representedParamRef(p) == oldParam &&
156+
(p.paramName.is(AvoidNameKind.BothBounds) ||
157+
variance != 0 && p.paramName.is(nameKind)))
158+
159+
// First, check if we can reuse an existing parameter, this is more than an optimization
160+
// since it avoids an infinite loop in tests/pos/i8900-cycle.scala
161+
findParam(constraint.lower(oldParam)).orElse(findParam(constraint.upper(oldParam))) match
162+
case Some(param) =>
163+
constraint.typeVarOfParam(param)
164+
case _ =>
165+
// Otherwise, try to return a fresh type variable at `maxLevel` with
166+
// the appropriate constraints.
167+
val name = nameKind(oldParam.paramName.toTermName).toTypeName
168+
val freshVar = newTypeVar(TypeBounds.upper(tp.topType), name,
169+
nestingLevel = maxLevel, represents = oldParam)
170+
val ok =
171+
if variance < 0 then
172+
addLess(freshVar.origin, oldParam)
173+
else if variance > 0 then
174+
addLess(oldParam, freshVar.origin)
175+
else
176+
unify(freshVar.origin, oldParam)
177+
if ok then freshVar else emptyRange
178+
end legalVar
179+
180+
override def apply(tp: Type): Type = tp match
181+
case tp: TypeVar if !tp.isInstantiated && !levelOK(tp.nestingLevel) =>
182+
legalVar(tp)
183+
// TypeParamRef can occur in tl bounds
184+
case tp: TypeParamRef =>
185+
constraint.typeVarOfParam(tp) match
186+
case tvar: TypeVar =>
187+
apply(tvar)
188+
case _ => super.apply(tp)
189+
case _ =>
190+
super.apply(tp)
191+
119192
override def mapWild(t: WildcardType) =
120193
if ctx.mode.is(Mode.TypevarsMissContext) then super.mapWild(t)
121194
else
122-
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds)
195+
val tvar = newTypeVar(apply(t.effectiveBounds).toBounds, nestingLevel = maxLevel)
123196
tvar
124197
end LevelAvoidMap
125198

@@ -151,7 +224,16 @@ trait ConstraintHandling {
151224
// flip the variance to under-approximate.
152225
if necessaryConstraintsOnly then variance = -variance
153226

154-
val approx = LevelAvoidMap(variance, level)
227+
val approx = new LevelAvoidMap(variance, nestingLevel(param)):
228+
override def legalVar(tp: TypeVar): Type =
229+
// `legalVar` will create a type variable whose bounds depend on
230+
// `variance`, but whether the variance is positive or negative,
231+
// we can still infer necessary constraints since just creating a
232+
// type variable doesn't reduce the set of possible solutions.
233+
// Therefore, we can safely "unflip" the variance flipped above.
234+
// This is necessary for i8900-unflip.scala to typecheck.
235+
val v = if necessaryConstraintsOnly then -this.variance else this.variance
236+
atVariance(v)(super.legalVar(tp))
155237
approx(rawBound)
156238

157239
val oldBounds @ TypeBounds(lo, hi) = constraint.nonParamBounds(param)
@@ -248,19 +330,50 @@ trait ConstraintHandling {
248330

249331
def location(using Context) = "" // i"in ${ctx.typerState.stateChainStr}" // use for debugging
250332

251-
/** Make p2 = p1, transfer all bounds of p2 to p1
252-
* @pre less(p1)(p2)
333+
/** Unify p1 with p2: one parameter will be kept in the constraint, the
334+
* other will be removed and its bounds transferred to the remaining one.
335+
*
336+
* If p1 and p2 have different `nestingLevel`, the parameter with the lowest
337+
* level will be kept and the transferred bounds from the other parameter
338+
* will be adjusted for level-correctness.
253339
*/
254340
private def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): Boolean = {
255341
constr.println(s"unifying $p1 $p2")
256-
assert(constraint.isLess(p1, p2))
257-
constraint = constraint.addLess(p2, p1)
342+
if !constraint.isLess(p1, p2) then
343+
constraint = constraint.addLess(p1, p2)
344+
345+
val level1 = nestingLevel(p1)
346+
val level2 = nestingLevel(p2)
347+
val pKept = if level1 <= level2 then p1 else p2
348+
val pRemoved = if level1 <= level2 then p2 else p1
349+
350+
constraint = constraint.addLess(p2, p1, direction = if pKept eq p1 then KeepParam2 else KeepParam1)
351+
352+
val boundKept = constraint.nonParamBounds(pKept).substParam(pRemoved, pKept)
353+
var boundRemoved = constraint.nonParamBounds(pRemoved).substParam(pRemoved, pKept)
354+
355+
if level1 != level2 then
356+
boundRemoved = LevelAvoidMap(-1, math.min(level1, level2))(boundRemoved)
357+
val TypeBounds(lo, hi) = boundRemoved
358+
// After avoidance, the interval might be empty, e.g. in
359+
// tests/pos/i8900-promote.scala:
360+
// >: x.type <: Singleton
361+
// becomes:
362+
// >: Int <: Singleton
363+
// In that case, we can still get a legal constraint
364+
// by replacing the lower-bound to get:
365+
// >: Int & Singleton <: Singleton
366+
if !isSub(lo, hi) then
367+
boundRemoved = TypeBounds(lo & hi, hi)
368+
258369
val down = constraint.exclusiveLower(p2, p1)
259370
val up = constraint.exclusiveUpper(p1, p2)
260-
constraint = constraint.unify(p1, p2)
261-
val bounds = constraint.nonParamBounds(p1)
262-
val lo = bounds.lo
263-
val hi = bounds.hi
371+
372+
val newBounds = (boundKept & boundRemoved).bounds
373+
constraint = constraint.updateEntry(pKept, newBounds).replace(pRemoved, pKept)
374+
375+
val lo = newBounds.lo
376+
val hi = newBounds.hi
264377
isSub(lo, hi) &&
265378
down.forall(addOneBound(_, hi, isUpper = true)) &&
266379
up.forall(addOneBound(_, lo, isUpper = false))
@@ -313,8 +426,13 @@ trait ConstraintHandling {
313426
final def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
314427
constraint.entry(param) match
315428
case entry: TypeBounds =>
429+
val maxLevel = nestingLevel(param)
316430
val useLowerBound = fromBelow || param.occursIn(entry.hi)
317-
val inst = if useLowerBound then fullLowerBound(param) else fullUpperBound(param)
431+
val inst =
432+
if useLowerBound then
433+
fullLowerBound(param, maxLevel)
434+
else
435+
fullUpperBound(param, maxLevel)
318436
typr.println(s"approx ${param.show}, from below = $fromBelow, inst = ${inst.show}")
319437
inst
320438
case inst =>

compiler/src/dotty/tools/dotc/core/GadtConstraint.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import collection.mutable
1111
import printing._
1212

1313
import scala.annotation.internal.sharable
14+
import scala.annotation.unused
1415

1516
/** Represents GADT constraints currently in scope */
1617
sealed abstract class GadtConstraint extends Showable {
@@ -232,12 +233,14 @@ final class ProperGadtConstraint private(
232233
}
233234
externalizeMap(constraint.nonParamBounds(param)).bounds
234235

235-
override def fullLowerBound(param: TypeParamRef)(using Context): Type =
236+
override def fullLowerBound(param: TypeParamRef,
237+
@unused("no level checking needed in GADT constraints") maxLevel: Int)(using Context): Type =
236238
constraint.minLower(param).foldLeft(nonParamBounds(param).lo) {
237239
(t, u) => t | externalize(u)
238240
}
239241

240-
override def fullUpperBound(param: TypeParamRef)(using Context): Type =
242+
override def fullUpperBound(param: TypeParamRef,
243+
@unused("no level checking needed in GADT constraints") maxLevel: Int)(using Context): Type =
241244
constraint.minUpper(param).foldLeft(nonParamBounds(param).hi) { (t, u) =>
242245
val eu = externalize(u)
243246
// Any as the upper bound means "no bound", but if F is higher-kinded,

compiler/src/dotty/tools/dotc/core/NameKinds.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,13 @@ object NameKinds {
358358
val ProtectedAccessorName: PrefixNameKind = new PrefixNameKind(PROTECTEDACCESSOR, "protected$")
359359
val InlineAccessorName: PrefixNameKind = new PrefixNameKind(INLINEACCESSOR, "inline$")
360360

361+
/** See `ConstraintHandling#LevelAvoidMap`. */
362+
enum AvoidNameKind(tag: Int, prefix: String) extends PrefixNameKind(tag, prefix):
363+
override def definesNewName = true
364+
case UpperBound extends AvoidNameKind(AVOIDUPPER, "(upper)")
365+
case LowerBound extends AvoidNameKind(AVOIDLOWER, "(lower)")
366+
case BothBounds extends AvoidNameKind(AVOIDBOTH, "(avoid)")
367+
361368
val BodyRetainerName: SuffixNameKind = new SuffixNameKind(BODYRETAINER, "$retainedBody")
362369
val FieldName: SuffixNameKind = new SuffixNameKind(FIELD, "$$local") {
363370
override def mkString(underlying: TermName, info: ThisInfo) = underlying.toString

compiler/src/dotty/tools/dotc/core/NameTags.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,11 @@ object NameTags extends TastyFormat.NameTags {
3232

3333
inline val SETTER = 34 // A synthesized += suffix.
3434

35+
// Name of type variables created by `ConstraintHandling#LevelAvoidMap`.
36+
final val AVOIDUPPER = 35
37+
final val AVOIDLOWER = 36
38+
final val AVOIDBOTH = 37
39+
3540
def nameTagToString(tag: Int): String = tag match {
3641
case UTF8 => "UTF8"
3742
case QUALIFIED => "QUALIFIED"

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
134134
private val lowerMap : ParamOrdering,
135135
private val upperMap : ParamOrdering) extends Constraint {
136136

137+
import UnificationDirection.*
138+
137139
type This = OrderingConstraint
138140

139141
// ----------- Basic indices --------------------------------------------------
@@ -350,29 +352,37 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
350352
/** Add the fact `param1 <: param2` to the constraint `current` and propagate
351353
* `<:<` relationships between parameters ("edges") but not bounds.
352354
*/
353-
private def order(current: This, param1: TypeParamRef, param2: TypeParamRef)(using Context): This =
355+
def order(current: This, param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection = NoUnification)(using Context): This =
354356
if (param1 == param2 || current.isLess(param1, param2)) this
355357
else {
356358
assert(contains(param1), i"$param1")
357359
assert(contains(param2), i"$param2")
358-
// Is `order` called during parameter unification?
359-
val unifying = isLess(param2, param1)
360+
val unifying = direction != NoUnification
360361
val newUpper = {
361362
val up = exclusiveUpper(param2, param1)
362363
if unifying then
363364
// Since param2 <:< param1 already holds now, filter out param1 to avoid adding
364365
// duplicated orderings.
365-
param2 :: up.filterNot(_ eq param1)
366+
val filtered = up.filterNot(_ eq param1)
367+
// Only add bounds for param2 if it will be kept in the constraint after unification.
368+
if direction == KeepParam2 then
369+
param2 :: filtered
370+
else
371+
filtered
366372
else
367373
param2 :: up
368374
}
369375
val newLower = {
370376
val lower = exclusiveLower(param1, param2)
371377
if unifying then
372-
// Do not add bounds for param1 since it will be unified to param2 soon.
373-
// And, similarly filter out param2 from lowerly-ordered parameters
378+
// Similarly, filter out param2 from lowerly-ordered parameters
374379
// to avoid duplicated orderings.
375-
lower.filterNot(_ eq param2)
380+
val filtered = lower.filterNot(_ eq param2)
381+
// Only add bounds for param1 if it will be kept in the constraint after unification.
382+
if direction == KeepParam1 then
383+
param1 :: filtered
384+
else
385+
filtered
376386
else
377387
param1 :: lower
378388
}
@@ -416,14 +426,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
416426
def updateEntry(param: TypeParamRef, tp: Type)(using Context): This =
417427
updateEntry(this, param, ensureNonCyclic(param, tp)).checkNonCyclic()
418428

419-
def addLess(param1: TypeParamRef, param2: TypeParamRef)(using Context): This =
420-
order(this, param1, param2).checkNonCyclic()
421-
422-
def unify(p1: TypeParamRef, p2: TypeParamRef)(using Context): This =
423-
val bound1 = nonParamBounds(p1).substParam(p2, p1)
424-
val bound2 = nonParamBounds(p2).substParam(p2, p1)
425-
val p1Bounds = bound1 & bound2
426-
updateEntry(p1, p1Bounds).replace(p2, p1)
429+
def addLess(param1: TypeParamRef, param2: TypeParamRef, direction: UnificationDirection)(using Context): This =
430+
order(this, param1, param2, direction).checkNonCyclic()
427431

428432
// ---------- Replacements and Removals -------------------------------------
429433

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,16 @@ class TypeApplications(val self: Type) extends AnyVal {
231231
(alias ne self) && alias.hasSimpleKind
232232
}
233233

234+
/** The top type with the same kind as `self`. */
235+
def topType(using Context): Type =
236+
if self.hasSimpleKind then
237+
defn.AnyType
238+
else EtaExpand(self.typeParams) match
239+
case tp: HKTypeLambda =>
240+
tp.derivedLambdaType(resType = tp.resultType.topType)
241+
case _ =>
242+
defn.AnyKindType
243+
234244
/** If self type is higher-kinded, its result type, otherwise NoType.
235245
* Note: The hkResult of an any-kinded type is again AnyKind.
236246
*/

0 commit comments

Comments
 (0)