Skip to content

Commit ee23239

Browse files
committed
Check that only @consume parameters flow to @consume parameters
1 parent 3adf4b0 commit ee23239

17 files changed

+121
-58
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -427,10 +427,6 @@ extension (tp: Type)
427427
mapOver(t)
428428
tm(tp)
429429

430-
def hasUseAnnot(using Context): Boolean = tp match
431-
case AnnotatedType(_, ann) => ann.symbol == defn.UseAnnot
432-
case _ => false
433-
434430
/** If `x` is a capture ref, its maybe capability `x?`, represented internally
435431
* as `x @maybeCapability`. `x?` stands for a capability `x` that might or might
436432
* not be part of a capture set. We have `{} <: {x?} <: {x}`. Maybe capabilities

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -655,11 +655,13 @@ class CheckCaptures extends Recheck, SymTransformer:
655655
* on method parameter symbols to the corresponding paramInfo types.
656656
*/
657657
override def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType =
658-
val paramInfosWithUses = funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
659-
val param = meth.paramNamed(pname)
660-
param.getAnnotation(defn.UseAnnot) match
661-
case Some(ann) => AnnotatedType(formal, ann)
662-
case _ => formal
658+
val paramInfosWithUses =
659+
funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
660+
val param = meth.paramNamed(pname)
661+
def copyAnnot(tp: Type, cls: ClassSymbol) = param.getAnnotation(cls) match
662+
case Some(ann) => AnnotatedType(tp, ann)
663+
case _ => tp
664+
copyAnnot(copyAnnot(formal, defn.UseAnnot), defn.ConsumeAnnot)
663665
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
664666

665667
/** Recheck applications, with special handling of unsafeAssumePure.
@@ -687,7 +689,7 @@ class CheckCaptures extends Recheck, SymTransformer:
687689
val freshenedFormal = Fresh.fromCap(formal)
688690
val argType = recheck(arg, freshenedFormal)
689691
.showing(i"recheck arg $arg vs $freshenedFormal", capt)
690-
if formal.hasUseAnnot then
692+
if formal.hasAnnotation(defn.UseAnnot) then
691693
// The @use annotation is added to `formal` by `prepareFunction`
692694
capt.println(i"charging deep capture set of $arg: ${argType} = ${argType.deepCaptureSet}")
693695
markFree(argType.deepCaptureSet, arg)
@@ -722,7 +724,7 @@ class CheckCaptures extends Recheck, SymTransformer:
722724
val qualCaptures = qualType.captureSet
723725
val argCaptures =
724726
for (argType, formal) <- argTypes.lazyZip(funType.paramInfos) yield
725-
if formal.hasUseAnnot then argType.deepCaptureSet else argType.captureSet
727+
if formal.hasAnnotation(defn.UseAnnot) then argType.deepCaptureSet else argType.captureSet
726728
appType match
727729
case appType @ CapturingType(appType1, refs)
728730
if qualType.exists
@@ -1569,10 +1571,11 @@ class CheckCaptures extends Recheck, SymTransformer:
15691571
(params1, params2) <- member.rawParamss.lazyZip(other.rawParamss)
15701572
(param1, param2) <- params1.lazyZip(params2)
15711573
do
1572-
if param1.hasAnnotation(defn.UseAnnot) != param2.hasAnnotation(defn.UseAnnot) then
1573-
fail(i"has a parameter ${param1.name} with different @use status than the corresponding parameter in the overridden definition")
1574-
if param1.hasAnnotation(defn.ConsumeAnnot) != param2.hasAnnotation(defn.ConsumeAnnot) then
1575-
fail(i"has a parameter ${param1.name} with different @consume status than the corresponding parameter in the overridden definition")
1574+
def checkAnnot(cls: ClassSymbol) =
1575+
if param1.hasAnnotation(cls) != param2.hasAnnotation(cls) then
1576+
fail(i"has a parameter ${param1.name} with different @${cls.name} status than the corresponding parameter in the overridden definition")
1577+
checkAnnot(defn.UseAnnot)
1578+
checkAnnot(defn.ConsumeAnnot)
15761579
end OverridingPairsCheckerCC
15771580

15781581
def traverse(t: Tree)(using Context) =

compiler/src/dotty/tools/dotc/cc/SepCheck.scala

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import CaptureSet.{Refs, emptySet, HiddenSet}
1111
import config.Printers.capt
1212
import StdNames.nme
1313
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
14+
import tpd.*
1415

1516
object SepChecker:
1617

@@ -31,15 +32,14 @@ object SepChecker:
3132
/** The kind of checked type, used for composing error messages */
3233
enum TypeKind:
3334
case Result(sym: Symbol, inferred: Boolean)
34-
case Argument
35+
case Argument(arg: Tree)
3536

3637
def dclSym = this match
3738
case Result(sym, _) => sym
3839
case _ => NoSymbol
3940
end TypeKind
4041

4142
class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
42-
import tpd.*
4343
import checker.*
4444
import SepChecker.*
4545

@@ -214,7 +214,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
214214
for (arg, idx) <- indexedArgs do
215215
if arg.needsSepCheck then
216216
val ac = formalCaptures(arg)
217-
checkType(arg.formalType, arg.srcPos, TypeKind.Argument)
217+
checkType(arg.formalType, arg.srcPos, TypeKind.Argument(arg))
218218
val hiddenInArg = ac.hidden.footprint
219219
//println(i"check sep $arg: $ac, footprint so far = $footprint, hidden = $hiddenInArg")
220220
val overlap = hiddenInArg.overlapWith(footprint).deductCapturesOf(deps(arg))
@@ -252,9 +252,9 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
252252
case TypeKind.Result(sym, inferred) =>
253253
def inferredStr = if inferred then " inferred" else ""
254254
def resultStr = if sym.info.isInstanceOf[MethodicType] then " result" else ""
255-
i" $sym's$inferredStr$resultStr"
256-
case TypeKind.Argument =>
257-
" the argument's adapted type"
255+
i"$sym's$inferredStr$resultStr"
256+
case TypeKind.Argument(_) =>
257+
"the argument's adapted"
258258

259259
def explicitRefs(tp: Type): Refs = tp match
260260
case tp: (TermRef | ThisType) => SimpleIdentitySet(tp)
@@ -292,7 +292,7 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
292292
.nextOption
293293
.getOrElse(("", current, globalOverlap))
294294
report.error(
295-
em"""Separation failure in$typeDescr type $tpe.
295+
em"""Separation failure in $typeDescr type $tpe.
296296
|One part, $part , $nextRel ${CaptureSet(next)}.
297297
|A previous part$prevStr $prevRel ${CaptureSet(prevRefs)}.
298298
|The two sets overlap at ${CaptureSet(overlap)}.""",
@@ -346,10 +346,10 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
346346
case t =>
347347
foldOver(c, t)
348348

349-
def checkParameters() =
349+
def checkParams(refsToCheck: Refs, descr: => String) =
350350
val badParams = mutable.ListBuffer[Symbol]()
351351
def currentOwner = kind.dclSym.orElse(ctx.owner)
352-
for hiddenRef <- prune(tpe.deepCaptureSet.elems.hidden.footprint) do
352+
for hiddenRef <- prune(refsToCheck.footprint) do
353353
val refSym = hiddenRef.termSymbol
354354
if refSym.is(TermParam)
355355
&& !refSym.hasAnnotation(defn.ConsumeAnnot)
@@ -364,25 +364,29 @@ class SepChecker(checker: CheckCaptures.CheckerAPI) extends tpd.TreeTraverser:
364364
case p :: ps => i"${p.name}, ${paramsStr(ps)}"
365365
val (pluralS, singleS) = if badParams.tail.isEmpty then ("", "s") else ("s", "")
366366
report.error(
367-
em"""Separation failure:$typeDescr type $tpe hides parameter$pluralS ${paramsStr(badParams.toList)}
367+
em"""Separation failure: $descr parameter$pluralS ${paramsStr(badParams.toList)}.
368368
|The parameter$pluralS need$singleS to be annotated with @consume to allow this.""",
369369
pos)
370370

371-
def flagHiddenParams =
372-
kind match
373-
case TypeKind.Result(sym, _) =>
374-
!sym.isAnonymousFunction // we don't check return types of anonymous functions
375-
&& !sym.is(Case) // We don't check so far binders in patterns since they
376-
// have inferred universal types. TODO come back to this;
377-
// either infer more precise types for such binders or
378-
// "see through them" when we look at hidden sets.
379-
case TypeKind.Argument =>
380-
false
371+
def checkParameters() = kind match
372+
case TypeKind.Result(sym, _) =>
373+
if !sym.isAnonymousFunction // we don't check return types of anonymous functions
374+
&& !sym.is(Case) // We don't check so far binders in patterns since they
375+
// have inferred universal types. TODO come back to this;
376+
// either infer more precise types for such binders or
377+
// "see through them" when we look at hidden sets.
378+
then checkParams(tpe.deepCaptureSet.elems.hidden, i"$typeDescr type $tpe hides")
379+
case TypeKind.Argument(arg) =>
380+
if tpe.hasAnnotation(defn.ConsumeAnnot) then
381+
val capts = captures(arg)
382+
def descr(verb: String) = i"argument to @consume parameter with type ${arg.nuType} $verb"
383+
checkParams(capts, descr("refers to"))
384+
checkParams(capts.hidden, descr("hides"))
381385

382386
if !tpe.hasAnnotation(defn.UntrackedCapturesAnnot) then
383387
traverse(Captures.None, tpe)
384388
traverse.toCheck.foreach(checkParts)
385-
if flagHiddenParams then checkParameters()
389+
checkParameters()
386390
end checkType
387391

388392
private def collectMethodTypes(tp: Type): List[TermLambda] = tp match

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -528,7 +528,7 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
528528
case _ =>
529529
traverseChildren(tree)
530530
postProcess(tree)
531-
checkProperUse(tree)
531+
checkProperUseOrConsume(tree)
532532
end traverse
533533

534534
/** Processing done on node `tree` after its children are traversed */
@@ -682,16 +682,22 @@ class Setup extends PreRecheck, SymTransformer, SetupAPI:
682682
case _ =>
683683
end postProcess
684684

685-
/** Check that @use annotations only appear on parameters and not on anonymous function parameters */
686-
def checkProperUse(tree: Tree)(using Context): Unit = tree match
685+
/** Check that @use and @consume annotations only appear on parameters and not on
686+
* anonymous function parameters
687+
*/
688+
def checkProperUseOrConsume(tree: Tree)(using Context): Unit = tree match
687689
case tree: MemberDef =>
688-
def useAllowed(sym: Symbol) =
689-
(sym.is(Param) || sym.is(ParamAccessor)) && !sym.owner.isAnonymousFunction
690690
for ann <- tree.symbol.annotations do
691-
if ann.symbol == defn.UseAnnot && !useAllowed(tree.symbol) then
692-
report.error(i"Only parameters of methods can have @use annotations", tree.srcPos)
691+
def isAllowedFor(sym: Symbol) =
692+
(sym.is(Param) || sym.is(ParamAccessor))
693+
&& (ann.symbol != defn.ConsumeAnnot || sym.isTerm)
694+
&& !sym.owner.isAnonymousFunction
695+
def termStr =
696+
if ann.symbol == defn.ConsumeAnnot then " term" else ""
697+
if defn.ccParamOnlyAnnotations.contains(ann.symbol) && !isAllowedFor(tree.symbol) then
698+
report.error(i"Only$termStr parameters of methods can have @${ann.symbol.name} annotations", tree.srcPos)
693699
case _ =>
694-
end checkProperUse
700+
end checkProperUseOrConsume
695701
end setupTraverser
696702

697703
// --------------- Adding capture set variables ----------------------------------

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,8 @@ class Definitions {
11181118
@tu lazy val SilentAnnots: Set[Symbol] =
11191119
Set(InlineParamAnnot, ErasedParamAnnot, RefineOverrideAnnot)
11201120

1121+
@tu lazy val ccParamOnlyAnnotations: Set[Symbol] = Set(UseAnnot, ConsumeAnnot)
1122+
11211123
// A list of annotations that are commonly used to indicate that a field/method argument or return
11221124
// type is not null. These annotations are used by the nullification logic in JavaNullInterop to
11231125
// improve the precision of type nullification.
Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1-
import caps.use
2-
class Test:
1+
import caps.{use, consume}
2+
class TestUse:
33
@use def F = ??? // error
44
@use val x = ??? // error
55
@use type T // error
66
def foo[@use T](@use c: T): Unit = ??? // OK
77

8+
class TestConsume:
9+
@consume def F = ??? // error
10+
@consume val x = ??? // error
11+
@consume type T // error
12+
def foo[@consume T](@use c: T): Unit = ??? // error
13+

tests/neg-custom-args/captures/capt-depfun.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
-- Error: tests/neg-custom-args/captures/capt-depfun.scala:11:24 -------------------------------------------------------
99
11 | val dc: ((Str^{y, z}) => Str^{y, z}) = ac(g()) // error // error: separatioon
1010
| ^^^^^^^^^^^^^^^^^^^^^^^^^^
11-
| Separation failure: value dc's type Str^{y, z} => Str^{y, z} hides parameters y and z
11+
| Separation failure: value dc's type Str^{y, z} => Str^{y, z} hides parameters y and z.
1212
| The parameters need to be annotated with @consume to allow this.

tests/neg-custom-args/captures/depfun-reach.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@
1515
-- Error: tests/neg-custom-args/captures/depfun-reach.scala:12:17 ------------------------------------------------------
1616
12 | : (xs: List[(X, () ->{io} Unit)]) => List[() ->{} Unit] = // error
1717
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
18-
|Separation failure: method foo's result type (xs: List[(X, box () ->{io} Unit)]) => List[() -> Unit] hides parameter op
18+
|Separation failure: method foo's result type (xs: List[(X, box () ->{io} Unit)]) => List[() -> Unit] hides parameter op.
1919
|The parameter needs to be annotated with @consume to allow this.

tests/neg-custom-args/captures/i15772.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,5 +42,5 @@
4242
-- Error: tests/neg-custom-args/captures/i15772.scala:34:10 ------------------------------------------------------------
4343
34 | def c : C^ = new C(x) // error separation
4444
| ^^
45-
| Separation failure: method c's result type C^ hides parameter x
45+
| Separation failure: method c's result type C^ hides parameter x.
4646
| The parameter needs to be annotated with @consume to allow this.

tests/neg-custom-args/captures/i19330.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@
1313
-- Error: tests/neg-custom-args/captures/i19330.scala:16:14 ------------------------------------------------------------
1414
16 | val t: () => Logger^ = () => l // error
1515
| ^^^^^^^^^^^^^
16-
| Separation failure: value t's type () => (ex$5: caps.Exists) -> Logger^{ex$5} hides parameter l
16+
| Separation failure: value t's type () => (ex$5: caps.Exists) -> Logger^{ex$5} hides parameter l.
1717
| The parameter needs to be annotated with @consume to allow this.

tests/neg-custom-args/captures/i21442.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
-- Error: tests/neg-custom-args/captures/i21442.scala:17:10 ------------------------------------------------------------
1111
17 | val x1: Boxed[IO^] = x // error
1212
| ^^^^^^^^^^
13-
| Separation failure: value x1's type Boxed[box IO^] hides parameter x
13+
| Separation failure: value x1's type Boxed[box IO^] hides parameter x.
1414
| The parameter needs to be annotated with @consume to allow this.

tests/neg-custom-args/captures/sepchecks2.check

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,17 +36,17 @@
3636
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:27:6 ---------------------------------------------------------
3737
27 | bar((c, c)) // error
3838
| ^^^^^^
39-
| Separation failure in the argument's adapted type type (box Object^, box Object^).
39+
| Separation failure in the argument's adapted type (box Object^, box Object^).
4040
| One part, box Object^ , hides {c}.
4141
| A previous part, box Object^ , also hides {c}.
4242
| The two sets overlap at {c}.
4343
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:30:9 ---------------------------------------------------------
4444
30 | val x: (Object^, Object^{c}) = (d, c) // error
4545
| ^^^^^^^^^^^^^^^^^^^^^
46-
| Separation failure: value x's type (box Object^, box Object^{c}) hides parameter d
46+
| Separation failure: value x's type (box Object^, box Object^{c}) hides parameter d.
4747
| The parameter needs to be annotated with @consume to allow this.
4848
-- Error: tests/neg-custom-args/captures/sepchecks2.scala:33:9 ---------------------------------------------------------
4949
33 | val x: (Object^, Object^) = (c, d) // error
5050
| ^^^^^^^^^^^^^^^^^^
51-
| Separation failure: value x's type (box Object^, box Object^) hides parameters c and d
51+
| Separation failure: value x's type (box Object^, box Object^) hides parameters c and d.
5252
| The parameters need to be annotated with @consume to allow this.

tests/neg-custom-args/captures/sepchecks4.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
-- Error: tests/neg-custom-args/captures/sepchecks4.scala:8:12 ---------------------------------------------------------
22
8 | val x: () => Unit = () => println(io) // error
33
| ^^^^^^^^^^
4-
| Separation failure: value x's type () => Unit hides parameter io
4+
| Separation failure: value x's type () => Unit hides parameter io.
55
| The parameter needs to be annotated with @consume to allow this.
66
-- Error: tests/neg-custom-args/captures/sepchecks4.scala:7:25 ---------------------------------------------------------
77
7 |def bad(io: Object^): () => Unit = // error
88
| ^^^^^^^^^^
9-
| Separation failure: method bad's result type () => Unit hides parameter io
9+
| Separation failure: method bad's result type () => Unit hides parameter io.
1010
| The parameter needs to be annotated with @consume to allow this.
1111
-- Error: tests/neg-custom-args/captures/sepchecks4.scala:12:25 --------------------------------------------------------
1212
12 | par(() => println(io))(() => println(io)) // error // (1)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
-- Error: tests/neg-custom-args/captures/sepchecks5.scala:12:37 --------------------------------------------------------
2+
12 |def bad(io: Object^): () => Unit = f(io) // error
3+
| ^^
4+
| Separation failure: argument to @consume parameter with type (io : Object^) refers to parameter io.
5+
| The parameter needs to be annotated with @consume to allow this.
6+
-- Error: tests/neg-custom-args/captures/sepchecks5.scala:19:13 --------------------------------------------------------
7+
19 | val f2 = g(io) // error
8+
| ^^
9+
| Separation failure: argument to @consume parameter with type (io : Object^) refers to parameter io.
10+
| The parameter needs to be annotated with @consume to allow this.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import caps.{cap, consume}
2+
import language.future
3+
import language.experimental.captureChecking
4+
5+
def par(op1: () => Unit)(op2: () => Unit): Unit = ()
6+
7+
def f(@consume io: Object^): () => Unit =
8+
() => println(io)
9+
10+
def g(@consume io: Object^): () => Unit = f(io) // ok
11+
12+
def bad(io: Object^): () => Unit = f(io) // error
13+
14+
def test(io: Object^): Unit =
15+
16+
val f1 = bad(io)
17+
par(f1)(() => println(io)) // !!! separation failure
18+
19+
val f2 = g(io) // error
20+
par(f2)(() => println(io)) // !!! separation failure
21+
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- Error: tests/neg-custom-args/captures/unsound-reach-6.scala:7:13 ----------------------------------------------------
2+
7 | println(xs.head) // error
3+
| ^^^^^^^
4+
| Local reach capability xs* leaks into capture scope of method f.
5+
| To allow this, the parameter xs should be declared with a @use annotation
6+
-- Error: tests/neg-custom-args/captures/unsound-reach-6.scala:11:14 ---------------------------------------------------
7+
11 | val z = f(ys) // error @consume failure
8+
| ^^
9+
|Separation failure: argument to @consume parameter with type (ys : List[box () ->{io} Unit]) refers to parameters ys and io.
10+
|The parameters need to be annotated with @consume to allow this.
11+
-- Error: tests/neg-custom-args/captures/unsound-reach-6.scala:19:14 ---------------------------------------------------
12+
19 | val z = f(ys) // error @consume failure
13+
| ^^
14+
|Separation failure: argument to @consume parameter with type (ys : -> List[box () ->{io} Unit]) refers to parameter io.
15+
|The parameter needs to be annotated with @consume to allow this.

tests/neg-custom-args/captures/unsound-reach-6.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ def f(@consume xs: List[() => Unit]): () => Unit = () =>
88

99
def test(io: IO^)(ys: List[() ->{io} Unit]) =
1010
val x = () =>
11-
val z = f(ys)
11+
val z = f(ys) // error @consume failure
1212
z()
1313
val _: () -> Unit = x // !!! ys* gets lost
1414
()
1515

1616
def test(io: IO^) =
1717
def ys: List[() ->{io} Unit] = ???
1818
val x = () =>
19-
val z = f(ys)
19+
val z = f(ys) // error @consume failure
2020
z()
2121
val _: () -> Unit = x // !!! io gets lost
2222
()

0 commit comments

Comments
 (0)