Skip to content

Commit 91da9fb

Browse files
committed
Be more careful computing underlying types of reach capabilities
We can use the dcs only if there are no type variables.
1 parent c4c69b0 commit 91da9fb

File tree

11 files changed

+165
-46
lines changed

11 files changed

+165
-46
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -220,8 +220,8 @@ extension (tp: Type)
220220
* a singleton capability `x` or a reach capability `x*`, the deep capture
221221
* set can be narrowed to`{x*}`.
222222
*/
223-
def deepCaptureSet(using Context): CaptureSet =
224-
val dcs = CaptureSet.ofTypeDeeply(tp.widen.stripCapturing)
223+
def deepCaptureSet(includeTypevars: Boolean)(using Context): CaptureSet =
224+
val dcs = CaptureSet.ofTypeDeeply(tp.widen.stripCapturing, includeTypevars)
225225
if dcs.isAlwaysEmpty then tp.captureSet
226226
else tp match
227227
case tp @ ReachCapability(_) =>
@@ -231,6 +231,9 @@ extension (tp: Type)
231231
case _ =>
232232
tp.captureSet ++ dcs
233233

234+
def deepCaptureSet(using Context): CaptureSet =
235+
deepCaptureSet(includeTypevars = false)
236+
234237
/** A type capturing `ref` */
235238
def capturing(ref: CaptureRef)(using Context): Type =
236239
if tp.captureSet.accountsFor(ref) then tp
@@ -593,16 +596,26 @@ extension (sym: Symbol)
593596
def isRefiningParamAccessor(using Context): Boolean =
594597
sym.is(ParamAccessor)
595598
&& {
596-
val param = sym.owner.primaryConstructor.paramSymss
597-
.nestedFind(_.name == sym.name)
598-
.getOrElse(NoSymbol)
599+
val param = sym.owner.primaryConstructor.paramNamed(sym.name)
599600
!param.hasAnnotation(defn.ConstructorOnlyAnnot)
600601
&& !param.hasAnnotation(defn.UntrackedCapturesAnnot)
601602
}
602603

603604
def hasTrackedParts(using Context): Boolean =
604605
!CaptureSet.ofTypeDeeply(sym.info).isAlwaysEmpty
605606

607+
/** `sym` is annotated @use or it is a type parameter with a matching
608+
* @use-annotated term parameter that contains `sym` in its deep capture set.
609+
*/
610+
def isUseParam(using Context): Boolean =
611+
sym.hasAnnotation(defn.UseAnnot)
612+
|| sym.is(TypeParam)
613+
&& sym.owner.rawParamss.nestedExists: param =>
614+
param.is(TermParam) && param.hasAnnotation(defn.UseAnnot)
615+
&& param.info.deepCaptureSet.elems.exists:
616+
case c: TypeRef => c.symbol == sym
617+
case _ => false
618+
606619
extension (tp: AnnotatedType)
607620
/** Is this a boxed capturing type? */
608621
def isBoxed(using Context): Boolean = tp.annot match

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,8 +1064,9 @@ object CaptureSet:
10641064
case ref: (TermRef | TermParamRef) if ref.isMaxCapability =>
10651065
if ref.isTrackableRef then ref.singletonCaptureSet
10661066
else CaptureSet.universal
1067-
case ReachCapability(ref1) => ref1.widen.deepCaptureSet
1068-
.showing(i"Deep capture set of $ref: ${ref1.widen} = $result", capt)
1067+
case ReachCapability(ref1) =>
1068+
ref1.widen.deepCaptureSet(includeTypevars = true)
1069+
.showing(i"Deep capture set of $ref: ${ref1.widen} = ${result}", capt)
10691070
case _ => ofType(ref.underlying, followResult = true)
10701071

10711072
/** Capture set of a type */
@@ -1120,7 +1121,7 @@ object CaptureSet:
11201121
* arguments. This have to be included to be conservative in dcs but must be
11211122
* excluded in narrowCaps.
11221123
*/
1123-
def ofTypeDeeply(tp: Type)(using Context): CaptureSet =
1124+
def ofTypeDeeply(tp: Type, includeTypevars: Boolean = false)(using Context): CaptureSet =
11241125
val collect = new TypeAccumulator[CaptureSet]:
11251126
val seen = util.HashSet[Symbol]()
11261127
def apply(cs: CaptureSet, t: Type) =
@@ -1132,7 +1133,9 @@ object CaptureSet:
11321133
this(cs, parent)
11331134
case t: TypeRef if t.symbol.isAbstractOrParamType && !seen.contains(t.symbol) =>
11341135
seen += t.symbol
1135-
this(cs, t.info.bounds.hi)
1136+
val upper = t.info.bounds.hi
1137+
if includeTypevars && upper.isExactlyAny then CaptureSet.universal
1138+
else this(cs, t.info.bounds.hi)
11361139
case t @ FunctionOrMethod(args, res @ Existential(_, _))
11371140
if args.forall(_.isAlwaysPure) =>
11381141
this(cs, Existential.toCap(res))

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -185,34 +185,6 @@ object CheckCaptures:
185185
if ccConfig.useSealed then check.traverse(tp)
186186
end disallowRootCapabilitiesIn
187187

188-
/** Under the sealed policy, disallow the root capability in type arguments.
189-
* Type arguments come either from a TypeApply node or from an AppliedType
190-
* which represents a trait parent in a template.
191-
* @param fn the type application, of type TypeApply or TypeTree
192-
* @param sym the constructor symbol (could be a method or a val or a class)
193-
* @param args the type arguments
194-
*/
195-
private def disallowCapInTypeArgs(fn: Tree, sym: Symbol, args: List[Tree], thisPhase: Phase)(using Context): Unit =
196-
def isExempt = sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
197-
if ccConfig.useSealed && !isExempt then
198-
val paramNames = atPhase(thisPhase.prev):
199-
fn.tpe.widenDealias match
200-
case tl: TypeLambda => tl.paramNames
201-
case ref: AppliedType if ref.typeSymbol.isClass => ref.typeSymbol.typeParams.map(_.name)
202-
case t =>
203-
println(i"parent type: $t")
204-
args.map(_ => EmptyTypeName)
205-
for case (arg: TypeTree, pname) <- args.lazyZip(paramNames) do
206-
def where = if sym.exists then i" in an argument of $sym" else ""
207-
val (addendum, pos) =
208-
if arg.isInferred
209-
then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
210-
else if arg.span.exists then ("", arg.srcPos)
211-
else ("", fn.srcPos)
212-
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
213-
i"Type variable $pname of $sym", "be instantiated to", addendum, pos)
214-
end disallowCapInTypeArgs
215-
216188
/** If we are not under the sealed policy, and a tree is an application that unboxes
217189
* its result or is a try, check that the tree's type does not have covariant universal
218190
* capabilities.
@@ -409,14 +381,14 @@ class CheckCaptures extends Recheck, SymTransformer:
409381
if lastEnv != null && env.nestedClosure.exists && env.nestedClosure == lastEnv.owner then
410382
() // access is from a nested closure, so it's OK
411383
else c.pathRoot match
412-
case ref: NamedType if !ref.symbol.hasAnnotation(defn.UseAnnot) =>
384+
case ref: NamedType if !ref.symbol.isUseParam =>
413385
val what = if ref.isType then "Capture set parameter" else "Local reach capability"
414386
report.error(
415387
em"""$what $c leaks into capture scope of ${env.ownerString}.
416388
|To allow this, the ${ref.symbol} should be declared with a @use annotation""", pos)
417389
case _ =>
418390

419-
def recur(cs: CaptureSet, env: Env, lastEnv: Env | Null)(using Context): Unit =
391+
def recur(cs: CaptureSet, env: Env, lastEnv: Env | Null): Unit =
420392
if env.isOpen && !env.owner.isStaticOwner && !cs.isAlwaysEmpty then
421393
// Only captured references that are visible from the environment
422394
// should be included.
@@ -480,6 +452,40 @@ class CheckCaptures extends Recheck, SymTransformer:
480452
case _ =>
481453
if sym.exists && curEnv.isOpen then markFree(capturedVars(sym), pos)
482454

455+
/** Under the sealed policy, disallow the root capability in type arguments.
456+
* Type arguments come either from a TypeApply node or from an AppliedType
457+
* which represents a trait parent in a template. Also, if a corresponding
458+
* formal type parameter is declared or implied @use, charge the deep capture
459+
* set of the argument to the environent.
460+
* @param fn the type application, of type TypeApply or TypeTree
461+
* @param sym the constructor symbol (could be a method or a val or a class)
462+
* @param args the type arguments
463+
*/
464+
def disallowCapInTypeArgs(fn: Tree, sym: Symbol, args: List[Tree])(using Context): Unit =
465+
def isExempt = sym.isTypeTestOrCast || sym == defn.Compiletime_erasedValue
466+
if ccConfig.useSealed && !isExempt then
467+
val paramNames = atPhase(thisPhase.prev):
468+
fn.tpe.widenDealias match
469+
case tl: TypeLambda => tl.paramNames
470+
case ref: AppliedType if ref.typeSymbol.isClass => ref.typeSymbol.typeParams.map(_.name)
471+
case t =>
472+
println(i"parent type: $t")
473+
args.map(_ => EmptyTypeName)
474+
475+
for case (arg: TypeTree, pname) <- args.lazyZip(paramNames) do
476+
def where = if sym.exists then i" in an argument of $sym" else ""
477+
val (addendum, pos) =
478+
if arg.isInferred
479+
then ("\nThis is often caused by a local capability$where\nleaking as part of its result.", fn.srcPos)
480+
else if arg.span.exists then ("", arg.srcPos)
481+
else ("", fn.srcPos)
482+
disallowRootCapabilitiesIn(arg.knownType, NoSymbol,
483+
i"Type variable $pname of $sym", "be instantiated to", addendum, pos)
484+
485+
val param = fn.symbol.paramNamed(pname)
486+
if param.isUseParam then markFree(arg.knownType.deepCaptureSet, pos)
487+
end disallowCapInTypeArgs
488+
483489
override def recheckIdent(tree: Ident, pt: Type)(using Context): Type =
484490
val sym = tree.symbol
485491
if sym.is(Method) then
@@ -558,8 +564,8 @@ class CheckCaptures extends Recheck, SymTransformer:
558564
*/
559565
override def prepareFunction(funtpe: MethodType, meth: Symbol)(using Context): MethodType =
560566
val paramInfosWithUses = funtpe.paramInfos.zipWithConserve(funtpe.paramNames): (formal, pname) =>
561-
val paramOpt = meth.rawParamss.nestedFind(_.name == pname)
562-
paramOpt.flatMap(_.getAnnotation(defn.UseAnnot)) match
567+
val param = meth.paramNamed(pname)
568+
param.getAnnotation(defn.UseAnnot) match
563569
case Some(ann) => AnnotatedType(formal, ann)
564570
case _ => formal
565571
funtpe.derivedLambdaType(paramInfos = paramInfosWithUses)
@@ -725,7 +731,7 @@ class CheckCaptures extends Recheck, SymTransformer:
725731
val meth = tree.fun match
726732
case fun @ Select(qual, nme.apply) => qual.symbol.orElse(fun.symbol)
727733
case fun => fun.symbol
728-
disallowCapInTypeArgs(tree.fun, meth, tree.args, thisPhase)
734+
disallowCapInTypeArgs(tree.fun, meth, tree.args)
729735
val res = Existential.toCap(super.recheckTypeApply(tree, pt))
730736
includeCallCaptures(tree.symbol, res, tree.srcPos)
731737
checkContains(tree)
@@ -956,7 +962,7 @@ class CheckCaptures extends Recheck, SymTransformer:
956962
for case tpt: TypeTree <- impl.parents do
957963
tpt.tpe match
958964
case AppliedType(fn, args) =>
959-
disallowCapInTypeArgs(tpt, fn.typeSymbol, args.map(TypeTree(_)), thisPhase)
965+
disallowCapInTypeArgs(tpt, fn.typeSymbol, args.map(TypeTree(_)))
960966
case _ =>
961967
inNestedLevelUnless(cls.is(Module)):
962968
super.recheckClassDef(tree, impl, cls)

compiler/src/dotty/tools/dotc/core/SymUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,9 @@ class SymUtils:
271271
self.owner.info.decl(fieldName).suchThat(!_.is(Method)).symbol
272272
}
273273

274+
def paramNamed(name: Name)(using Context): Symbol =
275+
self.rawParamss.nestedFind(_.name == name).getOrElse(NoSymbol)
276+
274277
/** Is this symbol a constant expression final val?
275278
*
276279
* This is the case if all of the following are true:

tests/pos/gears-probem-1.scala renamed to tests/neg-custom-args/captures/gears-problem-1.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,4 @@ extension [T](@use fs: Seq[Future[T]^])
2222
val collector//: Collector[T]{val futures: Seq[Future[T]^{fs*}]}
2323
= Collector(fs)
2424
// val ch = collector.results // also errors
25-
val fut: Future[T]^{fs*} = collector.results.read().get // found ...^{caps.cap}
25+
val fut: Future[T]^{fs*} = collector.results.read().get // error
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/gears-problem.scala:19:62 --------------------------------
2+
19 | val fut: Future[T]^{fs*} = collector.results.read().right.get // error
3+
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
4+
| Found: Future[T]^{collector.futures*}
5+
| Required: Future[T]^{fs*}
6+
|
7+
| longer explanation available when compiling with `-explain`
8+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/gears-problem.scala:24:34 --------------------------------
9+
24 | val fut2: Future[T]^{fs*} = r.get // error
10+
| ^^^^^
11+
| Found: Future[box T^?]^{collector.futures*}
12+
| Required: Future[T]^{fs*}
13+
|
14+
| longer explanation available when compiling with `-explain`
15+
there were 4 deprecation warnings; re-run with -deprecation for details

tests/pos-custom-args/captures/gears-problem.scala renamed to tests/neg-custom-args/captures/gears-problem.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ extension [T](@use fs: Seq[Future[T]^])
1616
val collector: Collector[T]{val futures: Seq[Future[T]^{fs*}]}
1717
= Collector(fs)
1818
// val ch = collector.results // also errors
19-
val fut: Future[T]^{fs*} = collector.results.read().right.get // found ...^{caps.cap}
19+
val fut: Future[T]^{fs*} = collector.results.read().right.get // error
2020

2121
val ch = collector.results
2222
val item = ch.read()
2323
val r = item.right
24-
val fut2: Future[T]^{fs*} = r.get
24+
val fut2: Future[T]^{fs*} = r.get // error
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import language.experimental.captureChecking
2+
import caps.{cap, use}
3+
4+
trait IO
5+
trait Async
6+
7+
def main(io: IO^, async: Async^) =
8+
def bad[X](ops: List[(X, () ->{io} Unit)])(f: () ->{ops*} Unit): () ->{io} Unit = f // error
9+
def runOps(@use ops: List[(() => Unit, () => Unit)]): () ->{ops*} Unit =
10+
() => ops.foreach((f1, f2) => { f1(); f2() })
11+
def delayOps(@use ops: List[(() ->{async} Unit, () ->{io} Unit)]): () ->{io} Unit =
12+
val runner: () ->{ops*} Unit = runOps(ops)
13+
val badRunner: () ->{io} Unit = bad[() ->{async} Unit](ops)(runner)
14+
// it uses both async and io, but we losed track of async.
15+
badRunner
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
-- Error: tests/neg-custom-args/captures/use-capset.scala:7:50 ---------------------------------------------------------
2+
7 |private def g[C^] = (xs: List[Object^{C^}]) => xs.head // error
3+
| ^^^^^^^
4+
| Capture set parameter C leaks into capture scope of method g.
5+
| To allow this, the type C should be declared with a @use annotation
6+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/use-capset.scala:13:22 -----------------------------------
7+
13 | val _: () -> Unit = h // error: should be ->{io}
8+
| ^
9+
| Found: (h : () ->{io} Unit)
10+
| Required: () -> Unit
11+
|
12+
| longer explanation available when compiling with `-explain`
13+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/use-capset.scala:15:50 -----------------------------------
14+
15 | val _: () -> List[Object^{io}] -> Object^{io} = h2 // error, should be ->{io}
15+
| ^^
16+
| Found: () ->? (x$0: List[box Object^{io}]^{}) ->{io} (ex$13: caps.Exists) -> Object^{io}
17+
| Required: () -> List[box Object^{io}] -> Object^{io}
18+
|
19+
| longer explanation available when compiling with `-explain`
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import caps.{use, CapSet}
2+
3+
4+
5+
def f[C^](@use xs: List[Object^{C^}]): Unit = ???
6+
7+
private def g[C^] = (xs: List[Object^{C^}]) => xs.head // error
8+
9+
private def g2[@use C^] = (xs: List[Object^{C^}]) => xs.head // ok
10+
11+
def test(io: Object^)(@use xs: List[Object^{io}]): Unit =
12+
val h = () => f(xs)
13+
val _: () -> Unit = h // error: should be ->{io}
14+
val h2 = () => g[CapSet^{io}]
15+
val _: () -> List[Object^{io}] -> Object^{io} = h2 // error, should be ->{io}
16+
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import language.experimental.captureChecking
2+
import caps.{use, CapSet}
3+
4+
trait Future[+T]:
5+
def await: T
6+
7+
trait Channel[+T]:
8+
def read(): Ok[T]
9+
10+
class Collector[T, C^](val futures: Seq[Future[T]^{C^}]):
11+
val results: Channel[Future[T]^{C^}] = ???
12+
end Collector
13+
14+
class Result[+T, +E]:
15+
def get: T = ???
16+
17+
case class Err[+E](e: E) extends Result[Nothing, E]
18+
case class Ok[+T](x: T) extends Result[T, Nothing]
19+
20+
extension [T, C^](@use fs: Seq[Future[T]^{C^}])
21+
def awaitAllPoly =
22+
val collector = Collector(fs)
23+
val fut: Future[T]^{C^} = collector.results.read().get
24+
25+
extension [T](@use fs: Seq[Future[T]^])
26+
def awaitAll = fs.awaitAllPoly
27+
28+
def awaitExplicit[T](@use fs: Seq[Future[T]^]): Unit =
29+
awaitAllPoly[T, CapSet^{fs*}](fs)

0 commit comments

Comments
 (0)