Skip to content

Commit 43f70cf

Browse files
committed
Track type variable dependencies to guide instantiation decisions
We now keep track of reverse type variable dependencies in constraints. E.g. is a constraint contains a clause like A >: List[B] We associate with `B` info that A depends co-variantly on it. Or, if A <: B => C we associate with `B` that `A` depends co-variantly on it and with `C` that `A` depends contra-variantly on it. These dependencies are then used to guide type variable instantiation. If an eligible type variable does not appear in the type of a typed expression, we interpolate it to one of its bounds. Previously this was done in an ad-hoc manner where we minimized the type variable if it had a lower bound and maximized it otherwise. We now take reverse dependencies into account. If maximizing a type variable would narrow the remaining constraint we minimize, and if minimizing a type variable would narrow the remaining constraint we maximize. Only if the type variable is not referred to from the remaining constraint we resort to the old heuristic based on the lower bound. Fixes scala#15864 Todo: This could be generalized in several directions: - We could base the dependency tracking on type param refs instead of type variables. That could make the `replace` operation in a constraint more efficient. - We could base more interpolation decisions on dependencies. E.g. we could interpolate a type variable only if both the type an expression and the enclosing constraint agree in which direction this should be done.
1 parent cecaffe commit 43f70cf

File tree

6 files changed

+199
-7
lines changed

6 files changed

+199
-7
lines changed

compiler/src/dotty/tools/dotc/config/Config.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,9 @@ object Config {
184184
/** If set, prints a trace of all symbol completions */
185185
inline val showCompletions = false
186186

187+
/** If set, show variable/variable reverse dependencoes when printing constraints. */
188+
inline val showConstraintDeps = true
189+
187190
/** If set, method results that are context functions are flattened by adding
188191
* the parameters of the context function results to the methods themselves.
189192
* This is an optimization that reduces closure allocations.

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ package core
44

55
import Types._, Contexts._
66
import printing.Showable
7+
import util.SimpleIdentityMap
78

89
/** Constraint over undetermined type parameters. Constraints are built
910
* over values of the following types:
@@ -128,7 +129,7 @@ abstract class Constraint extends Showable {
128129

129130
/** Is `tv` marked as hard in the constraint? */
130131
def isHard(tv: TypeVar): Boolean
131-
132+
132133
/** The same as this constraint, but with `tv` marked as hard. */
133134
def withHard(tv: TypeVar)(using Context): This
134135

@@ -165,6 +166,31 @@ abstract class Constraint extends Showable {
165166
*/
166167
def hasConflictingTypeVarsFor(tl: TypeLambda, that: Constraint): Boolean
167168

169+
/** A map that associates type variables with all other type variables that
170+
* refer to them in their bounds covariantly, such that, if the type variable
171+
* is isntantiated to a larger type, the constraint would be narrowed.
172+
*/
173+
def coDeps: Constraint.TypeVarDeps
174+
175+
/** A map that associates type variables with all other type variables that
176+
* refer to them in their bounds covariantly, such that, if the type variable
177+
* is isntantiated to a smaller type, the constraint would be narrowed.
178+
*/
179+
def contraDeps: Constraint.TypeVarDeps
180+
181+
/** A string showing the `coDeps` and `contraDeps` maps */
182+
def depsToString(using Context): String
183+
184+
/** If the type variable `tv` would be instantiated to a larger type, the
185+
* constraint over outer type variables would be narrowed.
186+
*/
187+
def outerDependsCovariantlyOn(tv: TypeVar)(using Context): Boolean
188+
189+
/** If the type variable `tv` would be instantiated to a smaller type, the
190+
* constraint over outer type variables would be narrowed.
191+
*/
192+
def outerDependsContravariantlyOn(tv: TypeVar)(using Context): Boolean
193+
168194
/** Check that no constrained parameter contains itself as a bound */
169195
def checkNonCyclic()(using Context): this.type
170196

@@ -183,6 +209,10 @@ abstract class Constraint extends Showable {
183209
def checkConsistentVars()(using Context): Unit
184210
}
185211

212+
object Constraint:
213+
type TypeVarDeps = SimpleIdentityMap[TypeVar, TypeVars]
214+
215+
186216
/** When calling `Constraint#addLess(p1, p2, ...)`, the caller might end up
187217
* unifying one parameter with the other, this enum lets `addLess` know which
188218
* direction the unification will take.

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 136 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ import printing.Texts._
1010
import config.Config
1111
import config.Printers.constr
1212
import reflect.ClassTag
13+
import Constraint.TypeVarDeps
14+
import NameKinds.DepParamName
1315
import annotation.tailrec
1416
import annotation.internal.sharable
1517
import cc.{CapturingType, derivedCapturingType}
@@ -148,6 +150,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
148150
else
149151
val result = new OrderingConstraint(boundsMap, lowerMap, upperMap, hardVars)
150152
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
153+
result.coDeps = this.coDeps
154+
result.contraDeps = this.contraDeps
151155
result
152156

153157
// ----------- Basic indices --------------------------------------------------
@@ -219,6 +223,124 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
219223
if tvar == null then NoType
220224
else tvar
221225

226+
// ------------- TypeVar dependencies -----------------------------------
227+
228+
var coDeps, contraDeps: TypeVarDeps = SimpleIdentityMap.empty
229+
230+
private def outerDependsOn(tv: TypeVar, nestingLevel: Int, deps: TypeVarDeps, lens: ConstraintLens[List[TypeParamRef]], seen: TypeVars)(using Context): Boolean =
231+
232+
def qualifies(tp: Type): Boolean = tp match
233+
case tv1: TypeVar =>
234+
tv1.nestingLevel < tv.nestingLevel
235+
|| tv1.nestingLevel == tv.nestingLevel && tv.origin.paramName.is(DepParamName)
236+
// dependent parameters get created at the same nesting level as the current constraint,
237+
// but are logically nested inside the variables of the expected result type.
238+
|| outerDependsOn(tv1, nestingLevel, deps, lens, seen + tv)
239+
case tp: TypeParamRef =>
240+
qualifies(typeVarOfParam(tp))
241+
case _ =>
242+
false
243+
244+
val tvdeps = deps(tv)
245+
null != tvdeps
246+
&& !seen.contains(tv)
247+
&& (tvdeps.exists(qualifies)
248+
|| lens(this, tv.origin.binder, tv.origin.paramNum).exists(qualifies))
249+
//.showing(i"outer depends on $tv ${tv.initNestingLevel} with ${tvdeps.toList.map(_.initNestingLevel)}%, % = $result")
250+
end outerDependsOn
251+
252+
def outerDependsCovariantlyOn(tv: TypeVar)(using Context) =
253+
outerDependsOn(tv, tv.nestingLevel, coDeps, upperLens, SimpleIdentitySet.empty)
254+
255+
def outerDependsContravariantlyOn(tv: TypeVar)(using Context) =
256+
outerDependsOn(tv, tv.nestingLevel, contraDeps, lowerLens, SimpleIdentitySet.empty)
257+
258+
private class Adjuster(tvar: TypeVar)(using Context) extends TypeTraverser:
259+
var add: Boolean = compiletime.uninitialized
260+
261+
def update(deps: TypeVarDeps, referenced: TypeVar): TypeVarDeps =
262+
val entry = deps(referenced)
263+
val prev = if null == entry then SimpleIdentitySet.empty else entry
264+
val now = if add then prev + tvar else prev - tvar
265+
deps.updated(referenced, now)
266+
267+
def traverse(t: Type) = t match
268+
case tv: TypeVar =>
269+
val inst = tv.instanceOpt
270+
if inst.exists then traverse(inst)
271+
else
272+
if variance >= 0 then coDeps = update(coDeps, tv)
273+
if variance <= 0 then contraDeps = update(contraDeps, tv)
274+
case param: TypeParamRef =>
275+
traverse(typeVarOfParam(param))
276+
case tp: LazyRef if !tp.completed =>
277+
case _ =>
278+
traverseChildren(t)
279+
280+
def adjustDeps(entry: Type | Null, prevEntry: Type | Null, tvar: Type | Null)(using Context): this.type =
281+
tvar match
282+
case tvar: TypeVar =>
283+
val adjuster = new Adjuster(tvar)
284+
285+
def adjustReferenced(bound: Type, isLower: Boolean, add: Boolean) =
286+
adjuster.variance = if isLower then 1 else -1
287+
adjuster.add = add
288+
adjuster.traverse(bound)
289+
290+
def adjustDelta(bound: Type, prevBound: Type, isLower: Boolean): Boolean =
291+
if bound eq prevBound then true
292+
else bound match
293+
case bound: AndOrType =>
294+
adjustDelta(bound.tp1, prevBound, isLower) && {
295+
adjustReferenced(bound.tp2, isLower, add = true)
296+
true
297+
}
298+
case _ => false
299+
300+
def adjustBounds(bound: Type, prevBound: Type, isLower: Boolean) =
301+
if !adjustDelta(bound, prevBound, isLower) then
302+
adjustReferenced(prevBound, isLower, add = false)
303+
adjustReferenced(bound, isLower, add = true)
304+
305+
entry match
306+
case TypeBounds(lo, hi) =>
307+
prevEntry match
308+
case TypeBounds(plo, phi) =>
309+
adjustBounds(lo, plo, isLower = true)
310+
adjustBounds(hi, phi, isLower = false)
311+
case _ =>
312+
adjustReferenced(lo, isLower = true, add = true)
313+
adjustReferenced(hi, isLower = false, add = true)
314+
case _ =>
315+
prevEntry match
316+
case TypeBounds(plo, phi) =>
317+
adjustReferenced(plo, isLower = true, add = false)
318+
adjustReferenced(phi, isLower = false, add = false)
319+
case _ =>
320+
dropDeps(tvar)
321+
case _ =>
322+
this
323+
end adjustDeps
324+
325+
def adjustDeps(entries: Array[Type], add: Boolean)(using Context): this.type =
326+
for n <- 0 until paramCount(entries) do
327+
if add
328+
then adjustDeps(entries(n), NoType, typeVar(entries, n))
329+
else adjustDeps(NoType, entries(n), typeVar(entries, n))
330+
this
331+
332+
def dropDeps(tp: Type)(using Context): Unit = tp match
333+
case tv: TypeVar =>
334+
coDeps = coDeps.remove(tv)
335+
contraDeps = contraDeps.remove(tv)
336+
case _ =>
337+
338+
def depsToString(using Context): String =
339+
def depsStr(deps: SimpleIdentityMap[TypeVar, TypeVars]): String =
340+
def depStr(tv: TypeVar) = i"$tv --> ${deps(tv).nn.toList}%, %"
341+
if deps.isEmpty then "" else i"\n ${deps.toList.map((k, v) => depStr(k))}%\n %"
342+
i"co-deps:${depsStr(coDeps)}\ncontra-deps:${depsStr(contraDeps)}\n"
343+
222344
// ---------- Adding TypeLambdas --------------------------------------------------
223345

224346
/** The bound type `tp` without constrained parameters which are clearly
@@ -286,6 +408,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
286408
tvars.copyToArray(entries1, nparams)
287409
newConstraint(boundsMap = this.boundsMap.updated(poly, entries1))
288410
.init(poly)
411+
.adjustDeps(entries1, add = true)
289412
}
290413

291414
/** Split dependent parameters off the bounds for parameters in `poly`.
@@ -432,6 +555,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
432555
private def updateEntry(current: This, param: TypeParamRef, tp: Type)(using Context): This = {
433556
if Config.checkNoWildcardsInConstraint then assert(!tp.containsWildcardTypes)
434557
var current1 = boundsLens.update(this, current, param, tp)
558+
current1.adjustDeps(tp, current.entry(param), typeVarOfParam(param))
435559
tp match {
436560
case TypeBounds(lo, hi) =>
437561
for p <- dependentParams(lo, isUpper = false) do
@@ -471,10 +595,15 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
471595
current.ensureNonCyclic(atPoly.paramRefs(atIdx), tp.substParam(param, replacement))
472596

473597
current.foreachParam { (p, i) =>
474-
current = boundsLens.map(this, current, p, i, replaceParam(_, p, i))
598+
current = boundsLens.map(this, current, p, i,
599+
entry =>
600+
val newEntry = replaceParam(entry, p, i)
601+
adjustDeps(newEntry, entry, typeVar(this.boundsMap(p).nn, i))
602+
newEntry)
475603
current = lowerLens.map(this, current, p, i, removeParam)
476604
current = upperLens.map(this, current, p, i, removeParam)
477605
}
606+
current.dropDeps(typeVarOfParam(param))
478607
current.checkNonCyclic()
479608
end replace
480609

@@ -489,6 +618,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
489618
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
490619
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
491620
.checkNonCyclic()
621+
.adjustDeps(boundsMap(pt).nn, add = false)
492622
}
493623

494624
def isRemovable(pt: TypeLambda): Boolean = {
@@ -666,13 +796,16 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
666796
val constrainedText =
667797
" constrained types = " + domainLambdas.mkString("\n")
668798
val boundsText =
669-
" bounds = " + {
799+
"\n bounds = " + {
670800
val assocs =
671801
for (param <- domainParams)
672802
yield
673803
s"${param.binder.paramNames(param.paramNum)}: ${entryText(entry(param))}"
674804
assocs.mkString("\n")
675805
}
676-
constrainedText + "\n" + boundsText
806+
val depsText =
807+
"\n coDeps = " + coDeps +
808+
"\n contraDeps = " + contraDeps
809+
constrainedText + boundsText + depsText
677810
}
678811
}

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -688,8 +688,9 @@ class PlainPrinter(_ctx: Context) extends Printer {
688688
Text(ups.map(toText), ", ")
689689
Text(deps, "\n")
690690
}
691+
val depsText = if Config.showConstraintDeps then c.depsToString else ""
691692
//Printer.debugPrintUnique = false
692-
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText))
693+
Text.lines(List(uninstVarsText, constrainedText, boundsText, orderingText, depsText))
693694
finally
694695
ctx.typerState.constraint = savedConstraint
695696

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,11 @@ trait Inferencing { this: Typer =>
639639
// Otherwise, we simply rely on `hasLowerBound`.
640640
val name = tvar.origin.paramName
641641
val fromBelow =
642-
name.is(AvoidNameKind.UpperBound) ||
643-
!name.is(AvoidNameKind.LowerBound) && tvar.hasLowerBound
642+
if name.is(AvoidNameKind.UpperBound) then true
643+
else if name.is(AvoidNameKind.LowerBound) then false
644+
else if constraint.outerDependsCovariantlyOn(tvar) then true
645+
else if constraint.outerDependsContravariantlyOn(tvar) then false
646+
else tvar.hasLowerBound
644647
typr.println(i"interpolate non-occurring $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
645648
toInstantiate += ((tvar, fromBelow))
646649
else if v.intValue != 0 then

tests/pos/i15864.scala

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
object Test:
2+
def op[O, P](ta: List[O], tb: List[P]): List[P] = ???
3+
4+
class Graph { class Node }
5+
6+
def outsQ(using g: Graph): List[List[g.Node]] = ???
7+
8+
object aGraph extends Graph
9+
given implA: aGraph.type = aGraph
10+
11+
val q1: List[List[aGraph.Node]] = op(outsQ, op(outsQ, outsQ))
12+
implicitly[q1.type <:< List[List[aGraph.Node]]]
13+
14+
val a1 = outsQ
15+
val a2 = op(outsQ, outsQ)
16+
val a3 = op(a1, a2)
17+
18+
val q2 = op(outsQ, op(outsQ, outsQ))
19+
val q3: List[List[aGraph.Node]] = q2
20+
implicitly[q2.type <:< List[List[aGraph.Node]]]
21+
22+

0 commit comments

Comments
 (0)