Skip to content

Commit 3fa8ee6

Browse files
committed
Properly nest ThreadContextElement
* Restore the context in the reverse order of update, so they are properly nested into each other * Also, do a minor cleanup Fixes #2195
1 parent 727c38f commit 3fa8ee6

File tree

3 files changed

+84
-21
lines changed

3 files changed

+84
-21
lines changed

kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
package kotlinx.coroutines
@@ -134,7 +134,7 @@ internal abstract class DispatchedTask<in T>(
134134
* Fatal exception handling can be intercepted with [CoroutineExceptionHandler] element in the context of
135135
* a failed coroutine, but such exceptions should be reported anyway.
136136
*/
137-
internal fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
137+
public fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
138138
if (exception === null && finallyException === null) return
139139
if (exception !== null && finallyException !== null) {
140140
exception.addSuppressedThrowable(finallyException)

kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt

+17-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
package kotlinx.coroutines.internal
@@ -11,13 +11,22 @@ import kotlin.coroutines.*
1111
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")
1212

1313
// Used when there are >= 2 active elements in the context
14-
private class ThreadState(val context: CoroutineContext, n: Int) {
15-
private var a = arrayOfNulls<Any>(n)
14+
@Suppress("UNCHECKED_CAST")
15+
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
16+
private val values = arrayOfNulls<Any>(n)
17+
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
1618
private var i = 0
1719

18-
fun append(value: Any?) { a[i++] = value }
19-
fun take() = a[i++]
20-
fun start() { i = 0 }
20+
fun append(element: ThreadContextElement<*>, value: Any?) {
21+
values[i] = value
22+
elements[i++] = element as ThreadContextElement<Any?>
23+
}
24+
25+
fun restore(context: CoroutineContext) {
26+
for (i in elements.indices.reversed()) {
27+
elements[i]?.restoreThreadContext(context, values[i])
28+
}
29+
}
2130
}
2231

2332
// Counts ThreadContextElements in the context
@@ -42,17 +51,7 @@ private val findOne =
4251
private val updateState =
4352
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
4453
if (element is ThreadContextElement<*>) {
45-
state.append(element.updateThreadContext(state.context))
46-
}
47-
return state
48-
}
49-
50-
// Restores state for all ThreadContextElements in the context from the given ThreadState
51-
private val restoreState =
52-
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
53-
@Suppress("UNCHECKED_CAST")
54-
if (element is ThreadContextElement<*>) {
55-
(element as ThreadContextElement<Any?>).restoreThreadContext(state.context, state.take())
54+
state.append(element, element.updateThreadContext(state.context))
5655
}
5756
return state
5857
}
@@ -86,8 +85,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
8685
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
8786
oldState is ThreadState -> {
8887
// slow path with multiple stored ThreadContextElements
89-
oldState.start()
90-
context.fold(oldState, restoreState)
88+
oldState.restore(context)
9189
}
9290
else -> {
9391
// fast path for one ThreadContextElement, but need to find it
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlinx.coroutines.internal.*
8+
import org.junit.Test
9+
import kotlin.coroutines.*
10+
import kotlin.test.*
11+
12+
class ThreadContextOrderTest : TestBase() {
13+
/*
14+
* The test verifies that two thread context elements are correctly nested:
15+
* The restoration order is the reverse of update order.
16+
*/
17+
private val transactionalContext = ThreadLocal<String>()
18+
private val loggingContext = ThreadLocal<String>()
19+
20+
private val transactionalElement = object : ThreadContextElement<String> {
21+
override val key = ThreadLocalKey(transactionalContext)
22+
23+
override fun updateThreadContext(context: CoroutineContext): String {
24+
assertEquals("test", loggingContext.get())
25+
val previous = transactionalContext.get()
26+
transactionalContext.set("tr coroutine")
27+
return previous
28+
}
29+
30+
override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
31+
assertEquals("test", loggingContext.get())
32+
assertEquals("tr coroutine", transactionalContext.get())
33+
transactionalContext.set(oldState)
34+
}
35+
}
36+
37+
private val loggingElement = object : ThreadContextElement<String> {
38+
override val key = ThreadLocalKey(loggingContext)
39+
40+
override fun updateThreadContext(context: CoroutineContext): String {
41+
val previous = loggingContext.get()
42+
loggingContext.set("log coroutine")
43+
return previous
44+
}
45+
46+
override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
47+
assertEquals("log coroutine", loggingContext.get())
48+
assertEquals("tr coroutine", transactionalContext.get())
49+
loggingContext.set(oldState)
50+
}
51+
}
52+
53+
@Test
54+
fun testCorrectOrder() = runTest {
55+
transactionalContext.set("test")
56+
loggingContext.set("test")
57+
launch(transactionalElement + loggingElement) {
58+
assertEquals("log coroutine", loggingContext.get())
59+
assertEquals("tr coroutine", transactionalContext.get())
60+
}
61+
assertEquals("test", loggingContext.get())
62+
assertEquals("test", transactionalContext.get())
63+
64+
}
65+
}

0 commit comments

Comments
 (0)