Skip to content

Introduce Maybe Capabilities #19500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
* otherwise specified).
*/
def Closure(meth: TermSymbol, rhsFn: List[List[Tree]] => Tree, targs: List[Tree] = Nil, targetType: Type = NoType)(using Context): Block = {
val targetTpt = if (targetType.exists) TypeTree(targetType) else EmptyTree
val targetTpt = if (targetType.exists) TypeTree(targetType, inferred = true) else EmptyTree
val call =
if (targs.isEmpty) Ident(TermRef(NoPrefix, meth))
else TypeApply(Ident(TermRef(NoPrefix, meth)), targs)
Block(
DefDef(meth, rhsFn) :: Nil,
Closure(Nil, call, targetTpt))
var mdef0 = DefDef(meth, rhsFn)
val mdef = cpy.DefDef(mdef0)(tpt = TypeTree(mdef0.tpt.tpe, inferred = true))
Block(mdef :: Nil, Closure(Nil, call, targetTpt))
}

/** A closure whose anonymous function has the given method type */
Expand Down
47 changes: 39 additions & 8 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,31 @@ extension (tp: Type)
* type of `x`. If `x` and `y` are different variables then `{x*}` and `{y*}`
* are unrelated.
*/
def reach(using Context): CaptureRef =
assert(tp.isTrackableRef)
AnnotatedType(tp, Annotation(defn.ReachCapabilityAnnot, util.Spans.NoSpan))
def reach(using Context): CaptureRef = tp match
case tp: CaptureRef if tp.isTrackableRef =>
if tp.isReach then tp else ReachCapability(tp)

/** If `x` is a capture ref, its maybe capability `x?`, represented internally
* as `x @maybeCapability`. `x?` stands for a capability `x` that might or might
* not be part of a capture set. We have `{} <: {x?} <: {x}`. Maybe capabilities
* cannot be propagated between sets. If `a <: b` and `a` acquires `x?` then
* `x` is propagated to `b` as a conservative approximation.
*
* Maybe capabilities should only arise for capture sets that appear in invariant
* position in their surrounding type. They are similar to TypeBunds types, but
* restricted to capture sets. For instance,
*
* Array[C^{x?}]
*
* should be morally equivalent to
*
* Array[_ >: C^{} <: C^{x}]
*
* but it has fewer issues with type inference.
*/
def maybe(using Context): CaptureRef = tp match
case tp: CaptureRef if tp.isTrackableRef =>
if tp.isMaybe then tp else MaybeCapability(tp)

/** If `ref` is a trackable capture ref, and `tp` has only covariant occurrences of a
* universal capture set, replace all these occurrences by `{ref*}`. This implements
Expand Down Expand Up @@ -419,12 +441,21 @@ object ReachCapabilityApply:
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
case _ => None

class AnnotatedCapability(annot: Context ?=> ClassSymbol):
def apply(tp: Type)(using Context) =
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))
def unapply(tree: AnnotatedType)(using Context): Option[SingletonCaptureRef] = tree match
case AnnotatedType(parent: SingletonCaptureRef, ann) if ann.symbol == annot => Some(parent)
case _ => None

/** An extractor for `ref @annotation.internal.reachCapability`, which is used to express
* the reach capability `ref*` as a type.
*/
object ReachCapability:
def unapply(tree: AnnotatedType)(using Context): Option[SingletonCaptureRef] = tree match
case AnnotatedType(parent: SingletonCaptureRef, ann)
if ann.symbol == defn.ReachCapabilityAnnot => Some(parent)
case _ => None
object ReachCapability extends AnnotatedCapability(defn.ReachCapabilityAnnot)

/** An extractor for `ref @maybeCapability`, which is used to express
* the maybe capability `ref?` as a type.
*/
object MaybeCapability extends AnnotatedCapability(defn.MaybeCapabilityAnnot)


40 changes: 34 additions & 6 deletions compiler/src/dotty/tools/dotc/cc/CaptureSet.scala
Original file line number Diff line number Diff line change
Expand Up @@ -145,15 +145,16 @@ sealed abstract class CaptureSet extends Showable:

/** x subsumes x
* this subsumes this.f
* x subsumes y ==> x* subsumes y
* x subsumes y ==> x* subsumes y*
* x subsumes y ==> x* subsumes y, x subsumes y?
* x subsumes y ==> x* subsumes y*, x? subsumes y?
*/
extension (x: CaptureRef)
private def subsumes(y: CaptureRef)(using Context): Boolean =
(x eq y)
|| x.isRootCapability
|| y.match
case y: TermRef => !y.isReach && (y.prefix eq x)
case y: TermRef => y.prefix eq x
case MaybeCapability(y1) => x.stripMaybe.subsumes(y1)
case _ => false
|| x.match
case ReachCapability(x1) => x1.subsumes(y.stripReach)
Expand Down Expand Up @@ -312,6 +313,8 @@ sealed abstract class CaptureSet extends Showable:
def substParams(tl: BindingType, to: List[Type])(using Context) =
map(Substituters.SubstParamsMap(tl, to))

def maybe(using Context): CaptureSet = map(MaybeMap())

/** Invoke handler if this set has (or later aquires) the root capability `cap` */
def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
if isUniversal then handler()
Expand Down Expand Up @@ -445,6 +448,8 @@ object CaptureSet:
def isConst = isSolved
def isAlwaysEmpty = false

def isMaybeSet = false // overridden in BiMapped

/** A handler to be invoked if the root reference `cap` is added to this set */
var rootAddedHandler: () => Context ?=> Unit = () => ()

Expand Down Expand Up @@ -490,9 +495,10 @@ object CaptureSet:
if elem.isRootCapability then
rootAddedHandler()
newElemAddedHandler(elem)
val normElem = if isMaybeSet then elem else elem.stripMaybe
// assert(id != 5 || elems.size != 3, this)
val res = (CompareResult.OK /: deps): (r, dep) =>
r.andAlso(dep.tryInclude(elem, this))
r.andAlso(dep.tryInclude(normElem, this))
res.orElse:
elems -= elem
res.addToTrace(this)
Expand All @@ -508,6 +514,8 @@ object CaptureSet:
levelLimit.isContainedIn(elem.cls.levelOwner)
case ReachCapability(elem1) =>
levelOK(elem1)
case MaybeCapability(elem1) =>
levelOK(elem1)
case _ =>
true

Expand Down Expand Up @@ -760,6 +768,7 @@ object CaptureSet:
if source eq origin then supApprox.map(bimap.inverse)
else source.upperApprox(this).map(bimap) ** supApprox

override def isMaybeSet: Boolean = bimap.isInstanceOf[MaybeMap]
override def toString = s"BiMapped$id($source, elems = $elems)"
end BiMapped

Expand Down Expand Up @@ -840,8 +849,7 @@ object CaptureSet:
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
if variance > 0 || isExact then upper
else if variance < 0 then CaptureSet.empty
else if ctx.mode.is(Mode.Printing) then upper
else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting")
else upper.maybe

/** Apply `f` to each element in `xs`, and join result sets with `++` */
def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet =
Expand Down Expand Up @@ -980,6 +988,26 @@ object CaptureSet:
/** The current VarState, as passed by the implicit context */
def varState(using state: VarState): VarState = state

/** Maps `x` to `x?` */
private class MaybeMap(using Context) extends BiTypeMap:

def apply(t: Type) = t match
case t: CaptureRef if t.isTrackableRef => t.maybe
case _ => mapOver(t)

override def toString = "Maybe"

lazy val inverse = new BiTypeMap:

def apply(t: Type) = t match
case t: CaptureRef if t.isMaybe => t.stripMaybe
case t => mapOver(t)

def inverse = MaybeMap.this

override def toString = "Maybe.inverse"
end MaybeMap

/* Not needed:
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =
CaptureSet.empty
Expand Down
6 changes: 0 additions & 6 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,6 @@ object CheckCaptures:
*/
final class SubstParamsMap(from: BindingType, to: List[Type])(using Context)
extends ApproximatingTypeMap, IdempotentCaptRefMap:
/** This SubstParamsMap is exact if `to` only contains `CaptureRef`s. */
private val isExactSubstitution: Boolean = to.forall(_.isTrackableRef)

/** As long as this substitution is exact, there is no need to create `Range`s when mapping invariant positions. */
override protected def needsRangeIfInvariant(refs: CaptureSet): Boolean = !isExactSubstitution

def apply(tp: Type): Type =
tp match
case tp: ParamRef =>
Expand Down
22 changes: 13 additions & 9 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,10 @@ class Definitions {
List(AnyType), EmptyScope)
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef

@tu lazy val MaybeCapabilityAnnot: ClassSymbol =
completeClass(enterCompleteClassSymbol(
ScalaPackageClass, tpnme.maybeCapability, Final, List(StaticAnnotationClass.typeRef)))

@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
Expand Down Expand Up @@ -993,7 +997,7 @@ class Definitions {

// Annotation base classes
@tu lazy val AnnotationClass: ClassSymbol = requiredClass("scala.annotation.Annotation")
// @tu lazy val StaticAnnotationClass: ClassSymbol = requiredClass("scala.annotation.StaticAnnotation")
@tu lazy val StaticAnnotationClass: ClassSymbol = requiredClass("scala.annotation.StaticAnnotation")
@tu lazy val RefiningAnnotationClass: ClassSymbol = requiredClass("scala.annotation.RefiningAnnotation")
@tu lazy val JavaAnnotationClass: ClassSymbol = requiredClass("java.lang.annotation.Annotation")

Expand Down Expand Up @@ -1171,19 +1175,18 @@ class Definitions {
}
}

object RefinedFunctionOf {
object RefinedFunctionOf:

/** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
* Extracts the method type type and apply info.
*/
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = {
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] =
tpe.refinedInfo match
case mt: MethodOrPoly
if tpe.refinedName == nme.apply
&& (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) =>
Some(mt)
if tpe.refinedName == nme.apply && isFunctionType(tpe.parent) => Some(mt)
case _ => None
}
}

end RefinedFunctionOf

object PolyFunctionOf {

Expand Down Expand Up @@ -2137,7 +2140,8 @@ class Definitions {
AnyValClass,
NullClass,
NothingClass,
SingletonClass)
SingletonClass,
MaybeCapabilityAnnot)

@tu lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
EmptyPackageVal,
Expand Down
1 change: 1 addition & 0 deletions compiler/src/dotty/tools/dotc/core/StdNames.scala
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,7 @@ object StdNames {
val ManifestFactory: N = "ManifestFactory"
val manifestToTypeTag: N = "manifestToTypeTag"
val map: N = "map"
val maybeCapability: N = "maybeCapability"
val materializeClassTag: N = "materializeClassTag"
val materializeWeakTypeTag: N = "materializeWeakTypeTag"
val materializeTypeTag: N = "materializeTypeTag"
Expand Down
12 changes: 0 additions & 12 deletions compiler/src/dotty/tools/dotc/core/TypeOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -540,18 +540,6 @@ object TypeOps:
val sym = tp.symbol
forbidden.contains(sym)

/** We need to split the set into upper and lower approximations
* only if it contains a local element. The idea here is that at the
* time we perform an `avoid` all local elements are already accounted for
* and no further elements will be added afterwards. So we can just keep
* the set as it is. See comment by @linyxus on #16261.
*/
override def needsRangeIfInvariant(refs: CaptureSet): Boolean =
refs.elems.exists {
case ref: TermRef => toAvoid(ref)
case _ => false
}

override def apply(tp: Type): Type = tp match
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
val lo = TypeComparer.instanceType(
Expand Down
30 changes: 17 additions & 13 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,9 @@ object Types extends TypeUtils {
safeIntersection = ctx.base.pendingMemberSearches.contains(name))
joint match
case joint: SingleDenotation
if isRefinedMethod && rinfo <:< joint.info =>
if isRefinedMethod
&& (rinfo <:< joint.info
|| name == nme.apply && defn.isFunctionType(tp.parent)) =>
// use `rinfo` to keep the right parameter names for named args. See i8516.scala.
joint.derivedSingleDenotation(joint.symbol, rinfo, pre, isRefinedMethod)
case _ =>
Expand Down Expand Up @@ -2198,7 +2200,11 @@ object Types extends TypeUtils {
/** Is this a reach reference of the form `x*`? */
def isReach(using Context): Boolean = false // overridden in AnnotatedType

/** Is this a maybe reference of the form `x?`? */
def isMaybe(using Context): Boolean = false // overridden in AnnotatedType

def stripReach(using Context): CaptureRef = this // overridden in AnnotatedType
def stripMaybe(using Context): CaptureRef = this // overridden in AnnotatedType

/** Is this reference the generic root capability `cap` ? */
def isRootCapability(using Context): Boolean = false
Expand Down Expand Up @@ -5618,14 +5624,21 @@ object Types extends TypeUtils {
}

override def isTrackableRef(using Context) =
isReach && parent.isTrackableRef
(isReach || isMaybe) && parent.isTrackableRef

/** Is this a reach reference of the form `x*`? */
override def isReach(using Context): Boolean =
annot.symbol == defn.ReachCapabilityAnnot

override def stripReach(using Context): SingletonCaptureRef =
(if isReach then parent else this).asInstanceOf[SingletonCaptureRef]
/** Is this a reach reference of the form `x*`? */
override def isMaybe(using Context): Boolean =
annot.symbol == defn.MaybeCapabilityAnnot

override def stripReach(using Context): CaptureRef =
if isReach then parent.asInstanceOf[CaptureRef] else this

override def stripMaybe(using Context): CaptureRef =
if isMaybe then parent.asInstanceOf[CaptureRef] else this

override def normalizedRef(using Context): CaptureRef =
if isReach then AnnotatedType(stripReach.normalizedRef, annot) else this
Expand Down Expand Up @@ -6475,15 +6488,6 @@ object Types extends TypeUtils {
tp.derivedLambdaType(tp.paramNames, formals, restpe)
}

/** Overridden in TypeOps.avoid and in CheckCaptures.substParamsMap */
protected def needsRangeIfInvariant(refs: CaptureSet): Boolean = true

override def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type =
if v == 0 && needsRangeIfInvariant(refs) then
range(mapCapturingType(tp, parent, refs, -1), mapCapturingType(tp, parent, refs, 1))
else
super.mapCapturingType(tp, parent, refs, v)

protected def reapply(tp: Type): Type = apply(tp)
}

Expand Down
3 changes: 2 additions & 1 deletion compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import util.SourcePosition
import scala.util.control.NonFatal
import scala.annotation.switch
import config.{Config, Feature}
import cc.{CapturingType, RetainingType, CaptureSet, ReachCapability, isBoxed, levelOwner, retainedElems}
import cc.{CapturingType, RetainingType, CaptureSet, ReachCapability, MaybeCapability, isBoxed, levelOwner, retainedElems}

class PlainPrinter(_ctx: Context) extends Printer {

Expand Down Expand Up @@ -404,6 +404,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
case tp: TermRef if tp.symbol == defn.captureRoot => Str("cap")
case tp: SingletonType => toTextRef(tp)
case ReachCapability(tp1) => toTextRef(tp1) ~ "*"
case MaybeCapability(tp1) => toTextRef(tp1) ~ "?"
case _ => toText(tp)

protected def isOmittablePrefix(sym: Symbol): Boolean =
Expand Down
Loading