Skip to content

Commit 363f142

Browse files
committed
refactor box adaptation
- special handling for the env created during box adaptation - rewrite `adapt` to make it cleaner and easier to understand
1 parent daf6766 commit 363f142

File tree

4 files changed

+106
-109
lines changed

4 files changed

+106
-109
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 89 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,20 @@ object CheckCaptures:
4242
end Pre
4343

4444
/** A class describing environments.
45-
* @param owner the current owner
46-
* @param captured the caputure set containing all references to tracked free variables outside of boxes
47-
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
48-
* @param outer0 the next enclosing environment
45+
* @param owner the current owner
46+
* @param nestedInOwner true if the environment is a temporary one nested in the owner's environment,
47+
* and does not have an actual owner symbol (this happens when doing box adaptation).
48+
* @param captured the caputure set containing all references to tracked free variables outside of boxes
49+
* @param isBoxed true if the environment is inside a box (in which case references are not counted)
50+
* @param outer0 the next enclosing environment
4951
*/
50-
case class Env(owner: Symbol, captured: CaptureSet, isBoxed: Boolean, outer0: Env | Null):
52+
case class Env(
53+
owner: Symbol,
54+
nestedInOwner: Boolean,
55+
captured: CaptureSet,
56+
isBoxed: Boolean,
57+
outer0: Env | Null
58+
):
5159
def outer = outer0.nn
5260

5361
def isOutermost = outer0 == null
@@ -204,7 +212,7 @@ class CheckCaptures extends Recheck, SymTransformer:
204212
report.error(i"$header included in allowed capture set ${res.blocking}", pos)
205213

206214
/** The current environment */
207-
private var curEnv: Env = Env(NoSymbol, CaptureSet.empty, isBoxed = false, null)
215+
private var curEnv: Env = Env(NoSymbol, false, CaptureSet.empty, isBoxed = false, null)
208216

209217
private val myCapturedVars: util.EqHashMap[Symbol, CaptureSet] = EqHashMap()
210218

@@ -249,8 +257,12 @@ class CheckCaptures extends Recheck, SymTransformer:
249257
if !cs.isAlwaysEmpty then
250258
forallOuterEnvsUpTo(ctx.owner.topLevelClass) { env =>
251259
val included = cs.filter {
252-
case ref: TermRef => env.owner.isProperlyContainedIn(ref.symbol.owner)
253-
case ref: ThisType => env.owner.isProperlyContainedIn(ref.cls)
260+
case ref: TermRef =>
261+
(env.nestedInOwner || env.owner != ref.symbol.owner)
262+
&& env.owner.isContainedIn(ref.symbol.owner)
263+
case ref: ThisType =>
264+
(env.nestedInOwner || env.owner != ref.cls)
265+
&& env.owner.isContainedIn(ref.cls)
254266
case _ => false
255267
}
256268
capt.println(i"Include call capture $included in ${env.owner}")
@@ -439,7 +451,7 @@ class CheckCaptures extends Recheck, SymTransformer:
439451
if !Synthetics.isExcluded(sym) then
440452
val saved = curEnv
441453
val localSet = capturedVars(sym)
442-
if !localSet.isAlwaysEmpty then curEnv = Env(sym, localSet, isBoxed = false, curEnv)
454+
if !localSet.isAlwaysEmpty then curEnv = Env(sym, false, localSet, isBoxed = false, curEnv)
443455
try super.recheckDefDef(tree, sym)
444456
finally
445457
interpolateVarsIn(tree.tpt)
@@ -455,7 +467,7 @@ class CheckCaptures extends Recheck, SymTransformer:
455467
val localSet = capturedVars(cls)
456468
for parent <- impl.parents do // (1)
457469
checkSubset(capturedVars(parent.tpe.classSymbol), localSet, parent.srcPos)
458-
if !localSet.isAlwaysEmpty then curEnv = Env(cls, localSet, isBoxed = false, curEnv)
470+
if !localSet.isAlwaysEmpty then curEnv = Env(cls, false, localSet, isBoxed = false, curEnv)
459471
try
460472
val thisSet = cls.classInfo.selfType.captureSet.withDescription(i"of the self type of $cls")
461473
checkSubset(localSet, thisSet, tree.srcPos) // (2)
@@ -502,7 +514,7 @@ class CheckCaptures extends Recheck, SymTransformer:
502514
override def recheck(tree: Tree, pt: Type = WildcardType)(using Context): Type =
503515
if tree.isTerm && pt.isBoxedCapturing then
504516
val saved = curEnv
505-
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = true, curEnv)
517+
curEnv = Env(curEnv.owner, false, CaptureSet.Var(), isBoxed = true, curEnv)
506518
try super.recheck(tree, pt)
507519
finally curEnv = saved
508520
else
@@ -595,12 +607,11 @@ class CheckCaptures extends Recheck, SymTransformer:
595607
* to `expected` type.
596608
* @param reconstruct how to rebuild the adapted function type
597609
*/
598-
def adaptFun(actualTp: (Type, CaptureSet), aargs: List[Type], ares: Type, expected: Type,
610+
def adaptFun(actual: Type, aargs: List[Type], ares: Type, expected: Type,
599611
covariant: Boolean, boxed: Boolean,
600612
reconstruct: (List[Type], Type) => Type): (Type, CaptureSet) =
601-
val (actual, cs0) = actualTp
602613
val saved = curEnv
603-
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
614+
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
604615

605616
try
606617
val (eargs, eres) = expected.dealias match
@@ -618,17 +629,16 @@ class CheckCaptures extends Recheck, SymTransformer:
618629
else reconstruct(aargs1, ares1)
619630

620631
curEnv.captured.asVar.markSolved()
621-
(resTp, curEnv.captured ++ cs0)
632+
(resTp, curEnv.captured)
622633
finally
623634
curEnv = saved
624635

625636
def adaptTypeFun(
626-
actualTp: (Type, CaptureSet), ares: Type, expected: Type,
637+
actual: Type, ares: Type, expected: Type,
627638
covariant: Boolean, boxed: Boolean,
628639
reconstruct: Type => Type): (Type, CaptureSet) =
629-
val (actual, cs0) = actualTp
630640
val saved = curEnv
631-
curEnv = Env(curEnv.owner, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
641+
curEnv = Env(curEnv.owner, true, CaptureSet.Var(), isBoxed = false, if boxed then null else curEnv)
632642

633643
try
634644
val eres = expected.dealias.stripCapturing match
@@ -642,7 +652,7 @@ class CheckCaptures extends Recheck, SymTransformer:
642652
else reconstruct(ares1)
643653

644654
curEnv.captured.asVar.markSolved()
645-
(resTp, curEnv.captured ++ cs0)
655+
(resTp, curEnv.captured)
646656
finally
647657
curEnv = saved
648658
end adaptTypeFun
@@ -651,8 +661,8 @@ class CheckCaptures extends Recheck, SymTransformer:
651661
val arrow = if covariant then "~~>" else "<~~"
652662
i"adapting $actual $arrow $expected"
653663

654-
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
655-
def destructCapturingType(tp: Type, reconstruct: Type => Type): ((Type, CaptureSet, Boolean), Type => Type) = tp.dealias match
664+
def destructCapturingType(tp: Type, reconstruct: Type => Type = x => x): ((Type, CaptureSet, Boolean), Type => Type) =
665+
tp.dealias match
656666
case tp @ CapturingType(parent, cs) =>
657667
if parent.dealias.isCapturingType then
658668
destructCapturingType(parent, res => reconstruct(tp.derivedCapturingType(res, cs)))
@@ -661,72 +671,67 @@ class CheckCaptures extends Recheck, SymTransformer:
661671
case actual =>
662672
((actual, CaptureSet(), false), reconstruct)
663673

664-
if expected.isInstanceOf[WildcardType] then
665-
actual
674+
def adapt(actual: Type, expected: Type, covariant: Boolean): Type = trace(adaptInfo(actual, expected, covariant), recheckr, show = true) {
675+
if expected.isInstanceOf[WildcardType] then actual
666676
else
667-
val (actualTp, recon) = destructCapturingType(actual, x => x)
668-
val (parent1, cs1, isBoxed1) = adaptCapturingType(actualTp, expected, covariant)
669-
recon(CapturingType(parent1, cs1, isBoxed1))
670-
}
671-
672-
def adaptCapturingType(
673-
actual: (Type, CaptureSet, Boolean),
674-
expected: Type,
675-
covariant: Boolean
676-
): (Type, CaptureSet, Boolean) =
677-
val (parent, cs, actualIsBoxed) = actual
678-
679-
val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
680-
val insertBox = needsAdaptation && covariant != actualIsBoxed
681-
682-
val (parent1, cs1) = parent match {
683-
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
684-
adaptFun((parent, cs), args.init, args.last, expected, covariant, insertBox,
685-
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
686-
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
687-
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
688-
adaptFun((parent, cs), rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
689-
(aargs1, ares1) =>
690-
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
691-
.toFunctionType(isJava = false, alwaysDependent = true))
692-
case actual: MethodType =>
693-
adaptFun((parent, cs), actual.paramInfos, actual.resType, expected, covariant, insertBox,
694-
(aargs1, ares1) =>
695-
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
696-
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
697-
adaptTypeFun((parent, cs), rinfo.resType, expected, covariant, insertBox,
698-
ares1 =>
699-
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
700-
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
701-
actual1
702-
)
703-
case _ =>
704-
(parent, cs)
705-
}
677+
val ((parent, cs, actualIsBoxed), recon) = destructCapturingType(actual)
678+
679+
val needsAdaptation = actualIsBoxed != expected.isBoxedCapturing
680+
val insertBox = needsAdaptation && covariant != actualIsBoxed
681+
682+
val (parent1, cs1) = parent match {
683+
case actual @ AppliedType(tycon, args) if defn.isNonRefinedFunction(actual) =>
684+
val (parent1, cs1) = adaptFun(parent, args.init, args.last, expected, covariant, insertBox,
685+
(aargs1, ares1) => actual.derivedAppliedType(tycon, aargs1 :+ ares1))
686+
(parent1, cs1 ++ cs)
687+
case actual @ RefinedType(_, _, rinfo: MethodType) if defn.isFunctionType(actual) =>
688+
// TODO Find a way to combine handling of generic and dependent function types (here and elsewhere)
689+
val (parent1, cs1) = adaptFun(parent, rinfo.paramInfos, rinfo.resType, expected, covariant, insertBox,
690+
(aargs1, ares1) =>
691+
rinfo.derivedLambdaType(paramInfos = aargs1, resType = ares1)
692+
.toFunctionType(isJava = false, alwaysDependent = true))
693+
(parent1, cs1 ++ cs)
694+
case actual: MethodType =>
695+
val (parent1, cs1) = adaptFun(parent, actual.paramInfos, actual.resType, expected, covariant, insertBox,
696+
(aargs1, ares1) =>
697+
actual.derivedLambdaType(paramInfos = aargs1, resType = ares1))
698+
(parent1, cs1 ++ cs)
699+
case actual @ RefinedType(p, nme, rinfo: PolyType) if defn.isFunctionOrPolyType(actual) =>
700+
val (parent1, cs1) = adaptTypeFun(parent, rinfo.resType, expected, covariant, insertBox,
701+
ares1 =>
702+
val rinfo1 = rinfo.derivedLambdaType(rinfo.paramNames, rinfo.paramInfos, ares1)
703+
val actual1 = actual.derivedRefinedType(p, nme, rinfo1)
704+
actual1
705+
)
706+
(parent1, cs1 ++ cs)
707+
case _ =>
708+
(parent, cs)
709+
}
706710

707-
if needsAdaptation then
708-
val criticalSet = // the set which is not allowed to have `*`
709-
if covariant then cs1 // can't box with `*`
710-
else expected.captureSet // can't unbox with `*`
711-
if criticalSet.isUniversal then
712-
// We can't box/unbox the universal capability. Leave `actual` as it is
713-
// so we get an error in checkConforms. This tends to give better error
714-
// messages than disallowing the root capability in `criticalSet`.
715-
capt.println(i"cannot box/unbox $cs $parent vs $expected")
716-
actual
711+
if needsAdaptation then
712+
val criticalSet = // the set which is not allowed to have `*`
713+
if covariant then cs1 // can't box with `*`
714+
else expected.captureSet // can't unbox with `*`
715+
if criticalSet.isUniversal then
716+
// We can't box/unbox the universal capability. Leave `actual` as it is
717+
// so we get an error in checkConforms. This tends to give better error
718+
// messages than disallowing the root capability in `criticalSet`.
719+
capt.println(i"cannot box/unbox $actual vs $expected")
720+
actual
721+
else
722+
// Disallow future addition of `*` to `criticalSet`.
723+
criticalSet.disallowRootCapability { () =>
724+
report.error(
725+
em"""$actual cannot be box-converted to $expected
726+
|since one of their capture sets contains the root capability `*`""",
727+
pos)
728+
}
729+
if !insertBox then // unboxing
730+
markFree(criticalSet, pos)
731+
recon(CapturingType(parent1, cs1, !actualIsBoxed))
717732
else
718-
// Disallow future addition of `*` to `criticalSet`.
719-
criticalSet.disallowRootCapability { () =>
720-
report.error(
721-
em"""$actualIsBoxed $cs $parent cannot be box-converted to $expected
722-
|since one of their capture sets contains the root capability `*`""",
723-
pos)
724-
}
725-
if !insertBox then // unboxing
726-
markFree(cs1, pos)
727-
(parent1, cs1, !actualIsBoxed)
728-
else
729-
(parent1, cs1, actualIsBoxed)
733+
recon(CapturingType(parent1, cs1, actualIsBoxed))
734+
}
730735

731736

732737
var actualw = actual.widenDealias

tests/neg-custom-args/captures/capt1.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,14 +40,14 @@
4040
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:32:24 ----------------------------------------
4141
32 | val z2 = h[() -> Cap](() => x) // error
4242
| ^^^^^^^
43-
| Found: {x} () -> {*} C
43+
| Found: {x} () -> Cap
4444
| Required: () -> box {*} C
4545
|
4646
| longer explanation available when compiling with `-explain`
4747
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:33:5 -----------------------------------------
4848
33 | (() => C()) // error
4949
| ^^^^^^^^^
50-
| Found: ? () -> {*} C
50+
| Found: ? () -> Cap
5151
| Required: () -> box {*} C
5252
|
5353
| longer explanation available when compiling with `-explain`

tests/neg-custom-args/captures/i15772.check

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,15 @@
1-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:18:2 ----------------------------------------
2-
18 | () => // error
3-
| ^
4-
| Found: {x} () -> Int
5-
| Required: () -> Int
6-
19 | val c : {x} C = new C(x)
7-
20 | val boxed1 : (({*} C) => Unit) -> Unit = box1(c)
8-
21 | boxed1((cap: {*} C) => unsafe(c))
9-
22 | 0
1+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:20:49 ---------------------------------------
2+
20 | val boxed1 : (({*} C) => Unit) -> Unit = box1(c) // error
3+
| ^^^^^^^
4+
| Found: {c} ({*} ({c} C{arg: {*} C}) -> Unit) -> Unit
5+
| Required: (({*} C) => Unit) -> Unit
106
|
117
| longer explanation available when compiling with `-explain`
12-
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:25:2 ----------------------------------------
13-
25 | () => // error
14-
| ^
15-
| Found: {x} () -> Int
16-
| Required: () -> Int
17-
26 | val c : {x} C = new C(x)
18-
27 | val boxed2 : Observe[{*} C] = box2(c)
19-
28 | boxed2((cap: {*} C) => unsafe(c))
20-
29 | 0
8+
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:27:38 ---------------------------------------
9+
27 | val boxed2 : Observe[{*} C] = box2(c) // error
10+
| ^^^^^^^
11+
| Found: {c} ({*} ({c} C{arg: {*} C}) -> Unit) -> Unit
12+
| Required: Observe[{*} C]
2113
|
2214
| longer explanation available when compiling with `-explain`
2315
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/i15772.scala:33:37 ---------------------------------------

tests/neg-custom-args/captures/i15772.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,16 @@ class C(val arg: {*} C) {
1515
}
1616

1717
def main1(x: {*} C) : () -> Int =
18-
() => // error
18+
() =>
1919
val c : {x} C = new C(x)
20-
val boxed1 : (({*} C) => Unit) -> Unit = box1(c)
20+
val boxed1 : (({*} C) => Unit) -> Unit = box1(c) // error
2121
boxed1((cap: {*} C) => unsafe(c))
2222
0
2323

2424
def main2(x: {*} C) : () -> Int =
25-
() => // error
25+
() =>
2626
val c : {x} C = new C(x)
27-
val boxed2 : Observe[{*} C] = box2(c)
27+
val boxed2 : Observe[{*} C] = box2(c) // error
2828
boxed2((cap: {*} C) => unsafe(c))
2929
0
3030

@@ -41,4 +41,4 @@ def main(io: {*} Any) =
4141
val sayHello: (({io} File) => Unit) = (file: {io} File) => file.write("Hello World!\r\n")
4242
val filesList : List[{io} File] = ???
4343
val x = () => filesList.foreach(sayHello)
44-
x: (() -> Unit) // error
44+
x: (() -> Unit) // error

0 commit comments

Comments
 (0)