diff --git a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala index 8967558c3ab3..b6b5d569677c 100644 --- a/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala +++ b/compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala @@ -38,7 +38,7 @@ object CheckCaptures: if sym.isAllOf(PrivateParamAccessor) && !sym.hasAnnotation(defn.ConstructorOnlyAnnot) then sym.copySymDenotation(initFlags = sym.flags &~ Private | Recheck.ResetPrivate) else if Synthetics.needsTransform(sym) then - Synthetics.transformToCC(sym) + Synthetics.transform(sym, toCC = true) else sym end Pre @@ -173,7 +173,7 @@ class CheckCaptures extends Recheck, SymTransformer: super.run override def transformSym(sym: SymDenotation)(using Context): SymDenotation = - if Synthetics.needsTransform(sym) then Synthetics.transformFromCC(sym) + if Synthetics.needsTransform(sym) then Synthetics.transform(sym, toCC = false) else super.transformSym(sym) class CaptureChecker(ictx: Context) extends Rechecker(ictx): diff --git a/compiler/src/dotty/tools/dotc/cc/Synthetics.scala b/compiler/src/dotty/tools/dotc/cc/Synthetics.scala index e1c0a4272276..1e7c8d641238 100644 --- a/compiler/src/dotty/tools/dotc/cc/Synthetics.scala +++ b/compiler/src/dotty/tools/dotc/cc/Synthetics.scala @@ -56,177 +56,155 @@ object Synthetics: || isSyntheticCompanionMethod(sym, nme.fromProduct) || needsTransform(sym)) - /** Add capture dependencies to the type of the `apply` or `copy` method of a case class. - * An apply method in a case class like this: - * case class CC(a: {d} A, b: B, {cap} c: C) - * would get type - * def apply(a': {d} A, b: B, {cap} c': C): {a', c'} CC { val a = {a'} A, val c = {c'} C } - * where `'` is used to indicate the difference between parameter symbol and refinement name. - * Analogous for the copy method. + /** Transform the type of a method either to its type under capture checking + * or back to its previous type. + * @param sym The method to transform @pre needsTransform(sym) must hold. + * @param toCC Whether to transform the type to capture checking or back. */ - private def addCaptureDeps(info: Type)(using Context): Type = info match - case info: MethodType => - val trackedParams = info.paramRefs.filter(atPhase(checkCapturesPhase)(_.isTracked)) - def augmentResult(tp: Type): Type = tp match - case tp: MethodOrPoly => - tp.derivedLambdaType(resType = augmentResult(tp.resType)) - case _ => - val refined = trackedParams.foldLeft(tp) { (parent, pref) => - RefinedType(parent, pref.paramName, - CapturingType( - atPhase(ctx.phase.next)(pref.underlying.stripCapturing), - CaptureSet(pref))) - } - CapturingType(refined, CaptureSet(trackedParams*)) - if trackedParams.isEmpty then info - else augmentResult(info).showing(i"augment apply/copy type $info to $result", capt) - case info: PolyType => - info.derivedLambdaType(resType = addCaptureDeps(info.resType)) - case _ => - info - - /** Drop capture dependencies from the type of `apply` or `copy` method of a case class */ - private def dropCaptureDeps(tp: Type)(using Context): Type = tp match - case tp: MethodOrPoly => - tp.derivedLambdaType(resType = dropCaptureDeps(tp.resType)) - case CapturingType(parent, _) => - dropCaptureDeps(parent) - case RefinedType(parent, _, _) => - dropCaptureDeps(parent) - case _ => - tp - - /** Add capture information to the type of the default getter of a case class copy method */ - private def addDefaultGetterCapture(info: Type, owner: Symbol, idx: Int)(using Context): Type = info match - case info: MethodOrPoly => - info.derivedLambdaType(resType = addDefaultGetterCapture(info.resType, owner, idx)) - case info: ExprType => - info.derivedExprType(addDefaultGetterCapture(info.resType, owner, idx)) - case EventuallyCapturingType(parent, _) => - addDefaultGetterCapture(parent, owner, idx) - case info @ AnnotatedType(parent, annot) => - info.derivedAnnotatedType(addDefaultGetterCapture(parent, owner, idx), annot) - case _ if idx < owner.asClass.paramGetters.length => - val param = owner.asClass.paramGetters(idx) - val pinfo = param.info - atPhase(ctx.phase.next) { - if pinfo.captureSet.isAlwaysEmpty then info - else CapturingType(pinfo.stripCapturing, CaptureSet(param.termRef)) - } - case _ => - info - - /** Drop capture information from the type of the default getter of a case class copy method */ - private def dropDefaultGetterCapture(info: Type)(using Context): Type = info match - case info: MethodOrPoly => - info.derivedLambdaType(resType = dropDefaultGetterCapture(info.resType)) - case CapturingType(parent, _) => - parent - case info @ AnnotatedType(parent, annot) => - info.derivedAnnotatedType(dropDefaultGetterCapture(parent), annot) - case _ => - info - - /** Augment an unapply of type `(x: C): D` to `(x: {cap} C): {x} D` */ - private def addUnapplyCaptures(info: Type)(using Context): Type = info match - case info: MethodType => - val paramInfo :: Nil = info.paramInfos: @unchecked - val newParamInfo = - CapturingType(paramInfo, CaptureSet.universal) - val trackedParam = info.paramRefs.head - def newResult(tp: Type): Type = tp match - case tp: MethodOrPoly => - tp.derivedLambdaType(resType = newResult(tp.resType)) - case _ => - CapturingType(tp, CaptureSet(trackedParam)) - info.derivedLambdaType(paramInfos = newParamInfo :: Nil, resType = newResult(info.resType)) - .showing(i"augment unapply type $info to $result", capt) - case info: PolyType => - info.derivedLambdaType(resType = addUnapplyCaptures(info.resType)) - - /** Drop added capture information from the type of an `unapply` */ - private def dropUnapplyCaptures(info: Type)(using Context): Type = info match - case info: MethodType => - info.paramInfos match - case CapturingType(oldParamInfo, _) :: Nil => - def oldResult(tp: Type): Type = tp match + def transform(sym: SymDenotation, toCC: Boolean)(using Context): SymDenotation = + + /** Add capture dependencies to the type of the `apply` or `copy` method of a case class. + * An apply method in a case class like this: + * case class CC(a: A^{d}, b: B, c: C^{cap}) + * would get type + * def apply(a': A^{d}, b: B, c': C^{cap}): CC^{a', c'} { val a = A^{a'}, val c = C^{c'} } + * where `'` is used to indicate the difference between parameter symbol and refinement name. + * Analogous for the copy method. + */ + def addCaptureDeps(info: Type): Type = info match + case info: MethodType => + val trackedParams = info.paramRefs.filter(atPhase(checkCapturesPhase)(_.isTracked)) + def augmentResult(tp: Type): Type = tp match + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = augmentResult(tp.resType)) + case _ => + val refined = trackedParams.foldLeft(tp) { (parent, pref) => + RefinedType(parent, pref.paramName, + CapturingType( + atPhase(ctx.phase.next)(pref.underlying.stripCapturing), + CaptureSet(pref))) + } + CapturingType(refined, CaptureSet(trackedParams*)) + if trackedParams.isEmpty then info + else augmentResult(info).showing(i"augment apply/copy type $info to $result", capt) + case info: PolyType => + info.derivedLambdaType(resType = addCaptureDeps(info.resType)) + case _ => + info + + /** Drop capture dependencies from the type of `apply` or `copy` method of a case class */ + def dropCaptureDeps(tp: Type): Type = tp match + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = dropCaptureDeps(tp.resType)) + case CapturingType(parent, _) => + dropCaptureDeps(parent) + case RefinedType(parent, _, _) => + dropCaptureDeps(parent) + case _ => + tp + + /** Add capture information to the type of the default getter of a case class copy method + * if toCC = true, or remove the added info again if toCC = false. + */ + def transformDefaultGetterCaptures(info: Type, owner: Symbol, idx: Int)(using Context): Type = info match + case info: MethodOrPoly => + info.derivedLambdaType(resType = transformDefaultGetterCaptures(info.resType, owner, idx)) + case info: ExprType => + info.derivedExprType(transformDefaultGetterCaptures(info.resType, owner, idx)) + case EventuallyCapturingType(parent, _) => + if toCC then transformDefaultGetterCaptures(parent, owner, idx) + else parent + case info @ AnnotatedType(parent, annot) => + info.derivedAnnotatedType(transformDefaultGetterCaptures(parent, owner, idx), annot) + case _ if toCC && idx < owner.asClass.paramGetters.length => + val param = owner.asClass.paramGetters(idx) + val pinfo = param.info + atPhase(ctx.phase.next) { + if pinfo.captureSet.isAlwaysEmpty then info + else CapturingType(pinfo.stripCapturing, CaptureSet(param.termRef)) + } + case _ => + info + + /** Augment an unapply of type `(x: C): D` to `(x: C^{cap}): D^{x}` if toCC is true, + * or remove the added capture sets again if toCC = false. + */ + def transformUnapplyCaptures(info: Type)(using Context): Type = info match + case info: MethodType => + if toCC then + val paramInfo :: Nil = info.paramInfos: @unchecked + val newParamInfo = CapturingType(paramInfo, CaptureSet.universal) + val trackedParam = info.paramRefs.head + def newResult(tp: Type): Type = tp match case tp: MethodOrPoly => - tp.derivedLambdaType(resType = oldResult(tp.resType)) - case CapturingType(tp, _) => - tp - info.derivedLambdaType(paramInfos = oldParamInfo :: Nil, resType = oldResult(info.resType)) - case _ => - info - case info: PolyType => - info.derivedLambdaType(resType = dropUnapplyCaptures(info.resType)) - - private def transformComposeCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type = - val (pt: PolyType) = symd.info: @unchecked - val (mt: MethodType) = pt.resType: @unchecked - val (enclThis: ThisType) = symd.owner.thisType: @unchecked - val mt1 = + tp.derivedLambdaType(resType = newResult(tp.resType)) + case _ => + CapturingType(tp, CaptureSet(trackedParam)) + info.derivedLambdaType(paramInfos = newParamInfo :: Nil, resType = newResult(info.resType)) + .showing(i"augment unapply type $info to $result", capt) + else info.paramInfos match + case CapturingType(oldParamInfo, _) :: Nil => + def oldResult(tp: Type): Type = tp match + case tp: MethodOrPoly => + tp.derivedLambdaType(resType = oldResult(tp.resType)) + case CapturingType(tp, _) => + tp + info.derivedLambdaType(paramInfos = oldParamInfo :: Nil, resType = oldResult(info.resType)) + case _ => + info + case info: PolyType => + info.derivedLambdaType(resType = transformUnapplyCaptures(info.resType)) + + def transformComposeCaptures(symd: SymDenotation) = + val (pt: PolyType) = symd.info: @unchecked + val (mt: MethodType) = pt.resType: @unchecked + val (enclThis: ThisType) = symd.owner.thisType: @unchecked + val mt1 = + if toCC then + MethodType(mt.paramNames)( + mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)), + mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head))) + else + MethodType(mt.paramNames)( + mt1 => mt.paramInfos.map(_.stripCapturing), + mt1 => mt.resType.stripCapturing) + pt.derivedLambdaType(resType = mt1) + + def transformCurriedTupledCaptures(symd: SymDenotation) = + val (et: ExprType) = symd.info: @unchecked + val (enclThis: ThisType) = symd.owner.thisType: @unchecked + def mapFinalResult(tp: Type, f: Type => Type): Type = + val defn.FunctionOf(args, res, isContextual) = tp: @unchecked + if defn.isFunctionNType(res) then + defn.FunctionOf(args, mapFinalResult(res, f), isContextual) + else + f(tp) + val resType1 = + if toCC then + mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis))) + else + et.resType.stripCapturing + ExprType(resType1) + + def transformCompareCaptures = if toCC then - MethodType(mt.paramNames)( - mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)), - mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head))) + MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType) else - MethodType(mt.paramNames)( - mt1 => mt.paramInfos.map(_.stripCapturing), - mt1 => mt.resType.stripCapturing) - pt.derivedLambdaType(resType = mt1) - - def transformCurriedTupledCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type = - val (et: ExprType) = symd.info: @unchecked - val (enclThis: ThisType) = symd.owner.thisType: @unchecked - def mapFinalResult(tp: Type, f: Type => Type): Type = - val defn.FunctionOf(args, res, isContextual) = tp: @unchecked - if defn.isFunctionNType(res) then - defn.FunctionOf(args, mapFinalResult(res, f), isContextual) - else - f(tp) - val resType1 = - if toCC then - mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis))) - else - et.resType.stripCapturing - ExprType(resType1) - - /** If `sym` refers to a synthetic apply, unapply, copy, or copy default getter method - * of a case class, transform it to account for capture information. - * The method is run in phase CheckCaptures.Pre - * @pre needsTransform(sym) - */ - def transformToCC(sym: SymDenotation)(using Context): SymDenotation = sym.name match - case DefaultGetterName(nme.copy, n) => - sym.copySymDenotation(info = addDefaultGetterCapture(sym.info, sym.owner, n)) - case nme.unapply => - sym.copySymDenotation(info = addUnapplyCaptures(sym.info)) - case nme.apply | nme.copy => - sym.copySymDenotation(info = addCaptureDeps(sym.info)) - case nme.andThen | nme.compose => - sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = true)) - case nme.curried | nme.tupled => - sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = true)) - case n if n == nme.eq || n == nme.ne => - sym.copySymDenotation(info = - MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType)) - - /** If `sym` refers to a synthetic apply, unapply, copy, or copy default getter method - * of a case class, transform it back to what it was before the CC phase. - * @pre needsTransform(sym) - */ - def transformFromCC(sym: SymDenotation)(using Context): SymDenotation = sym.name match - case DefaultGetterName(nme.copy, n) => - sym.copySymDenotation(info = dropDefaultGetterCapture(sym.info)) - case nme.unapply => - sym.copySymDenotation(info = dropUnapplyCaptures(sym.info)) - case nme.apply | nme.copy => - sym.copySymDenotation(info = dropCaptureDeps(sym.info)) - case nme.andThen | nme.compose => - sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = false)) - case nme.curried | nme.tupled => - sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = false)) - case n if n == nme.eq || n == nme.ne => - sym.copySymDenotation(info = defn.methOfAnyRef(defn.BooleanType)) + defn.methOfAnyRef(defn.BooleanType) + + sym.copySymDenotation(info = sym.name match + case DefaultGetterName(nme.copy, n) => + transformDefaultGetterCaptures(sym.info, sym.owner, n) + case nme.unapply => + transformUnapplyCaptures(sym.info) + case nme.apply | nme.copy => + if toCC then addCaptureDeps(sym.info) else dropCaptureDeps(sym.info) + case nme.andThen | nme.compose => + transformComposeCaptures(sym) + case nme.curried | nme.tupled => + transformCurriedTupledCaptures(sym) + case n if n == nme.eq || n == nme.ne => + transformCompareCaptures) + end transform end Synthetics \ No newline at end of file