Skip to content

Commit c16ac69

Browse files
qwwdfsadpablobaxter
authored andcommitted
Properly preserve thread local values for coroutines that are not intercepted with DispatchedContinuation (Kotlin#3252)
* Properly preserve thread local values for coroutines that are not intercepted with DispatchedContinuation Fixes Kotlin#2930
1 parent 20e4bde commit c16ac69

File tree

2 files changed

+125
-1
lines changed

2 files changed

+125
-1
lines changed

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+31
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,37 @@ internal actual class UndispatchedCoroutine<in T>actual constructor (
181181
*/
182182
private var threadStateToRecover = ThreadLocal<Pair<CoroutineContext, Any?>>()
183183

184+
init {
185+
/*
186+
* This is a hack for a very specific case in #2930 unless #3253 is implemented.
187+
* 'ThreadLocalStressTest' covers this change properly.
188+
*
189+
* The scenario this change covers is the following:
190+
* 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function,
191+
* e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking
192+
* `withContext(tlElement)` which creates `UndispatchedCoroutine`.
193+
* 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()`
194+
* and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both
195+
* do thread context element tracking.
196+
* 3) So thread locals never got chance to get properly set up via `saveThreadContext`,
197+
* but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`.
198+
*
199+
* Here we detect precisely this situation and properly setup context to recover later.
200+
*
201+
*/
202+
if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) {
203+
/*
204+
* We cannot just "read" the elements as there is no such API,
205+
* so we update-restore it immediately and use the intermediate value
206+
* as the initial state, leveraging the fact that thread context element
207+
* is idempotent and such situations are increasingly rare.
208+
*/
209+
val values = updateThreadContext(context, null)
210+
restoreThreadContext(context, values)
211+
saveThreadContext(context, values)
212+
}
213+
}
214+
184215
fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
185216
threadStateToRecover.set(context to oldValue)
186217
}

kotlinx-coroutines-core/jvm/test/ThreadLocalStressTest.kt

+94-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
package kotlinx.coroutines
66

7+
import kotlinx.coroutines.sync.*
8+
import java.util.concurrent.*
9+
import kotlin.coroutines.*
10+
import kotlin.coroutines.intrinsics.*
711
import kotlin.test.*
812

913

@@ -63,10 +67,99 @@ class ThreadLocalStressTest : TestBase() {
6367
withContext(threadLocal.asContextElement("foo")) {
6468
yield()
6569
cancel()
66-
suspendCancellableCoroutineReusable<Unit> { }
70+
suspendCancellableCoroutineReusable<Unit> { }
6771
}
6872
} finally {
6973
assertEquals(expectedValue, threadLocal.get())
7074
}
7175
}
76+
77+
/*
78+
* Another set of tests for undispatcheable continuations that do not require stress test multiplier.
79+
* Also note that `uncaughtExceptionHandler` is used as the only available mechanism to propagate error from
80+
* `resumeWith`
81+
*/
82+
83+
@Test
84+
fun testNonDispatcheableLeak() {
85+
repeat(100) {
86+
doTestWithPreparation(
87+
::doTest,
88+
{ threadLocal.set(null) }) { threadLocal.get() == null }
89+
assertNull(threadLocal.get())
90+
}
91+
}
92+
93+
@Test
94+
fun testNonDispatcheableLeakWithInitial() {
95+
repeat(100) {
96+
doTestWithPreparation(::doTest, { threadLocal.set("initial") }) { threadLocal.get() == "initial" }
97+
assertEquals("initial", threadLocal.get())
98+
}
99+
}
100+
101+
@Test
102+
fun testNonDispatcheableLeakWithContextSwitch() {
103+
repeat(100) {
104+
doTestWithPreparation(
105+
::doTestWithContextSwitch,
106+
{ threadLocal.set(null) }) { threadLocal.get() == null }
107+
assertNull(threadLocal.get())
108+
}
109+
}
110+
111+
@Test
112+
fun testNonDispatcheableLeakWithInitialWithContextSwitch() {
113+
repeat(100) {
114+
doTestWithPreparation(
115+
::doTestWithContextSwitch,
116+
{ threadLocal.set("initial") }) { true /* can randomly wake up on the non-main thread */ }
117+
// Here we are always on the main thread
118+
assertEquals("initial", threadLocal.get())
119+
}
120+
}
121+
122+
private fun doTestWithPreparation(testBody: suspend () -> Unit, setup: () -> Unit, isValid: () -> Boolean) {
123+
setup()
124+
val latch = CountDownLatch(1)
125+
testBody.startCoroutineUninterceptedOrReturn(Continuation(EmptyCoroutineContext) {
126+
if (!isValid()) {
127+
Thread.currentThread().uncaughtExceptionHandler.uncaughtException(
128+
Thread.currentThread(),
129+
IllegalStateException("Unexpected error: thread local was not cleaned")
130+
)
131+
}
132+
latch.countDown()
133+
})
134+
latch.await()
135+
}
136+
137+
private suspend fun doTest() {
138+
withContext(threadLocal.asContextElement("foo")) {
139+
try {
140+
coroutineScope {
141+
val semaphore = Semaphore(1, 1)
142+
cancel()
143+
semaphore.acquire()
144+
}
145+
} catch (e: CancellationException) {
146+
// Ignore cancellation
147+
}
148+
}
149+
}
150+
151+
private suspend fun doTestWithContextSwitch() {
152+
withContext(threadLocal.asContextElement("foo")) {
153+
try {
154+
coroutineScope {
155+
val semaphore = Semaphore(1, 1)
156+
GlobalScope.launch { }.join()
157+
cancel()
158+
semaphore.acquire()
159+
}
160+
} catch (e: CancellationException) {
161+
// Ignore cancellation
162+
}
163+
}
164+
}
72165
}

0 commit comments

Comments
 (0)