@@ -125,11 +125,30 @@ class CheckCaptures extends Recheck:
125
125
case _ =>
126
126
tp
127
127
128
+ /** Refine a possibly applied class type C where the class has tracked parameters
129
+ * x_1: T_1, ..., x_n: T_n to C { val x_1: CV_1 T_1, ..., val x_n: CV_n T_n }
130
+ * where CV_1, ..., CV_n are fresh capture sets.
131
+ */
132
+ def addCaptureRefinements (tp : Type ): Type = tp.stripped match
133
+ case _ : TypeRef | _ : AppliedType if tp.typeSymbol.isClass =>
134
+ val cls = tp.typeSymbol.asClass
135
+ cls.paramGetters.foldLeft(tp) { (core, getter) =>
136
+ if getter.termRef.isTracked then
137
+ val getterType = tp.memberInfo(getter).strippedDealias
138
+ RefinedType (core, getter.name, CapturingType (getterType, CaptureSet .Var ()))
139
+ .showing(i " add capture refinement $tp --> $result" , capt)
140
+ else
141
+ core
142
+ }
143
+ case _ =>
144
+ tp
145
+
128
146
def addVars (tp : Type ): Type =
129
- val tp1 = addInnerVars(tp)
147
+ var tp1 = addInnerVars(tp)
148
+ val tp2 = addCaptureRefinements(tp1)
130
149
if tp1.canHaveInferredCapture
131
- then CapturingType (tp1 , CaptureSet .Var ())
132
- else tp1
150
+ then CapturingType (tp2 , CaptureSet .Var ())
151
+ else tp2
133
152
134
153
addVars(cleanType(tp))
135
154
end reinfer
@@ -204,15 +223,38 @@ class CheckCaptures extends Recheck:
204
223
try super .recheckClassDef(tree, impl, sym)
205
224
finally curEnv = saved
206
225
207
- override def instantiate (mt : MethodType , argTypes : => List [Type ])(using Context ): Type =
208
- if mt.isResultDependent then SubstParamsMap (mt, argTypes)(mt.resType)
209
- else mt.resType
226
+ /** Refine the type of a constructor call `new C(t_1, ..., t_n)`
227
+ * to C{val x_1: T_1, ..., x_m: T_m} where x_1, ..., x_m are the tracked
228
+ * parameters of C and T_1, ..., T_m are the types of the corresponding arguments.
229
+ */
230
+ private def addParamArgRefinements (core : Type , argTypes : List [Type ], cls : ClassSymbol )(using Context ): Type =
231
+ cls.paramGetters.lazyZip(argTypes).foldLeft(core) { (core, refine) =>
232
+ val (getter, argType) = refine
233
+ if getter.termRef.isTracked then RefinedType (core, getter.name, argType)
234
+ else core
235
+ }
236
+
237
+ /** Handle an application of method `sym` with type `mt` to arguments of types `argTypes`.
238
+ * This means:
239
+ * - Instantiate result type with actual arguments
240
+ * - If call is to a constructor:
241
+ * - remember types of arguments corresponding to tracked
242
+ * parameters in refinements.
243
+ * - add capture set of instantiated class to capture set of result type.
244
+ */
245
+ override def instantiate (mt : MethodType , argTypes : List [Type ], sym : Symbol )(using Context ): Type =
246
+ val ownType =
247
+ if mt.isResultDependent then SubstParamsMap (mt, argTypes)(mt.resType)
248
+ else mt.resType
249
+ if sym.isConstructor then
250
+ val cls = sym.owner.asClass
251
+ addParamArgRefinements(ownType, argTypes, cls).capturing(capturedVars(cls))
252
+ .showing(i " constr type $mt with $argTypes%, % in $cls = $result" , capt)
253
+ else ownType
210
254
211
255
override def recheckApply (tree : Apply , pt : Type )(using Context ): Type =
212
- val sym = tree.symbol
213
- includeCallCaptures(sym, tree.srcPos)
214
- val cs = if sym.isConstructor then capturedVars(sym.owner) else CaptureSet .empty
215
- super .recheckApply(tree, pt).capturing(cs)
256
+ includeCallCaptures(tree.symbol, tree.srcPos)
257
+ super .recheckApply(tree, pt)
216
258
217
259
override def recheck (tree : Tree , pt : Type = WildcardType )(using Context ): Type =
218
260
val saved = curEnv
0 commit comments