Skip to content

Commit 92f6b8c

Browse files
authored
Several fixes to box inference (#16141)
Inspired by the formalism, this PR changes the box adaptation logic to fix a bunch of problems. The main fixes: - *Correctly charging the capture set of the type being adapted.* For example, the following code snippet should fail: ```scala type Id[X] = [T] -> (op: X -> T) -> T val x: Id[{io} Cap] = ??? x(cap => cap.use()) // error ``` Denote `cap => cap.use()` as `f`. The expected type is `(□ {io} Cap) -> Int`, while the actual type is `({io} Cap) -> Int`. The function should be adapted as `(x: □ {io} Cap) => f(unbox {io} x)`, whose capture set is `{io}` because of the unbox. - *Stop the newly-captured variables from escaping a box*. Specifically, assuming that `T` is a function type, if the actual type is `T` while the expected type is `□ U`, the variables captured when adapting `T` to `U` will not get charged to the outer environment. For example, the following code should compile: ```scala type Box[X] = X type Id[X] = Box[X] -> Unit type Op[X] = Unit -> Box[X] val f: Unit -> ({io} Cap) -> Unit = ??? val g: {} Op[{io} Id[{io} Cap]] = f ``` The expected type of `f` in the last statement is `{} Unit -> □ {io} (□ {io} Cap) -> Unit`. The function would be adapted as `(x: Unit) => let g = (y: □ {io} Cap) => f(unbox {io} y) in box g`. Note that although an `unbox` is inferenced, outermost capture set is still empty due to the `box` we insert, which conforms to the expectation. Besides, it also adds the support to adapting polymorphic function types.
2 parents 4c46cee + 65e42d3 commit 92f6b8c

16 files changed

+317
-76
lines changed

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,11 @@ extension (tp: Type)
121121
case _ =>
122122
tp
123123

124+
def isCapturingType(using Context): Boolean =
125+
tp match
126+
case CapturingType(_, _) => true
127+
case _ => false
128+
124129
extension (sym: Symbol)
125130

126131
/** Does this symbol allow results carrying the universal capability?

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 140 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,20 @@ object CheckCaptures:
4242
end Pre
4343

4444
/** A class describing environments.
45-
* @param owner the current owner
46-
* @param captured the caputure set containing all references to tracked free variables outside of boxes
47-
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
48-
* @param outer0 the next enclosing environment
45+
* @param owner the current owner
46+
* @param nestedInOwner true if the environment is a temporary one nested in the owner's environment,
47+
* and does not have a different actual owner symbol (this happens when doing box adaptation).
48+
* @param captured the caputure set containing all references to tracked free variables outside of boxes
49+
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
50+
* @param outer0 the next enclosing environment
4951
*/
50-
case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer0: Env | Null):
52+
case class Env(
53+
owner: Symbol,
54+
nestedInOwner: Boolean,
55+
captured: CaptureSet,
56+
isBoxed: Boolean,
57+
outer0: Env | Null
58+
):
5159
def outer = outer0.nn
5260

5361
def isOutermost = outer0 == null
@@ -204,7 +212,7 @@ class CheckCaptures extends Recheck, SymTransformer:
204212
report.error(i"$header included in allowed capture set ${res.blocking}", pos)
205213

206214
/** The current environment */
207-
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, isBoxed = false, null)
215+
private var curEnv: Env = Env(NoSymbol, nestedInOwner = false, CaptureSet.empty, isBoxed = false, null)
208216

209217
private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap()
210218

@@ -249,8 +257,12 @@ class CheckCaptures extends Recheck, SymTransformer:
249257
if !cs.isAlwaysEmpty then
250258
forallOuterEnvsUpTo(ctx.owner.topLevelClass) { env =>
251259
val included = cs.filter {
252-
case ref: TermRef => env.owner.isProperlyContainedIn(ref.symbol.owner)
253-
case ref: ThisType => env.owner.isProperlyContainedIn(ref.cls)
260+
case ref: TermRef =>
261+
(env.nestedInOwner || env.owner != ref.symbol.owner)
262+
&& env.owner.isContainedIn(ref.symbol.owner)
263+
case ref: ThisType =>
264+
(env.nestedInOwner || env.owner != ref.cls)
265+
&& env.owner.isContainedIn(ref.cls)
254266
case _ => false
255267
}
256268
capt.println(i"Include call capture $included in ${env.owner}")
@@ -439,7 +451,7 @@ class CheckCaptures extends Recheck, SymTransformer:
439451
if !Synthetics.isExcluded(sym) then
440452
val saved = curEnv
441453
val localSet = capturedVars(sym)
442-
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, isBoxed = false, curEnv)
454+
if !localSet.isAlwaysEmpty then curEnv = Env(sym, nestedInOwner = false, localSet, isBoxed = false, curEnv)
443455
try super.recheckDefDef(tree, sym)
444456
finally
445457
interpolateVarsIn(tree.tpt)
@@ -455,7 +467,7 @@ class CheckCaptures extends Recheck, SymTransformer:
455467
val localSet = capturedVars(cls)
456468
for parent <- impl.parents do // (1)
457469
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
458-
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, isBoxed = false, curEnv)
470+
if !localSet.isAlwaysEmpty then curEnv = Env(cls, nestedInOwner = false, localSet, isBoxed = false, curEnv)
459471
try
460472
val thisSet = cls.classInfo.selfType.captureSet.withDescription(i"of the self type of $cls")
461473
checkSubset(localSet, thisSet, tree.srcPos) // (2)
@@ -502,7 +514,7 @@ class CheckCaptures extends Recheck, SymTransformer:
502514
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
503515
if tree.isTerm && pt.isBoxedCapturing then
504516
val saved = curEnv
505-
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = true, curEnv)
517+
curEnv = Env(curEnv.owner, nestedInOwner = false, CaptureSet.Var(), isBoxed = true, curEnv)
506518
try super.recheck(tree, pt)
507519
finally curEnv = saved
508520
else
@@ -593,25 +605,121 @@ class CheckCaptures extends Recheck, SymTransformer:
593605

594606
/** Adapt function type `actual`, which is `aargs -> ares` (possibly with dependencies)
595607
* to `expected` type.
608+
* It returns the adapted type along with the additionally captured variable
609+
* during adaptation.
596610
* @param reconstruct how to rebuild the adapted function type
597611
*/
598612
def adaptFun(actual: Type, aargs: List[Type], ares: Type, expected: Type,
599-
covariant: Boolean,
600-
reconstruct: (List[Type], Type) => Type): Type =
601-
val (eargs, eres) = expected.dealias match
602-
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
603-
case _ => (aargs.map(_ => WildcardType), WildcardType)
604-
val aargs1 = aargs.zipWithConserve(eargs)(adapt(_, _, !covariant))
605-
val ares1 = adapt(ares, eres, covariant)
606-
if (ares1 eq ares) && (aargs1 eq aargs) then actual
607-
else reconstruct(aargs1, ares1)
608-
609-
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = actual.dealias match
610-
case actual @ CapturingType(parent, refs) =>
611-
val parent1 = adapt(parent, expected, covariant)
612-
if actual.isBoxed != expected.isBoxedCapturing then
613+
covariant: Boolean, boxed: Boolean,
614+
reconstruct: (List[Type], Type) => Type): (Type, CaptureSet) =
615+
val saved = curEnv
616+
curEnv = Env(curEnv.owner, nestedInOwner = true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
617+
618+
try
619+
val (eargs, eres) = expected.dealias.stripCapturing match
620+
case defn.FunctionOf(eargs, eres, _, _) => (eargs, eres)
621+
case expected: MethodType => (expected.paramInfos, expected.resType)
622+
case expected @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(expected) => (rinfo.paramInfos, rinfo.resType)
623+
case _ => (aargs.map(_ => WildcardType), WildcardType)
624+
val aargs1 = aargs.zipWithConserve(eargs) { (aarg, earg) => adapt(aarg, earg, !covariant) }
625+
val ares1 = adapt(ares, eres, covariant)
626+
627+
val resTp =
628+
if (ares1 eq ares) && (aargs1 eq aargs) then actual
629+
else reconstruct(aargs1, ares1)
630+
631+
(resTp, curEnv.captured)
632+
finally
633+
curEnv = saved
634+
635+
/** Adapt type function type `actual` to the expected type.
636+
* @see [[adaptFun]]
637+
*/
638+
def adaptTypeFun(
639+
actual: Type, ares: Type, expected: Type,
640+
covariant: Boolean, boxed: Boolean,
641+
reconstruct: Type => Type): (Type, CaptureSet) =
642+
val saved = curEnv
643+
curEnv = Env(curEnv.owner, nestedInOwner = true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
644+
645+
try
646+
val eres = expected.dealias.stripCapturing match
647+
case RefinedType(_, _, rinfo: PolyType) => rinfo.resType
648+
case expected: PolyType => expected.resType
649+
case _ => WildcardType
650+
651+
val ares1 = adapt(ares, eres, covariant)
652+
653+
val resTp =
654+
if ares1 eq ares then actual
655+
else reconstruct(ares1)
656+
657+
(resTp, curEnv.captured)
658+
finally
659+
curEnv = saved
660+
end adaptTypeFun
661+
662+
def adaptInfo(actual: Type, expected: Type, covariant: Boolean): String =
663+
val arrow = if covariant then "~~>" else "<~~"
664+
i"adapting $actual $arrow $expected"
665+
666+
/** Destruct a capturing type `tp` to a tuple (cs, tp0, boxed),
667+
* where `tp0` is not a capturing type.
668+
*
669+
* If `tp` is a nested capturing type, the return tuple always represents
670+
* the innermost capturing type. The outer capture annotations can be
671+
* reconstructed with the returned function.
672+
*/
673+
def destructCapturingType(tp: Type, reconstruct: Type => Type = x => x): ((Type, CaptureSet, Boolean), Type => Type) =
674+
tp.dealias match
675+
case tp @ CapturingType(parent, cs) =>
676+
if parent.dealias.isCapturingType then
677+
destructCapturingType(parent, res => reconstruct(tp.derivedCapturingType(res, cs)))
678+
else
679+
((parent, cs, tp.isBoxed), reconstruct)
680+
case actual =>
681+
((actual, CaptureSet(), false), reconstruct)
682+
683+
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
684+
if expected.isInstanceOf[WildcardType] then actual
685+
else
686+
val ((parent, cs, actualIsBoxed), recon) = destructCapturingType(actual)
687+
688+
val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
689+
val insertBox = needsAdaptation && covariant != actualIsBoxed
690+
691+
val (parent1, cs1) = parent match {
692+
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
693+
val (parent1, leaked) = adaptFun(parent, args.init, args.last, expected, covariant, insertBox,
694+
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
695+
(parent1, leaked ++ cs)
696+
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
697+
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
698+
val (parent1, leaked) = adaptFun(parent, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
699+
(aargs1, ares1) =>
700+
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
701+
.toFunctionType(isJava = false, alwaysDependent = true))
702+
(parent1, leaked ++ cs)
703+
case actual: MethodType =>
704+
val (parent1, leaked) = adaptFun(parent, actual.paramInfos, actual.resType, expected, covariant, insertBox,
705+
(aargs1, ares1) =>
706+
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
707+
(parent1, leaked ++ cs)
708+
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
709+
val (parent1, leaked) = adaptTypeFun(parent, rinfo.resType, expected, covariant, insertBox,
710+
ares1 =>
711+
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
712+
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
713+
actual1
714+
)
715+
(parent1, leaked ++ cs)
716+
case _ =>
717+
(parent, cs)
718+
}
719+
720+
if needsAdaptation then
613721
val criticalSet = // the set which is not allowed to have `*`
614-
if covariant then refs // can't box with `*`
722+
if covariant then cs1 // can't box with `*`
615723
else expected.captureSet // can't unbox with `*`
616724
if criticalSet.isUniversal then
617725
// We can't box/unbox the universal capability. Leave `actual` as it is
@@ -627,20 +735,13 @@ class CheckCaptures extends Recheck, SymTransformer:
627735
|since one of their capture sets contains the root capability `*`""",
628736
pos)
629737
}
630-
if covariant == actual.isBoxed then markFree(refs, pos)
631-
CapturingType(parent1, refs, boxed = !actual.isBoxed)
738+
if !insertBox then // unboxing
739+
markFree(criticalSet, pos)
740+
recon(CapturingType(parent1, cs1, !actualIsBoxed))
632741
else
633-
actual.derivedCapturingType(parent1, refs)
634-
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
635-
adaptFun(actual, args.init, args.last, expected, covariant,
636-
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
637-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
638-
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
639-
adaptFun(actual, rinfo.paramInfos, rinfo.resType, expected, covariant,
640-
(aargs1, ares1) =>
641-
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
642-
.toFunctionType(isJava = false, alwaysDependent = true))
643-
case _ => actual
742+
recon(CapturingType(parent1, cs1, actualIsBoxed))
743+
}
744+
644745

645746
var actualw = actual.widenDealias
646747
actual match

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,13 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
633633
case (info1: MethodType, info2: MethodType) =>
634634
matchingMethodParams(info1, info2, precise = false)
635635
&& isSubInfo(info1.resultType, info2.resultType.subst(info2, info1))
636+
case (info1 @ CapturingType(parent1, refs1), info2: Type) =>
637+
subCaptures(refs1, info2.captureSet, frozenConstraint).isOK && sameBoxed(info1, info2, refs1)
638+
&& isSubInfo(parent1, info2)
639+
case (info1: Type, CapturingType(parent2, refs2)) =>
640+
val refs1 = info1.captureSet
641+
(refs1.isAlwaysEmpty || subCaptures(refs1, refs2, frozenConstraint).isOK) && sameBoxed(info1, info2, refs1)
642+
&& isSubInfo(info1, parent2)
636643
case _ =>
637644
isSubType(info1, info2)
638645

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
trait Cap
2+
3+
def main(io: {*} Cap, fs: {*} Cap): Unit = {
4+
val test1: {} Unit -> Unit = _ => { // error
5+
type Op = [T] -> ({io} T -> Unit) -> Unit
6+
val f: ({io} Cap) -> Unit = ???
7+
val op: Op = ???
8+
op[{io} Cap](f)
9+
// expected type of f: {io} (box {io} Cap) -> Unit
10+
// actual type: ({io} Cap) -> Unit
11+
// adapting f to the expected type will also
12+
// charge the environment with {io}
13+
}
14+
15+
val test2: {} Unit -> Unit = _ => {
16+
type Box[X] = X
17+
type Op0[X] = Box[X] -> Unit
18+
type Op1[X] = Unit -> Box[X]
19+
val f: Unit -> ({io} Cap) -> Unit = ???
20+
val test: {} Op1[{io} Op0[{io} Cap]] = f
21+
// expected: {} Unit -> box {io} (box {io} Cap) -> Unit
22+
// actual: Unit -> ({io} Cap) -> Unit
23+
//
24+
// although adapting `({io} Cap) -> Unit` to
25+
// `box {io} (box {io} Cap) -> Unit` will leak the
26+
// captured variables {io}, but since it is inside a box,
27+
// we will charge neither the outer type nor the environment
28+
}
29+
30+
val test3 = {
31+
type Box[X] = X
32+
type Id[X] = Box[X] -> Unit
33+
type Op[X] = Unit -> Box[X]
34+
val f: Unit -> ({io} Cap) -> Unit = ???
35+
val g: Op[{fs} Id[{io} Cap]] = f // error
36+
val h: {} Op[{io} Id[{io} Cap]] = f
37+
}
38+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
trait Cap { def use(): Int }
2+
3+
def test1(): Unit = {
4+
type Id[X] = [T] -> (op: X => T) -> T
5+
6+
val x: Id[{*} Cap] = ???
7+
x(cap => cap.use()) // error
8+
}
9+
10+
def test2(io: {*} Cap): Unit = {
11+
type Id[X] = [T] -> (op: X -> T) -> T
12+
13+
val x: Id[{io} Cap] = ???
14+
x(cap => cap.use()) // error
15+
}
16+
17+
def test3(io: {*} Cap): Unit = {
18+
type Id[X] = [T] -> (op: {io} X -> T) -> T
19+
20+
val x: Id[{io} Cap] = ???
21+
x(cap => cap.use()) // ok
22+
}
23+
24+
def test4(io: {*} Cap, fs: {*} Cap): Unit = {
25+
type Id[X] = [T] -> (op: {io} X -> T) -> T
26+
27+
val x: Id[{io, fs} Cap] = ???
28+
x(cap => cap.use()) // error
29+
}
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
trait Cap
2+
3+
def test1(io: {*} Cap) = {
4+
type Op[X] = [T] -> Unit -> X
5+
val f: Op[{io} Cap] = ???
6+
val x: [T] -> Unit -> ({io} Cap) = f // error
7+
}
8+
9+
def test2(io: {*} Cap) = {
10+
type Op[X] = [T] -> Unit -> {io} X
11+
val f: Op[{io} Cap] = ???
12+
val x: Unit -> ({io} Cap) = f[Unit] // error
13+
val x1: {io} Unit -> ({io} Cap) = f[Unit] // ok
14+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
trait Cap { def use(): Int }
2+
3+
def test1(io: {*} Cap): Unit = {
4+
type Id[X] = [T] -> (op: {io} X -> T) -> T
5+
6+
val x: Id[{io} Cap] = ???
7+
val f: ({*} Cap) -> Unit = ???
8+
x(f) // ok
9+
// actual: {*} Cap -> Unit
10+
// expected: {io} box {io} Cap -> Unit
11+
}
12+
13+
def test2(io: {*} Cap): Unit = {
14+
type Id[X] = [T] -> (op: {*} X -> T) -> T
15+
16+
val x: Id[{*} Cap] = ???
17+
val f: ({io} Cap) -> Unit = ???
18+
x(f) // error
19+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
trait Cap { def use(): Int }
2+
3+
def test1(io: {*} Cap): Unit = {
4+
type Id[X] = [T] -> (op: {io} X -> T) -> T
5+
6+
val x: Id[{io} Cap] = ???
7+
x(cap => cap.use()) // ok
8+
}
9+
10+
def test2(io: {*} Cap): Unit = {
11+
type Id[X] = [T] -> (op: {io} (x: X) -> T) -> T
12+
13+
val x: Id[{io} Cap] = ???
14+
x(cap => cap.use())
15+
// should work when the expected type is a dependent function
16+
}
17+
18+
def test3(io: {*} Cap): Unit = {
19+
type Id[X] = [T] -> (op: {} (x: X) -> T) -> T
20+
21+
val x: Id[{io} Cap] = ???
22+
x(cap => cap.use()) // error
23+
}

0 commit comments

Comments
 (0)