Skip to content

Commit 7e9ab8d

Browse files
committed
Better handling of vars with universal capture sets
- add unsafeBox operation so that a variable with universal capture set can be initialized and assigned to - infer unsafebox
1 parent d76dbe4 commit 7e9ab8d

File tree

5 files changed

+79
-39
lines changed

5 files changed

+79
-39
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -96,16 +96,16 @@ extension (tp: Type)
9696
/** Is the boxedCaptureSet of this type nonempty? */
9797
def isBoxedCapturing(using Context) = !tp.boxedCaptureSet.isAlwaysEmpty
9898

99-
/** If this type is a boxed capturing type, its unboxed version
100-
* If it is a TermRef of boxed capturing type, an unboxed capturing
101-
* type capturing the TermRef.
99+
/** If this type is a capturing type, the version with boxed statues as given by `boxed`.
100+
* If it is a TermRef of a capturing type, and the box status flips, widen to a capturing
101+
* type that captures the TermRef.
102102
*/
103-
def unbox(using Context): Type = tp.widenDealias match
104-
case tp @ CapturingType(parent, refs) if tp.isBoxed =>
103+
def forceBoxStatus(boxed: Boolean)(using Context): Type = tp.widenDealias match
104+
case tp @ CapturingType(parent, refs) if tp.isBoxed != boxed =>
105105
val refs1 = tp match
106106
case ref: CaptureRef if ref.isTracked => ref.singletonCaptureSet
107107
case _ => refs
108-
CapturingType(parent, refs1, boxed = false)
108+
CapturingType(parent, refs1, boxed)
109109
case _ =>
110110
tp
111111

@@ -168,6 +168,7 @@ extension (sym: Symbol)
168168
case _ => false
169169
containsEnclTypeParam(sym.info.finalResultType)
170170
&& !sym.allowsRootCapture
171+
&& sym != defn.Caps_unsafeBox
171172
&& sym != defn.Caps_unsafeUnbox
172173

173174
extension (tp: AnnotatedType)

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -332,29 +332,31 @@ class CheckCaptures extends Recheck, SymTransformer:
332332
* and Cr otherwise.
333333
*/
334334
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
335-
includeCallCaptures(tree.symbol, tree.srcPos)
336-
tree match
337-
case Apply(fn, arg :: Nil) if fn.symbol == defn.Caps_unsafeUnbox =>
338-
val argType0 = recheckStart(arg, pt).unbox
339-
val argType = super.recheckFinish(argType0, arg, pt)
340-
super.recheckFinish(argType, tree, pt)
341-
case _ =>
342-
super.recheckApply(tree, pt) match
343-
case appType @ CapturingType(appType1, refs) =>
344-
tree.fun match
345-
case Select(qual, _)
346-
if !tree.fun.symbol.isConstructor
347-
&& !qual.tpe.isBoxedCapturing
348-
&& !tree.args.exists(_.tpe.isBoxedCapturing)
349-
&& qual.tpe.captureSet.mightSubcapture(refs)
350-
&& tree.args.forall(_.tpe.captureSet.mightSubcapture(refs))
351-
=>
352-
val callCaptures = tree.args.foldLeft(qual.tpe.captureSet)((cs, arg) =>
353-
cs ++ arg.tpe.captureSet)
354-
appType.derivedCapturingType(appType1, callCaptures)
355-
.showing(i"narrow $tree: $appType, refs = $refs, qual = ${qual.tpe.captureSet} --> $result", capt)
356-
case _ => appType
357-
case appType => appType
335+
val meth = tree.fun.symbol
336+
includeCallCaptures(meth, tree.srcPos)
337+
if meth == defn.Caps_unsafeBox || meth == defn.Caps_unsafeUnbox then
338+
val arg :: Nil = tree.args: @unchecked
339+
val argType0 = recheckStart(arg, pt)
340+
.forceBoxStatus(boxed = meth == defn.Caps_unsafeBox)
341+
val argType = super.recheckFinish(argType0, arg, pt)
342+
super.recheckFinish(argType, tree, pt)
343+
else
344+
super.recheckApply(tree, pt) match
345+
case appType @ CapturingType(appType1, refs) =>
346+
tree.fun match
347+
case Select(qual, _)
348+
if !tree.fun.symbol.isConstructor
349+
&& !qual.tpe.isBoxedCapturing
350+
&& !tree.args.exists(_.tpe.isBoxedCapturing)
351+
&& qual.tpe.captureSet.mightSubcapture(refs)
352+
&& tree.args.forall(_.tpe.captureSet.mightSubcapture(refs))
353+
=>
354+
val callCaptures = tree.args.foldLeft(qual.tpe.captureSet)((cs, arg) =>
355+
cs ++ arg.tpe.captureSet)
356+
appType.derivedCapturingType(appType1, callCaptures)
357+
.showing(i"narrow $tree: $appType, refs = $refs, qual = ${qual.tpe.captureSet} --> $result", capt)
358+
case _ => appType
359+
case appType => appType
358360
end recheckApply
359361

360362
/** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`.
@@ -460,10 +462,25 @@ class CheckCaptures extends Recheck, SymTransformer:
460462
case _ =>
461463
super.recheckBlock(block, pt)
462464

465+
/** If `rhsProto` has `*` as its capture set, wrap `rhs` in a `unsafeBox`.
466+
* Used to infer `unsafeBox` for expressions that get assigned to variables
467+
* that have universal capture set.
468+
*/
469+
def maybeBox(rhs: Tree, rhsProto: Type)(using Context): Tree =
470+
if rhsProto.captureSet.isUniversal then
471+
ref(defn.Caps_unsafeBox).appliedToType(rhsProto).appliedTo(rhs)
472+
else rhs
473+
474+
override def recheckAssign(tree: Assign)(using Context): Type =
475+
val rhsProto = recheck(tree.lhs).widen
476+
recheck(maybeBox(tree.rhs, rhsProto), rhsProto)
477+
defn.UnitType
478+
463479
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
464480
try
465481
if !sym.is(Module) then // Modules are checked by checking the module class
466-
super.recheckValDef(tree, sym)
482+
if sym.is(Mutable) then recheck(maybeBox(tree.rhs, sym.info), sym.info)
483+
else super.recheckValDef(tree, sym)
467484
finally
468485
if !sym.is(Param) then
469486
// Parameters with inferred types belong to anonymous methods. We need to wait
@@ -765,7 +782,6 @@ class CheckCaptures extends Recheck, SymTransformer:
765782
recon(CapturingType(parent1, cs1, actualIsBoxed))
766783
}
767784

768-
769785
var actualw = actual.widenDealias
770786
actual match
771787
case ref: CaptureRef if ref.isTracked =>

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,7 @@ class Definitions {
961961
def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions")
962962

963963
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
964+
@tu lazy val Caps_unsafeBox: Symbol = CapsModule.requiredMethod("unsafeBox")
964965
@tu lazy val Caps_unsafeUnbox: Symbol = CapsModule.requiredMethod("unsafeUnbox")
965966

966967
// Annotation base classes

library/src/scala/caps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,11 @@ import annotation.experimental
55
// @experimental , suppress @experimental so we can use in compiler itself
66
object caps:
77

8+
/** If argument is of type `cs T`, converts to type `box cs T`. This
9+
* avoids the error that would be raised when boxing `*`.
10+
*/
11+
extension [T](x: T) def unsafeBox: T = x
12+
813
/** If argument is of type `box cs T`, converts to type `cs T`. This
914
* avoids the error that would be raised when unboxing `*`.
1015
*/
Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
1-
import caps.unsafeUnbox
2-
type ErrorHandler = (Int, String) => Unit
3-
4-
var defaultIncompleteHandler: ErrorHandler = ???
5-
var incompleteHandler: ErrorHandler = defaultIncompleteHandler
6-
val x = incompleteHandler.unsafeUnbox
7-
val _ : ErrorHandler = x
8-
val _ = x(1, "a")
1+
import caps.*
2+
3+
object Test:
4+
type ErrorHandler = (Int, String) => Unit
5+
6+
var defaultIncompleteHandler: ErrorHandler = ???
7+
var incompleteHandler: ErrorHandler = defaultIncompleteHandler
8+
val x = incompleteHandler.unsafeUnbox
9+
val _ : ErrorHandler = x
10+
val _ = x(1, "a")
11+
12+
def defaultIncompleteHandler1(): ErrorHandler = ???
13+
val defaultIncompleteHandler2: ErrorHandler = ???
14+
var incompleteHandler1: ErrorHandler = defaultIncompleteHandler1()
15+
var incompleteHandler2: ErrorHandler = defaultIncompleteHandler2
16+
var incompleteHandler3: ErrorHandler = defaultIncompleteHandler1().unsafeBox
17+
var incompleteHandler4: ErrorHandler = defaultIncompleteHandler2.unsafeBox
18+
private var incompleteHandler5 = defaultIncompleteHandler1()
19+
private var incompleteHandler6 = defaultIncompleteHandler2
20+
private var incompleteHandler7 = defaultIncompleteHandler1().unsafeBox
21+
private var incompleteHandler8 = defaultIncompleteHandler2.unsafeBox
22+
23+
incompleteHandler1 = defaultIncompleteHandler2
24+
incompleteHandler1 = defaultIncompleteHandler2.unsafeBox
25+
val saved = incompleteHandler1.unsafeUnbox

0 commit comments

Comments
 (0)