Skip to content

Simplify cc/Synthetics.scala #18273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ object CheckCaptures:
if sym.isAllOf(PrivateParamAccessor) && !sym.hasAnnotation(defn.ConstructorOnlyAnnot) then
sym.copySymDenotation(initFlags = sym.flags &~ Private | Recheck.ResetPrivate)
else if Synthetics.needsTransform(sym) then
Synthetics.transformToCC(sym)
Synthetics.transform(sym, toCC = true)
else
sym
end Pre
Expand Down Expand Up @@ -173,7 +173,7 @@ class CheckCaptures extends Recheck, SymTransformer:
super.run

override def transformSym(sym: SymDenotation)(using Context): SymDenotation =
if Synthetics.needsTransform(sym) then Synthetics.transformFromCC(sym)
if Synthetics.needsTransform(sym) then Synthetics.transform(sym, toCC = false)
else super.transformSym(sym)

class CaptureChecker(ictx: Context) extends Rechecker(ictx):
Expand Down
314 changes: 146 additions & 168 deletions compiler/src/dotty/tools/dotc/cc/Synthetics.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,177 +56,155 @@ object Synthetics:
|| isSyntheticCompanionMethod(sym, nme.fromProduct)
|| needsTransform(sym))

/** Add capture dependencies to the type of the `apply` or `copy` method of a case class.
* An apply method in a case class like this:
* case class CC(a: {d} A, b: B, {cap} c: C)
* would get type
* def apply(a': {d} A, b: B, {cap} c': C): {a', c'} CC { val a = {a'} A, val c = {c'} C }
* where `'` is used to indicate the difference between parameter symbol and refinement name.
* Analogous for the copy method.
/** Transform the type of a method either to its type under capture checking
* or back to its previous type.
* @param sym The method to transform @pre needsTransform(sym) must hold.
* @param toCC Whether to transform the type to capture checking or back.
*/
private def addCaptureDeps(info: Type)(using Context): Type = info match
case info: MethodType =>
val trackedParams = info.paramRefs.filter(atPhase(checkCapturesPhase)(_.isTracked))
def augmentResult(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = augmentResult(tp.resType))
case _ =>
val refined = trackedParams.foldLeft(tp) { (parent, pref) =>
RefinedType(parent, pref.paramName,
CapturingType(
atPhase(ctx.phase.next)(pref.underlying.stripCapturing),
CaptureSet(pref)))
}
CapturingType(refined, CaptureSet(trackedParams*))
if trackedParams.isEmpty then info
else augmentResult(info).showing(i"augment apply/copy type $info to $result", capt)
case info: PolyType =>
info.derivedLambdaType(resType = addCaptureDeps(info.resType))
case _ =>
info

/** Drop capture dependencies from the type of `apply` or `copy` method of a case class */
private def dropCaptureDeps(tp: Type)(using Context): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = dropCaptureDeps(tp.resType))
case CapturingType(parent, _) =>
dropCaptureDeps(parent)
case RefinedType(parent, _, _) =>
dropCaptureDeps(parent)
case _ =>
tp

/** Add capture information to the type of the default getter of a case class copy method */
private def addDefaultGetterCapture(info: Type, owner: Symbol, idx: Int)(using Context): Type = info match
case info: MethodOrPoly =>
info.derivedLambdaType(resType = addDefaultGetterCapture(info.resType, owner, idx))
case info: ExprType =>
info.derivedExprType(addDefaultGetterCapture(info.resType, owner, idx))
case EventuallyCapturingType(parent, _) =>
addDefaultGetterCapture(parent, owner, idx)
case info @ AnnotatedType(parent, annot) =>
info.derivedAnnotatedType(addDefaultGetterCapture(parent, owner, idx), annot)
case _ if idx < owner.asClass.paramGetters.length =>
val param = owner.asClass.paramGetters(idx)
val pinfo = param.info
atPhase(ctx.phase.next) {
if pinfo.captureSet.isAlwaysEmpty then info
else CapturingType(pinfo.stripCapturing, CaptureSet(param.termRef))
}
case _ =>
info

/** Drop capture information from the type of the default getter of a case class copy method */
private def dropDefaultGetterCapture(info: Type)(using Context): Type = info match
case info: MethodOrPoly =>
info.derivedLambdaType(resType = dropDefaultGetterCapture(info.resType))
case CapturingType(parent, _) =>
parent
case info @ AnnotatedType(parent, annot) =>
info.derivedAnnotatedType(dropDefaultGetterCapture(parent), annot)
case _ =>
info

/** Augment an unapply of type `(x: C): D` to `(x: {cap} C): {x} D` */
private def addUnapplyCaptures(info: Type)(using Context): Type = info match
case info: MethodType =>
val paramInfo :: Nil = info.paramInfos: @unchecked
val newParamInfo =
CapturingType(paramInfo, CaptureSet.universal)
val trackedParam = info.paramRefs.head
def newResult(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = newResult(tp.resType))
case _ =>
CapturingType(tp, CaptureSet(trackedParam))
info.derivedLambdaType(paramInfos = newParamInfo :: Nil, resType = newResult(info.resType))
.showing(i"augment unapply type $info to $result", capt)
case info: PolyType =>
info.derivedLambdaType(resType = addUnapplyCaptures(info.resType))

/** Drop added capture information from the type of an `unapply` */
private def dropUnapplyCaptures(info: Type)(using Context): Type = info match
case info: MethodType =>
info.paramInfos match
case CapturingType(oldParamInfo, _) :: Nil =>
def oldResult(tp: Type): Type = tp match
def transform(sym: SymDenotation, toCC: Boolean)(using Context): SymDenotation =

/** Add capture dependencies to the type of the `apply` or `copy` method of a case class.
* An apply method in a case class like this:
* case class CC(a: A^{d}, b: B, c: C^{cap})
* would get type
* def apply(a': A^{d}, b: B, c': C^{cap}): CC^{a', c'} { val a = A^{a'}, val c = C^{c'} }
* where `'` is used to indicate the difference between parameter symbol and refinement name.
* Analogous for the copy method.
*/
def addCaptureDeps(info: Type): Type = info match
case info: MethodType =>
val trackedParams = info.paramRefs.filter(atPhase(checkCapturesPhase)(_.isTracked))
def augmentResult(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = augmentResult(tp.resType))
case _ =>
val refined = trackedParams.foldLeft(tp) { (parent, pref) =>
RefinedType(parent, pref.paramName,
CapturingType(
atPhase(ctx.phase.next)(pref.underlying.stripCapturing),
CaptureSet(pref)))
}
CapturingType(refined, CaptureSet(trackedParams*))
if trackedParams.isEmpty then info
else augmentResult(info).showing(i"augment apply/copy type $info to $result", capt)
case info: PolyType =>
info.derivedLambdaType(resType = addCaptureDeps(info.resType))
case _ =>
info

/** Drop capture dependencies from the type of `apply` or `copy` method of a case class */
def dropCaptureDeps(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = dropCaptureDeps(tp.resType))
case CapturingType(parent, _) =>
dropCaptureDeps(parent)
case RefinedType(parent, _, _) =>
dropCaptureDeps(parent)
case _ =>
tp

/** Add capture information to the type of the default getter of a case class copy method
* if toCC = true, or remove the added info again if toCC = false.
*/
def transformDefaultGetterCaptures(info: Type, owner: Symbol, idx: Int)(using Context): Type = info match
case info: MethodOrPoly =>
info.derivedLambdaType(resType = transformDefaultGetterCaptures(info.resType, owner, idx))
case info: ExprType =>
info.derivedExprType(transformDefaultGetterCaptures(info.resType, owner, idx))
case EventuallyCapturingType(parent, _) =>
if toCC then transformDefaultGetterCaptures(parent, owner, idx)
else parent
case info @ AnnotatedType(parent, annot) =>
info.derivedAnnotatedType(transformDefaultGetterCaptures(parent, owner, idx), annot)
case _ if toCC && idx < owner.asClass.paramGetters.length =>
val param = owner.asClass.paramGetters(idx)
val pinfo = param.info
atPhase(ctx.phase.next) {
if pinfo.captureSet.isAlwaysEmpty then info
else CapturingType(pinfo.stripCapturing, CaptureSet(param.termRef))
}
case _ =>
info

/** Augment an unapply of type `(x: C): D` to `(x: C^{cap}): D^{x}` if toCC is true,
* or remove the added capture sets again if toCC = false.
*/
def transformUnapplyCaptures(info: Type)(using Context): Type = info match
case info: MethodType =>
if toCC then
val paramInfo :: Nil = info.paramInfos: @unchecked
val newParamInfo = CapturingType(paramInfo, CaptureSet.universal)
val trackedParam = info.paramRefs.head
def newResult(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = oldResult(tp.resType))
case CapturingType(tp, _) =>
tp
info.derivedLambdaType(paramInfos = oldParamInfo :: Nil, resType = oldResult(info.resType))
case _ =>
info
case info: PolyType =>
info.derivedLambdaType(resType = dropUnapplyCaptures(info.resType))

private def transformComposeCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type =
val (pt: PolyType) = symd.info: @unchecked
val (mt: MethodType) = pt.resType: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
val mt1 =
tp.derivedLambdaType(resType = newResult(tp.resType))
case _ =>
CapturingType(tp, CaptureSet(trackedParam))
info.derivedLambdaType(paramInfos = newParamInfo :: Nil, resType = newResult(info.resType))
.showing(i"augment unapply type $info to $result", capt)
else info.paramInfos match
case CapturingType(oldParamInfo, _) :: Nil =>
def oldResult(tp: Type): Type = tp match
case tp: MethodOrPoly =>
tp.derivedLambdaType(resType = oldResult(tp.resType))
case CapturingType(tp, _) =>
tp
info.derivedLambdaType(paramInfos = oldParamInfo :: Nil, resType = oldResult(info.resType))
case _ =>
info
case info: PolyType =>
info.derivedLambdaType(resType = transformUnapplyCaptures(info.resType))

def transformComposeCaptures(symd: SymDenotation) =
val (pt: PolyType) = symd.info: @unchecked
val (mt: MethodType) = pt.resType: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
val mt1 =
if toCC then
MethodType(mt.paramNames)(
mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)),
mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head)))
else
MethodType(mt.paramNames)(
mt1 => mt.paramInfos.map(_.stripCapturing),
mt1 => mt.resType.stripCapturing)
pt.derivedLambdaType(resType = mt1)

def transformCurriedTupledCaptures(symd: SymDenotation) =
val (et: ExprType) = symd.info: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
def mapFinalResult(tp: Type, f: Type => Type): Type =
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
if defn.isFunctionNType(res) then
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
else
f(tp)
val resType1 =
if toCC then
mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis)))
else
et.resType.stripCapturing
ExprType(resType1)

def transformCompareCaptures =
if toCC then
MethodType(mt.paramNames)(
mt1 => mt.paramInfos.map(_.capturing(CaptureSet.universal)),
mt1 => CapturingType(mt.resType, CaptureSet(enclThis, mt1.paramRefs.head)))
MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType)
else
MethodType(mt.paramNames)(
mt1 => mt.paramInfos.map(_.stripCapturing),
mt1 => mt.resType.stripCapturing)
pt.derivedLambdaType(resType = mt1)

def transformCurriedTupledCaptures(symd: SymDenotation, toCC: Boolean)(using Context): Type =
val (et: ExprType) = symd.info: @unchecked
val (enclThis: ThisType) = symd.owner.thisType: @unchecked
def mapFinalResult(tp: Type, f: Type => Type): Type =
val defn.FunctionOf(args, res, isContextual) = tp: @unchecked
if defn.isFunctionNType(res) then
defn.FunctionOf(args, mapFinalResult(res, f), isContextual)
else
f(tp)
val resType1 =
if toCC then
mapFinalResult(et.resType, CapturingType(_, CaptureSet(enclThis)))
else
et.resType.stripCapturing
ExprType(resType1)

/** If `sym` refers to a synthetic apply, unapply, copy, or copy default getter method
* of a case class, transform it to account for capture information.
* The method is run in phase CheckCaptures.Pre
* @pre needsTransform(sym)
*/
def transformToCC(sym: SymDenotation)(using Context): SymDenotation = sym.name match
case DefaultGetterName(nme.copy, n) =>
sym.copySymDenotation(info = addDefaultGetterCapture(sym.info, sym.owner, n))
case nme.unapply =>
sym.copySymDenotation(info = addUnapplyCaptures(sym.info))
case nme.apply | nme.copy =>
sym.copySymDenotation(info = addCaptureDeps(sym.info))
case nme.andThen | nme.compose =>
sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = true))
case nme.curried | nme.tupled =>
sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = true))
case n if n == nme.eq || n == nme.ne =>
sym.copySymDenotation(info =
MethodType(defn.ObjectType.capturing(CaptureSet.universal) :: Nil, defn.BooleanType))

/** If `sym` refers to a synthetic apply, unapply, copy, or copy default getter method
* of a case class, transform it back to what it was before the CC phase.
* @pre needsTransform(sym)
*/
def transformFromCC(sym: SymDenotation)(using Context): SymDenotation = sym.name match
case DefaultGetterName(nme.copy, n) =>
sym.copySymDenotation(info = dropDefaultGetterCapture(sym.info))
case nme.unapply =>
sym.copySymDenotation(info = dropUnapplyCaptures(sym.info))
case nme.apply | nme.copy =>
sym.copySymDenotation(info = dropCaptureDeps(sym.info))
case nme.andThen | nme.compose =>
sym.copySymDenotation(info = transformComposeCaptures(sym, toCC = false))
case nme.curried | nme.tupled =>
sym.copySymDenotation(info = transformCurriedTupledCaptures(sym, toCC = false))
case n if n == nme.eq || n == nme.ne =>
sym.copySymDenotation(info = defn.methOfAnyRef(defn.BooleanType))
defn.methOfAnyRef(defn.BooleanType)

sym.copySymDenotation(info = sym.name match
case DefaultGetterName(nme.copy, n) =>
transformDefaultGetterCaptures(sym.info, sym.owner, n)
case nme.unapply =>
transformUnapplyCaptures(sym.info)
case nme.apply | nme.copy =>
if toCC then addCaptureDeps(sym.info) else dropCaptureDeps(sym.info)
case nme.andThen | nme.compose =>
transformComposeCaptures(sym)
case nme.curried | nme.tupled =>
transformCurriedTupledCaptures(sym)
case n if n == nme.eq || n == nme.ne =>
transformCompareCaptures)
end transform

end Synthetics