Skip to content

Commit 614fe90

Browse files
committed
Track type variable dependencies to guide instantiation decisions
We now keep track of reverse type variable dependencies in constraints. E.g. is a constraint contains a clause like A >: List[B] We associate with `B` info that A depends co-variantly on it. Or, if A <: B => C we associate with `B` that `A` depends co-variantly on it and with `C` that `A` depends contra-variantly on it. These dependencies are then used to guide type variable instantiation. If an eligible type variable does not appear in the type of a typed expression, we interpolate it to one of its bounds. Previously this was done in an ad-hoc manner where we minimized the type variable if it had a lower bound and maximized it otherwise. We now take reverse dependencies into account. If maximizing a type variable would narrow the remaining constraint we minimize, and if minimizing a type variable would narrow the remaining constraint we maximize. Only if the type variable is not referred to from the remaining constraint we resort to the old heuristic based on the lower bound. Fixes scala#15864
1 parent 3d4275d commit 614fe90

File tree

6 files changed

+213
-27
lines changed

6 files changed

+213
-27
lines changed

compiler/src/dotty/tools/dotc/config/Config.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ object Config {
184184
/** If set, prints a trace of all symbol completions */
185185
inline val showCompletions = false
186186

187+
/** If set, show variable/variable reverse dependencoes when printing constraints. */
188+
inline val showConstraintDeps = true
189+
187190
/** If set, method results that are context functions are flattened by adding
188191
* the parameters of the context function results to the methods themselves.
189192
* This is an optimization that reduces closure allocations.

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package core
44

55
import Types._, Contexts._
66
import printing.Showable
7+
import util.SimpleIdentityMap
78

89
/** Constraint over undetermined type parameters. Constraints are built
910
* over values of the following types:
@@ -128,7 +129,7 @@ abstract class Constraint extends Showable {
128129

129130
/** Is `tv` marked as hard in the constraint? */
130131
def isHard(tv: TypeVar): Boolean
131-
132+
132133
/** The same as this constraint, but with `tv` marked as hard. */
133134
def withHard(tv: TypeVar)(using Context): This
134135

@@ -165,6 +166,28 @@ abstract class Constraint extends Showable {
165166
*/
166167
def hasConflictingTypeVarsFor(tl: TypeLambda, that: Constraint): Boolean
167168

169+
/** A map that associates type variables with all other type variables that
170+
* refer to them in their bounds covariantly, such that, if the type variable
171+
* is isntantiated to a larger type, the constraint would be narrowed.
172+
*/
173+
def coDeps: Constraint.TypeVarDeps
174+
175+
/** A map that associates type variables with all other type variables that
176+
* refer to them in their bounds covariantly, such that, if the type variable
177+
* is isntantiated to a smaller type, the constraint would be narrowed.
178+
*/
179+
def contraDeps: Constraint.TypeVarDeps
180+
181+
/** A string showing the `coDeps` and `contraDeps` maps */
182+
def depsToString(using Context): String
183+
184+
/** Does the constraint restricted to variables outside `except` depend on `tv`
185+
* in the given direction `co`?
186+
* @param `co` If true, test whether the constraint would change if the variable is made larger
187+
* otherwise, test whether the constraint would change if the variable is made smaller.
188+
*/
189+
def dependsOn(tv: TypeVar, except: TypeVars, co: Boolean)(using Context): Boolean
190+
168191
/** Check that no constrained parameter contains itself as a bound */
169192
def checkNonCyclic()(using Context): this.type
170193

@@ -183,6 +206,10 @@ abstract class Constraint extends Showable {
183206
def checkConsistentVars()(using Context): Unit
184207
}
185208

209+
object Constraint:
210+
type TypeVarDeps = SimpleIdentityMap[TypeVar, TypeVars]
211+
212+
186213
/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
187214
* unifying one parameter with the other, this enum lets `addLess` know which
188215
* direction the unification will take.

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 139 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import printing.Texts._
1010
import config.Config
1111
import config.Printers.constr
1212
import reflect.ClassTag
13+
import Constraint.TypeVarDeps
14+
import NameKinds.DepParamName
1315
import annotation.tailrec
1416
import annotation.internal.sharable
1517
import cc.{CapturingType, derivedCapturingType}
@@ -148,6 +150,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
148150
else
149151
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap, hardVars)
150152
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
153+
result.coDeps = this.coDeps
154+
result.contraDeps = this.contraDeps
151155
result
152156

153157
// ----------- Basic indices --------------------------------------------------
@@ -219,6 +223,127 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
219223
if tvar == null then NoType
220224
else tvar
221225

226+
// ------------- TypeVar dependencies -----------------------------------
227+
228+
var coDeps, contraDeps: TypeVarDeps = SimpleIdentityMap.empty
229+
230+
def dependsOn(tv: TypeVar, except: TypeVars, co: Boolean)(using Context): Boolean =
231+
def test(deps: TypeVarDeps, lens: ConstraintLens[List[TypeParamRef]]) =
232+
val tvdeps = deps(tv)
233+
null != tvdeps && tvdeps.exists(!except.contains(_))
234+
|| lens(this, tv.origin.binder, tv.origin.paramNum).exists(
235+
p => !except.contains(typeVarOfParam(p)))
236+
//.showing(i"outer depends on $tv with ${tvdeps.toList}%, % = $result")
237+
if co then test(coDeps, upperLens) else test(contraDeps, lowerLens)
238+
239+
private class Adjuster(tvar: TypeVar)(using Context) extends TypeTraverser:
240+
var add: Boolean = compiletime.uninitialized
241+
242+
def update(deps: TypeVarDeps, referenced: TypeVar): TypeVarDeps =
243+
val entry = deps(referenced)
244+
val prev = if null == entry then SimpleIdentitySet.empty else entry
245+
val now = if add then prev + tvar else prev - tvar
246+
deps.updated(referenced, now)
247+
248+
def traverse(t: Type) = t match
249+
case tv: TypeVar =>
250+
val inst = tv.instanceOpt
251+
if inst.exists then traverse(inst)
252+
else
253+
if variance >= 0 then coDeps = update(coDeps, tv)
254+
if variance <= 0 then contraDeps = update(contraDeps, tv)
255+
case param: TypeParamRef =>
256+
traverse(typeVarOfParam(param))
257+
case tp: LazyRef if !tp.completed =>
258+
case _ =>
259+
traverseChildren(t)
260+
261+
/** Adjust dependencies to account for the delta of previous entry `prevEntry`
262+
* and new bound `entry` for the type variable `tvar`.
263+
*/
264+
def adjustDeps(entry: Type | Null, prevEntry: Type | Null, tvar: Type | Null)(using Context): this.type =
265+
tvar match
266+
case tvar: TypeVar =>
267+
val adjuster = new Adjuster(tvar)
268+
269+
/** Adjust reverse depemdencies of all type variables referenced by `bound`
270+
* @param isLower `bound` is a lower bound
271+
* @param add if true, add referenced variables to dependencoes, otherwise drop them.
272+
*/
273+
def adjustReferenced(bound: Type, isLower: Boolean, add: Boolean) =
274+
adjuster.variance = if isLower then 1 else -1
275+
adjuster.add = add
276+
adjuster.traverse(bound)
277+
278+
/** Use an optimized strategy to adjust dependencies to account for the delta
279+
* of previous bound `prevBound` and new bound `bound`: If `prevBound` is some
280+
* and/or prefix of `bound`, just add the new parts of `bound`.
281+
* @param isLower `bound` and `prevBound` are lower bounds
282+
*/
283+
def adjustDelta(bound: Type, prevBound: Type, isLower: Boolean): Boolean =
284+
if bound eq prevBound then true
285+
else bound match
286+
case bound: AndOrType =>
287+
adjustDelta(bound.tp1, prevBound, isLower) && {
288+
adjustReferenced(bound.tp2, isLower, add = true)
289+
true
290+
}
291+
case _ => false
292+
293+
/** Adjust dependencies to account for the delta of previous bound `prevBound`
294+
* and new bound `bound`.
295+
* @param isLower `bound` and `prevBound` are lower bounds
296+
*/
297+
def adjustBounds(bound: Type, prevBound: Type, isLower: Boolean) =
298+
if !adjustDelta(bound, prevBound, isLower) then
299+
adjustReferenced(prevBound, isLower, add = false)
300+
adjustReferenced(bound, isLower, add = true)
301+
302+
entry match
303+
case TypeBounds(lo, hi) =>
304+
prevEntry match
305+
case TypeBounds(plo, phi) =>
306+
adjustBounds(lo, plo, isLower = true)
307+
adjustBounds(hi, phi, isLower = false)
308+
case _ =>
309+
adjustReferenced(lo, isLower = true, add = true)
310+
adjustReferenced(hi, isLower = false, add = true)
311+
case _ =>
312+
prevEntry match
313+
case TypeBounds(plo, phi) =>
314+
adjustReferenced(plo, isLower = true, add = false)
315+
adjustReferenced(phi, isLower = false, add = false)
316+
case _ =>
317+
dropDeps(tvar)
318+
case _ =>
319+
this
320+
end adjustDeps
321+
322+
/** Adjust dependencies to account for adding or dropping `entries` to the
323+
* constraint.
324+
* @param add if true, entries is added, otherwise it is dropped
325+
*/
326+
def adjustDeps(entries: Array[Type], add: Boolean)(using Context): this.type =
327+
for n <- 0 until paramCount(entries) do
328+
if add
329+
then adjustDeps(entries(n), NoType, typeVar(entries, n))
330+
else adjustDeps(NoType, entries(n), typeVar(entries, n))
331+
this
332+
333+
/** If `tp` is a type variable, remove all its reverse dependencies */
334+
def dropDeps(tp: Type)(using Context): Unit = tp match
335+
case tv: TypeVar =>
336+
coDeps = coDeps.remove(tv)
337+
contraDeps = contraDeps.remove(tv)
338+
case _ =>
339+
340+
/** A string representing the two depenecy maps */
341+
def depsToString(using Context): String =
342+
def depsStr(deps: SimpleIdentityMap[TypeVar, TypeVars]): String =
343+
def depStr(tv: TypeVar) = i"$tv --> ${deps(tv).nn.toList}%, %"
344+
if deps.isEmpty then "" else i"\n ${deps.toList.map((k, v) => depStr(k))}%\n %"
345+
i"co-deps:${depsStr(coDeps)}\ncontra-deps:${depsStr(contraDeps)}\n"
346+
222347
// ---------- Adding TypeLambdas --------------------------------------------------
223348

224349
/** The bound type `tp` without constrained parameters which are clearly
@@ -286,6 +411,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
286411
tvars.copyToArray(entries1, nparams)
287412
newConstraint(boundsMap = this.boundsMap.updated(poly, entries1))
288413
.init(poly)
414+
.adjustDeps(entries1, add = true)
289415
}
290416

291417
/** Split dependent parameters off the bounds for parameters in `poly`.
@@ -432,6 +558,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
432558
private def updateEntry(current: This, param: TypeParamRef, tp: Type)(using Context): This = {
433559
if Config.checkNoWildcardsInConstraint then assert(!tp.containsWildcardTypes)
434560
var current1 = boundsLens.update(this, current, param, tp)
561+
current1.adjustDeps(tp, current.entry(param), typeVarOfParam(param))
435562
tp match {
436563
case TypeBounds(lo, hi) =>
437564
for p <- dependentParams(lo, isUpper = false) do
@@ -471,10 +598,15 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
471598
current.ensureNonCyclic(atPoly.paramRefs(atIdx), tp.substParam(param, replacement))
472599

473600
current.foreachParam { (p, i) =>
474-
current = boundsLens.map(this, current, p, i, replaceParam(_, p, i))
601+
current = boundsLens.map(this, current, p, i,
602+
entry =>
603+
val newEntry = replaceParam(entry, p, i)
604+
adjustDeps(newEntry, entry, typeVar(this.boundsMap(p).nn, i))
605+
newEntry)
475606
current = lowerLens.map(this, current, p, i, removeParam)
476607
current = upperLens.map(this, current, p, i, removeParam)
477608
}
609+
current.dropDeps(typeVarOfParam(param))
478610
current.checkNonCyclic()
479611
end replace
480612

@@ -489,6 +621,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
489621
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
490622
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
491623
.checkNonCyclic()
624+
.adjustDeps(boundsMap(pt).nn, add = false)
492625
}
493626

494627
def isRemovable(pt: TypeLambda): Boolean = {
@@ -666,13 +799,16 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
666799
val constrainedText =
667800
" constrained types = " + domainLambdas.mkString("\n")
668801
val boundsText =
669-
" bounds = " + {
802+
"\n bounds = " + {
670803
val assocs =
671804
for (param <- domainParams)
672805
yield
673806
s"${param.binder.paramNames(param.paramNum)}: ${entryText(entry(param))}"
674807
assocs.mkString("\n")
675808
}
676-
constrainedText + "\n" + boundsText
809+
val depsText =
810+
"\n coDeps = " + coDeps +
811+
"\n contraDeps = " + contraDeps
812+
constrainedText + boundsText + depsText
677813
}
678814
}

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -689,8 +689,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
689689
Text(ups.map(toText), ", ")
690690
Text(deps, "\n")
691691
}
692+
val depsText = if Config.showConstraintDeps then c.depsToString else ""
692693
//Printer.debugPrintUnique = false
693-
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
694+
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText, depsText))
694695
finally
695696
ctx.typerState.constraint = savedConstraint
696697

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@ import core._
66
import ast._
77
import Contexts._, Types._, Flags._, Symbols._
88
import ProtoTypes._
9-
import NameKinds.{AvoidNameKind, UniqueName}
9+
import NameKinds.UniqueName
1010
import util.Spans._
11-
import util.{Stats, SimpleIdentityMap, SrcPos}
11+
import util.{Stats, SimpleIdentityMap, SimpleIdentitySet, SrcPos}
1212
import Decorators._
1313
import config.Printers.{gadts, typr}
1414
import annotation.tailrec
1515
import reporting._
1616
import collection.mutable
17-
1817
import scala.annotation.internal.sharable
1918

2019
object Inferencing {
@@ -619,7 +618,7 @@ trait Inferencing { this: Typer =>
619618
if state.reporter.hasUnreportedErrors then return tree
620619

621620
def constraint = state.constraint
622-
type InstantiateQueue = mutable.ListBuffer[(TypeVar, Boolean)]
621+
type InstantiateQueue = mutable.ListBuffer[(TypeVar, Int)]
623622
val toInstantiate = new InstantiateQueue
624623
for tvar <- qualifying do
625624
if !tvar.isInstantiated && constraint.contains(tvar) && tvar.nestingLevel >= ctx.nestingLevel then
@@ -628,24 +627,9 @@ trait Inferencing { this: Typer =>
628627
// instantiated `tvar` through unification.
629628
val v = vs(tvar)
630629
if v == null then
631-
// Even though `tvar` is non-occurring in `v`, the specific
632-
// instantiation we pick still matters because `tvar` might appear
633-
// in the bounds of a non-`qualifying` type variable in the
634-
// constraint.
635-
// In particular, if `tvar` was created as the upper or lower
636-
// bound of an existing variable by `LevelAvoidMap`, we
637-
// instantiate it in the direction corresponding to the
638-
// original variable which might be further constrained later.
639-
// Otherwise, we simply rely on `hasLowerBound`.
640-
val name = tvar.origin.paramName
641-
val fromBelow =
642-
name.is(AvoidNameKind.UpperBound) ||
643-
!name.is(AvoidNameKind.LowerBound) && tvar.hasLowerBound
644-
typr.println(i"interpolate non-occurring $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
645-
toInstantiate += ((tvar, fromBelow))
630+
toInstantiate += ((tvar, 0))
646631
else if v.intValue != 0 then
647-
typr.println(i"interpolate $tvar in $state in $tree: $tp, fromBelow = ${v.intValue == 1}, $constraint")
648-
toInstantiate += ((tvar, v.intValue == 1))
632+
toInstantiate += ((tvar, v.intValue))
649633
else comparing(cmp =>
650634
if !cmp.levelOK(tvar.nestingLevel, ctx.nestingLevel) then
651635
// Invariant: The type of a tree whose enclosing scope is level
@@ -686,10 +670,23 @@ trait Inferencing { this: Typer =>
686670
* V2 := V3, O2 := O3
687671
*/
688672
def doInstantiate(buf: InstantiateQueue): Unit =
673+
val varsToInstantiate = buf.foldLeft(SimpleIdentitySet.empty: TypeVars) {
674+
case (tvs, (tv, _)) => tvs + tv
675+
}
689676
if buf.nonEmpty then
690677
val suspended = new InstantiateQueue
691678
while buf.nonEmpty do
692-
val first @ (tvar, fromBelow) = buf.head
679+
val first @ (tvar, v) = buf.head
680+
val fromBelow =
681+
if v == 0 then
682+
val aboveOK = !constraint.dependsOn(tvar, varsToInstantiate, co = true)
683+
val belowOK = !constraint.dependsOn(tvar, varsToInstantiate, co = false)
684+
if aboveOK == belowOK then tvar.hasLowerBound
685+
else belowOK
686+
else
687+
v == 1
688+
typr.println(
689+
i"interpolate${if v == 0 then "non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
693690
buf.dropInPlace(1)
694691
if !tvar.isInstantiated then
695692
val suspend = buf.exists{ (following, _) =>

tests/pos/i15864.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object Test:
2+
def op[O, P](ta: List[O], tb: List[P]): List[P] = ???
3+
4+
class Graph { class Node }
5+
6+
def outsQ(using g: Graph): List[List[g.Node]] = ???
7+
8+
object aGraph extends Graph
9+
given implA: aGraph.type = aGraph
10+
11+
val q1: List[List[aGraph.Node]] = op(outsQ, op(outsQ, outsQ))
12+
implicitly[q1.type <:< List[List[aGraph.Node]]]
13+
14+
val a1 = outsQ
15+
val a2 = op(outsQ, outsQ)
16+
val a3 = op(a1, a2)
17+
18+
val q2 = op(outsQ, op(outsQ, outsQ))
19+
val q3: List[List[aGraph.Node]] = q2
20+
implicitly[q2.type <:< List[List[aGraph.Node]]]
21+
22+

0 commit comments

Comments
 (0)