From 3fa8ee681989227144551ea5ffa513c8319aa35d Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Tue, 2 Feb 2021 17:17:07 +0300 Subject: [PATCH 1/2] Properly nest ThreadContextElement * Restore the context in the reverse order of update, so they are properly nested into each other * Also, do a minor cleanup Fixes #2195 --- .../common/src/internal/DispatchedTask.kt | 4 +- .../jvm/src/internal/ThreadContext.kt | 36 +++++----- .../jvm/test/ThreadContextOrderTest.kt | 65 +++++++++++++++++++ 3 files changed, 84 insertions(+), 21 deletions(-) create mode 100644 kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt diff --git a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt index ce05979db6..d982f95bdf 100644 --- a/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt +++ b/kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt @@ -1,5 +1,5 @@ /* - * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ package kotlinx.coroutines @@ -134,7 +134,7 @@ internal abstract class DispatchedTask( * Fatal exception handling can be intercepted with [CoroutineExceptionHandler] element in the context of * a failed coroutine, but such exceptions should be reported anyway. */ - internal fun handleFatalException(exception: Throwable?, finallyException: Throwable?) { + public fun handleFatalException(exception: Throwable?, finallyException: Throwable?) { if (exception === null && finallyException === null) return if (exception !== null && finallyException !== null) { exception.addSuppressedThrowable(finallyException) diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 18c2ce0459..5e71923632 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -1,5 +1,5 @@ /* - * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ package kotlinx.coroutines.internal @@ -11,13 +11,22 @@ import kotlin.coroutines.* internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") // Used when there are >= 2 active elements in the context -private class ThreadState(val context: CoroutineContext, n: Int) { - private var a = arrayOfNulls(n) +@Suppress("UNCHECKED_CAST") +private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { + private val values = arrayOfNulls(n) + private val elements = arrayOfNulls>(n) private var i = 0 - fun append(value: Any?) { a[i++] = value } - fun take() = a[i++] - fun start() { i = 0 } + fun append(element: ThreadContextElement<*>, value: Any?) { + values[i] = value + elements[i++] = element as ThreadContextElement + } + + fun restore(context: CoroutineContext) { + for (i in elements.indices.reversed()) { + elements[i]?.restoreThreadContext(context, values[i]) + } + } } // Counts ThreadContextElements in the context @@ -42,17 +51,7 @@ private val findOne = private val updateState = fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { if (element is ThreadContextElement<*>) { - state.append(element.updateThreadContext(state.context)) - } - return state - } - -// Restores state for all ThreadContextElements in the context from the given ThreadState -private val restoreState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - @Suppress("UNCHECKED_CAST") - if (element is ThreadContextElement<*>) { - (element as ThreadContextElement).restoreThreadContext(state.context, state.take()) + state.append(element, element.updateThreadContext(state.context)) } return state } @@ -86,8 +85,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements oldState is ThreadState -> { // slow path with multiple stored ThreadContextElements - oldState.start() - context.fold(oldState, restoreState) + oldState.restore(context) } else -> { // fast path for one ThreadContextElement, but need to find it diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt new file mode 100644 index 0000000000..49f4a12ee0 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextOrderTest.kt @@ -0,0 +1,65 @@ +/* + * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.coroutines.internal.* +import org.junit.Test +import kotlin.coroutines.* +import kotlin.test.* + +class ThreadContextOrderTest : TestBase() { + /* + * The test verifies that two thread context elements are correctly nested: + * The restoration order is the reverse of update order. + */ + private val transactionalContext = ThreadLocal() + private val loggingContext = ThreadLocal() + + private val transactionalElement = object : ThreadContextElement { + override val key = ThreadLocalKey(transactionalContext) + + override fun updateThreadContext(context: CoroutineContext): String { + assertEquals("test", loggingContext.get()) + val previous = transactionalContext.get() + transactionalContext.set("tr coroutine") + return previous + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + assertEquals("test", loggingContext.get()) + assertEquals("tr coroutine", transactionalContext.get()) + transactionalContext.set(oldState) + } + } + + private val loggingElement = object : ThreadContextElement { + override val key = ThreadLocalKey(loggingContext) + + override fun updateThreadContext(context: CoroutineContext): String { + val previous = loggingContext.get() + loggingContext.set("log coroutine") + return previous + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + assertEquals("log coroutine", loggingContext.get()) + assertEquals("tr coroutine", transactionalContext.get()) + loggingContext.set(oldState) + } + } + + @Test + fun testCorrectOrder() = runTest { + transactionalContext.set("test") + loggingContext.set("test") + launch(transactionalElement + loggingElement) { + assertEquals("log coroutine", loggingContext.get()) + assertEquals("tr coroutine", transactionalContext.get()) + } + assertEquals("test", loggingContext.get()) + assertEquals("test", transactionalContext.get()) + + } +} From c71ef3ee3e26ff3ea29d246beb6b7ee2c3c1f635 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Mon, 8 Feb 2021 12:04:26 +0300 Subject: [PATCH 2/2] ~replace .? with !! where elements are never null --- kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 5e71923632..8536cef65d 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -24,7 +24,7 @@ private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { fun restore(context: CoroutineContext) { for (i in elements.indices.reversed()) { - elements[i]?.restoreThreadContext(context, values[i]) + elements[i]!!.restoreThreadContext(context, values[i]) } } }