Skip to content

Commit 94b2717

Browse files
committed
Add capture root levels -- first draft
1 parent ac9a42d commit 94b2717

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+588
-200
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,18 @@ private[cc] def retainedElems(tree: Tree)(using Context): List[Tree] = tree matc
2626
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
2727
case _ => Nil
2828

29+
/** Map @retain'ed arguments with a Symbol->Symbol function, keeping the same trees
30+
* but with mapped symbols.
31+
*/
32+
private[cc] def mapRetainedElems(tree: Tree)(f: Symbol => Symbol)(using Context): Tree = tree match
33+
case Apply(fn, (typd @ Typed(seqlit @ SeqLiteral(elems, tpt1), tpt2)) :: Nil) =>
34+
val elems1 = elems.mapConserve: elem =>
35+
val sym1 = f(elem.symbol)
36+
if sym1 ne elem.symbol then elem.withType(sym1.namedType) else elem
37+
if elems1 ne elems then
38+
cpy.Apply(tree)(fn, cpy.Typed(typd)(cpy.SeqLiteral(seqlit)(elems1, tpt1), tpt2) :: Nil)
39+
else tree
40+
2941
def allowUniversalInBoxed(using Context) =
3042
Feature.sourceVersion.isAtLeast(SourceVersion.`3.3`)
3143

@@ -253,6 +265,10 @@ extension (sym: Symbol)
253265
&& sym != defn.Caps_unsafeBox
254266
&& sym != defn.Caps_unsafeUnbox
255267

268+
/** The nesting level of `sym` according to capture checking */
269+
def ccNestingLevel(using Context): Int =
270+
Setup.ccNestingLevel(sym)
271+
256272
extension (tp: AnnotatedType)
257273
/** Is this a boxed capturing type? */
258274
def isBoxed(using Context): Boolean = tp.annot match

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import util.{SimpleIdentitySet, Property}
1616
import util.common.alwaysTrue
1717
import scala.collection.mutable
1818
import config.Config.ccAllowUnsoundMaps
19+
import config.Printers.{captLevels, noPrinter}
1920
import transform.SymUtils.{maxNested, minNested}
2021

2122
/** A class for capture sets. Capture sets can be constants or variables.
@@ -74,11 +75,10 @@ sealed abstract class CaptureSet extends Showable:
7475
asInstanceOf[Var]
7576

7677
/** Does this capture set contain the root reference `cap` as element? */
77-
final def isUniversal(using Context) =
78-
elems.exists {
79-
case ref: TermRef => ref.symbol == defn.captureRoot
80-
case _ => false
81-
}
78+
final def isUniversal(using Context) = elems.exists(_.isGlobalRootCapability)
79+
80+
/** Does this capture set contain a global or local root references as elements? */
81+
final def isRoot(using Context) = elems.exists(_.isRootCapability)
8282

8383
/** Add new elements to this capture set if allowed.
8484
* @pre `newElems` is not empty and does not overlap with `this.elems`.
@@ -120,17 +120,31 @@ sealed abstract class CaptureSet extends Showable:
120120
private def subsumes(y: CaptureRef) =
121121
(x eq y)
122122
|| y.match
123-
case y: TermRef => y.prefix eq x
123+
case y: TermRef => (y.prefix eq x) || x.isRootIncluding(y)
124124
case _ => false
125125

126+
private def isRootIncluding(y: CaptureRef) =
127+
x.isRootCapability && y.isRootCapability && {
128+
val xsym = x.termSymbol
129+
val ysym = y.termSymbol
130+
xsym == defn.captureRoot // roots without level are compatible with all other roots
131+
|| ysym == defn.captureRoot
132+
|| xsym.ccNestingLevel >= ysym.ccNestingLevel
133+
}
134+
end extension
135+
126136
/** {x} <:< this where <:< is subcapturing, but treating all variables
127137
* as frozen.
128138
*/
129139
def accountsFor(x: CaptureRef)(using Context): Boolean =
130-
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) {
131-
elems.exists(_.subsumes(x))
132-
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
133-
}
140+
if comparer.isInstanceOf[ExplainingTypeComparer] then // !!! DEBUG
141+
reporting.trace.force(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true):
142+
elems.exists(_.subsumes(x))
143+
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
144+
else
145+
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true):
146+
elems.exists(_.subsumes(x))
147+
|| !x.isRootCapability && x.captureSetOfInfo.subCaptures(this, frozen = true).isOK
134148

135149
/** A more optimistic version of accountsFor, which does not take variable supersets
136150
* of the `x` reference into account. A set might account for `x` if it accounts
@@ -278,7 +292,9 @@ sealed abstract class CaptureSet extends Showable:
278292

279293
/** Invoke handler if this set has (or later aquires) the root capability `cap` */
280294
def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
281-
if isUniversal then handler()
295+
if
296+
if allowUniversalInBoxed then isUniversal else isRoot
297+
then handler()
282298
this
283299

284300
/** Invoke handler on the elements to ensure wellformedness of the capture set.
@@ -433,10 +449,32 @@ object CaptureSet:
433449
def resetDeps()(using state: VarState): Unit =
434450
deps = state.deps(this)
435451

452+
/** If `root` is a local root, check that its level is not smaller than
453+
* the set's owner's nesting level.
454+
* @param ref The actually checked reference to be use in the error message.
455+
* Currently always the same as `root`.
456+
*/
457+
private def checkLevel(ref: CaptureRef, root: CaptureRef)(using Context): Unit =
458+
if root.isLocalRootCapability then
459+
if root.termSymbol.ccNestingLevel > owner.ccNestingLevel then
460+
throw LevelException(owner, ref)
461+
436462
def addNewElems(newElems: Refs, origin: CaptureSet)(using Context, VarState): CompareResult =
463+
//println(i"ADD ${newElems.toList} to $this in $owner")
437464
if !isConst && recordElemsState() then
438465
elems ++= newElems
439-
if isUniversal then rootAddedHandler()
466+
for elem <- newElems do
467+
if elem.isLocalRootCapability then
468+
checkLevel(elem, elem)
469+
else if elem.isGlobalRootCapability then
470+
rootAddedHandler()
471+
else if false then
472+
// Disabled since it gives spurious errors for references that end up
473+
// being avoided through healing. TODO: enable it when level-checking
474+
// and avoidance is better integrated.
475+
for superElem <- elem.captureSetOfInfo.elems.toList do
476+
checkLevel(elem, superElem)
477+
440478
newElemAddedHandler(newElems.toList)
441479
// assert(id != 5 || elems.size != 3, this)
442480
(CompareResult.OK /: deps) { (r, dep) =>
@@ -950,4 +988,8 @@ object CaptureSet:
950988
println(i" ${cv.show.padTo(20, ' ')} :: ${cv.deps.toList}%, %")
951989
}
952990
else op
991+
992+
class LevelException(val owner: Symbol, val ref: CaptureRef) extends Exception {
993+
if captLevels ne noPrinter then printStackTrace()
994+
}
953995
end CaptureSet

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 58 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@ import transform.SymUtils.*
1717
import transform.{Recheck, PreRecheck}
1818
import Recheck.*
1919
import scala.collection.mutable
20-
import CaptureSet.{withCaptureSetsExplained, IdempotentCaptRefMap}
20+
import CaptureSet.{withCaptureSetsExplained, IdempotentCaptRefMap, LevelException}
2121
import StdNames.nme
2222
import NameKinds.DefaultGetterName
23-
import reporting.{trace, ClosureParameterMismatch}
23+
import reporting.*
2424

2525
/** The capture checker */
2626
object CheckCaptures:
@@ -157,6 +157,10 @@ object CheckCaptures:
157157
/** Attachment key for bodies of closures, provided they are values */
158158
val ClosureBodyValue = Property.Key[Unit]
159159

160+
/** Report level violation */
161+
def levelError[T](ex: LevelException, pos: SrcPos)(using Context): Unit =
162+
report.error(em"local reference ${ex.ref} escapes into outer capture set of ${ex.owner}", pos)
163+
160164
class CheckCaptures extends Recheck, SymTransformer:
161165
thisPhase =>
162166

@@ -203,6 +207,15 @@ class CheckCaptures extends Recheck, SymTransformer:
203207
case _ =>
204208
traverseChildren(t)
205209

210+
/** Map all occurrences of `cap` in a type to the currently valid local capture root. */
211+
private def mapRootsIn(using Context) = new IdempotentCaptRefMap:
212+
def apply(t: Type) = t.dealiasKeepAnnots match
213+
case CapturingType(parent, refs) if refs.isUniversal =>
214+
val parent1 = apply(parent)
215+
CapturingType(parent1, CaptureSet(setup.localRoot(setup.currentScopeOwner)))
216+
case _ =>
217+
mapOver(t)
218+
206219
/** If `tpt` is an inferred type, interpolate capture set variables appearing contra-
207220
* variantly in it.
208221
*/
@@ -504,11 +517,16 @@ class CheckCaptures extends Recheck, SymTransformer:
504517
mdef.rhs.putAttachment(ClosureBodyValue, ())
505518
case _ =>
506519

520+
val mapRoots = mapRootsIn(using ctx.withOwner(mdef.symbol))
521+
522+
def checkCompatible(lower: Type, upper: Type, tree: Tree, msgFn: (Type, Type) => TypeMismatchMsg): Unit =
523+
withMode(Mode.CCIgnoreBoxing):
524+
if !isCompatible(lower, upper, tree) then
525+
report.error(msgFn(lower, upper), tree.srcPos)
526+
507527
def constrainParams(ptformals: List[Type], params: ParamClause) =
508-
for (param, ptformal) <- params.lazyZip(ptformals) do
509-
withMode(Mode.CCIgnoreBoxing):
510-
if !isCompatible(ptformal, param.symbol.info, param) then
511-
report.error(ClosureParameterMismatch(ptformal, param.symbol.info), param.srcPos)
528+
for (param, ptformal) <- params.lazyZip(ptformals.mapConserve(mapRoots)) do
529+
checkCompatible(ptformal, param.symbol.info, param, ClosureParameterMismatch(_, _))
512530

513531
def constrain(pt: Type, paramss: List[ParamClause]): Unit = (pt: @unchecked) match
514532
case RefinedType(_, nme.apply, rinfo) =>
@@ -521,10 +539,11 @@ class CheckCaptures extends Recheck, SymTransformer:
521539
constrainParams(ptformals, paramss.head)
522540
case _ =>
523541

524-
//println(i"recheck closure $mdef with $pt")
525542
constrain(pt.dealias.stripCapturing, mdef.paramss)
526543
recheckDef(mdef, mdef.symbol)
527-
recheckClosure(expr, pt)
544+
val clt = recheckClosure(expr, pt)
545+
inContext(ctx.withOwner(mdef.symbol)):
546+
mapRootsIn(clt)
528547
end recheckClosureBlock
529548

530549
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
@@ -621,19 +640,23 @@ class CheckCaptures extends Recheck, SymTransformer:
621640
* adding all references in the boxed capture set to the current environment.
622641
*/
623642
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
624-
val saved = curEnv
625-
tree match
626-
case _: RefTree | closureDef(_) if pt.isBoxedCapturing =>
627-
curEnv = Env(curEnv.owner, EnvKind.Boxed, CaptureSet.Var(curEnv.owner), curEnv)
628-
case _ if tree.hasAttachment(ClosureBodyValue) =>
629-
curEnv = Env(curEnv.owner, EnvKind.ClosureResult, CaptureSet.Var(curEnv.owner), curEnv)
630-
case _ =>
631-
val res =
632-
try super.recheck(tree, pt)
633-
finally curEnv = saved
634-
if tree.isTerm && !pt.isBoxedCapturing then
635-
markFree(res.boxedCaptureSet, tree.srcPos)
636-
res
643+
try
644+
val saved = curEnv
645+
tree match
646+
case _: RefTree | closureDef(_) if pt.isBoxedCapturing =>
647+
curEnv = Env(curEnv.owner, EnvKind.Boxed, CaptureSet.Var(curEnv.owner), curEnv)
648+
case _ if tree.hasAttachment(ClosureBodyValue) =>
649+
curEnv = Env(curEnv.owner, EnvKind.ClosureResult, CaptureSet.Var(curEnv.owner), curEnv)
650+
case _ =>
651+
val res =
652+
try super.recheck(tree, pt)
653+
finally curEnv = saved
654+
if tree.isTerm && !pt.isBoxedCapturing then
655+
markFree(res.boxedCaptureSet, tree.srcPos)
656+
res
657+
catch case ex: LevelException =>
658+
levelError(ex, tree.srcPos)
659+
UnspecifiedErrorType
637660

638661
/** If `tree` is a reference or an application where the result type refers
639662
* to an enclosing class or method parameter of the reference, check that the result type
@@ -941,7 +964,8 @@ class CheckCaptures extends Recheck, SymTransformer:
941964
def traverse(t: Tree)(using Context) =
942965
t match
943966
case t: Template =>
944-
checkAllOverrides(ctx.owner.asClass, OverridingPairsCheckerCC(_, _, t))
967+
try checkAllOverrides(ctx.owner.asClass, OverridingPairsCheckerCC(_, _, t))
968+
catch case ex: LevelException => levelError(ex, t.srcPos)
945969
case _ =>
946970
traverseChildren(t)
947971

@@ -1014,7 +1038,7 @@ class CheckCaptures extends Recheck, SymTransformer:
10141038
capt.println(i"checked $root with $selfType")
10151039
end checkSelfTypes
10161040

1017-
/** Heal ill-formed capture sets in the type parameter.
1041+
/** Heal ill-formed capture sets in the type parameter of function `meth`
10181042
*
10191043
* We can push parameter refs into a capture set in type parameters
10201044
* that this type parameter can't see.
@@ -1032,28 +1056,32 @@ class CheckCaptures extends Recheck, SymTransformer:
10321056
* compensate this by pushing the widened capture set of `f` into ?1.
10331057
* This solves the soundness issue caused by the ill-formness of ?1.
10341058
*/
1035-
private def healTypeParam(tree: Tree)(using Context): Unit =
1059+
private def healTypeParam(meth: Symbol, tree: Tree)(using Context): Unit =
10361060
val checker = new TypeTraverser:
10371061
private var allowed: SimpleIdentitySet[TermParamRef] = SimpleIdentitySet.empty
10381062

10391063
private def isAllowed(ref: CaptureRef): Boolean = ref match
10401064
case ref: TermParamRef => allowed.contains(ref)
1065+
case ref: TermRef => !(ref.isLocalRootCapability && ref.symbol.owner == meth)
10411066
case _ => true
10421067

10431068
private def healCaptureSet(cs: CaptureSet): Unit =
10441069
cs.ensureWellformed: elems =>
10451070
ctx ?=>
10461071
var seen = new util.HashSet[CaptureRef]
1047-
def recur(elems: List[CaptureRef]): Unit =
1072+
def recur(elems: List[CaptureRef], prev: Option[CaptureRef]): Unit =
10481073
for ref <- elems do
10491074
if !isAllowed(ref) && !seen.contains(ref) then
1075+
if ref.isLocalRootCapability
1076+
&& ref.termSymbol.nestingLevel > cs.owner.nestingLevel then
1077+
throw LevelException(cs.owner, prev.getOrElse(ref))
10501078
seen += ref
10511079
val widened = ref.captureSetOfInfo
10521080
val added = widened.filter(isAllowed(_))
10531081
capt.println(i"heal $ref in $cs by widening to $added")
10541082
checkSubset(added, cs, tree.srcPos)
1055-
recur(widened.elems.toList)
1056-
recur(elems)
1083+
recur(widened.elems.toList, Some(ref))
1084+
recur(elems, None)
10571085

10581086
def traverse(tp: Type) =
10591087
tp match
@@ -1144,7 +1172,9 @@ class CheckCaptures extends Recheck, SymTransformer:
11441172
checkBounds(normArgs, tl)
11451173
case _ =>
11461174

1147-
args.foreach(healTypeParam(_))
1175+
args.foreach: arg =>
1176+
try healTypeParam(fun.symbol, arg)
1177+
catch case ex: LevelException => levelError(ex, arg.srcPos)
11481178
case _ =>
11491179
end check
11501180
end checker

0 commit comments

Comments
 (0)