Skip to content

Commit 9a877cb

Browse files
committed
Make retains an annotation
Translate @retains types to CapturingTypes only at phase cc, not before.
1 parent 632efd8 commit 9a877cb

File tree

12 files changed

+171
-70
lines changed

12 files changed

+171
-70
lines changed

compiler/src/dotty/tools/dotc/core/CaptureSet.scala

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -302,23 +302,26 @@ object CaptureSet:
302302
def varState(using state: VarState): VarState = state
303303

304304
def ofClass(cinfo: ClassInfo, argTypes: List[Type])(using Context): CaptureSet =
305-
def captureSetOf(tp: Type): CaptureSet = tp match
306-
case tp: TypeRef if tp.symbol.is(ParamAccessor) =>
307-
def mapArg(accs: List[Symbol], tps: List[Type]): CaptureSet = accs match
308-
case acc :: accs1 if tps.nonEmpty =>
309-
if acc == tp.symbol then tps.head.captureSet
310-
else mapArg(accs1, tps.tail)
311-
case _ =>
312-
empty
313-
mapArg(cinfo.cls.paramAccessors, argTypes)
314-
case _ =>
315-
tp.captureSet
316-
val css =
317-
for
318-
parent <- cinfo.parents if parent.classSymbol == defn.RetainsClass
319-
arg <- parent.argInfos
320-
yield captureSetOf(arg)
321-
css.foldLeft(empty)(_ ++ _)
305+
CaptureSet.empty
306+
/*
307+
def captureSetOf(tp: Type): CaptureSet = tp match
308+
case tp: TypeRef if tp.symbol.is(ParamAccessor) =>
309+
def mapArg(accs: List[Symbol], tps: List[Type]): CaptureSet = accs match
310+
case acc :: accs1 if tps.nonEmpty =>
311+
if acc == tp.symbol then tps.head.captureSet
312+
else mapArg(accs1, tps.tail)
313+
case _ =>
314+
empty
315+
mapArg(cinfo.cls.paramAccessors, argTypes)
316+
case _ =>
317+
tp.captureSet
318+
val css =
319+
for
320+
parent <- cinfo.parents if parent.classSymbol == defn.RetainingClass
321+
arg <- parent.argInfos
322+
yield captureSetOf(arg)
323+
css.foldLeft(empty)(_ ++ _)
324+
*/
322325

323326
def ofType(tp: Type)(using Context): CaptureSet =
324327
def recur(tp: Type): CaptureSet = tp.dealias match
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package dotty.tools
2+
package dotc
3+
package core
4+
5+
import Types.*, Symbols.*, Contexts.*
6+
7+
object CaptureTypeOps:
8+
9+
extension (tp: Type)
10+
11+
/** If this is type variable instantiated or upper bounded with a capturing type,
12+
* the capture set associated with that type. Extended to and-or types and
13+
* type proxies in the obvious way. If a term has a type with a boxed captureset,
14+
* that captureset counts towards the capture variables of the envirionment.
15+
*/
16+
def boxedCaptured(using Context): CaptureSet =
17+
def getBoxed(tp: Type, enabled: Boolean): CaptureSet = tp match
18+
case tp: CapturingType if enabled => tp.refs
19+
case tp: TypeVar => getBoxed(tp.underlying, enabled = true)
20+
case tp: TypeRef if tp.symbol == defn.AnyClass && enabled => CaptureSet.universal
21+
case tp: TypeProxy => getBoxed(tp.superType, enabled)
22+
case tp: AndType => getBoxed(tp.tp1, enabled) ++ getBoxed(tp.tp2, enabled)
23+
case tp: OrType => getBoxed(tp.tp1, enabled) ** getBoxed(tp.tp2, enabled)
24+
case _ => CaptureSet.empty
25+
getBoxed(tp, enabled = false)
26+
27+
/** If this type appears as an expected type of a term, does it imply
28+
* that the term should be boxed? ^^^ Special treat Any?
29+
*/
30+
def needsBox(using Context): Boolean = tp match
31+
case _: TypeVar => true
32+
case tp: TypeRef =>
33+
tp.info match
34+
case TypeBounds(lo, _) => lo.needsBox
35+
case _ => false
36+
case tp: RefinedOrRecType => tp.parent.needsBox
37+
case tp: AnnotatedType => tp.parent.needsBox
38+
case tp: LazyRef => tp.ref.needsBox
39+
case tp: AndType => tp.tp1.needsBox || tp.tp2.needsBox
40+
case tp: OrType => tp.tp1.needsBox && tp.tp2.needsBox
41+
case _ => false
42+
43+
def canHaveInferredCapture(using Context): Boolean = tp match
44+
case tp: CapturingType =>
45+
false
46+
case tp: TypeRef =>
47+
if tp.symbol.isClass then
48+
!tp.symbol.isValueClass && tp.symbol != defn.AnyClass
49+
else
50+
tp.underlying.canHaveInferredCapture
51+
case tp: TypeProxy =>
52+
tp.underlying.canHaveInferredCapture
53+
case tp: AndType =>
54+
tp.tp1.canHaveInferredCapture && tp.tp2.canHaveInferredCapture
55+
case tp: OrType =>
56+
tp.tp1.canHaveInferredCapture || tp.tp2.canHaveInferredCapture
57+
case _ =>
58+
false
59+
60+
def addCaptureVars(using Context): Type =
61+
if ctx.settings.Ycc.value && canHaveInferredCapture then
62+
CapturingType(tp, CaptureSet.Var()) // ^^^ go deep
63+
else
64+
tp
65+
end CaptureTypeOps

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,8 +888,6 @@ class Definitions {
888888
lazy val RuntimeTuples_isInstanceOfEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfEmptyTuple")
889889
lazy val RuntimeTuples_isInstanceOfNonEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfNonEmptyTuple")
890890

891-
@tu lazy val RetainsClass: ClassSymbol = requiredClass("scala.Retains")
892-
893891
// Annotation base classes
894892
@tu lazy val AnnotationClass: ClassSymbol = requiredClass("scala.annotation.Annotation")
895893
@tu lazy val ClassfileAnnotationClass: ClassSymbol = requiredClass("scala.annotation.ClassfileAnnotation")
@@ -943,6 +941,7 @@ class Definitions {
943941
@tu lazy val FunctionalInterfaceAnnot: ClassSymbol = requiredClass("java.lang.FunctionalInterface")
944942
@tu lazy val TargetNameAnnot: ClassSymbol = requiredClass("scala.annotation.targetName")
945943
@tu lazy val VarargsAnnot: ClassSymbol = requiredClass("scala.annotation.varargs")
944+
@tu lazy val RetainsAnnot: ClassSymbol = requiredClass("scala.retains")
946945
@tu lazy val AbilityAnnot: ClassSymbol = requiredClass("scala.annotation.ability")
947946

948947
@tu lazy val JavaRepeatableAnnot: ClassSymbol = requiredClass("java.lang.annotation.Repeatable")

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
5353
else sym.flags
5454
).installAfter(preRecheckPhase)
5555

56-
/** Hook to be overridden */
56+
/** Hooks to be overridden */
5757
protected def reinfer(tp: Type)(using Context): Type = tp
58+
protected def transformType(tp: Type)(using Context): Type = tp
5859

5960
def reinferResult(info: Type)(using Context): Type = info match
6061
case info: MethodOrPoly =>
@@ -66,12 +67,14 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
6667

6768
def enterDef(stat: Tree)(using Context): Unit =
6869
val sym = stat.symbol
70+
var newInfo = transformType(sym.info)
6971
stat match
7072
case stat: ValOrDefDef if stat.tpt.isInstanceOf[InferredTypeTree] =>
71-
sym.updateInfo(reinferResult(sym.info))
73+
newInfo = reinferResult(sym.info)
7274
case stat: Bind =>
73-
sym.updateInfo(reinferResult(sym.info))
75+
newInfo = reinferResult(sym.info)
7476
case _ =>
77+
sym.updateInfo(newInfo)
7578

7679
def constFold(tree: Tree, tp: Type)(using Context): Type =
7780
val tree1 = tree.withType(tp)
@@ -107,14 +110,14 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
107110
bindType
108111

109112
def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
110-
if !tree.rhs.isEmpty then recheck(tree.rhs, tree.symbol.info)
113+
if !tree.rhs.isEmpty then recheck(tree.rhs, sym.info)
111114
sym.termRef
112115

113116
def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type =
114117
tree.paramss.foreach(_.foreach(enterDef))
115118
val rhsCtx = linkConstructorParams(sym)
116119
if !tree.rhs.isEmpty && !sym.isInlineMethod && !sym.isEffectivelyErased then
117-
recheck(tree.rhs, tree.symbol.localReturnType)(using rhsCtx)
120+
recheck(tree.rhs, sym.localReturnType)(using rhsCtx)
118121
sym.termRef
119122

120123
def recheckTypeDef(tree: TypeDef, sym: Symbol)(using Context): Type =
@@ -231,7 +234,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
231234

232235
def recheckTypeTree(tree: TypeTree)(using Context): Type = tree match
233236
case tree: InferredTypeTree => reinfer(tree.tpe)
234-
case _ => tree.tpe
237+
case _ => transformType(tree.tpe)
235238

236239
def recheckAnnotated(tree: Annotated)(using Context): Type =
237240
tree.tpe match

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,17 @@ class CheckCaptures extends Recheck:
8888
override def reinfer(tp: Type)(using Context): Type =
8989
CapturingType(tp, CaptureSet.Var()) // ^^^ go deep
9090

91+
override def transformType(tp: Type)(using Context): Type =
92+
val mapType = new TypeMap:
93+
def apply(t: Type) = mapOver(t) match
94+
case AnnotatedType(parent, annot) if annot.symbol == defn.RetainsAnnot =>
95+
annot.tree match
96+
case Apply(_, Typed(SeqLiteral(elems, _), _) :: Nil) =>
97+
CapturingType(parent, CaptureSet(elems.tpes.asInstanceOf[List[CaptureRef]]*))
98+
case t1 =>
99+
t1
100+
mapType(tp)
101+
91102
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, false, null)
92103

93104
private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap()
@@ -130,6 +141,18 @@ class CheckCaptures extends Recheck:
130141
if res != CompareResult.OK then
131142
report.error(i"references $cs1 are not all included in allowed capture set ${res.blocking}", pos)
132143

144+
private var keepTypesCache: Boolean = compiletime.uninitialized
145+
private var keepTypesDefined = false
146+
147+
private def keepTypes(using Context) =
148+
if !keepTypesDefined then
149+
keepTypesCache = ctx.settings.Xprint.value.containsPhase(thisPhase)
150+
keepTypesDefined = true
151+
keepTypesCache
152+
153+
def remember(tree: Tree, tpe: Type)(using Context): Unit =
154+
if keepTypes && (tpe ne tree.tpe) then tree.putAttachment(RecheckedType, tpe)
155+
133156
override def recheckClosure(tree: Closure, pt: Type)(using Context): Type =
134157
val cs = capturedVars(tree.meth.symbol)
135158
recheckr.println(i"typing closure $tree with cvs $cs")
@@ -141,7 +164,12 @@ class CheckCaptures extends Recheck:
141164
if tree.symbol.is(Method) then includeCallCaptures(tree.symbol, tree.srcPos)
142165
super.recheckIdent(tree)
143166

167+
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Type =
168+
remember(tree.tpt, sym.info)
169+
super.recheckValDef(tree, sym)
170+
144171
override def recheckDefDef(tree: DefDef, sym: Symbol)(using Context): Type =
172+
remember(tree, sym.localReturnType)
145173
val saved = curEnv
146174
val localSet = capturedVars(sym)
147175
if !localSet.isEmpty then curEnv = Env(sym, localSet, false, curEnv)
@@ -161,22 +189,13 @@ class CheckCaptures extends Recheck:
161189
val cs = if sym.isConstructor then capturedVars(sym.owner) else CaptureSet.empty
162190
super.recheckApply(tree, pt).capturing(cs)
163191

164-
private var keepTypesCache: Boolean = compiletime.uninitialized
165-
private var keepTypesDefined = false
166-
167-
private def keepTypes(using Context) =
168-
if !keepTypesDefined then
169-
keepTypesCache = ctx.settings.Xprint.value.containsPhase(thisPhase)
170-
keepTypesDefined = true
171-
keepTypesCache
172-
173192
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
174193
val saved = curEnv
175194
if pt.needsBox && !curEnv.isBoxed then // ^^^ refine?
176195
curEnv = Env(NoSymbol, CaptureSet.Var(), true, curEnv)
177196
try
178197
val res = super.recheck(tree, pt)
179-
if keepTypes && (res ne tree.tpe) then tree.putAttachment(RecheckedType, res)
198+
remember(tree, res)
180199
if curEnv.isOpen then assertSub(res.boxedCaptured, curEnv.captured)
181200
res
182201
finally curEnv = saved
Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package scala
22

3-
/** Parent trait that indicates capturing. Example usage:
4-
*
5-
* class Foo(using ctx: Context) extends Holds[ctx | CanThrow[Exception]]
3+
/** An annotation that indicates capture
64
*/
7-
trait Retains[T]
5+
class retains(xs: Any*) extends annotation.StaticAnnotation
6+

tests/neg-custom-args/captures/capt1.check

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,40 @@
1-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:3:2 ------------------------------------------
2-
3 | () => if x == null then y else y // error
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:4:2 ------------------------------------------
2+
4 | () => if x == null then y else y // error
33
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
44
| Found: {x} () => C
55
| Required: () => C
66

77
longer explanation available when compiling with `-explain`
8-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:6:2 ------------------------------------------
9-
6 | () => if x == null then y else y // error
8+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:7:2 ------------------------------------------
9+
7 | () => if x == null then y else y // error
1010
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
1111
| Found: {x} () => C
1212
| Required: Matchable
1313

1414
longer explanation available when compiling with `-explain`
15-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:14:2 -----------------------------------------
16-
14 | f // error
15+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:15:2 -----------------------------------------
16+
15 | f // error
1717
| ^
1818
| Found: {x} Int => Int
1919
| Required: Matchable
2020

2121
longer explanation available when compiling with `-explain`
22-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:22:2 -----------------------------------------
23-
22 | F(22) // error
22+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:23:2 -----------------------------------------
23+
23 | F(22) // error
2424
| ^^^^^
2525
| Found: {x} A
2626
| Required: A
2727

2828
longer explanation available when compiling with `-explain`
29-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:26:40 ----------------------------------------
30-
26 | def m() = if x == null then y else y // error
29+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:27:40 ----------------------------------------
30+
27 | def m() = if x == null then y else y // error
3131
| ^
3232
| Found: {x} A
3333
| Required: A
3434

3535
longer explanation available when compiling with `-explain`
36-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:31:24 ----------------------------------------
37-
31 | val z2 = h[() => Cap](() => x)(() => C()) // error
36+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:32:24 ----------------------------------------
37+
32 | val z2 = h[() => Cap](() => x)(() => C()) // error
3838
| ^^^^^^^
3939
| Found: {x} () => Cap
4040
| Required: () => Cap
Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,21 @@
1+
import scala.retains
12
class C
2-
def f(x: C retains *, y: C): () => C =
3+
def f(x: C @retains(*), y: C): () => C =
34
() => if x == null then y else y // error
45

5-
def g(x: C retains *, y: C): Matchable =
6+
def g(x: C @retains(*), y: C): Matchable =
67
() => if x == null then y else y // error
78

8-
def h1(x: C retains *, y: C): Any =
9+
def h1(x: C @retains(*), y: C): Any =
910
def f() = if x == null then y else y
1011
() => f() // ok
1112

12-
def h2(x: C retains *): Matchable =
13+
def h2(x: C @retains(*)): Matchable =
1314
def f(y: Int) = if x == null then y else y
1415
f // error
1516

1617
class A
17-
type Cap = C retains *
18+
type Cap = C @retains(*)
1819

1920
def h3(x: Cap): A =
2021
class F(y: Int) extends A:
@@ -26,9 +27,9 @@ def h4(x: Cap, y: Int): A =
2627
def m() = if x == null then y else y // error
2728

2829
def foo() =
29-
val x: C retains * = ???
30+
val x: C @retains(*) = ???
3031
def h[X](a: X)(b: X) = a
3132
val z2 = h[() => Cap](() => x)(() => C()) // error
32-
val z3 = h[(() => Cap) retains x.type](() => x)(() => C()) // ok
33-
val z4 = h[(() => Cap) retains x.type](() => x)(() => C()) // what was inferred for z3
33+
val z3 = h[(() => Cap) @retains(x)](() => x)(() => C()) // ok
34+
val z4 = h[(() => Cap) @retains(x)](() => x)(() => C()) // what was inferred for z3
3435

tests/new/test.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
object Test:
22

3+
val x = ""
4+
5+
class C extends AnyRef @scala.retains(x)
6+
37
def test = ???
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
object Test:
2+
3+
def test() =
4+
val x = "abc"
5+
val y: Object @scala.retains(x) = ???
6+
val z: Object @scala.retains(x, *) = y: Object @scala.retains(x)
7+
Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1+
import scala.retains
12
class C
2-
type Cap = C retains *
3+
type Cap = C @retains(*)
34

45
def test1() =
56
val y = ""
6-
def x: {y} Object = y
7+
def x: Object @retains(y) = y
78

89
def test2() =
910
val x: Cap = C()
10-
val y: ({x} () => Unit) = () => { x; () } // TODO: drop type ascription
11-
def z: ({y} () => Unit) = y
12-
z: ({y} () => Unit)
11+
val y = () => { x; () }
12+
def z: (() => Unit) @retains(x) = y
13+
z: (() => Unit) @retains(x) // TODO: replace x with y

0 commit comments

Comments
 (0)