Skip to content

Commit 4c69d1c

Browse files
committed
Properly cleanup thread locals for non-CoroutineDispatcher-intercepted continuations
There was a one codepath not covered by undispatched thread local cleanup procedure: when a custom ContinuationInterceptor is used and the scoped coroutine (i.e. withContext) is completed in-place without suspensions. Fixed with the introduction of the corresponding machinery for ScopeCoroutine Fixes #4296
1 parent 2cafea4 commit 4c69d1c

File tree

5 files changed

+126
-5
lines changed

5 files changed

+126
-5
lines changed

kotlinx-coroutines-core/common/src/AbstractCoroutine.kt

+10
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ package kotlinx.coroutines
55
import kotlinx.coroutines.CoroutineStart.*
66
import kotlinx.coroutines.intrinsics.*
77
import kotlin.coroutines.*
8+
import kotlinx.coroutines.internal.ScopeCoroutine
89

910
/**
1011
* Abstract base class for implementation of coroutines in coroutine builders.
@@ -100,6 +101,15 @@ public abstract class AbstractCoroutine<in T>(
100101
afterResume(state)
101102
}
102103

104+
/**
105+
* Invoked when the corresponding `AbstractCoroutine` was **conceptually** resumed, but not mechanically.
106+
* Currently, this function only invokes `resume` on the underlying continuation for [ScopeCoroutine]
107+
* or does nothing otherwise.
108+
*
109+
* Examples of resumes:
110+
* - `afterCompletion` calls when the corresponding `Job` changed its state (i.e. got cancelled)
111+
* - [AbstractCoroutine.resumeWith] was invoked
112+
*/
103113
protected open fun afterResume(state: Any?): Unit = afterCompletion(state)
104114

105115
internal final override fun handleOnCompletionException(exception: Throwable) {

kotlinx-coroutines-core/common/src/internal/Scopes.kt

+7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ internal open class ScopeCoroutine<in T>(
2323
uCont.intercepted().resumeCancellableWith(recoverResult(state, uCont))
2424
}
2525

26+
/**
27+
* Invoked when a scoped coorutine was completed in an undispatched manner directly
28+
* at the place of its start because it never suspended.
29+
*/
30+
open fun afterCompletionUndispatched() {
31+
}
32+
2633
override fun afterResume(state: Any?) {
2734
// Resume direct because scope is already in the correct context
2835
uCont.resumeWith(recoverResult(state, uCont))

kotlinx-coroutines-core/common/src/intrinsics/Undispatched.kt

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ private inline fun <T> ScopeCoroutine<T>.undispatchedResult(
7979
if (result === COROUTINE_SUSPENDED) return COROUTINE_SUSPENDED // (1)
8080
val state = makeCompletingOnce(result)
8181
if (state === COMPLETING_WAITING_CHILDREN) return COROUTINE_SUSPENDED // (2)
82+
afterCompletionUndispatched()
8283
return if (state is CompletedExceptionally) { // (3)
8384
when {
8485
shouldThrow(state.cause) -> throw recoverStackTrace(state.cause, uCont)

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+13-5
Original file line numberDiff line numberDiff line change
@@ -253,18 +253,26 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
253253
}
254254
}
255255

256+
override fun afterCompletionUndispatched() {
257+
clearThreadLocal()
258+
}
259+
256260
override fun afterResume(state: Any?) {
261+
clearThreadLocal()
262+
// resume undispatched -- update context but stay on the same dispatcher
263+
val result = recoverResult(state, uCont)
264+
withContinuationContext(uCont, null) {
265+
uCont.resumeWith(result)
266+
}
267+
}
268+
269+
private fun clearThreadLocal() {
257270
if (threadLocalIsSet) {
258271
threadStateToRecover.get()?.let { (ctx, value) ->
259272
restoreThreadContext(ctx, value)
260273
}
261274
threadStateToRecover.remove()
262275
}
263-
// resume undispatched -- update context but stay on the same dispatcher
264-
val result = recoverResult(state, uCont)
265-
withContinuationContext(uCont, null) {
266-
uCont.resumeWith(result)
267-
}
268276
}
269277
}
270278

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
package kotlinx.coroutines
2+
3+
import kotlinx.coroutines.testing.TestBase
4+
import java.lang.ref.WeakReference
5+
import kotlin.coroutines.AbstractCoroutineContextElement
6+
import kotlin.coroutines.Continuation
7+
import kotlin.coroutines.ContinuationInterceptor
8+
import kotlin.coroutines.CoroutineContext
9+
import kotlin.test.Test
10+
import kotlin.test.assertNull
11+
12+
/*
13+
* This is an adapted verion of test from #4296.
14+
*
15+
* qwwdfsad: the test relies on System.gc() actually collecting the garbage.
16+
* If these tests flake on CI, first check that JDK/GC setup in not an issue.
17+
*/
18+
class ThreadLocalCustomContinuationInterceptorTest : TestBase() {
19+
20+
private class CustomContinuationInterceptor(private val delegate: ContinuationInterceptor) :
21+
AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
22+
23+
override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> {
24+
return delegate.interceptContinuation(continuation)
25+
}
26+
}
27+
28+
private class CustomNeverEqualContinuationInterceptor(private val delegate: ContinuationInterceptor) :
29+
AbstractCoroutineContextElement(ContinuationInterceptor), ContinuationInterceptor {
30+
31+
override fun <T> interceptContinuation(continuation: Continuation<T>): Continuation<T> {
32+
return delegate.interceptContinuation(continuation)
33+
}
34+
35+
override fun equals(other: Any?) = false
36+
}
37+
38+
@Test
39+
fun testDefaultDispatcherNoSuspension() = ensureCoroutineContextGCed(Dispatchers.Default, suspend = false)
40+
41+
@Test
42+
fun testDefaultDispatcher() = ensureCoroutineContextGCed(Dispatchers.Default, suspend = true)
43+
44+
45+
@Test
46+
fun testNonCoroutineDispatcher() = ensureCoroutineContextGCed(
47+
CustomContinuationInterceptor(Dispatchers.Default),
48+
suspend = true
49+
)
50+
51+
@Test
52+
fun testNonCoroutineDispatcherSuspension() = ensureCoroutineContextGCed(
53+
CustomContinuationInterceptor(Dispatchers.Default),
54+
suspend = false
55+
)
56+
57+
// Note asymmetric equals codepath never goes through the undispatched withContext, thus the separate test case
58+
59+
@Test
60+
fun testNonCoroutineDispatcherAsymmetricEquals() =
61+
ensureCoroutineContextGCed(
62+
CustomNeverEqualContinuationInterceptor(Dispatchers.Default),
63+
suspend = true
64+
)
65+
66+
@Test
67+
fun testNonCoroutineDispatcherAsymmetricEqualsSuspension() =
68+
ensureCoroutineContextGCed(
69+
CustomNeverEqualContinuationInterceptor(Dispatchers.Default),
70+
suspend = false
71+
)
72+
73+
74+
private fun ensureCoroutineContextGCed(coroutineContext: CoroutineContext, suspend: Boolean) {
75+
runTest {
76+
lateinit var ref: WeakReference<CoroutineName>
77+
val job = GlobalScope.launch(coroutineContext) {
78+
val coroutineName = CoroutineName("Yo")
79+
ref = WeakReference(coroutineName)
80+
withContext(coroutineName) {
81+
if (suspend) {
82+
delay(1)
83+
}
84+
}
85+
}
86+
job.join()
87+
88+
// Twice is enough to ensure
89+
System.gc()
90+
System.gc()
91+
assertNull(ref.get())
92+
}
93+
}
94+
95+
}

0 commit comments

Comments
 (0)