Skip to content

Properly preserve thread local values for coroutines that are not intercepted with DispatchedContinuation #3252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,37 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
*/
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()

init {
/*
* This is a hack for a very specific case in #2930 until #3253 is implemented.
* 'ThreadLocalStressTest' covers this change properly.
*
* The scenario this change covers is the following:
* 1) The coroutine is started as a plain suspend function that is not related to kotlinx.coroutines,
* e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, and it invokes
* `withContext(tlElement)`, which creates an `UndispatchedCoroutine`.
* 2) It (the original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
* and goes neither through `DC.run` nor through `resumeUndispatchedWith`, both of which
* do thread context element tracking.
* 3) So thread locals never got a chance to be properly set up via `saveThreadContext`,
* but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
*
* Here we detect precisely this situation and properly set up the context to recover later.
*
*/
// The interceptor of the outer continuation is not a CoroutineDispatcher (this includes the case
// where there is no interceptor at all), which is how the scenario above is detected.
if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
/*
* We cannot just "read" the elements as there is no such API,
* so we update and immediately restore them and use the intermediate value
* as the initial state, leveraging the fact that thread context elements
* are idempotent and such situations are increasingly rare.
*/
val values = updateThreadContext(context, null)
restoreThreadContext(context, values)
// Store the same (context, value) pair that `saveThreadContext` would have stored.
threadStateToRecover.set(context to values)
}
}

/**
 * Stores [context] together with [oldValue] in the current thread's slot of
 * `threadStateToRecover`, replacing whatever pair was stored there before.
 */
fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
    val stateToRecover: Pair<CoroutineContext, Any?> = Pair(context, oldValue)
    threadStateToRecover.set(stateToRecover)
}
Expand Down
95 changes: 94 additions & 1 deletion kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

package kotlinx.coroutines

import kotlinx.coroutines.sync.*
import java.util.concurrent.*
import kotlin.coroutines.*
import kotlin.coroutines.intrinsics.*
import kotlin.test.*


Expand Down Expand Up @@ -63,10 +67,99 @@ class ThreadLocalStressTest : TestBase() {
withContext(threadLocal.asContextElement("foo")) {
yield()
cancel()
suspendCancellableCoroutineReusable<Unit> { }
suspendCancellableCoroutineReusable<Unit> { }
}
} finally {
assertEquals(expectedValue, threadLocal.get())
}
}

/*
* Another set of tests, for non-dispatchable continuations, that does not require the stress test multiplier.
* Also note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate an error from
* `resumeWith`.
*/

/**
 * Runs [doTest] 100 times with the thread local cleared beforehand and checks that
 * no value is left in it, both at completion time and after the helper returns.
 */
@Test
fun testNonDispatcheableLeak() {
    repeat(100) {
        doTestWithPreparation(
            testBody = ::doTest,
            setup = { threadLocal.set(null) },
            isInvalid = { threadLocal.get() != null }
        )
        // Checked again on the calling thread once the helper has returned.
        assertNull(threadLocal.get())
    }
}

/**
 * Runs [doTest] 100 times with the thread local pre-set to "initial" and checks that
 * this value is still in place, both at completion time and after the helper returns.
 */
@Test
fun testNonDispatcheableLeakWithInitial() {
    repeat(100) {
        doTestWithPreparation(
            testBody = ::doTest,
            setup = { threadLocal.set("initial") },
            isInvalid = { threadLocal.get() != "initial" }
        )
        // Checked again on the calling thread once the helper has returned.
        assertEquals("initial", threadLocal.get())
    }
}

/**
 * Same as the plain leak test, but the body is [doTestWithContextSwitch]: the thread local
 * is cleared beforehand and must hold no value at completion time and afterwards.
 */
@Test
fun testNonDispatcheableLeakWithContextSwitch() {
    repeat(100) {
        doTestWithPreparation(
            testBody = ::doTestWithContextSwitch,
            setup = { threadLocal.set(null) },
            isInvalid = { threadLocal.get() != null }
        )
        // Checked again on the calling thread once the helper has returned.
        assertNull(threadLocal.get())
    }
}

/**
 * Runs [doTestWithContextSwitch] 100 times with the thread local pre-set to "initial".
 * No check is made at completion time; the value is only verified after the helper returns.
 */
@Test
fun testNonDispatcheableLeakWithInitialWithContextSwitch() {
    repeat(100) {
        doTestWithPreparation(
            testBody = ::doTestWithContextSwitch,
            setup = { threadLocal.set("initial") },
            isInvalid = { false /* can randomly wake up on the non-main thread */ }
        )
        // Here we are always on the main thread
        assertEquals("initial", threadLocal.get())
    }
}

/**
 * Runs [setup], starts [testBody] as an unintercepted coroutine with an empty context,
 * and blocks the calling thread until the body has completed.
 *
 * When the body completes, [isInvalid] is evaluated on the completing thread. If it returns
 * `true`, an [IllegalStateException] is handed to that thread's uncaught exception handler,
 * because an exception thrown from `resumeWith` cannot be propagated to the caller.
 *
 * The body may also finish without ever suspending. In that case
 * `startCoroutineUninterceptedOrReturn` returns the result directly and never invokes the
 * completion continuation, so the same validation is performed here on the calling thread.
 * Without this the latch would never be released and the helper would block forever.
 *
 * An exception thrown by a non-suspending [testBody] propagates to the caller unchanged.
 *
 * @param testBody the suspending code under test
 * @param setup prepares the state before the body is started
 * @param isInvalid returns `true` if the state observed at completion is wrong
 */
private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isInvalid: () -> Boolean) {
    setup()
    val latch = CountDownLatch(1)

    // Shared by the asynchronous (completion callback) and synchronous (direct return) paths.
    fun verifyAndRelease() {
        if (isInvalid()) {
            Thread.currentThread().uncaughtExceptionHandler.uncaughtException(
                Thread.currentThread(),
                IllegalStateException("Unexpected error: thread local was not cleaned")
            )
        }
        latch.countDown()
    }

    val outcome = testBody.startCoroutineUninterceptedOrReturn(
        Continuation(EmptyCoroutineContext) { verifyAndRelease() }
    )
    // Exactly one of the two paths runs: the completion is invoked only if the body suspended.
    if (outcome !== COROUTINE_SUSPENDED) {
        verifyAndRelease()
    }
    latch.await()
}

/**
 * Test body: inside `withContext` with the thread local bound to "foo", cancels a nested
 * `coroutineScope` and then tries to acquire a semaphore that has no free permits.
 * The resulting [CancellationException] is swallowed.
 */
private suspend fun doTest() {
    val fooElement = threadLocal.asContextElement("foo")
    withContext(fooElement) {
        try {
            coroutineScope {
                // Its single permit is already taken, so acquire() cannot succeed immediately.
                val exhausted = Semaphore(permits = 1, acquiredPermits = 1)
                // Cancel the scope first, then attempt the acquire in the cancelled scope.
                this.cancel()
                exhausted.acquire()
            }
        } catch (ignored: CancellationException) {
            // Ignore cancellation
        }
    }
}

/**
 * Same as [doTest], except that it first launches an empty coroutine in `GlobalScope` and
 * joins it before cancelling the nested scope and attempting the semaphore acquire.
 * The resulting [CancellationException] is swallowed.
 */
private suspend fun doTestWithContextSwitch() {
    val fooElement = threadLocal.asContextElement("foo")
    withContext(fooElement) {
        try {
            coroutineScope {
                // Its single permit is already taken, so acquire() cannot succeed immediately.
                val exhausted = Semaphore(permits = 1, acquiredPermits = 1)
                // Wait for an unrelated, empty coroutine before cancelling.
                val unrelated = GlobalScope.launch { }
                unrelated.join()
                this.cancel()
                exhausted.acquire()
            }
        } catch (ignored: CancellationException) {
            // Ignore cancellation
        }
    }
}
}