Skip to content

Commit 98f6c68

Browse files
committed
First stab at handling classes
1 parent ca04038 commit 98f6c68

File tree

3 files changed

+75
-13
lines changed

3 files changed

+75
-13
lines changed

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
119119
excluded = if tree.symbol.is(Private) then EmptyFlags else Private
120120
).suchThat(tree.symbol ==)
121121
constFold(tree, qualType.select(name, mbr))
122+
//.showing(i"recheck select $qualType . $name : ${mbr.symbol.info} = $result")
122123

123124
def recheckBind(tree: Bind, pt: Type)(using Context): Type = tree match
124125
case Bind(name, body) =>
@@ -161,8 +162,10 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
161162
case _ => mapOver(t)
162163
formals.mapConserve(tm)
163164

164-
/** Hook for method type instantiation */
165-
def instantiate(mt: MethodType, argTypes: => List[Type])(using Context): Type =
165+
/** Hook for method type instantiation
166+
*
167+
*/
168+
def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type =
166169
mt.instantiate(argTypes)
167170

168171
def recheckApply(tree: Apply, pt: Type)(using Context): Type =
@@ -184,7 +187,7 @@ abstract class Recheck extends Phase, IdentityDenotTransformer:
184187
assert(formals.isEmpty)
185188
Nil
186189
val argTypes = recheckArgs(tree.args, formals, fntpe.paramRefs)
187-
constFold(tree, instantiate(fntpe, argTypes))
190+
constFold(tree, instantiate(fntpe, argTypes, tree.fun.symbol))
188191

189192
def recheckTypeApply(tree: TypeApply, pt: Type)(using Context): Type =
190193
recheck(tree.fun).widen match

compiler/src/dotty/tools/dotc/typer/CheckCaptures.scala

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -125,11 +125,30 @@ class CheckCaptures extends Recheck:
125125
case _ =>
126126
tp
127127

128+
/** Refine a possibly applied class type C where the class has tracked parameters
129+
* x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n }
130+
* where CV_1, ..., CV_n are fresh capture sets.
131+
*/
132+
def addCaptureRefinements(tp: Type): Type = tp.stripped match
133+
case _: TypeRef | _: AppliedType if tp.typeSymbol.isClass =>
134+
val cls = tp.typeSymbol.asClass
135+
cls.paramGetters.foldLeft(tp) { (core, getter) =>
136+
if getter.termRef.isTracked then
137+
val getterType = tp.memberInfo(getter).strippedDealias
138+
RefinedType(core, getter.name, CapturingType(getterType, CaptureSet.Var()))
139+
.showing(i"add capture refinement $tp --> $result", capt)
140+
else
141+
core
142+
}
143+
case _ =>
144+
tp
145+
128146
def addVars(tp: Type): Type =
129-
val tp1 = addInnerVars(tp)
147+
var tp1 = addInnerVars(tp)
148+
val tp2 = addCaptureRefinements(tp1)
130149
if tp1.canHaveInferredCapture
131-
then CapturingType(tp1, CaptureSet.Var())
132-
else tp1
150+
then CapturingType(tp2, CaptureSet.Var())
151+
else tp2
133152

134153
addVars(cleanType(tp))
135154
end reinfer
@@ -204,15 +223,38 @@ class CheckCaptures extends Recheck:
204223
try super.recheckClassDef(tree, impl, sym)
205224
finally curEnv = saved
206225

207-
override def instantiate(mt: MethodType, argTypes: => List[Type])(using Context): Type =
208-
if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType)
209-
else mt.resType
226+
/** Refine the type of a constructor call `new C(t_1, ..., t_n)`
227+
* to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked
228+
* parameters of C and T_1, ..., T_m are the types of the corresponding arguments.
229+
*/
230+
private def addParamArgRefinements(core: Type, argTypes: List[Type], cls: ClassSymbol)(using Context): Type =
231+
cls.paramGetters.lazyZip(argTypes).foldLeft(core) { (core, refine) =>
232+
val (getter, argType) = refine
233+
if getter.termRef.isTracked then RefinedType(core, getter.name, argType)
234+
else core
235+
}
236+
237+
/** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`.
238+
* This means:
239+
* - Instantiate result type with actual arguments
240+
* - If call is to a constructor:
241+
* - remember types of arguments corresponding to tracked
242+
* parameters in refinements.
243+
* - add capture set of instantiated class to capture set of result type.
244+
*/
245+
override def instantiate(mt: MethodType, argTypes: List[Type], sym: Symbol)(using Context): Type =
246+
val ownType =
247+
if mt.isResultDependent then SubstParamsMap(mt, argTypes)(mt.resType)
248+
else mt.resType
249+
if sym.isConstructor then
250+
val cls = sym.owner.asClass
251+
addParamArgRefinements(ownType, argTypes, cls).capturing(capturedVars(cls))
252+
.showing(i"constr type $mt with $argTypes%, % in $cls = $result", capt)
253+
else ownType
210254

211255
override def recheckApply(tree: Apply, pt: Type)(using Context): Type =
212-
val sym = tree.symbol
213-
includeCallCaptures(sym, tree.srcPos)
214-
val cs = if sym.isConstructor then capturedVars(sym.owner) else CaptureSet.empty
215-
super.recheckApply(tree, pt).capturing(cs)
256+
includeCallCaptures(tree.symbol, tree.srcPos)
257+
super.recheckApply(tree, pt)
216258

217259
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
218260
val saved = curEnv
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
class B
2+
type Cap = {*} B
3+
class C(val n: Cap):
4+
this: ({n} C) =>
5+
def foo(): {n} B = n
6+
7+
8+
def test(x: Cap, y: Cap) =
9+
val c0 = C(x)
10+
val c1: {x} C {val n: {x} B} = c0
11+
val z = c1.foo()
12+
z: ({x} B)
13+
14+
val c2 = if ??? then C(x) else identity(C(y))
15+
val c3: {x, y} C { val n: {x, y} B } = c2
16+
val z1 = c3.foo()
17+
z1: B @retains(x, y)

0 commit comments

Comments
 (0)