Skip to content

Commit b58205b

Browse files
committed
Add capture set constrains
- Introduce monadic capture set constraint handling - Change CapturingType to take a capture set instead of a single capture refernce
1 parent f66e902 commit b58205b

14 files changed

+213
-109
lines changed

compiler/src/dotty/tools/dotc/core/CaptureSet.scala

Lines changed: 137 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,74 +11,167 @@ import reporting.trace
1111
import printing.{Showable, Printer}
1212
import printing.Texts.*
1313

14-
case class CaptureSet private (elems: CaptureSet.Refs) extends Showable:
14+
/** A class for capture sets. Capture sets can be constants or variables.
15+
*/
16+
sealed abstract class CaptureSet extends Showable:
1517
import CaptureSet.*
1618

17-
def isEmpty: Boolean = elems.isEmpty
19+
/** The elements of this capture set. For capture variables,
20+
* the elements known so far.
21+
*/
22+
def elems: Refs
23+
24+
/** Is this capture set constant (i.e. not a capture variable)?
25+
*/
26+
def isConst: Boolean
27+
28+
/** Is this capture set (always) empty? For capture veraiables, returns
29+
* always false
30+
*/
31+
def isEmpty: Boolean
1832
def nonEmpty: Boolean = !isEmpty
1933

20-
private var myClosure: Refs | Null = null
21-
22-
def closure(using Context): Refs =
23-
if myClosure == null then
24-
var cl = elems
25-
var seen: Refs = SimpleIdentitySet.empty
26-
while
27-
val prev = cl
28-
for ref <- cl do
29-
if !seen.contains(ref) then
30-
seen += ref
31-
cl = cl ++ ref.captureSetOfInfo.elems
32-
prev ne cl
33-
do ()
34-
myClosure = cl
35-
myClosure
36-
37-
def ++ (that: CaptureSet): CaptureSet =
38-
if this.isEmpty then that
39-
else if that.isEmpty then this
40-
else CaptureSet(elems ++ that.elems)
41-
42-
def + (ref: CaptureRef) =
43-
if elems.contains(ref) then this
44-
else CaptureSet(elems + ref)
45-
46-
def intersect (that: CaptureSet): CaptureSet =
47-
CaptureSet(this.elems.intersect(that.elems))
34+
/** Add new elements to this capture set if allowed.
35+
* @pre `newElems` is not empty and does not overlap with `this.elems`.
36+
* Constant capture sets never allow to add new elements.
37+
* Variables allow it if and only if the new elements can be included
38+
* in all their supersets.
39+
* @return true iff elements were added
40+
*/
41+
protected def addNewElems(newElems: Refs)(using Context): Boolean
42+
43+
/** If this is a variable, add `cs` as a super set */
44+
protected def addSuper(cs: CaptureSet): this.type
45+
46+
/** If `cs` is a variable, add this capture set as one of its super sets */
47+
protected def addSub(cs: CaptureSet): this.type =
48+
cs.addSuper(this)
49+
this
50+
51+
/** Try to include all references of `elems` that are not yet accounted by this
52+
* capture set. Inclusion is via `addElems`.
53+
* @return true iff elements were added
54+
*/
55+
protected def tryInclude(elems: Refs)(using Context): Boolean =
56+
val unaccounted = elems.filter(!accountsFor(_))
57+
unaccounted.isEmpty || addNewElems(unaccounted)
4858

4959
/** {x} <:< this where <:< is subcapturing */
5060
def accountsFor(x: CaptureRef)(using Context) =
5161
elems.contains(x) || !x.isRootCapability && x.captureSetOfInfo <:< this
5262

5363
/** The subcapturing test */
5464
def <:< (that: CaptureSet)(using Context): Boolean =
55-
elems.isEmpty || elems.forall(that.accountsFor)
56-
65+
that.tryInclude(elems) && { addSuper(that); true }
66+
67+
/** The smallest capture set (via <:<) that is a superset of both
68+
* `this` and `that`
69+
*/
70+
def ++ (that: CaptureSet)(using Context): CaptureSet =
71+
if this.isConst && this.elems.forall(that.accountsFor) then that
72+
else if that.isConst && that.elems.forall(this.accountsFor) then this
73+
else if this.isConst && that.isConst then Const(this.elems ++ that.elems)
74+
else Var(this.elems ++ that.elems).addSub(this).addSub(that)
75+
76+
/** The smallest superset (via <:<) of this capture set that also contains `ref`.
77+
*/
78+
def + (ref: CaptureRef)(using Context) = ++ (ref.singletonCaptureSet)
79+
80+
/** The largest capture set (via <:<) that is a subset of both `this` and `that`
81+
*/
82+
def intersect(that: CaptureSet)(using Context): CaptureSet =
83+
if this.isConst && this.elems.forall(that.accountsFor) then this
84+
else if that.isConst && that.elems.forall(this.accountsFor) then that
85+
else if this.isConst && that.isConst then Const(this.elems.intersect(that.elems))
86+
else Var(this.elems.intersect(that.elems)).addSuper(this).addSuper(that)
87+
88+
/** capture set obtained by applying `f` to all elements of the current capture set
89+
* and joining the results. If the current capture set is a variable, the same
90+
* transformation is applied to all future additions of new elements.
91+
*/
5792
def flatMap(f: CaptureRef => CaptureSet)(using Context): CaptureSet =
58-
(empty /: elems)((cs, ref) => cs ++ f(ref))
93+
mapRefs(elems, f) match
94+
case cs: Const => cs
95+
case cs: Var => Mapped(cs, f)
5996

6097
def substParams(tl: BindingType, to: List[Type])(using Context) =
6198
flatMap {
6299
case ref: ParamRef if ref.binder eq tl => to(ref.paramNum).captureSet
63100
case ref => ref.singletonCaptureSet
64101
}
65102

66-
override def toString = elems.toString
103+
def toRetainsTypeArg(using Context): Type =
104+
assert(isConst)
105+
((NoType: Type) /: elems) ((tp, ref) =>
106+
if tp.exists then OrType(tp, ref, soft = false) else ref)
67107

68108
override def toText(printer: Printer): Text =
69109
Str("{") ~ Text(elems.toList.map(printer.toTextCaptureRef), ", ") ~ Str("}")
70110

71111
object CaptureSet:
72112
type Refs = SimpleIdentitySet[CaptureRef]
113+
type Vars = SimpleIdentitySet[Var]
114+
type Deps = SimpleIdentitySet[CaptureSet]
115+
116+
private val emptySet = SimpleIdentitySet.empty
117+
@sharable private var varId = 0
118+
119+
val empty: CaptureSet = Const(emptySet)
73120

74-
@sharable val empty: CaptureSet = CaptureSet(SimpleIdentitySet.empty)
121+
/** The universal capture set `{*}` */
122+
def universal(using Context): CaptureSet =
123+
defn.captureRootType.typeRef.singletonCaptureSet
75124

76125
/** Used as a recursion brake */
77-
@sharable private[core] val Pending = CaptureSet(SimpleIdentitySet.empty)
126+
@sharable private[core] val Pending = Const(SimpleIdentitySet.empty)
78127

79128
def apply(elems: CaptureRef*)(using Context): CaptureSet =
80129
if elems.isEmpty then empty
81-
else CaptureSet(SimpleIdentitySet(elems.map(_.normalizedRef)*))
130+
else Const(SimpleIdentitySet(elems.map(_.normalizedRef)*))
131+
132+
class Const private[CaptureSet] (val elems: Refs) extends CaptureSet:
133+
assert(elems != null)
134+
def isConst = true
135+
def isEmpty: Boolean = elems.isEmpty
136+
137+
def addNewElems(elems: Refs)(using Context): Boolean = false
138+
def addSuper(cs: CaptureSet) = this
139+
140+
override def toString = elems.toString
141+
end Const
142+
143+
class Var private[CaptureSet] (initialElems: Refs) extends CaptureSet:
144+
val id =
145+
varId += 1
146+
varId
147+
148+
var elems: Refs = initialElems
149+
var deps: Deps = emptySet
150+
def isConst = false
151+
def isEmpty = false
152+
153+
def addNewElems(newElems: Refs)(using Context): Boolean =
154+
deps.forall(_.tryInclude(newElems)) && { elems ++= newElems; true }
155+
156+
def addSuper(cs: CaptureSet) = { deps += cs; this }
157+
158+
override def toString = s"Var$id$elems"
159+
end Var
160+
161+
class Mapped private[CaptureSet] (cv: Var, f: CaptureRef => CaptureSet) extends Var(cv.elems):
162+
addSub(cv)
163+
164+
override def accountsFor(x: CaptureRef)(using Context): Boolean =
165+
f(x).elems.forall(super.accountsFor)
166+
167+
override def addNewElems(newElems: Refs)(using Context): Boolean =
168+
super.addNewElems(mapRefs(newElems, f).elems)
169+
170+
override def toString = s"Mapped$id$elems"
171+
end Mapped
172+
173+
def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet =
174+
(empty /: xs)((cs, x) => cs ++ f(x))
82175

83176
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =
84177
def captureSetOf(tp: Type): CaptureSet = tp match
@@ -105,8 +198,8 @@ object CaptureSet:
105198
tp.captureSet
106199
case tp: ParamRef =>
107200
tp.captureSet
108-
case CapturingType(parent, ref) =>
109-
recur(parent) + ref
201+
case CapturingType(parent, refs) =>
202+
recur(parent) ++ refs
110203
case AppliedType(tycon, args) =>
111204
val cs = recur(tycon)
112205
tycon.typeParams match
@@ -125,3 +218,8 @@ object CaptureSet:
125218
recur(tp)
126219
.showing(i"capture set of $tp = $result", capt)
127220

221+
def fromRetainsTypeArg(tp: Type)(using Context): CaptureSet = tp match
222+
case tp: CaptureRef => tp.singletonCaptureSet
223+
case OrType(tp1, tp2) => fromRetainsTypeArg(tp1) ++ fromRetainsTypeArg(tp2)
224+
225+
end CaptureSet

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ class Definitions {
264264
*/
265265
@tu lazy val AnyClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Any, Abstract, Nil), ensureCtor = false)
266266
def AnyType: TypeRef = AnyClass.typeRef
267-
@tu lazy val TopType: Type = CapturingType(AnyType, captureRootType.typeRef)
267+
@tu lazy val TopType: Type = CapturingType(AnyType, CaptureSet.universal)
268268
@tu lazy val MatchableClass: ClassSymbol = completeClass(enterCompleteClassSymbol(ScalaPackageClass, tpnme.Matchable, Trait, AnyType :: Nil), ensureCtor = false)
269269
def MatchableType: TypeRef = MatchableClass.typeRef
270270
@tu lazy val AnyValClass: ClassSymbol =

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
331331
if parent1 ne tp.parent then tp.derivedAnnotatedType(parent1, tp.annot) else tp
332332
case tp: CapturingType =>
333333
val parent1 = recur(tp.parent, fromBelow)
334-
if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.ref) else tp
334+
if parent1 ne tp.parent then tp.derivedCapturingType(parent1, tp.refs) else tp
335335
case _ =>
336336
val tp1 = tp.dealiasKeepAnnots
337337
if tp1 ne tp then

compiler/src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2151,7 +2151,7 @@ object SymDenotations {
21512151
recur(TypeComparer.bounds(tp).hi)
21522152

21532153
case tp: CapturingType =>
2154-
tp.derivedCapturingType(recur(tp.parent), tp.ref)
2154+
tp.derivedCapturingType(recur(tp.parent), tp.refs)
21552155

21562156
case tp: TypeProxy =>
21572157
def computeTypeProxy = {

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
490490
// under -Ycheck. Test case is i7965.scala.
491491

492492
case tp1: CapturingType =>
493-
if tp2.captureSet.accountsFor(tp1.ref) then recur(tp1.parent, tp2)
493+
if tp1.refs <:< tp2.captureSet then recur(tp1.parent, tp2)
494494
else thirdTry
495495
case tp1: MatchType =>
496496
val reduced = tp1.reduced
@@ -818,7 +818,7 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
818818
// ---------------------------
819819
// E |- x: {x} T
820820
//
821-
CapturingType(tp2, defn.captureRootType.typeRef)
821+
CapturingType(tp2, CaptureSet.universal)
822822
case _ => tp2
823823
isSubType(tp1.underlying.widenExpr, tp2n, approx.addLow)
824824
}
@@ -2361,8 +2361,9 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
23612361
tp1.underlying & tp2
23622362
case tp1: AnnotatedType if !tp1.isRefining =>
23632363
tp1.underlying & tp2
2364-
case tp1: CapturingType if !tp2.captureSet.accountsFor(tp1.ref) =>
2365-
tp1.parent & tp2
2364+
case tp1: CapturingType =>
2365+
if tp2.captureSet <:< tp1.refs then tp1.parent & tp2
2366+
else tp1.derivedCapturingType(tp1.parent & tp2, tp1.refs)
23662367
case _ =>
23672368
NoType
23682369
}

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ object TypeOps:
169169
val normed = tp.tryNormalize
170170
if (normed.exists) normed else mapOver
171171
case tp: CapturingType
172-
if !ctx.mode.is(Mode.Type) && tp.parent.captureSet.accountsFor(tp.ref) =>
172+
if !ctx.mode.is(Mode.Type) && tp.refs <:< tp.parent.captureSet =>
173173
simplify(tp.parent, theMap)
174174
case tp: MethodicType =>
175175
tp // See documentation of `Types#simplified`
@@ -275,7 +275,7 @@ object TypeOps:
275275
case tp1: RecType =>
276276
return tp1.rebind(approximateOr(tp1.parent, tp2))
277277
case tp1: CapturingType =>
278-
return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.ref)
278+
return tp1.derivedCapturingType(approximateOr(tp1.parent, tp2), tp1.refs)
279279
case err: ErrorType =>
280280
return err
281281
case _ =>
@@ -284,7 +284,7 @@ object TypeOps:
284284
case tp2: RecType =>
285285
return tp2.rebind(approximateOr(tp1, tp2.parent))
286286
case tp2: CapturingType =>
287-
return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.ref)
287+
return tp2.derivedCapturingType(approximateOr(tp1, tp2.parent), tp2.refs)
288288
case err: ErrorType =>
289289
return err
290290
case _ =>

0 commit comments

Comments
 (0)