Skip to content

Commit 72a20bd

Browse files
committed
Improve Contains handling
Make use of enclosing Contains assumptions to improve the subsumes logic.
1 parent a8cc133 commit 72a20bd

File tree

6 files changed

+66
-23
lines changed

6 files changed

+66
-23
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,21 @@ extension (self: Type)
713713
case _ =>
714714
self
715715

716+
/** An extractor for a contains argument */
717+
object ContainsImpl:
718+
def unapply(tree: TypeApply)(using Context): Option[(Tree, Tree)] =
719+
tree.fun.tpe.widen match
720+
case fntpe: PolyType if tree.fun.symbol == defn.Caps_containsImpl =>
721+
tree.args match
722+
case csArg :: refArg :: Nil => Some((csArg, refArg))
723+
case _ => None
724+
case _ => None
725+
726+
/** An extractor for a contains parameter */
727+
object ContainsParam:
728+
def unapply(sym: Symbol)(using Context): Option[(TypeRef, CaptureRef)] =
729+
sym.info.dealias match
730+
case AppliedType(tycon, (cs: TypeRef) :: (ref: CaptureRef) :: Nil)
731+
if tycon.typeSymbol == defn.Caps_ContainsTrait
732+
&& cs.typeSymbol.isAbstractOrParamType => Some((cs, ref))
733+
case _ => None

compiler/src/dotty/tools/dotc/cc/CaptureRef.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,12 @@ trait CaptureRef extends TypeProxy, ValueType:
116116
case x1: SingletonCaptureRef => x1.subsumes(y)
117117
case _ => false
118118
case x: TermParamRef => subsumesExistentially(x, y)
119+
case x: TypeRef => assumedContainsOf(x).contains(y)
119120
case _ => false
120121

122+
def assumedContainsOf(x: TypeRef)(using Context): SimpleIdentitySet[CaptureRef] =
123+
CaptureSet.assumedContains.getOrElse(x, SimpleIdentitySet.empty)
124+
121125
end CaptureRef
122126

123127
trait SingletonCaptureRef extends SingletonType, CaptureRef

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ import util.{SimpleIdentitySet, Property}
1616
import typer.ErrorReporting.Addenda
1717
import TypeComparer.subsumesExistentially
1818
import util.common.alwaysTrue
19-
import scala.collection.mutable
19+
import scala.collection.{mutable, immutable}
2020
import CCState.*
2121

2222
/** A class for capture sets. Capture sets can be constants or variables.
@@ -1125,6 +1125,12 @@ object CaptureSet:
11251125
foldOver(cs, t)
11261126
collect(CaptureSet.empty, tp)
11271127

1128+
type AssumedContains = immutable.Map[TypeRef, SimpleIdentitySet[CaptureRef]]
1129+
val AssumedContains: Property.Key[AssumedContains] = Property.Key()
1130+
1131+
def assumedContains(using Context): AssumedContains =
1132+
ctx.property(AssumedContains).getOrElse(immutable.Map.empty)
1133+
11281134
private val ShownVars: Property.Key[mutable.Set[Var]] = Property.Key()
11291135

11301136
/** Perform `op`. Under -Ycc-debug, collect and print info about all variables reachable

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -676,29 +676,24 @@ class CheckCaptures extends Recheck, SymTransformer:
676676
i"Sealed type variable $pname", "be instantiated to",
677677
i"This is often caused by a local capability$where\nleaking as part of its result.",
678678
tree.srcPos)
679-
val res = handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
680-
if meth == defn.Caps_containsImpl then checkContains(tree)
681-
res
679+
try handleCall(meth, tree, () => Existential.toCap(super.recheckTypeApply(tree, pt)))
680+
finally checkContains(tree)
682681
end recheckTypeApply
683682

684683
/** Faced with a tree of form `caps.contansImpl[CS, r.type]`, check that `R` is a tracked
685684
* capability and assert that `{r} <:CS`.
686685
*/
687-
def checkContains(tree: TypeApply)(using Context): Unit =
688-
tree.fun.knownType.widen match
689-
case fntpe: PolyType =>
690-
tree.args match
691-
case csArg :: refArg :: Nil =>
692-
val cs = csArg.knownType.captureSet
693-
val ref = refArg.knownType
694-
capt.println(i"check contains $cs , $ref")
695-
ref match
696-
case ref: CaptureRef if ref.isTracked =>
697-
checkElem(ref, cs, tree.srcPos)
698-
case _ =>
699-
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
700-
case _ =>
701-
case _ =>
686+
def checkContains(tree: TypeApply)(using Context): Unit = tree match
687+
case ContainsImpl(csArg, refArg) =>
688+
val cs = csArg.knownType.captureSet
689+
val ref = refArg.knownType
690+
capt.println(i"check contains $cs , $ref")
691+
ref match
692+
case ref: CaptureRef if ref.isTracked =>
693+
checkElem(ref, cs, tree.srcPos)
694+
case _ =>
695+
report.error(em"$refArg is not a tracked capability", refArg.srcPos)
696+
case _ =>
702697

703698
override def recheckBlock(tree: Block, pt: Type)(using Context): Type =
704699
inNestedLevel(super.recheckBlock(tree, pt))
@@ -814,15 +809,26 @@ class CheckCaptures extends Recheck, SymTransformer:
814809
val localSet = capturedVars(sym)
815810
if !localSet.isAlwaysEmpty then
816811
curEnv = Env(sym, EnvKind.Regular, localSet, curEnv)
812+
813+
// ctx with AssumedContains entries for each Contains parameter
814+
val bodyCtx =
815+
var ac = CaptureSet.assumedContains
816+
for paramSyms <- sym.paramSymss do
817+
for case ContainsParam(cs, ref) <- paramSyms do
818+
ac = ac.updated(cs, ac.getOrElse(cs, SimpleIdentitySet.empty) + ref)
819+
if ac.isEmpty then ctx
820+
else ctx.withProperty(CaptureSet.AssumedContains, Some(ac))
821+
817822
inNestedLevel: // TODO: needed here?
818-
try checkInferredResult(super.recheckDefDef(tree, sym), tree)
823+
try checkInferredResult(super.recheckDefDef(tree, sym)(using bodyCtx), tree)
819824
finally
820825
if !sym.isAnonymousFunction then
821826
// Anonymous functions propagate their type to the enclosing environment
822827
// so it is not in general sound to interpolate their types.
823828
interpolateVarsIn(tree.tpt)
824829
curEnv = saved
825-
830+
end recheckDefDef
831+
826832
/** If val or def definition with inferred (result) type is visible
827833
* in other compilation units, check that the actual inferred type
828834
* conforms to the expected type where all inferred capture sets are dropped.

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,7 @@ class Definitions {
10021002
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")
10031003
@tu lazy val Caps_unsafeUnbox: Symbol = CapsUnsafeModule.requiredMethod("unsafeUnbox")
10041004
@tu lazy val Caps_unsafeBoxFunArg: Symbol = CapsUnsafeModule.requiredMethod("unsafeBoxFunArg")
1005-
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Capability")
1005+
@tu lazy val Caps_ContainsTrait: TypeSymbol = CapsModule.requiredType("Contains")
10061006
@tu lazy val Caps_containsImpl: TermSymbol = CapsModule.requiredMethod("containsImpl")
10071007

10081008
@tu lazy val PureClass: Symbol = requiredClass("scala.Pure")

tests/pos-custom-args/captures/i21313.scala

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
import caps.CapSet
22

33
trait Async:
4-
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T
4+
def await[T, Cap^](using caps.Contains[Cap, this.type])(src: Source[T, Cap]^): T =
5+
val x: Async^{this} = ???
6+
val y: Async^{Cap^} = x
7+
val ac: Async^ = ???
8+
def f(using caps.Contains[Cap, ac.type]) =
9+
val x2: Async^{this} = ???
10+
val y2: Async^{Cap^} = x2
11+
val x3: Async^{ac} = ???
12+
val y3: Async^{Cap^} = x3
13+
???
514

615
trait Source[+T, Cap^]:
716
final def await(using ac: Async^{Cap^}) = ac.await[T, Cap](this) // Contains[Cap, ac] is assured because {ac} <: Cap.

0 commit comments

Comments
 (0)