Skip to content

Properly nest ThreadContextElement #2517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 8, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines
Expand Down Expand Up @@ -134,7 +134,7 @@ internal abstract class DispatchedTask<in T>(
* Fatal exception handling can be intercepted with [CoroutineExceptionHandler] element in the context of
* a failed coroutine, but such exceptions should be reported anyway.
*/
internal fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
public fun handleFatalException(exception: Throwable?, finallyException: Throwable?) {
if (exception === null && finallyException === null) return
if (exception !== null && finallyException !== null) {
exception.addSuppressedThrowable(finallyException)
Expand Down
36 changes: 17 additions & 19 deletions kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.internal
Expand All @@ -11,13 +11,22 @@ import kotlin.coroutines.*
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")

// Used when there are >= 2 active elements in the context
private class ThreadState(val context: CoroutineContext, n: Int) {
private var a = arrayOfNulls<Any>(n)
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
private val values = arrayOfNulls<Any>(n)
private val elements = arrayOfNulls<ThreadContextElement<Any?>>(n)
private var i = 0

fun append(value: Any?) { a[i++] = value }
fun take() = a[i++]
fun start() { i = 0 }
fun append(element: ThreadContextElement<*>, value: Any?) {
values[i] = value
elements[i++] = element as ThreadContextElement<Any?>
}

fun restore(context: CoroutineContext) {
for (i in elements.indices.reversed()) {
elements[i]?.restoreThreadContext(context, values[i])
}
}
}

// Counts ThreadContextElements in the context
Expand All @@ -42,17 +51,7 @@ private val findOne =
private val updateState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
if (element is ThreadContextElement<*>) {
state.append(element.updateThreadContext(state.context))
}
return state
}

// Restores state for all ThreadContextElements in the context from the given ThreadState
private val restoreState =
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
@Suppress("UNCHECKED_CAST")
if (element is ThreadContextElement<*>) {
(element as ThreadContextElement<Any?>).restoreThreadContext(state.context, state.take())
state.append(element, element.updateThreadContext(state.context))
}
return state
}
Expand Down Expand Up @@ -86,8 +85,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
oldState is ThreadState -> {
// slow path with multiple stored ThreadContextElements
oldState.start()
context.fold(oldState, restoreState)
oldState.restore(context)
}
else -> {
// fast path for one ThreadContextElement, but need to find it
Expand Down
65 changes: 65 additions & 0 deletions kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlinx.coroutines.internal.*
import org.junit.Test
import kotlin.coroutines.*
import kotlin.test.*

class ThreadContextOrderTest : TestBase() {
/*
* The test verifies that two thread context elements are correctly nested:
* The restoration order is the reverse of update order.
*/
private val transactionalContext = ThreadLocal<String>()
private val loggingContext = ThreadLocal<String>()

private val transactionalElement = object : ThreadContextElement<String> {
override val key = ThreadLocalKey(transactionalContext)

override fun updateThreadContext(context: CoroutineContext): String {
assertEquals("test", loggingContext.get())
val previous = transactionalContext.get()
transactionalContext.set("tr coroutine")
return previous
}

override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
assertEquals("test", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
transactionalContext.set(oldState)
}
}

private val loggingElement = object : ThreadContextElement<String> {
override val key = ThreadLocalKey(loggingContext)

override fun updateThreadContext(context: CoroutineContext): String {
val previous = loggingContext.get()
loggingContext.set("log coroutine")
return previous
}

override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
assertEquals("log coroutine", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
loggingContext.set(oldState)
}
}

@Test
fun testCorrectOrder() = runTest {
transactionalContext.set("test")
loggingContext.set("test")
launch(transactionalElement + loggingElement) {
assertEquals("log coroutine", loggingContext.get())
assertEquals("tr coroutine", transactionalContext.get())
}
assertEquals("test", loggingContext.get())
assertEquals("test", transactionalContext.get())

}
}