@@ -10,6 +10,8 @@ import printing.Texts._
10
10
import config .Config
11
11
import config .Printers .constr
12
12
import reflect .ClassTag
13
+ import Constraint .TypeVarDeps
14
+ import NameKinds .DepParamName
13
15
import annotation .tailrec
14
16
import annotation .internal .sharable
15
17
import cc .{CapturingType , derivedCapturingType }
@@ -148,6 +150,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
148
150
else
149
151
val result = new OrderingConstraint (boundsMap, lowerMap, upperMap, hardVars)
150
152
if ctx.run != null then ctx.run.nn.recordConstraintSize(result, result.boundsMap.size)
153
+ result.coDeps = this .coDeps
154
+ result.contraDeps = this .contraDeps
151
155
result
152
156
153
157
// ----------- Basic indices --------------------------------------------------
@@ -219,6 +223,124 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
219
223
if tvar == null then NoType
220
224
else tvar
221
225
226
+ // ------------- TypeVar dependencies -----------------------------------
227
+
228
+ var coDeps, contraDeps : TypeVarDeps = SimpleIdentityMap .empty
229
+
230
+ private def outerDependsOn (tv : TypeVar , nestingLevel : Int , deps : TypeVarDeps , lens : ConstraintLens [List [TypeParamRef ]], seen : TypeVars )(using Context ): Boolean =
231
+
232
+ def qualifies (tp : Type ): Boolean = tp match
233
+ case tv1 : TypeVar =>
234
+ tv1.nestingLevel < tv.nestingLevel
235
+ || tv1.nestingLevel == tv.nestingLevel && tv.origin.paramName.is(DepParamName )
236
+ // dependent parameters get created at the same nesting level as the current constraint,
237
+ // but are logically nested inside the variables of the expected result type.
238
+ || outerDependsOn(tv1, nestingLevel, deps, lens, seen + tv)
239
+ case tp : TypeParamRef =>
240
+ qualifies(typeVarOfParam(tp))
241
+ case _ =>
242
+ false
243
+
244
+ val tvdeps = deps(tv)
245
+ null != tvdeps
246
+ && ! seen.contains(tv)
247
+ && (tvdeps.exists(qualifies)
248
+ || lens(this , tv.origin.binder, tv.origin.paramNum).exists(qualifies))
249
+ // .showing(i"outer depends on $tv ${tv.initNestingLevel} with ${tvdeps.toList.map(_.initNestingLevel)}%, % = $result")
250
+ end outerDependsOn
251
+
252
+ def outerDependsCovariantlyOn (tv : TypeVar )(using Context ) =
253
+ outerDependsOn(tv, tv.nestingLevel, coDeps, upperLens, SimpleIdentitySet .empty)
254
+
255
+ def outerDependsContravariantlyOn (tv : TypeVar )(using Context ) =
256
+ outerDependsOn(tv, tv.nestingLevel, contraDeps, lowerLens, SimpleIdentitySet .empty)
257
+
258
+ private class Adjuster (tvar : TypeVar )(using Context ) extends TypeTraverser :
259
+ var add : Boolean = compiletime.uninitialized
260
+
261
+ def update (deps : TypeVarDeps , referenced : TypeVar ): TypeVarDeps =
262
+ val entry = deps(referenced)
263
+ val prev = if null == entry then SimpleIdentitySet .empty else entry
264
+ val now = if add then prev + tvar else prev - tvar
265
+ deps.updated(referenced, now)
266
+
267
+ def traverse (t : Type ) = t match
268
+ case tv : TypeVar =>
269
+ val inst = tv.instanceOpt
270
+ if inst.exists then traverse(inst)
271
+ else
272
+ if variance >= 0 then coDeps = update(coDeps, tv)
273
+ if variance <= 0 then contraDeps = update(contraDeps, tv)
274
+ case param : TypeParamRef =>
275
+ traverse(typeVarOfParam(param))
276
+ case tp : LazyRef if ! tp.completed =>
277
+ case _ =>
278
+ traverseChildren(t)
279
+
280
+ def adjustDeps (entry : Type | Null , prevEntry : Type | Null , tvar : Type | Null )(using Context ): this .type =
281
+ tvar match
282
+ case tvar : TypeVar =>
283
+ val adjuster = new Adjuster (tvar)
284
+
285
+ def adjustReferenced (bound : Type , isLower : Boolean , add : Boolean ) =
286
+ adjuster.variance = if isLower then 1 else - 1
287
+ adjuster.add = add
288
+ adjuster.traverse(bound)
289
+
290
+ def adjustDelta (bound : Type , prevBound : Type , isLower : Boolean ): Boolean =
291
+ if bound eq prevBound then true
292
+ else bound match
293
+ case bound : AndOrType =>
294
+ adjustDelta(bound.tp1, prevBound, isLower) && {
295
+ adjustReferenced(bound.tp2, isLower, add = true )
296
+ true
297
+ }
298
+ case _ => false
299
+
300
+ def adjustBounds (bound : Type , prevBound : Type , isLower : Boolean ) =
301
+ if ! adjustDelta(bound, prevBound, isLower) then
302
+ adjustReferenced(prevBound, isLower, add = false )
303
+ adjustReferenced(bound, isLower, add = true )
304
+
305
+ entry match
306
+ case TypeBounds (lo, hi) =>
307
+ prevEntry match
308
+ case TypeBounds (plo, phi) =>
309
+ adjustBounds(lo, plo, isLower = true )
310
+ adjustBounds(hi, phi, isLower = false )
311
+ case _ =>
312
+ adjustReferenced(lo, isLower = true , add = true )
313
+ adjustReferenced(hi, isLower = false , add = true )
314
+ case _ =>
315
+ prevEntry match
316
+ case TypeBounds (plo, phi) =>
317
+ adjustReferenced(plo, isLower = true , add = false )
318
+ adjustReferenced(phi, isLower = false , add = false )
319
+ case _ =>
320
+ dropDeps(tvar)
321
+ case _ =>
322
+ this
323
+ end adjustDeps
324
+
325
+ def adjustDeps (entries : Array [Type ], add : Boolean )(using Context ): this .type =
326
+ for n <- 0 until paramCount(entries) do
327
+ if add
328
+ then adjustDeps(entries(n), NoType , typeVar(entries, n))
329
+ else adjustDeps(NoType , entries(n), typeVar(entries, n))
330
+ this
331
+
332
+ def dropDeps (tp : Type )(using Context ): Unit = tp match
333
+ case tv : TypeVar =>
334
+ coDeps = coDeps.remove(tv)
335
+ contraDeps = contraDeps.remove(tv)
336
+ case _ =>
337
+
338
+ def depsToString (using Context ): String =
339
+ def depsStr (deps : SimpleIdentityMap [TypeVar , TypeVars ]): String =
340
+ def depStr (tv : TypeVar ) = i " $tv --> ${deps(tv).nn.toList}%, % "
341
+ if deps.isEmpty then " " else i " \n ${deps.toList.map((k, v) => depStr(k))}% \n % "
342
+ i " co-deps: ${depsStr(coDeps)}\n contra-deps: ${depsStr(contraDeps)}\n "
343
+
222
344
// ---------- Adding TypeLambdas --------------------------------------------------
223
345
224
346
/** The bound type `tp` without constrained parameters which are clearly
@@ -286,6 +408,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
286
408
tvars.copyToArray(entries1, nparams)
287
409
newConstraint(boundsMap = this .boundsMap.updated(poly, entries1))
288
410
.init(poly)
411
+ .adjustDeps(entries1, add = true )
289
412
}
290
413
291
414
/** Split dependent parameters off the bounds for parameters in `poly`.
@@ -432,6 +555,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
432
555
private def updateEntry (current : This , param : TypeParamRef , tp : Type )(using Context ): This = {
433
556
if Config .checkNoWildcardsInConstraint then assert(! tp.containsWildcardTypes)
434
557
var current1 = boundsLens.update(this , current, param, tp)
558
+ current1.adjustDeps(tp, current.entry(param), typeVarOfParam(param))
435
559
tp match {
436
560
case TypeBounds (lo, hi) =>
437
561
for p <- dependentParams(lo, isUpper = false ) do
@@ -471,10 +595,15 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
471
595
current.ensureNonCyclic(atPoly.paramRefs(atIdx), tp.substParam(param, replacement))
472
596
473
597
current.foreachParam { (p, i) =>
474
- current = boundsLens.map(this , current, p, i, replaceParam(_, p, i))
598
+ current = boundsLens.map(this , current, p, i,
599
+ entry =>
600
+ val newEntry = replaceParam(entry, p, i)
601
+ adjustDeps(newEntry, entry, typeVar(this .boundsMap(p).nn, i))
602
+ newEntry)
475
603
current = lowerLens.map(this , current, p, i, removeParam)
476
604
current = upperLens.map(this , current, p, i, removeParam)
477
605
}
606
+ current.dropDeps(typeVarOfParam(param))
478
607
current.checkNonCyclic()
479
608
end replace
480
609
@@ -489,6 +618,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
489
618
val hardVars1 = pt.paramRefs.foldLeft(hardVars)((hvs, param) => hvs - typeVarOfParam(param))
490
619
newConstraint(boundsMap.remove(pt), removeFromOrdering(lowerMap), removeFromOrdering(upperMap), hardVars1)
491
620
.checkNonCyclic()
621
+ .adjustDeps(boundsMap(pt).nn, add = false )
492
622
}
493
623
494
624
def isRemovable (pt : TypeLambda ): Boolean = {
@@ -666,13 +796,16 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
666
796
val constrainedText =
667
797
" constrained types = " + domainLambdas.mkString(" \n " )
668
798
val boundsText =
669
- " bounds = " + {
799
+ " \n bounds = " + {
670
800
val assocs =
671
801
for (param <- domainParams)
672
802
yield
673
803
s " ${param.binder.paramNames(param.paramNum)}: ${entryText(entry(param))}"
674
804
assocs.mkString(" \n " )
675
805
}
676
- constrainedText + " \n " + boundsText
806
+ val depsText =
807
+ " \n coDeps = " + coDeps +
808
+ " \n contraDeps = " + contraDeps
809
+ constrainedText + boundsText + depsText
677
810
}
678
811
}
0 commit comments