Skip to content

Commit 6bf993c

Browse files
committed
Refine handling of val-bound closures
Don't treat them as level roots if they are implicit eta expansions that don't mention `cap` explicitly.
1 parent b6132d9 commit 6bf993c

File tree

4 files changed

+41
-9
lines changed

4 files changed

+41
-9
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,23 @@ trait TypedTreeInfo extends TreeInfo[Type] { self: Trees.Instance[Type] =>
813813
case _ => tree
814814
}
815815

816+
/** An extractor for eta expanded `mdef` an eta-expansion of a method reference? To recognize this, we use
817+
* the following criterion: A method definition is an eta expansion, if
818+
* it contains at least one term paramter, the parameter has a zero extent span,
819+
* and the right hand side is either an application or a closure with'
820+
* an anonymous method that's itself characterized as an eta expansion.
821+
*/
822+
def isEtaExpansion(mdef: DefDef)(using Context): Boolean =
823+
!rhsOfEtaExpansion(mdef).isEmpty
824+
825+
def rhsOfEtaExpansion(mdef: DefDef)(using Context): Tree = mdef.paramss match
826+
case (param :: _) :: _ if param.asInstanceOf[Tree].span.isZeroExtent =>
827+
mdef.rhs match
828+
case rhs: Apply => rhs
829+
case closureDef(mdef1) => rhsOfEtaExpansion(mdef1)
830+
case _ => EmptyTree
831+
case _ => EmptyTree
832+
816833
/** The variables defined by a pattern, in reverse order of their appearance. */
817834
def patVars(tree: Tree)(using Context): List[Symbol] = {
818835
val acc = new TreeAccumulator[List[Symbol]] { outer =>

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -606,13 +606,6 @@ class CheckCaptures extends Recheck, SymTransformer:
606606
// rechecking the body.
607607
openClosures = (mdef.symbol, pt) :: openClosures
608608
try
609-
def isEtaExpansion(mdef: DefDef): Boolean = mdef.paramss match
610-
case (param :: _) :: _ if param.asInstanceOf[Tree].span.isZeroExtent =>
611-
mdef.rhs match
612-
case _: Apply => true
613-
case closureDef(mdef1) => isEtaExpansion(mdef1)
614-
case _ => false
615-
case _ => false
616609
val res = recheckClosure(expr, pt, forceDependent = true)
617610
if !isEtaExpansion(mdef) then
618611
// If closure is an eta expanded method reference it's better to not constrain

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,10 +314,23 @@ extends tpd.TreeTraverser:
314314
case _ =>
315315
traverseChildren(tree)
316316
case tree @ ValDef(_, tpt: TypeTree, rhs) =>
317+
def containsCap(tp: Type) = tp.existsPart:
318+
case CapturingType(_, refs) => refs.isUniversal
319+
case _ => false
320+
def mentionsCap(tree: Tree): Boolean = tree match
321+
case Apply(fn, _) => mentionsCap(fn)
322+
case TypeApply(fn, args) => args.exists(mentionsCap)
323+
case _: InferredTypeTree => false
324+
case _: TypeTree => containsCap(expandAliases(tree.tpe))
325+
case _ => false
317326
val mapRoots = rhs match
318-
case possiblyTypedClosureDef(ddef) =>
327+
case possiblyTypedClosureDef(ddef) if !mentionsCap(rhsOfEtaExpansion(ddef)) =>
319328
ddef.symbol.setNestingLevel(ctx.owner.nestingLevel + 1)
320-
// toplevel closures bound to vals count as level owners
329+
// Toplevel closures bound to vals count as level owners
330+
// unless the closure is an implicit eta expansion over a type application
331+
// that mentions `cap`. In that case we prefer not to silently rebind
332+
// the `cap` to a local root of an invisible closure. See
333+
// pos-custom-args/captures/eta-expansions.scala for examples of both cases.
321334
!tpt.isInstanceOf[InferredTypeTree]
322335
// in this case roots in inferred val type count as polymorphic
323336
case _ =>
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
@annotation.capability class Cap
2+
3+
def test(d: Cap) =
4+
def map2(xs: List[Int])(f: Int => Int): List[Int] = xs.map(f)
5+
val f1 = map2 // capture polymorphic implicit eta expansion
6+
def f2c: List[Int] => (Int => Int) => List[Int] = f1
7+
val a0 = identity[Cap ->{d} Unit] // capture monomorphic implicit eta expansion
8+
val a0c: (Cap ->{d} Unit) ->{d} Cap ->{d} Unit = a0
9+
val b0 = (x: Cap ->{d} Unit) => identity[Cap ->{d} Unit](x) // not an implicit eta expansion, hence capture polymorphic

0 commit comments

Comments
 (0)