Skip to content

Apply box adaptation when checking overrides #16479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Jan 13, 2023
5 changes: 5 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CaptureOps.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,11 @@ extension (tp: Type)
case CapturingType(_, _) => true
case _ => false

def isEventuallyCapturingType(using Context): Boolean =
tp match
case EventuallyCapturingType(_, _) => true
case _ => false

/** Is type known to be always pure by its class structure,
* so that adding a capture set to it would not make sense?
*/
Expand Down
10 changes: 10 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/CapturingType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ object CapturingType:
EventuallyCapturingType.unapply(tp)
else None

/** Check whether a type is uncachable when computing `baseType`.
* - Avoid caching all the types during the setup phase, since at that point
* the capture set variables are not fully installed yet.
* - Avoid caching capturing types when IgnoreCaptures mode is set, since the
* capture sets may be thrown away in the computed base type.
*/
def isUncachable(tp: Type)(using Context): Boolean =
ctx.phase == Phases.checkCapturesPhase &&
(Setup.isDuringSetup || ctx.mode.is(Mode.IgnoreCaptures) && tp.isEventuallyCapturingType)

end CapturingType

/** An extractor for types that will be capturing types at phase CheckCaptures. Also
Expand Down
70 changes: 49 additions & 21 deletions compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import config.Printers.{capt, recheckr}
import config.{Config, Feature}
import ast.{tpd, untpd, Trees}
import Trees.*
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents}
import typer.RefChecks.{checkAllOverrides, checkSelfAgainstParents, OverridingPairsChecker}
import typer.Checking.{checkBounds, checkAppliedTypesIn}
import util.{SimpleIdentitySet, EqHashMap, SrcPos}
import transform.SymUtils.*
Expand Down Expand Up @@ -141,25 +141,12 @@ class CheckCaptures extends Recheck, SymTransformer:

override def run(using Context): Unit =
if Feature.ccEnabled then
checkOverrides.traverse(ctx.compilationUnit.tpdTree)
super.run

override def transformSym(sym: SymDenotation)(using Context): SymDenotation =
if Synthetics.needsTransform(sym) then Synthetics.transformFromCC(sym)
else super.transformSym(sym)

/** Check overrides again, taking capture sets into account.
* TODO: Can we avoid doing overrides checks twice?
* We need to do them here since only at this phase CaptureTypes are relevant
* But maybe we can then elide the check during the RefChecks phase under captureChecking?
*/
def checkOverrides = new TreeTraverser:
def traverse(t: Tree)(using Context) =
t match
case t: Template => checkAllOverrides(ctx.owner.asClass)
case _ =>
traverseChildren(t)

class CaptureChecker(ictx: Context) extends Rechecker(ictx):
import ast.tpd.*

Expand Down Expand Up @@ -668,8 +655,11 @@ class CheckCaptures extends Recheck, SymTransformer:
case _ =>
expected

/** Adapt `actual` type to `expected` type by inserting boxing and unboxing conversions */
def adaptBoxed(actual: Type, expected: Type, pos: SrcPos)(using Context): Type =
/** Adapt `actual` type to `expected` type by inserting boxing and unboxing conversions
*
* @param alwaysConst always make capture set variables constant after adaptation
*/
def adaptBoxed(actual: Type, expected: Type, pos: SrcPos, alwaysConst: Boolean = false)(using Context): Type =

/** Adapt function type `actual`, which is `aargs -> ares` (possibly with dependencies)
* to `expected` type.
Expand Down Expand Up @@ -746,7 +736,8 @@ class CheckCaptures extends Recheck, SymTransformer:
else
((parent, cs, tp.isBoxed), reconstruct)
case actual =>
((actual, CaptureSet(), false), reconstruct)
val res = if tp.isFromJavaObject then tp else actual
((res, CaptureSet(), false), reconstruct)

def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
if expected.isInstanceOf[WildcardType] then actual
Expand Down Expand Up @@ -806,9 +797,9 @@ class CheckCaptures extends Recheck, SymTransformer:
}
if !insertBox then // unboxing
markFree(criticalSet, pos)
recon(CapturingType(parent1, cs1, !actualIsBoxed))
recon(CapturingType(parent1, if alwaysConst then CaptureSet(cs1.elems) else cs1, !actualIsBoxed))
else
recon(CapturingType(parent1, cs1, actualIsBoxed))
recon(CapturingType(parent1, if alwaysConst then CaptureSet(cs1.elems) else cs1, actualIsBoxed))
}

var actualw = actual.widenDealias
Expand All @@ -827,12 +818,49 @@ class CheckCaptures extends Recheck, SymTransformer:
else actual
end adaptBoxed

/** Check overrides again, taking capture sets into account.
* TODO: Can we avoid doing overrides checks twice?
* We need to do them here since only at this phase CaptureTypes are relevant
* But maybe we can then elide the check during the RefChecks phase under captureChecking?
*/
def checkOverrides = new TreeTraverser:
class OverridingPairsCheckerCC(clazz: ClassSymbol, self: Type, srcPos: SrcPos)(using Context) extends OverridingPairsChecker(clazz, self) {
/** Check subtype with box adaptation.
* This function is passed to RefChecks to check the compatibility of overriding pairs.
* @param sym symbol of the field definition that is being checked
*/
override def checkSubType(actual: Type, expected: Type)(using Context): Boolean =
val expected1 = alignDependentFunction(addOuterRefs(expected, actual), actual.stripCapturing)
val actual1 =
val saved = curEnv
try
curEnv = Env(clazz, nestedInOwner = true, capturedVars(clazz), isBoxed = false, outer0 = curEnv)
val adapted = adaptBoxed(actual, expected1, srcPos, alwaysConst = true)
actual match
case _: MethodType =>
// We remove the capture set resulted from box adaptation for method types,
// since class methods are always treated as pure, and their captured variables
// are charged to the capture set of the class (which is already done during
// box adaptation).
adapted.stripCapturing
case _ => adapted
finally curEnv = saved
actual1 frozen_<:< expected1
}

def traverse(t: Tree)(using Context) =
t match
case t: Template =>
checkAllOverrides(ctx.owner.asClass, OverridingPairsCheckerCC(_, _, t))
case _ =>
traverseChildren(t)

override def checkUnit(unit: CompilationUnit)(using Context): Unit =
Setup(preRecheckPhase, thisPhase, recheckDef)
.traverse(ctx.compilationUnit.tpdTree)
Setup(preRecheckPhase, thisPhase, recheckDef)(ctx.compilationUnit.tpdTree)
//println(i"SETUP:\n${Recheck.addRecheckedTypes.transform(ctx.compilationUnit.tpdTree)}")
withCaptureSetsExplained {
super.checkUnit(unit)
checkOverrides.traverse(unit.tpdTree)
checkSelfTypes(unit.tpdTree)
postCheck(unit.tpdTree)
if ctx.settings.YccDebug.value then
Expand Down
11 changes: 11 additions & 0 deletions compiler/src/dotty/tools/dotc/cc/Setup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import ast.tpd
import transform.Recheck.*
import CaptureSet.IdentityCaptRefMap
import Synthetics.isExcluded
import util.Property

/** A tree traverser that prepares a compilation unit to be capture checked.
* It does the following:
Expand Down Expand Up @@ -484,4 +485,14 @@ extends tpd.TreeTraverser:
capt.println(i"update info of ${tree.symbol} from $info to $newInfo")
case _ =>
end traverse

def apply(tree: Tree)(using Context): Unit =
traverse(tree)(using ctx.withProperty(Setup.IsDuringSetupKey, Some(())))
end Setup

object Setup:
val IsDuringSetupKey = new Property.Key[Unit]

def isDuringSetup(using Context): Boolean =
ctx.property(IsDuringSetupKey).isDefined

9 changes: 5 additions & 4 deletions compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import config.Config
import reporting._
import collection.mutable
import transform.TypeUtils._
import cc.{CapturingType, derivedCapturingType}
import cc.{CapturingType, derivedCapturingType, Setup, EventuallyCapturingType, isEventuallyCapturingType}

import scala.annotation.internal.sharable

Expand Down Expand Up @@ -2147,7 +2147,7 @@ object SymDenotations {
Stats.record("basetype cache entries")
if (!baseTp.exists) Stats.record("basetype cache NoTypes")
}
if (!tp.isProvisional)
if (!tp.isProvisional && !CapturingType.isUncachable(tp))
btrCache(tp) = baseTp
else
btrCache.remove(tp) // Remove any potential sentinel value
Expand All @@ -2161,8 +2161,9 @@ object SymDenotations {
def recur(tp: Type): Type = try {
tp match {
case tp: CachedType =>
val baseTp = btrCache.lookup(tp)
if (baseTp != null) return ensureAcyclic(baseTp)
val baseTp: Type | Null = btrCache.lookup(tp)
if (baseTp != null)
return ensureAcyclic(baseTp)
case _ =>
}
if (Stats.monitored) {
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/dotty/tools/dotc/core/Types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1071,12 +1071,15 @@ object Types {
* @param relaxedCheck if true type `Null` becomes a subtype of non-primitive value types in TypeComparer.
* @param matchLoosely if true the types `=> T` and `()T` are seen as overriding each other.
* @param checkClassInfo if true we check that ClassInfos are within bounds of abstract types
*
* @param isSubType a function used for checking subtype relationships.
*/
final def overrides(that: Type, relaxedCheck: Boolean, matchLoosely: => Boolean, checkClassInfo: Boolean = true)(using Context): Boolean = {
final def overrides(that: Type, relaxedCheck: Boolean, matchLoosely: => Boolean, checkClassInfo: Boolean = true,
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean = {
val overrideCtx = if relaxedCheck then ctx.relaxedOverrideContext else ctx
inContext(overrideCtx) {
!checkClassInfo && this.isInstanceOf[ClassInfo]
|| (this.widenExpr frozen_<:< that.widenExpr)
|| isSubType(this.widenExpr, that.widenExpr)
|| matchLoosely && {
val this1 = this.widenNullaryMethod
val that1 = that.widenNullaryMethod
Expand Down
11 changes: 7 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/OverridingPairs.scala
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,13 @@ object OverridingPairs:
/** Let `member` and `other` be members of some common class C with types
* `memberTp` and `otherTp` in C. Are the two symbols considered an overriding
* pair in C? We assume that names already match so we test only the types here.
* @param fallBack A function called if the initial test is false and
* `member` and `other` are term symbols.
* @param fallBack A function called if the initial test is false and
* `member` and `other` are term symbols.
* @param isSubType A function to be used for checking subtype relationships
* between term fields.
*/
def isOverridingPair(member: Symbol, memberTp: Type, other: Symbol, otherTp: Type, fallBack: => Boolean = false)(using Context): Boolean =
def isOverridingPair(member: Symbol, memberTp: Type, other: Symbol, otherTp: Type, fallBack: => Boolean = false,
isSubType: (Type, Type) => Context ?=> Boolean = (tp1, tp2) => tp1 frozen_<:< tp2)(using Context): Boolean =
if member.isType then // intersection of bounds to refined types must be nonempty
memberTp.bounds.hi.hasSameKindAs(otherTp.bounds.hi)
&& (
Expand All @@ -222,6 +225,6 @@ object OverridingPairs:
val relaxedOverriding = ctx.explicitNulls && (member.is(JavaDefined) || other.is(JavaDefined))
member.name.is(DefaultGetterName) // default getters are not checked for compatibility
|| memberTp.overrides(otherTp, relaxedOverriding,
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack)
member.matchNullaryLoosely || other.matchNullaryLoosely || fallBack, isSubType = isSubType)

end OverridingPairs
61 changes: 41 additions & 20 deletions compiler/src/dotty/tools/dotc/typer/RefChecks.scala
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,16 @@ object RefChecks {
&& inLinearizationOrder(sym1, sym2, parent)
&& !sym2.is(AbsOverride)

def checkAll(checkOverride: (Symbol, Symbol) => Unit) =
// Checks the subtype relationship tp1 <:< tp2.
// It is passed to the `checkOverride` operation in `checkAll`, to be used for
// compatibility checking.
def checkSubType(tp1: Type, tp2: Type)(using Context): Boolean = tp1 frozen_<:< tp2

private val subtypeChecker: (Type, Type) => Context ?=> Boolean = this.checkSubType

def checkAll(checkOverride: ((Type, Type) => Context ?=> Boolean, Symbol, Symbol) => Unit) =
while hasNext do
checkOverride(overriding, overridden)
checkOverride(subtypeChecker, overriding, overridden)
next()

// The OverridingPairs cursor does assume that concrete overrides abstract
Expand All @@ -253,7 +260,7 @@ object RefChecks {
if dcl.is(Deferred) then
for other <- dcl.allOverriddenSymbols do
if !other.is(Deferred) then
checkOverride(dcl, other)
checkOverride(checkSubType, dcl, other)
end checkAll
end OverridingPairsChecker

Expand Down Expand Up @@ -290,8 +297,11 @@ object RefChecks {
* TODO check that classes are not overridden
* TODO This still needs to be cleaned up; the current version is a straight port of what was there
* before, but it looks too complicated and method bodies are far too large.
*
* @param makeOverridePairsChecker A function for creating a OverridePairsChecker instance
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we might be able to streamline this, maybe by moving checkOverrides into OverridingPairsChecker. But we can do that later after this PR is merged.

* from the class symbol and the self type
*/
def checkAllOverrides(clazz: ClassSymbol)(using Context): Unit = {
def checkAllOverrides(clazz: ClassSymbol, makeOverridingPairsChecker: ((ClassSymbol, Type) => Context ?=> OverridingPairsChecker) | Null = null)(using Context): Unit = {
val self = clazz.thisType
val upwardsSelf = upwardsThisType(clazz)
var hasErrors = false
Expand Down Expand Up @@ -322,10 +332,17 @@ object RefChecks {
def infoStringWithLocation(sym: Symbol) =
err.infoString(sym, self, showLocation = true)

def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
mbr.is(ParamAccessor)
&& {
val next = ParamForwarding.inheritedAccessor(mbr)
next == other || isInheritedAccessor(next, other)
}

/* Check that all conditions for overriding `other` by `member`
* of class `clazz` are met.
*/
def checkOverride(member: Symbol, other: Symbol): Unit =
* of class `clazz` are met.
*/
def checkOverride(checkSubType: (Type, Type) => Context ?=> Boolean, member: Symbol, other: Symbol): Unit =
def memberTp(self: Type) =
if (member.isClass) TypeAlias(member.typeRef.EtaExpand(member.typeParams))
else self.memberInfo(member)
Expand All @@ -344,7 +361,8 @@ object RefChecks {
isOverridingPair(member, memberTp, other, otherTp,
fallBack = warnOnMigration(
overrideErrorMsg("no longer has compatible type"),
(if (member.owner == clazz) member else clazz).srcPos, version = `3.0`))
(if (member.owner == clazz) member else clazz).srcPos, version = `3.0`),
isSubType = checkSubType)
catch case ex: MissingType =>
// can happen when called with upwardsSelf as qualifier of memberTp and otherTp,
// because in that case we might access types that are not members of the qualifier.
Expand All @@ -356,7 +374,16 @@ object RefChecks {
* Type members are always assumed to match.
*/
def trueMatch: Boolean =
member.isType || memberTp(self).matches(otherTp(self))
member.isType || withMode(Mode.IgnoreCaptures) {
// `matches` does not perform box adaptation so the result here would be
// spurious during capture checking.
//
// Instead of parameterizing `matches` with the function for subtype checking
// with box adaptation, we simply ignore capture annotations here.
// This should be safe since the compatibility under box adaptation is already
// checked.
memberTp(self).matches(otherTp(self))
}

def emitOverrideError(fullmsg: Message) =
if (!(hasErrors && member.is(Synthetic) && member.is(Module))) {
Expand Down Expand Up @@ -491,7 +518,7 @@ object RefChecks {
else if (member.is(ModuleVal) && !other.isRealMethod && !other.isOneOf(DeferredOrLazy))
overrideError("may not override a concrete non-lazy value")
else if (member.is(Lazy, butNot = Module) && !other.isRealMethod && !other.is(Lazy) &&
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
!warnOnMigration(overrideErrorMsg("may not override a non-lazy value"), member.srcPos, version = `3.0`))
overrideError("may not override a non-lazy value")
else if (other.is(Lazy) && !other.isRealMethod && !member.is(Lazy))
overrideError("must be declared lazy to override a lazy value")
Expand Down Expand Up @@ -524,14 +551,8 @@ object RefChecks {
overrideDeprecation("", member, other, "removed or renamed")
end checkOverride

def isInheritedAccessor(mbr: Symbol, other: Symbol): Boolean =
mbr.is(ParamAccessor)
&& {
val next = ParamForwarding.inheritedAccessor(mbr)
next == other || isInheritedAccessor(next, other)
}

OverridingPairsChecker(clazz, self).checkAll(checkOverride)
val checker = if makeOverridingPairsChecker == null then OverridingPairsChecker(clazz, self) else makeOverridingPairsChecker(clazz, self)
checker.checkAll(checkOverride)
printMixinOverrideErrors()

// Verifying a concrete class has nothing unimplemented.
Expand Down Expand Up @@ -575,7 +596,7 @@ object RefChecks {
clazz.nonPrivateMembersNamed(mbr.name)
.filterWithPredicate(
impl => isConcrete(impl.symbol)
&& mbrDenot.matchesLoosely(impl, alwaysCompareTypes = true))
&& withMode(Mode.IgnoreCaptures)(mbrDenot.matchesLoosely(impl, alwaysCompareTypes = true)))
.exists

/** The term symbols in this class and its baseclasses that are
Expand Down Expand Up @@ -722,7 +743,7 @@ object RefChecks {
def checkNoAbstractDecls(bc: Symbol): Unit = {
for (decl <- bc.info.decls)
if (decl.is(Deferred)) {
val impl = decl.matchingMember(clazz.thisType)
val impl = withMode(Mode.IgnoreCaptures)(decl.matchingMember(clazz.thisType))
if (impl == NoSymbol || decl.owner.isSubClass(impl.owner))
&& !ignoreDeferred(decl)
then
Expand Down
Loading