Skip to content

Commit 0011ea6

Browse files
authored
Introduce Maybe Capabilities (scala#19500)
Introduce maybe capabilities x? to handle the case where a capture set appears invariantly in its surrounding type. Maybe capabilities are similar to TypeBounds types, but restricted to capture sets. For instance, Array[C^{x?}] should be morally equivaelent to Array[_ >: C^{} <: C^{x}] but it has fewer issues with type inference.
2 parents f15bbda + 3ef2514 commit 0011ea6

File tree

21 files changed

+173
-95
lines changed

21 files changed

+173
-95
lines changed

compiler/src/dotty/tools/dotc/ast/tpd.scala

+4-4
Original file line numberDiff line numberDiff line change
@@ -118,13 +118,13 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
118118
* otherwise specified).
119119
*/
120120
def Closure(meth: TermSymbol, rhsFn: List[List[Tree]] => Tree, targs: List[Tree] = Nil, targetType: Type = NoType)(using Context): Block = {
121-
val targetTpt = if (targetType.exists) TypeTree(targetType) else EmptyTree
121+
val targetTpt = if (targetType.exists) TypeTree(targetType, inferred = true) else EmptyTree
122122
val call =
123123
if (targs.isEmpty) Ident(TermRef(NoPrefix, meth))
124124
else TypeApply(Ident(TermRef(NoPrefix, meth)), targs)
125-
Block(
126-
DefDef(meth, rhsFn) :: Nil,
127-
Closure(Nil, call, targetTpt))
125+
var mdef0 = DefDef(meth, rhsFn)
126+
val mdef = cpy.DefDef(mdef0)(tpt = TypeTree(mdef0.tpt.tpe, inferred = true))
127+
Block(mdef :: Nil, Closure(Nil, call, targetTpt))
128128
}
129129

130130
/** A closure whose anonymous function has the given method type */

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+39-8
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,31 @@ extension (tp: Type)
220220
* type of `x`. If `x` and `y` are different variables then `{x*}` and `{y*}`
221221
* are unrelated.
222222
*/
223-
def reach(using Context): CaptureRef =
224-
assert(tp.isTrackableRef)
225-
AnnotatedType(tp, Annotation(defn.ReachCapabilityAnnot, util.Spans.NoSpan))
223+
def reach(using Context): CaptureRef = tp match
224+
case tp: CaptureRef if tp.isTrackableRef =>
225+
if tp.isReach then tp else ReachCapability(tp)
226+
227+
/** If `x` is a capture ref, its maybe capability `x?`, represented internally
228+
* as `x @maybeCapability`. `x?` stands for a capability `x` that might or might
229+
* not be part of a capture set. We have `{} <: {x?} <: {x}`. Maybe capabilities
230+
* cannot be propagated between sets. If `a <: b` and `a` acquires `x?` then
231+
* `x` is propagated to `b` as a conservative approximation.
232+
*
233+
* Maybe capabilities should only arise for capture sets that appear in invariant
234+
* position in their surrounding type. They are similar to TypeBunds types, but
235+
* restricted to capture sets. For instance,
236+
*
237+
* Array[C^{x?}]
238+
*
239+
* should be morally equivalent to
240+
*
241+
* Array[_ >: C^{} <: C^{x}]
242+
*
243+
* but it has fewer issues with type inference.
244+
*/
245+
def maybe(using Context): CaptureRef = tp match
246+
case tp: CaptureRef if tp.isTrackableRef =>
247+
if tp.isMaybe then tp else MaybeCapability(tp)
226248

227249
/** If `ref` is a trackable capture ref, and `tp` has only covariant occurrences of a
228250
* universal capture set, replace all these occurrences by `{ref*}`. This implements
@@ -419,12 +441,21 @@ object ReachCapabilityApply:
419441
case Apply(reach, arg :: Nil) if reach.symbol == defn.Caps_reachCapability => Some(arg)
420442
case _ => None
421443

444+
class AnnotatedCapability(annot: Context ?=> ClassSymbol):
445+
def apply(tp: Type)(using Context) =
446+
AnnotatedType(tp, Annotation(annot, util.Spans.NoSpan))
447+
def unapply(tree: AnnotatedType)(using Context): Option[SingletonCaptureRef] = tree match
448+
case AnnotatedType(parent: SingletonCaptureRef, ann) if ann.symbol == annot => Some(parent)
449+
case _ => None
450+
422451
/** An extractor for `ref @annotation.internal.reachCapability`, which is used to express
423452
* the reach capability `ref*` as a type.
424453
*/
425-
object ReachCapability:
426-
def unapply(tree: AnnotatedType)(using Context): Option[SingletonCaptureRef] = tree match
427-
case AnnotatedType(parent: SingletonCaptureRef, ann)
428-
if ann.symbol == defn.ReachCapabilityAnnot => Some(parent)
429-
case _ => None
454+
object ReachCapability extends AnnotatedCapability(defn.ReachCapabilityAnnot)
455+
456+
/** An extractor for `ref @maybeCapability`, which is used to express
457+
* the maybe capability `ref?` as a type.
458+
*/
459+
object MaybeCapability extends AnnotatedCapability(defn.MaybeCapabilityAnnot)
460+
430461

compiler/src/dotty/tools/dotc/cc/CaptureSet.scala

+34-6
Original file line numberDiff line numberDiff line change
@@ -145,15 +145,16 @@ sealed abstract class CaptureSet extends Showable:
145145

146146
/** x subsumes x
147147
* this subsumes this.f
148-
* x subsumes y ==> x* subsumes y
149-
* x subsumes y ==> x* subsumes y*
148+
* x subsumes y ==> x* subsumes y, x subsumes y?
149+
* x subsumes y ==> x* subsumes y*, x? subsumes y?
150150
*/
151151
extension (x: CaptureRef)
152152
private def subsumes(y: CaptureRef)(using Context): Boolean =
153153
(x eq y)
154154
|| x.isRootCapability
155155
|| y.match
156-
case y: TermRef => !y.isReach && (y.prefix eq x)
156+
case y: TermRef => y.prefix eq x
157+
case MaybeCapability(y1) => x.stripMaybe.subsumes(y1)
157158
case _ => false
158159
|| x.match
159160
case ReachCapability(x1) => x1.subsumes(y.stripReach)
@@ -312,6 +313,8 @@ sealed abstract class CaptureSet extends Showable:
312313
def substParams(tl: BindingType, to: List[Type])(using Context) =
313314
map(Substituters.SubstParamsMap(tl, to))
314315

316+
def maybe(using Context): CaptureSet = map(MaybeMap())
317+
315318
/** Invoke handler if this set has (or later aquires) the root capability `cap` */
316319
def disallowRootCapability(handler: () => Context ?=> Unit)(using Context): this.type =
317320
if isUniversal then handler()
@@ -445,6 +448,8 @@ object CaptureSet:
445448
def isConst = isSolved
446449
def isAlwaysEmpty = false
447450

451+
def isMaybeSet = false // overridden in BiMapped
452+
448453
/** A handler to be invoked if the root reference `cap` is added to this set */
449454
var rootAddedHandler: () => Context ?=> Unit = () => ()
450455

@@ -490,9 +495,10 @@ object CaptureSet:
490495
if elem.isRootCapability then
491496
rootAddedHandler()
492497
newElemAddedHandler(elem)
498+
val normElem = if isMaybeSet then elem else elem.stripMaybe
493499
// assert(id != 5 || elems.size != 3, this)
494500
val res = (CompareResult.OK /: deps): (r, dep) =>
495-
r.andAlso(dep.tryInclude(elem, this))
501+
r.andAlso(dep.tryInclude(normElem, this))
496502
res.orElse:
497503
elems -= elem
498504
res.addToTrace(this)
@@ -508,6 +514,8 @@ object CaptureSet:
508514
levelLimit.isContainedIn(elem.cls.levelOwner)
509515
case ReachCapability(elem1) =>
510516
levelOK(elem1)
517+
case MaybeCapability(elem1) =>
518+
levelOK(elem1)
511519
case _ =>
512520
true
513521

@@ -760,6 +768,7 @@ object CaptureSet:
760768
if source eq origin then supApprox.map(bimap.inverse)
761769
else source.upperApprox(this).map(bimap) ** supApprox
762770

771+
override def isMaybeSet: Boolean = bimap.isInstanceOf[MaybeMap]
763772
override def toString = s"BiMapped$id($source, elems = $elems)"
764773
end BiMapped
765774

@@ -840,8 +849,7 @@ object CaptureSet:
840849
upper.isAlwaysEmpty || upper.isConst && upper.elems.size == 1 && upper.elems.contains(r1)
841850
if variance > 0 || isExact then upper
842851
else if variance < 0 then CaptureSet.empty
843-
else if ctx.mode.is(Mode.Printing) then upper
844-
else assert(false, i"trying to add $upper from $r via ${tm.getClass} in a non-variant setting")
852+
else upper.maybe
845853

846854
/** Apply `f` to each element in `xs`, and join result sets with `++` */
847855
def mapRefs(xs: Refs, f: CaptureRef => CaptureSet)(using Context): CaptureSet =
@@ -980,6 +988,26 @@ object CaptureSet:
980988
/** The current VarState, as passed by the implicit context */
981989
def varState(using state: VarState): VarState = state
982990

991+
/** Maps `x` to `x?` */
992+
private class MaybeMap(using Context) extends BiTypeMap:
993+
994+
def apply(t: Type) = t match
995+
case t: CaptureRef if t.isTrackableRef => t.maybe
996+
case _ => mapOver(t)
997+
998+
override def toString = "Maybe"
999+
1000+
lazy val inverse = new BiTypeMap:
1001+
1002+
def apply(t: Type) = t match
1003+
case t: CaptureRef if t.isMaybe => t.stripMaybe
1004+
case t => mapOver(t)
1005+
1006+
def inverse = MaybeMap.this
1007+
1008+
override def toString = "Maybe.inverse"
1009+
end MaybeMap
1010+
9831011
/* Not needed:
9841012
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =
9851013
CaptureSet.empty

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

-6
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,6 @@ object CheckCaptures:
6868
*/
6969
final class SubstParamsMap(from: BindingType, to: List[Type])(using Context)
7070
extends ApproximatingTypeMap, IdempotentCaptRefMap:
71-
/** This SubstParamsMap is exact if `to` only contains `CaptureRef`s. */
72-
private val isExactSubstitution: Boolean = to.forall(_.isTrackableRef)
73-
74-
/** As long as this substitution is exact, there is no need to create `Range`s when mapping invariant positions. */
75-
override protected def needsRangeIfInvariant(refs: CaptureSet): Boolean = !isExactSubstitution
76-
7771
def apply(tp: Type): Type =
7872
tp match
7973
case tp: ParamRef =>

compiler/src/dotty/tools/dotc/core/Definitions.scala

+13-9
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,10 @@ class Definitions {
535535
List(AnyType), EmptyScope)
536536
@tu lazy val SingletonType: TypeRef = SingletonClass.typeRef
537537

538+
@tu lazy val MaybeCapabilityAnnot: ClassSymbol =
539+
completeClass(enterCompleteClassSymbol(
540+
ScalaPackageClass, tpnme.maybeCapability, Final, List(StaticAnnotationClass.typeRef)))
541+
538542
@tu lazy val CollectionSeqType: TypeRef = requiredClassRef("scala.collection.Seq")
539543
@tu lazy val SeqType: TypeRef = requiredClassRef("scala.collection.immutable.Seq")
540544
@tu lazy val SeqModule: Symbol = requiredModule("scala.collection.immutable.Seq")
@@ -993,7 +997,7 @@ class Definitions {
993997

994998
// Annotation base classes
995999
@tu lazy val AnnotationClass: ClassSymbol = requiredClass("scala.annotation.Annotation")
996-
// @tu lazy val StaticAnnotationClass: ClassSymbol = requiredClass("scala.annotation.StaticAnnotation")
1000+
@tu lazy val StaticAnnotationClass: ClassSymbol = requiredClass("scala.annotation.StaticAnnotation")
9971001
@tu lazy val RefiningAnnotationClass: ClassSymbol = requiredClass("scala.annotation.RefiningAnnotation")
9981002
@tu lazy val JavaAnnotationClass: ClassSymbol = requiredClass("java.lang.annotation.Annotation")
9991003

@@ -1171,19 +1175,18 @@ class Definitions {
11711175
}
11721176
}
11731177

1174-
object RefinedFunctionOf {
1178+
object RefinedFunctionOf:
1179+
11751180
/** Matches a refined `PolyFunction`/`FunctionN[...]`/`ContextFunctionN[...]`.
11761181
* Extracts the method type type and apply info.
11771182
*/
1178-
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] = {
1183+
def unapply(tpe: RefinedType)(using Context): Option[MethodOrPoly] =
11791184
tpe.refinedInfo match
11801185
case mt: MethodOrPoly
1181-
if tpe.refinedName == nme.apply
1182-
&& (tpe.parent.derivesFrom(defn.PolyFunctionClass) || isFunctionNType(tpe.parent)) =>
1183-
Some(mt)
1186+
if tpe.refinedName == nme.apply && isFunctionType(tpe.parent) => Some(mt)
11841187
case _ => None
1185-
}
1186-
}
1188+
1189+
end RefinedFunctionOf
11871190

11881191
object PolyFunctionOf {
11891192

@@ -2137,7 +2140,8 @@ class Definitions {
21372140
AnyValClass,
21382141
NullClass,
21392142
NothingClass,
2140-
SingletonClass)
2143+
SingletonClass,
2144+
MaybeCapabilityAnnot)
21412145

21422146
@tu lazy val syntheticCoreClasses: List[Symbol] = syntheticScalaClasses ++ List(
21432147
EmptyPackageVal,

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ object StdNames {
537537
val ManifestFactory: N = "ManifestFactory"
538538
val manifestToTypeTag: N = "manifestToTypeTag"
539539
val map: N = "map"
540+
val maybeCapability: N = "maybeCapability"
540541
val materializeClassTag: N = "materializeClassTag"
541542
val materializeWeakTypeTag: N = "materializeWeakTypeTag"
542543
val materializeTypeTag: N = "materializeTypeTag"

compiler/src/dotty/tools/dotc/core/TypeOps.scala

-12
Original file line numberDiff line numberDiff line change
@@ -540,18 +540,6 @@ object TypeOps:
540540
val sym = tp.symbol
541541
forbidden.contains(sym)
542542

543-
/** We need to split the set into upper and lower approximations
544-
* only if it contains a local element. The idea here is that at the
545-
* time we perform an `avoid` all local elements are already accounted for
546-
* and no further elements will be added afterwards. So we can just keep
547-
* the set as it is. See comment by @linyxus on #16261.
548-
*/
549-
override def needsRangeIfInvariant(refs: CaptureSet): Boolean =
550-
refs.elems.exists {
551-
case ref: TermRef => toAvoid(ref)
552-
case _ => false
553-
}
554-
555543
override def apply(tp: Type): Type = tp match
556544
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
557545
val lo = TypeComparer.instanceType(

compiler/src/dotty/tools/dotc/core/Types.scala

+17-13
Original file line numberDiff line numberDiff line change
@@ -845,7 +845,9 @@ object Types extends TypeUtils {
845845
safeIntersection = ctx.base.pendingMemberSearches.contains(name))
846846
joint match
847847
case joint: SingleDenotation
848-
if isRefinedMethod && rinfo <:< joint.info =>
848+
if isRefinedMethod
849+
&& (rinfo <:< joint.info
850+
|| name == nme.apply && defn.isFunctionType(tp.parent)) =>
849851
// use `rinfo` to keep the right parameter names for named args. See i8516.scala.
850852
joint.derivedSingleDenotation(joint.symbol, rinfo, pre, isRefinedMethod)
851853
case _ =>
@@ -2198,7 +2200,11 @@ object Types extends TypeUtils {
21982200
/** Is this a reach reference of the form `x*`? */
21992201
def isReach(using Context): Boolean = false // overridden in AnnotatedType
22002202

2203+
/** Is this a maybe reference of the form `x?`? */
2204+
def isMaybe(using Context): Boolean = false // overridden in AnnotatedType
2205+
22012206
def stripReach(using Context): CaptureRef = this // overridden in AnnotatedType
2207+
def stripMaybe(using Context): CaptureRef = this // overridden in AnnotatedType
22022208

22032209
/** Is this reference the generic root capability `cap` ? */
22042210
def isRootCapability(using Context): Boolean = false
@@ -5618,14 +5624,21 @@ object Types extends TypeUtils {
56185624
}
56195625

56205626
override def isTrackableRef(using Context) =
5621-
isReach && parent.isTrackableRef
5627+
(isReach || isMaybe) && parent.isTrackableRef
56225628

56235629
/** Is this a reach reference of the form `x*`? */
56245630
override def isReach(using Context): Boolean =
56255631
annot.symbol == defn.ReachCapabilityAnnot
56265632

5627-
override def stripReach(using Context): SingletonCaptureRef =
5628-
(if isReach then parent else this).asInstanceOf[SingletonCaptureRef]
5633+
/** Is this a reach reference of the form `x*`? */
5634+
override def isMaybe(using Context): Boolean =
5635+
annot.symbol == defn.MaybeCapabilityAnnot
5636+
5637+
override def stripReach(using Context): CaptureRef =
5638+
if isReach then parent.asInstanceOf[CaptureRef] else this
5639+
5640+
override def stripMaybe(using Context): CaptureRef =
5641+
if isMaybe then parent.asInstanceOf[CaptureRef] else this
56295642

56305643
override def normalizedRef(using Context): CaptureRef =
56315644
if isReach then AnnotatedType(stripReach.normalizedRef, annot) else this
@@ -6475,15 +6488,6 @@ object Types extends TypeUtils {
64756488
tp.derivedLambdaType(tp.paramNames, formals, restpe)
64766489
}
64776490

6478-
/** Overridden in TypeOps.avoid and in CheckCaptures.substParamsMap */
6479-
protected def needsRangeIfInvariant(refs: CaptureSet): Boolean = true
6480-
6481-
override def mapCapturingType(tp: Type, parent: Type, refs: CaptureSet, v: Int): Type =
6482-
if v == 0 && needsRangeIfInvariant(refs) then
6483-
range(mapCapturingType(tp, parent, refs, -1), mapCapturingType(tp, parent, refs, 1))
6484-
else
6485-
super.mapCapturingType(tp, parent, refs, v)
6486-
64876491
protected def reapply(tp: Type): Type = apply(tp)
64886492
}
64896493

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import util.SourcePosition
1515
import scala.util.control.NonFatal
1616
import scala.annotation.switch
1717
import config.{Config, Feature}
18-
import cc.{CapturingType, RetainingType, CaptureSet, ReachCapability, isBoxed, levelOwner, retainedElems}
18+
import cc.{CapturingType, RetainingType, CaptureSet, ReachCapability, MaybeCapability, isBoxed, levelOwner, retainedElems}
1919

2020
class PlainPrinter(_ctx: Context) extends Printer {
2121

@@ -404,6 +404,7 @@ class PlainPrinter(_ctx: Context) extends Printer {
404404
case tp: TermRef if tp.symbol == defn.captureRoot => Str("cap")
405405
case tp: SingletonType => toTextRef(tp)
406406
case ReachCapability(tp1) => toTextRef(tp1) ~ "*"
407+
case MaybeCapability(tp1) => toTextRef(tp1) ~ "?"
407408
case _ => toText(tp)
408409

409410
protected def isOmittablePrefix(sym: Symbol): Boolean =

0 commit comments

Comments
 (0)