@@ -10,6 +10,8 @@ import printing.Texts._
10
10
import config .Config
11
11
import config .Printers .constr
12
12
import reflect .ClassTag
13
+ import Constraint .TypeVarDeps
14
+ import NameKinds .DepParamName
13
15
import annotation .tailrec
14
16
import annotation .internal .sharable
15
17
import cc .{CapturingType , derivedCapturingType }
@@ -148,6 +150,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
148
150
else
149
151
val result = new OrderingConstraint (boundsMap, lowerMap, upperMap, hardVars)
150
152
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
153
+ result.coDeps = this .coDeps
154
+ result.contraDeps = this .contraDeps
151
155
result
152
156
153
157
// ----------- Basic indices --------------------------------------------------
@@ -219,6 +223,127 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
219
223
if tvar == null then NoType
220
224
else tvar
221
225
226
+ // ------------- TypeVar dependencies -----------------------------------
227
+
228
+ var coDeps, contraDeps : TypeVarDeps = SimpleIdentityMap .empty
229
+
230
+ def dependsOn (tv : TypeVar , except : TypeVars , co : Boolean )(using Context ): Boolean =
231
+ def test (deps : TypeVarDeps , lens : ConstraintLens [List [TypeParamRef ]]) =
232
+ val tvdeps = deps(tv)
233
+ null != tvdeps && tvdeps.exists(! except.contains(_))
234
+ || lens(this , tv.origin.binder, tv.origin.paramNum).exists(
235
+ p => ! except.contains(typeVarOfParam(p)))
236
+ // .showing(i"outer depends on $tv with ${tvdeps.toList}%, % = $result")
237
+ if co then test(coDeps, upperLens) else test(contraDeps, lowerLens)
238
+
239
+ private class Adjuster (tvar : TypeVar )(using Context ) extends TypeTraverser :
240
+ var add : Boolean = compiletime.uninitialized
241
+
242
+ def update (deps : TypeVarDeps , referenced : TypeVar ): TypeVarDeps =
243
+ val entry = deps(referenced)
244
+ val prev = if null == entry then SimpleIdentitySet .empty else entry
245
+ val now = if add then prev + tvar else prev - tvar
246
+ deps.updated(referenced, now)
247
+
248
+ def traverse (t : Type ) = t match
249
+ case tv : TypeVar =>
250
+ val inst = tv.instanceOpt
251
+ if inst.exists then traverse(inst)
252
+ else
253
+ if variance >= 0 then coDeps = update(coDeps, tv)
254
+ if variance <= 0 then contraDeps = update(contraDeps, tv)
255
+ case param : TypeParamRef =>
256
+ traverse(typeVarOfParam(param))
257
+ case tp : LazyRef if ! tp.completed =>
258
+ case _ =>
259
+ traverseChildren(t)
260
+
261
+ /** Adjust dependencies to account for the delta of previous entry `prevEntry`
262
+ * and new bound `entry` for the type variable `tvar`.
263
+ */
264
+ def adjustDeps (entry : Type | Null , prevEntry : Type | Null , tvar : Type | Null )(using Context ): this .type =
265
+ tvar match
266
+ case tvar : TypeVar =>
267
+ val adjuster = new Adjuster (tvar)
268
+
269
+ /** Adjust reverse depemdencies of all type variables referenced by `bound`
270
+ * @param isLower `bound` is a lower bound
271
+ * @param add if true, add referenced variables to dependencoes, otherwise drop them.
272
+ */
273
+ def adjustReferenced (bound : Type , isLower : Boolean , add : Boolean ) =
274
+ adjuster.variance = if isLower then 1 else - 1
275
+ adjuster.add = add
276
+ adjuster.traverse(bound)
277
+
278
+ /** Use an optimized strategy to adjust dependencies to account for the delta
279
+ * of previous bound `prevBound` and new bound `bound`: If `prevBound` is some
280
+ * and/or prefix of `bound`, just add the new parts of `bound`.
281
+ * @param isLower `bound` and `prevBound` are lower bounds
282
+ */
283
+ def adjustDelta (bound : Type , prevBound : Type , isLower : Boolean ): Boolean =
284
+ if bound eq prevBound then true
285
+ else bound match
286
+ case bound : AndOrType =>
287
+ adjustDelta(bound.tp1, prevBound, isLower) && {
288
+ adjustReferenced(bound.tp2, isLower, add = true )
289
+ true
290
+ }
291
+ case _ => false
292
+
293
+ /** Adjust dependencies to account for the delta of previous bound `prevBound`
294
+ * and new bound `bound`.
295
+ * @param isLower `bound` and `prevBound` are lower bounds
296
+ */
297
+ def adjustBounds (bound : Type , prevBound : Type , isLower : Boolean ) =
298
+ if ! adjustDelta(bound, prevBound, isLower) then
299
+ adjustReferenced(prevBound, isLower, add = false )
300
+ adjustReferenced(bound, isLower, add = true )
301
+
302
+ entry match
303
+ case TypeBounds (lo, hi) =>
304
+ prevEntry match
305
+ case TypeBounds (plo, phi) =>
306
+ adjustBounds(lo, plo, isLower = true )
307
+ adjustBounds(hi, phi, isLower = false )
308
+ case _ =>
309
+ adjustReferenced(lo, isLower = true , add = true )
310
+ adjustReferenced(hi, isLower = false , add = true )
311
+ case _ =>
312
+ prevEntry match
313
+ case TypeBounds (plo, phi) =>
314
+ adjustReferenced(plo, isLower = true , add = false )
315
+ adjustReferenced(phi, isLower = false , add = false )
316
+ case _ =>
317
+ dropDeps(tvar)
318
+ case _ =>
319
+ this
320
+ end adjustDeps
321
+
322
+ /** Adjust dependencies to account for adding or dropping `entries` to the
323
+ * constraint.
324
+ * @param add if true, entries is added, otherwise it is dropped
325
+ */
326
+ def adjustDeps (entries : Array [Type ], add : Boolean )(using Context ): this .type =
327
+ for n <- 0 until paramCount(entries) do
328
+ if add
329
+ then adjustDeps(entries(n), NoType , typeVar(entries, n))
330
+ else adjustDeps(NoType , entries(n), typeVar(entries, n))
331
+ this
332
+
333
+ /** If `tp` is a type variable, remove all its reverse dependencies */
334
+ def dropDeps (tp : Type )(using Context ): Unit = tp match
335
+ case tv : TypeVar =>
336
+ coDeps = coDeps.remove(tv)
337
+ contraDeps = contraDeps.remove(tv)
338
+ case _ =>
339
+
340
+ /** A string representing the two depenecy maps */
341
+ def depsToString (using Context ): String =
342
+ def depsStr (deps : SimpleIdentityMap [TypeVar , TypeVars ]): String =
343
+ def depStr (tv : TypeVar ) = i " $tv --> ${deps(tv).nn.toList}%, % "
344
+ if deps.isEmpty then " " else i " \n ${deps.toList.map((k, v) => depStr(k))}% \n % "
345
+ i " co-deps: ${depsStr(coDeps)}\n contra-deps: ${depsStr(contraDeps)}\n "
346
+
222
347
// ---------- Adding TypeLambdas --------------------------------------------------
223
348
224
349
/** The bound type `tp` without constrained parameters which are clearly
@@ -286,6 +411,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
286
411
tvars.copyToArray(entries1, nparams)
287
412
newConstraint(boundsMap = this .boundsMap.updated(poly, entries1))
288
413
.init(poly)
414
+ .adjustDeps(entries1, add = true )
289
415
}
290
416
291
417
/** Split dependent parameters off the bounds for parameters in `poly`.
@@ -432,6 +558,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
432
558
private def updateEntry (current : This , param : TypeParamRef , tp : Type )(using Context ): This = {
433
559
if Config .checkNoWildcardsInConstraint then assert(! tp.containsWildcardTypes)
434
560
var current1 = boundsLens.update(this , current, param, tp)
561
+ current1.adjustDeps(tp, current.entry(param), typeVarOfParam(param))
435
562
tp match {
436
563
case TypeBounds (lo, hi) =>
437
564
for p <- dependentParams(lo, isUpper = false ) do
@@ -471,10 +598,15 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
471
598
current.ensureNonCyclic(atPoly.paramRefs(atIdx), tp.substParam(param, replacement))
472
599
473
600
current.foreachParam { (p, i) =>
474
- current = boundsLens.map(this , current, p, i, replaceParam(_, p, i))
601
+ current = boundsLens.map(this , current, p, i,
602
+ entry =>
603
+ val newEntry = replaceParam(entry, p, i)
604
+ adjustDeps(newEntry, entry, typeVar(this .boundsMap(p).nn, i))
605
+ newEntry)
475
606
current = lowerLens.map(this , current, p, i, removeParam)
476
607
current = upperLens.map(this , current, p, i, removeParam)
477
608
}
609
+ current.dropDeps(typeVarOfParam(param))
478
610
current.checkNonCyclic()
479
611
end replace
480
612
@@ -489,6 +621,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
489
621
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
490
622
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
491
623
.checkNonCyclic()
624
+ .adjustDeps(boundsMap(pt).nn, add = false )
492
625
}
493
626
494
627
def isRemovable (pt : TypeLambda ): Boolean = {
@@ -666,13 +799,16 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
666
799
val constrainedText =
667
800
" constrained types = " + domainLambdas.mkString(" \n " )
668
801
val boundsText =
669
- " bounds = " + {
802
+ " \n bounds = " + {
670
803
val assocs =
671
804
for (param <- domainParams)
672
805
yield
673
806
s " ${param.binder.paramNames(param.paramNum)}: ${entryText(entry(param))}"
674
807
assocs.mkString(" \n " )
675
808
}
676
- constrainedText + " \n " + boundsText
809
+ val depsText =
810
+ " \n coDeps = " + coDeps +
811
+ " \n contraDeps = " + contraDeps
812
+ constrainedText + boundsText + depsText
677
813
}
678
814
}
0 commit comments