Skip to content

Commit 649d03e

Browse files
authored
Confine context-specific state to the thread in UndispatchedCoroutine… (#3155)
* Confine context-specific state to the thread in UndispatchedCoroutine in order to avoid state interference when the coroutine is updated concurrently. Concurrency is inevitable in this scenario: when the coroutine that has UndispatchedCoroutine as its completion suspends, we have to clear the thread context, but while we are doing so, concurrent resume of the coroutine could've happened that also ends up in save/clear/update context Fixes #2930
1 parent b5b254b commit 649d03e

File tree

2 files changed

+92
-12
lines changed

2 files changed

+92
-12
lines changed

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+20-12
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC
107107

108108
/**
109109
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
110-
* Used as a performance optimization to avoid stack walking where it is not nesessary.
110+
* Used as a performance optimization to avoid stack walking where it is not necessary.
111111
*/
112112
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
113113
override val key: CoroutineContext.Key<*>
@@ -120,26 +120,34 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
120120
uCont: Continuation<T>
121121
) : ScopeCoroutine<T>(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) {
122122

123-
private var savedContext: CoroutineContext? = null
124-
private var savedOldValue: Any? = null
123+
/*
124+
* The state is thread-local because this coroutine can be used concurrently.
125+
* Scenario of usage (withContinuationContext):
126+
* val state = saveThreadContext(ctx)
127+
* try {
128+
* invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called
129+
* // COROUTINE_SUSPENDED is returned
130+
* } finally {
131+
* thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread
132+
* // and it also calls saveThreadContext and clearThreadContext
133+
* }
134+
*/
135+
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()
125136

126137
fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
127-
savedContext = context
128-
savedOldValue = oldValue
138+
threadStateToRecover.set(context to oldValue)
129139
}
130140

131141
fun clearThreadContext(): Boolean {
132-
if (savedContext == null) return false
133-
savedContext = null
134-
savedOldValue = null
142+
if (threadStateToRecover.get() == null) return false
143+
threadStateToRecover.set(null)
135144
return true
136145
}
137146

138147
override fun afterResume(state: Any?) {
139-
savedContext?.let { context ->
140-
restoreThreadContext(context, savedOldValue)
141-
savedContext = null
142-
savedOldValue = null
148+
threadStateToRecover.get()?.let { (ctx, value) ->
149+
restoreThreadContext(ctx, value)
150+
threadStateToRecover.set(null)
143151
}
144152
// resume undispatched -- update context but stay on the same dispatcher
145153
val result = recoverResult(state, uCont)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlin.test.*
8+
9+
10+
class ThreadLocalStressTest : TestBase() {
11+
12+
private val threadLocal = ThreadLocal<String>()
13+
14+
// See the comment in doStress for the machinery
15+
@Test
16+
fun testStress() = runTest {
17+
repeat (100 * stressTestMultiplierSqrt) {
18+
withContext(Dispatchers.Default) {
19+
repeat(100) {
20+
launch {
21+
doStress(null)
22+
}
23+
}
24+
}
25+
}
26+
}
27+
28+
@Test
29+
fun testStressWithOuterValue() = runTest {
30+
repeat (100 * stressTestMultiplierSqrt) {
31+
withContext(Dispatchers.Default + threadLocal.asContextElement("bar")) {
32+
repeat(100) {
33+
launch {
34+
doStress("bar")
35+
}
36+
}
37+
}
38+
}
39+
}
40+
41+
private suspend fun doStress(expectedValue: String?) {
42+
assertEquals(expectedValue, threadLocal.get())
43+
try {
44+
/*
45+
* Here we are using very specific code-path to trigger the execution we want to.
46+
* The bug, in general, has a larger impact, but this particular code pinpoints it:
47+
*
48+
* 1) We use _undispatched_ withContext with thread element
49+
* 2) We cancel the coroutine
50+
* 3) We use 'suspendCancellableCoroutineReusable' that does _postponed_ cancellation check
51+
* which makes the reproduction of this race pretty reliable.
52+
*
53+
* Now the following code path is likely to be triggered:
54+
*
55+
* T1 from within 'withContinuationContext' method:
56+
* Finds 'oldValue', finds undispatched completion, invokes its 'block' argument.
57+
* 'block' is this coroutine, it goes to 'trySuspend', checks for postponed cancellation and *dispatches* it.
58+
* The execution stops _right_ before 'undispatchedCompletion.clearThreadContext()'.
59+
*
60+
* T2 now executes the dispatched cancellation and concurrently mutates the state of the undispatched completion.
61+
* All bets are off, now both threads can leave the thread locals state inconsistent.
62+
*/
63+
withContext(threadLocal.asContextElement("foo")) {
64+
yield()
65+
cancel()
66+
suspendCancellableCoroutineReusable<Unit> { }
67+
}
68+
} finally {
69+
assertEquals(expectedValue, threadLocal.get())
70+
}
71+
}
72+
}

0 commit comments

Comments
 (0)