Skip to content

Commit 62a11b9

Browse files
committed
Check capture sets for well-formedness
1 parent c5b67b2 commit 62a11b9

File tree

10 files changed

+149
-62
lines changed

10 files changed

+149
-62
lines changed

compiler/src/dotty/tools/dotc/CaptureOps.scala renamed to compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ import config.Printers.capt
1212
object CaptureOps:
1313
import tpd.*
1414

15+
def retainedElems(tree: Tree)(using Context): List[Tree] = tree match
16+
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) => elems
17+
case _ => Nil
18+
1519
extension (cs: CaptureSet)
1620
def toAnnotation(using Context): Annotation =
1721
val refs = cs.elems.toList.map {
@@ -29,10 +33,8 @@ object CaptureOps:
2933
extension (annot: Annotation)
3034
def toCaptureSet(using Context): CaptureSet =
3135
assert(annot.symbol == defn.RetainsAnnot)
32-
annot.tree match
33-
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) =>
34-
CaptureSet(elems.map(_.tpe.asInstanceOf[CaptureRef])*)
35-
.showing(i"toCaptureSet $annot --> $result", capt)
36+
CaptureSet(retainedElems(annot.tree).map(_.tpe.asInstanceOf[CaptureRef])*)
37+
.showing(i"toCaptureSet $annot --> $result", capt)
3638

3739
extension (tp: AnnotatedType)
3840
def toCapturingType(using Context): Type =

compiler/src/dotty/tools/dotc/core/CaptureSet.scala

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,13 @@ sealed abstract class CaptureSet extends Showable:
3737
*/
3838
def isConst: Boolean
3939

40-
/** Is this capture set (always) empty? For capture veraiables, returns
40+
/** Is this capture set always empty? For capture veraiables, returns
4141
* always false
4242
*/
43-
def isEmpty: Boolean
44-
def nonEmpty: Boolean = !isEmpty
43+
def isAlwaysEmpty: Boolean
44+
45+
/** Is this capture set definitely non-empty? */
46+
final def isNotEmpty: Boolean = !elems.isEmpty
4547

4648
/** Add new elements to this capture set if allowed.
4749
* @pre `newElems` is not empty and does not overlap with `this.elems`.
@@ -74,8 +76,10 @@ sealed abstract class CaptureSet extends Showable:
7476
* as frozen.
7577
*/
7678
def accountsFor(x: CaptureRef)(using Context): Boolean =
77-
elems.contains(x)
78-
|| !x.isRootCapability && (x.captureSetOfInfo frozen_<:< this) == CompareResult.OK
79+
reporting.trace(i"$this accountsFor $x, ${x.captureSetOfInfo}?", show = true) {
80+
elems.contains(x)
81+
|| !x.isRootCapability && (x.captureSetOfInfo frozen_<:< this) == CompareResult.OK
82+
}
7983

8084
/** The subcapturing test */
8185
def <:< (that: CaptureSet)(using Context): CompareResult =
@@ -172,7 +176,7 @@ object CaptureSet:
172176
class Const private[CaptureSet] (val elems: Refs) extends CaptureSet:
173177
assert(elems != null)
174178
def isConst = true
175-
def isEmpty: Boolean = elems.isEmpty
179+
def isAlwaysEmpty = elems.isEmpty
176180

177181
def addNewElems(elems: Refs)(using Context, VarState): CompareResult =
178182
CompareResult.fail(this)
@@ -190,7 +194,7 @@ object CaptureSet:
190194
var elems: Refs = initialElems
191195
var deps: Deps = emptySet
192196
def isConst = false
193-
def isEmpty = false
197+
def isAlwaysEmpty = false
194198

195199
private def recordElemsState()(using VarState): Boolean =
196200
varState.getElems(this) match

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,7 +2041,7 @@ object Types {
20412041
private var mySingletonCaptureSet: CaptureSet = null
20422042

20432043
def canBeTracked(using Context): Boolean
2044-
final def isTracked(using Context): Boolean = canBeTracked && captureSetOfInfo.nonEmpty
2044+
final def isTracked(using Context): Boolean = canBeTracked && !captureSetOfInfo.isAlwaysEmpty
20452045
def isRootCapability(using Context): Boolean = false
20462046
def normalizedRef(using Context): CaptureRef = this
20472047

@@ -2067,7 +2067,7 @@ object Types {
20672067

20682068
override def captureSet(using Context): CaptureSet =
20692069
val cs = captureSetOfInfo
2070-
if canBeTracked && cs.nonEmpty then singletonCaptureSet else cs
2070+
if canBeTracked && !cs.isAlwaysEmpty then singletonCaptureSet else cs
20712071
end CaptureRef
20722072

20732073
/** A trait for types that bind other types that refer to them.
@@ -5224,7 +5224,7 @@ object Types {
52245224

52255225
object CapturingType:
52265226
def apply(parent: Type, refs: CaptureSet)(using Context): Type =
5227-
if refs.isEmpty then parent
5227+
if refs.isAlwaysEmpty then parent
52285228
else unique(CachedCapturingType(parent, refs))
52295229
end CapturingType
52305230

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 74 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,39 @@ import CaptureSet.CompareResult
3131
import cc.CaptureOps.*
3232

3333
object CheckCaptures:
34+
import ast.tpd.*
35+
3436
case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer: Env):
35-
def isOpen = !captured.isEmpty && !isBoxed
37+
def isOpen = !captured.isAlwaysEmpty && !isBoxed
3638

3739
/** Attachment key for printing trees with rechecked types */
3840
val RecheckedType = Property.Key[Type]
3941

42+
/** Check that a @retains annotation only mentions references that can be tracked
43+
* This check is performed at Typer.
44+
*/
45+
def checkWellformed(ann: Tree)(using Context): Unit =
46+
for elem <- retainedElems(ann) do
47+
elem.tpe match
48+
case ref: CaptureRef =>
49+
if !ref.canBeTracked then
50+
report.error(em"$elem cannot be tracked since it is not a parameter or a local variable", elem.srcPos)
51+
case tpe =>
52+
report.error(em"$tpe is not a legal type for a capture set", elem.srcPos)
53+
54+
/** If `tp` is a capturing type, check that all references it mentions have non-empty
55+
* capture sets.
56+
* This check is performed after capture sets are computed in phase cc.
57+
*/
58+
def checkWellformedPost(tp: Type, pos: SrcPos)(using Context): Unit = tp match
59+
case tp: CapturingType =>
60+
for ref <- tp.refs.elems do
61+
if (ref.captureSet frozen_<:< CaptureSet.empty) == CompareResult.OK then
62+
report.error(em"$ref cannot be tracked since its capture set is empty", pos)
63+
case _ =>
64+
65+
private inline val disallowGlobal = true
66+
4067
class CheckCaptures extends Recheck:
4168
thisPhase =>
4269

@@ -71,6 +98,7 @@ class CheckCaptures extends Recheck:
7198
parent
7299
case _ =>
73100
mapOver(t)
101+
74102
def addVars(tp: Type): Type =
75103
if tp.canHaveInferredCapture then
76104
val tp1 = tp match
@@ -80,6 +108,7 @@ class CheckCaptures extends Recheck:
80108
tp
81109
CapturingType(tp1, CaptureSet.Var())
82110
else tp
111+
83112
tp match
84113
case tp: MethodOrPoly =>
85114
tp.derivedLambdaType(resType = reinfer(tp.resType))
@@ -88,6 +117,7 @@ class CheckCaptures extends Recheck:
88117
case _ =>
89118
val tp1 = cleanType(tp)
90119
addVars(tp1)
120+
end reinfer
91121

92122
val tp1 = if inferred then reinfer(tp) else tp
93123
mapType(tp1)
@@ -115,7 +145,7 @@ class CheckCaptures extends Recheck:
115145
if curEnv.isOpen then
116146
val ownEnclosure = ctx.owner.enclosingMethodOrClass
117147
var targetSet = capturedVars(sym)
118-
if !targetSet.isEmpty && sym.enclosure == ownEnclosure then
148+
if !targetSet.isAlwaysEmpty && sym.enclosure == ownEnclosure then
119149
targetSet = targetSet.filter {
120150
case ref: TermRef => ref.symbol.enclosure != ownEnclosure
121151
case _ => true
@@ -166,14 +196,14 @@ class CheckCaptures extends Recheck:
166196
remember(tree, sym.localReturnType)
167197
val saved = curEnv
168198
val localSet = capturedVars(sym)
169-
if !localSet.isEmpty then curEnv = Env(sym, localSet, false, curEnv)
199+
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv)
170200
try super.recheckDefDef(tree, sym)
171201
finally curEnv = saved
172202

173203
override def recheckClassDef(tree: TypeDef, impl: Template, sym: ClassSymbol)(using Context): Type =
174204
val saved = curEnv
175205
val localSet = capturedVars(sym)
176-
if !localSet.isEmpty then curEnv = Env(sym, localSet, false, curEnv)
206+
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, false, curEnv)
177207
try super.recheckClassDef(tree, impl, sym)
178208
finally curEnv = saved
179209

@@ -198,44 +228,47 @@ class CheckCaptures extends Recheck:
198228
super.checkUnit(unit)
199229
PostRefinerCheck.traverse(unit.tpdTree)
200230

201-
end CaptureChecker
231+
def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit =
232+
if disallowGlobal then
233+
tree match
234+
case LambdaTypeTree(_, restpt) =>
235+
checkNotGlobal(restpt, allArgs*)
236+
case _ =>
237+
for ref <- tree.tpe.captureSet.elems do
238+
val isGlobal = ref match
239+
case ref: TermRef =>
240+
ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot)
241+
case _ => false
242+
val what = if ref.isRootCapability then "universal" else "global"
243+
if isGlobal then
244+
val notAllowed = i" is not allowed to capture the $what capability $ref"
245+
def msg = tree match
246+
case tree: InferredTypeTree =>
247+
i"""inferred type argument ${tree.tpe}$notAllowed
248+
|
249+
|The inferred arguments are: [$allArgs%, %]"""
250+
case _ => s"type argument$notAllowed"
251+
report.error(msg, tree.srcPos)
252+
253+
object PostRefinerCheck extends TreeTraverser:
254+
def traverse(tree: Tree)(using Context) =
255+
tree match
256+
case _: InferredTypeTree =>
257+
case tree: TypeTree =>
258+
transformType(tree.tpe, inferred = false).foreachPart(
259+
checkWellformedPost(_, tree.srcPos))
260+
261+
case tree1 @ TypeApply(fn, args) if disallowGlobal =>
262+
for arg <- args do
263+
//println(i"checking $arg in $tree: ${arg.tpe.captureSet}")
264+
checkNotGlobal(arg, args*)
265+
case _ =>
266+
traverseChildren(tree)
202267

203-
inline val disallowGlobal = true
204-
205-
def checkNotGlobal(tree: Tree, allArgs: Tree*)(using Context): Unit =
206-
if disallowGlobal then
207-
tree match
208-
case LambdaTypeTree(_, restpt) =>
209-
checkNotGlobal(restpt, allArgs*)
210-
case _ =>
211-
for ref <- tree.tpe.captureSet.elems do
212-
val isGlobal = ref match
213-
case ref: TermRef =>
214-
ref.isRootCapability || ref.prefix != NoPrefix && ref.symbol.hasAnnotation(defn.AbilityAnnot)
215-
case _ => false
216-
val what = if ref.isRootCapability then "universal" else "global"
217-
if isGlobal then
218-
val notAllowed = i" is not allowed to capture the $what capability $ref"
219-
def msg = tree match
220-
case tree: InferredTypeTree =>
221-
i"""inferred type argument ${tree.tpe}$notAllowed
222-
|
223-
|The inferred arguments are: [$allArgs%, %]"""
224-
case _ => s"type argument$notAllowed"
225-
report.error(msg, tree.srcPos)
226-
227-
object PostRefinerCheck extends TreeTraverser:
228-
def traverse(tree: Tree)(using Context) =
229-
tree match
230-
case tree1 @ TypeApply(fn, args) if disallowGlobal =>
231-
for arg <- args do
232-
//println(i"checking $arg in $tree: ${arg.tpe.captureSet}")
233-
checkNotGlobal(arg, args*)
234-
case _ =>
235-
traverseChildren(tree)
236-
237-
def postRefinerCheck(tree: tpd.Tree)(using Context): Unit =
238-
PostRefinerCheck.traverse(tree)
268+
def postRefinerCheck(tree: tpd.Tree)(using Context): Unit =
269+
PostRefinerCheck.traverse(tree)
270+
271+
end CaptureChecker
239272

240273
override def show(tree: untpd.Tree)(using Context): String =
241274
val addRecheckedTypes = new TreeMap:

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,6 +2490,8 @@ class Typer extends Namer
24902490
val annot1 = typedExpr(tree.annot, defn.AnnotationClass.typeRef)
24912491
val arg1 = typed(tree.arg, pt)
24922492
if (ctx.mode is Mode.Type) {
2493+
if annot1.symbol.maybeOwner == defn.RetainsAnnot then
2494+
CheckCaptures.checkWellformed(annot1)
24932495
if arg1.isType then
24942496
assignType(cpy.Annotated(tree)(arg1, annot1), arg1, annot1)
24952497
else
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class C
2+
type Cap = {*} C
3+
4+
object foo
5+
6+
def test(c: Cap, other: String): Unit =
7+
val x7: {c} String = ??? // OK
8+
val x8: String @retains(x7 + x7) = ??? // error
9+
val x9: String @retains(foo) = ??? // error
10+
()
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
class C
2+
type Cap = {*} C
3+
4+
object foo
5+
6+
def test(c: Cap, other: String): Unit =
7+
val x1: {*} C = ??? // OK
8+
val x2: {other} C = ??? // error: cs is empty
9+
val s1 = () => "abc"
10+
val x3: {s1} C = ??? // error: cs is empty
11+
val x3a: () => String = s1
12+
val s2 = () => if x1 == null then "" else "abc"
13+
val x4: {s2} C = ??? // OK
14+
val x5: {c, c} C = ??? // warning: redundant
15+
val x6: {c} {c} C = ??? // warning: redundant
16+
val x7: {c} {*} C = ??? // warning: redundant
17+
val x8: {*} {c} C = ??? // OK
18+
19+
def even(n: Int): Boolean = if n == 0 then true else odd(n - 1)
20+
def odd(n: Int): Boolean = if n == 1 then true else even(n - 1)
21+
val e1 = even
22+
val o1 = odd
23+
24+
val y1: {e1} String = ??? // error cs is empty
25+
val y2: {o1} String = ??? // error cs is empty
26+
27+
lazy val ev: (Int => Boolean) = (n: Int) =>
28+
lazy val od: (Int => Boolean) = (n: Int) =>
29+
if n == 1 then true else ev(n - 1)
30+
if n == 0 then true else od(n - 1)
31+
val y3: {ev} String = ??? // error cs is empty
32+
33+
()

tests/new/test.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
11
object Test:
22

3-
val x = ""
4-
5-
class C extends AnyRef @scala.retains(x)
6-
73
def test = ???
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
object Test:
22

33
def test() =
4-
val x = "abc"
4+
val x: {*} Any = "abc"
55
val y: Object @scala.retains(x) = ???
66
val z: Object @scala.retains(x, *) = y: Object @scala.retains(x)
77

tests/pos-custom-args/captures/capt2.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,18 @@ class C
33
type Cap = C @retains(*)
44

55
def test1() =
6-
val y = ""
6+
val y: {*} String = ""
77
def x: Object @retains(y) = y
88

99
def test2() =
1010
val x: Cap = C()
1111
val y = () => { x; () }
1212
def z: (() => Unit) @retains(x) = y
13-
z: (() => Unit) @retains(x) // TODO: replace x with y
13+
z: (() => Unit) @retains(x)
14+
def z2: (() => Unit) @retains(y) = y
15+
z2: (() => Unit) @retains(y)
16+
val p: {*} () => String = () => "abc"
17+
val q: {p} C = ???
18+
p: ({p} () => String)
19+
20+

0 commit comments

Comments
 (0)