diff --git a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala index 7c75ed833945..deeb474f018a 100644 --- a/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala +++ b/compiler/src/dotty/tools/dotc/cc/CaptureOps.scala @@ -445,6 +445,14 @@ extension (tp: AnnotatedType) case ann: CaptureAnnotation => ann.boxed case _ => false +/** Drop retains annotations in the type. */ +class CleanupRetains(using Context) extends TypeMap: + def apply(tp: Type): Type = + tp match + case AnnotatedType(tp, annot) if annot.symbol == defn.RetainsAnnot || annot.symbol == defn.RetainsByNameAnnot => + RetainingType(tp, Nil, byName = annot.symbol == defn.RetainsByNameAnnot) + case _ => mapOver(tp) + /** An extractor for `caps.reachCapability(ref)`, which is used to express a reach * capability as a tree in a @retains annotation. */ diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 3bcec80b5b10..a977694ded27 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -19,6 +19,7 @@ import config.Feature import util.SrcPos import reporting.* import NameKinds.WildcardParamName +import cc.* object PostTyper { val name: String = "posttyper" @@ -279,6 +280,21 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => if !tree.symbol.is(Package) then tree else errorTree(tree, em"${tree.symbol} cannot be used as a type") + // Cleans up retains annotations in inferred type trees. This is needed because + // during the typer, it is infeasible to correctly infer the capture sets in most + // cases, resulting ill-formed capture sets that could crash the pickler later on. + // See #20035. + private def cleanupRetainsAnnot(symbol: Symbol, tpt: Tree)(using Context): Tree = + tpt match + case tpt: InferredTypeTree + if !symbol.allOverriddenSymbols.hasNext => + // if there are overridden symbols, the annotation comes from an explicit type of the overridden symbol + // and should be retained. + val tm = new CleanupRetains + val tpe1 = tm(tpt.tpe) + tpt.withType(tpe1) + case _ => tpt + override def transform(tree: Tree)(using Context): Tree = try tree match { // TODO move CaseDef case lower: keep most probable trees first for performance @@ -388,7 +404,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => registerIfHasMacroAnnotations(tree) checkErasedDef(tree) Checking.checkPolyFunctionType(tree.tpt) - val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) + val tree1 = cpy.ValDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) if tree1.removeAttachment(desugar.UntupledParam).isDefined then checkStableSelection(tree.rhs) processValOrDefDef(super.transform(tree1)) @@ -398,7 +414,7 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase => checkErasedDef(tree) Checking.checkPolyFunctionType(tree.tpt) annotateContextResults(tree) - val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) + val tree1 = cpy.DefDef(tree)(tpt = cleanupRetainsAnnot(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef])) case tree: TypeDef => registerIfHasMacroAnnotations(tree) diff --git a/tests/neg-custom-args/captures/byname.check b/tests/neg-custom-args/captures/byname.check index 226bee2cd0e5..e06a3a1f8268 100644 --- a/tests/neg-custom-args/captures/byname.check +++ b/tests/neg-custom-args/captures/byname.check @@ -9,7 +9,7 @@ | Found: (x$0: Int) ->{cap2} Int | Required: (x$0: Int) -> Int | - | Note that the expected type Int => Int + | Note that the expected type Int ->{} Int | is the previously inferred result type of method test | which is also the type seen in separately compiled sources. | The new inferred type (x$0: Int) ->{cap2} Int diff --git a/tests/pos-custom-args/captures/tablediff.scala b/tests/pos-custom-args/captures/tablediff.scala new file mode 100644 index 000000000000..244ee1a46a23 --- /dev/null +++ b/tests/pos-custom-args/captures/tablediff.scala @@ -0,0 +1,11 @@ +import language.experimental.captureChecking + +trait Seq[+A]: + def zipAll[A1 >: A, B](that: Seq[B]^, thisElem: A1, thatElem: B): Seq[(A1, B)]^{this, that} + def map[B](f: A => B): Seq[B]^{this, f} + +def zipAllOption[X](left: Seq[X], right: Seq[X]) = + left.map(Option(_)).zipAll(right.map(Option(_)), None, None) + +def fillRow[T](headRow: Seq[T], tailRow: Seq[T]) = + val paddedZip = zipAllOption(headRow, tailRow)