From 32d48add142f557eb24943d198843739fd7d3e82 Mon Sep 17 00:00:00 2001 From: Dmitry Khalanskiy Date: Mon, 3 Mar 2025 11:46:21 +0100 Subject: [PATCH 1/3] Move code from the JVM to common (naively, will not compile) --- .../common/src/Builders.common.kt | 120 +++++++ .../common/src/CoroutineContext.common.kt | 74 +++++ .../src/CoroutineContext.sharedWithJvm.kt | 85 +++++ .../common/src/ThreadContextElement.common.kt | 187 +++++++++++ .../src/internal/ThreadContext.common.kt | 91 ++++++ .../common/test/ThreadContextElementTest.kt | 146 +++++++++ .../test/ThreadContextMutableCopiesTest.kt | 0 .../ThreadContextElementConcurrentTest.kt | 171 ++++++++++ .../jvm/src/CoroutineContext.kt | 273 ---------------- .../jvm/src/ThreadContextElement.kt | 184 ----------- .../jvm/src/internal/ThreadContext.kt | 88 ----- .../jvm/test/ThreadContextElementTest.kt | 301 ------------------ 12 files changed, 874 insertions(+), 846 deletions(-) create mode 100644 kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt create mode 100644 kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt create mode 100644 kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt rename kotlinx-coroutines-core/{jvm => common}/test/ThreadContextMutableCopiesTest.kt (100%) create mode 100644 kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt diff --git a/kotlinx-coroutines-core/common/src/Builders.common.kt b/kotlinx-coroutines-core/common/src/Builders.common.kt index 23ef7665b5..165a12c705 100644 --- a/kotlinx-coroutines-core/common/src/Builders.common.kt +++ b/kotlinx-coroutines-core/common/src/Builders.common.kt @@ -211,6 +211,126 @@ internal expect class UndispatchedCoroutine( uCont: Continuation ) : ScopeCoroutine +// Used by withContext when context changes, but dispatcher stays the same +internal actual class UndispatchedCoroutineactual constructor ( + context: CoroutineContext, + uCont: Continuation +) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { + + /** + * The state of [ThreadContextElement]s associated with the current undispatched coroutine. + * It is stored in a thread local because this coroutine can be used concurrently in suspend-resume race scenario. + * See the followin, boiled down example with inlined `withContinuationContext` body: + * ``` + * val state = saveThreadContext(ctx) + * try { + * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called + * // COROUTINE_SUSPENDED is returned + * } finally { + * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread + * // and it also calls saveThreadContext and clearThreadContext + * } + * ``` + * + * Usage note: + * + * This part of the code is performance-sensitive. + * It is a well-established pattern to wrap various activities into system-specific undispatched + * `withContext` for the sake of logging, MDC, tracing etc., meaning that there exists thousands of + * undispatched coroutines. + * Each access to Java's [ThreadLocal] leaves a footprint in the corresponding Thread's `ThreadLocalMap` + * that is cleared automatically as soon as the associated thread-local (-> UndispatchedCoroutine) is garbage collected + * when either the corresponding thread is GC'ed or it cleans up its stale entries on other TL accesses. + * When such coroutines are promoted to old generation, `ThreadLocalMap`s become bloated and an arbitrary accesses to thread locals + * start to consume significant amount of CPU because these maps are open-addressed and cleaned up incrementally on each access. + * (You can read more about this effect as "GC nepotism"). + * + * To avoid that, we attempt to narrow down the lifetime of this thread local as much as possible: + * - It's never accessed when we are sure there are no thread context elements + * - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished. + */ + private val threadStateToRecover = ThreadLocal>() + + /* + * Indicates that a coroutine has at least one thread context element associated with it + * and that 'threadStateToRecover' is going to be set in case of dispatchhing in order to preserve them. + * Better than nullable thread-local for easier debugging. + * + * It is used as a performance optimization to avoid 'threadStateToRecover' initialization + * (note: tl.get() initializes thread local), + * and is prone to false-positives as it is never reset: otherwise + * it may lead to logical data races between suspensions point where + * coroutine is yet being suspended in one thread while already being resumed + * in another. + */ + @Volatile + private var threadLocalIsSet = false + + init { + /* + * This is a hack for a very specific case in #2930 unless #3253 is implemented. + * 'ThreadLocalStressTest' covers this change properly. + * + * The scenario this change covers is the following: + * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function, + * e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking + * `withContext(tlElement)` which creates `UndispatchedCoroutine`. + * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()` + * and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both + * do thread context element tracking. + * 3) So thread locals never got chance to get properly set up via `saveThreadContext`, + * but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`. + * + * Here we detect precisely this situation and properly setup context to recover later. + * + */ + if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) { + /* + * We cannot just "read" the elements as there is no such API, + * so we update-restore it immediately and use the intermediate value + * as the initial state, leveraging the fact that thread context element + * is idempotent and such situations are increasingly rare. + */ + val values = updateThreadContext(context, null) + restoreThreadContext(context, values) + saveThreadContext(context, values) + } + } + + fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { + threadLocalIsSet = true // Specify that thread-local is touched at all + threadStateToRecover.set(context to oldValue) + } + + fun clearThreadContext(): Boolean { + return !(threadLocalIsSet && threadStateToRecover.get() == null).also { + threadStateToRecover.remove() + } + } + + override fun afterCompletionUndispatched() { + clearThreadLocal() + } + + override fun afterResume(state: Any?) { + clearThreadLocal() + // resume undispatched -- update context but stay on the same dispatcher + val result = recoverResult(state, uCont) + withContinuationContext(uCont, null) { + uCont.resumeWith(result) + } + } + + private fun clearThreadLocal() { + if (threadLocalIsSet) { + threadStateToRecover.get()?.let { (ctx, value) -> + restoreThreadContext(ctx, value) + } + threadStateToRecover.remove() + } + } +} + private const val UNDECIDED = 0 private const val SUSPENDED = 1 private const val RESUMED = 2 diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt index 48e59fe3a9..bc932d32eb 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt @@ -1,5 +1,6 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* import kotlin.coroutines.* /** @@ -25,3 +26,76 @@ internal expect inline fun withCoroutineContext(context: CoroutineContext, c internal expect inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T internal expect fun Continuation<*>.toDebugString(): String internal expect val CoroutineContext.coroutineName: String? + +/** + * Executes a block using a given coroutine context. + */ +internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { + val oldValue = updateThreadContext(context, countOrElement) + try { + return block() + } finally { + restoreThreadContext(context, oldValue) + } +} + +/** + * Executes a block using a context of a given continuation. + */ +internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { + val context = continuation.context + val oldValue = updateThreadContext(context, countOrElement) + val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { + // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them + continuation.updateUndispatchedCompletion(context, oldValue) + } else { + null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context + } + try { + return block() + } finally { + if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { + restoreThreadContext(context, oldValue) + } + } +} + +internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { + if (this !is CoroutineStackFrame) return null + /* + * Fast-path to detect whether we have undispatched coroutine at all in our stack. + * + * Implementation note. + * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: + * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance + * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` + * from the context when creating dispatched coroutine in `withContext`. + * Another option is to "unmark it" instead of removing to save an allocation. + * Both options should work, but it requires more careful studying of the performance + * and, mostly, maintainability impact. + */ + val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null + if (!potentiallyHasUndispatchedCoroutine) return null + val completion = undispatchedCompletion() + completion?.saveThreadContext(context, oldValue) + return completion +} + +internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { + // Find direct completion of this continuation + val completion: CoroutineStackFrame = when (this) { + is DispatchedCoroutine<*> -> return null + else -> callerFrame ?: return null // something else -- not supported + } + if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! + return completion.undispatchedCompletion() // walk up the call stack with tail call +} + +/** + * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. + * Used as a performance optimization to avoid stack walking where it is not necessary. + */ +private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { + override val key: CoroutineContext.Key<*> + get() = this +} diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt new file mode 100644 index 0000000000..d293b223cf --- /dev/null +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt @@ -0,0 +1,85 @@ +package kotlinx.coroutines + +import kotlin.coroutines.CoroutineContext +import kotlin.coroutines.EmptyCoroutineContext + + +/** + * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or + * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. + * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. + */ +@ExperimentalCoroutinesApi +public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { + val combined = foldCopies(coroutineContext, context, true) + val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined + return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) + debug + Dispatchers.Default else debug +} + +/** + * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. + * @suppress + */ +@InternalCoroutinesApi +public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { + /* + * Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements) + * contains copyable elements. + */ + if (!addedContext.hasCopyableElements()) return this + addedContext + return foldCopies(this, addedContext, false) +} + +private fun CoroutineContext.hasCopyableElements(): Boolean = + fold(false) { result, it -> result || it is CopyableThreadContextElement<*> } + +/** + * Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary. + * The rules are the following: + * - If neither context has CTCE, the sum of two contexts is returned + * - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context + * is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`. + * - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild] + * - Every CTCE from the right-hand side context that hasn't been merged is copied + * - Everything else is added to the resulting context as is. + */ +private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext { + // Do we have something to copy left-hand side? + val hasElementsLeft = originalContext.hasCopyableElements() + val hasElementsRight = appendContext.hasCopyableElements() + + // Nothing to fold, so just return the sum of contexts + if (!hasElementsLeft && !hasElementsRight) { + return originalContext + appendContext + } + + var leftoverContext = appendContext + val folded = originalContext.fold(EmptyCoroutineContext) { result, element -> + if (element !is CopyableThreadContextElement<*>) return@fold result + element + // Will this element be overwritten? + val newElement = leftoverContext[element.key] + // No, just copy it + if (newElement == null) { + // For 'withContext'-like builders we do not copy as the element is not shared + return@fold result + if (isNewCoroutine) element.copyForChild() else element + } + // Yes, then first remove the element from append context + leftoverContext = leftoverContext.minusKey(element.key) + // Return the sum + @Suppress("UNCHECKED_CAST") + return@fold result + (element as CopyableThreadContextElement).mergeForChild(newElement) + } + + if (hasElementsRight) { + leftoverContext = leftoverContext.fold(EmptyCoroutineContext) { result, element -> + // We're appending new context element -- we have to copy it, otherwise it may be shared with others + if (element is CopyableThreadContextElement<*>) { + return@fold result + element.copyForChild() + } + return@fold result + element + } + } + return folded + leftoverContext +} diff --git a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt new file mode 100644 index 0000000000..0fe5bc9351 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt @@ -0,0 +1,187 @@ +package kotlinx.coroutines + +import kotlin.coroutines.* + +/** + * Defines elements in a [CoroutineContext] that are installed into the thread context + * every time the coroutine with this element in the context is resumed on a thread. + * + * Implementations of this interface define a type [S] of the thread-local state that they need to store + * upon resuming a coroutine and restore later upon suspension. + * The infrastructure provides the corresponding storage. + * + * Example usage looks like this: + * + * ``` + * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * class CoroutineName(val name: String) : ThreadContextElement { + * // declare companion object for a key of this element in coroutine context + * companion object Key : CoroutineContext.Key + * + * // provide the key of the corresponding context element + * override val key: CoroutineContext.Key + * get() = Key + * + * // this is invoked before coroutine is resumed on current thread + * override fun updateThreadContext(context: CoroutineContext): String { + * val previousName = Thread.currentThread().name + * Thread.currentThread().name = "$previousName # $name" + * return previousName + * } + * + * // this is invoked after coroutine has suspended on current thread + * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + * Thread.currentThread().name = oldState + * } + * } + * + * // Usage + * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } + * ``` + * + * Every time this coroutine is resumed on a thread, UI thread name is updated to + * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * this coroutine suspends. + * + * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. + * + * ### Reentrancy and thread-safety + * + * Correct implementations of this interface must expect that calls to [restoreThreadContext] + * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. + * See [CopyableThreadContextElement] for advanced interleaving details. + * + * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state + * within an element accordingly. + */ +public interface ThreadContextElement : CoroutineContext.Element { + /** + * Updates context of the current thread. + * This function is invoked before the coroutine in the specified [context] is resumed in the current thread + * when the context of the coroutine this element. + * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + */ + public fun updateThreadContext(context: CoroutineContext): S + + /** + * Restores context of the current thread. + * This function is invoked after the coroutine in the specified [context] is suspended in the current thread + * if [updateThreadContext] was previously invoked on resume of this coroutine. + * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should + * be restored in the thread-local state by this function. + * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which + * context is updated in an undefined state and may crash an application. + * + * @param context the coroutine context. + * @param oldState the value returned by the previous invocation of [updateThreadContext]. + */ + public fun restoreThreadContext(context: CoroutineContext, oldState: S) +} + +/** + * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. + * + * When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement] + * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. + * + * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine + * will be visible to _itself_ and any child coroutine launched _after_ that write. + * + * Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen + * to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_ + * launching a child coroutine will not be visible to that child coroutine. + * + * This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and + * correctly, regardless of the coroutine's structured concurrency. + * + * This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace + * is in a coroutine: + * + * ``` + * class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement { + * companion object Key : CoroutineContext.Key + * + * override val key: CoroutineContext.Key = Key + * + * override fun updateThreadContext(context: CoroutineContext): TraceData? { + * val oldState = traceThreadLocal.get() + * traceThreadLocal.set(traceData) + * return oldState + * } + * + * override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) { + * traceThreadLocal.set(oldState) + * } + * + * override fun copyForChild(): TraceContextElement { + * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes + * // ThreadLocal writes between resumption of the parent coroutine and the launch of the + * // child coroutine visible to the child. + * return TraceContextElement(traceThreadLocal.get()?.copy()) + * } + * + * override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext { + * // Merge operation defines how to handle situations when both + * // the parent coroutine has an element in the context and + * // an element with the same key was also + * // explicitly passed to the child coroutine. + * // If merging does not require special behavior, + * // the copy of the element can be returned. + * return TraceContextElement(traceThreadLocal.get()?.copy()) + * } + * } + * ``` + * + * A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's + * value is installed into the target thread local. + * + * ### Reentrancy and thread-safety + * + * Correct implementations of this interface must expect that calls to [restoreThreadContext] + * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. + * + * Even though an element is copied for each child coroutine, an implementation should be able to handle the following + * interleaving when a coroutine with the corresponding element is launched on a multithreaded dispatcher: + * + * ``` + * coroutine.updateThreadContext() // Thread #1 + * ... coroutine body ... + * // suspension + immediate dispatch happen here + * coroutine.updateThreadContext() // Thread #2, coroutine is already resumed + * // ... coroutine body after suspension point on Thread #2 ... + * coroutine.restoreThreadContext() // Thread #1, is invoked late because Thread #1 is slow + * coroutine.restoreThreadContext() // Thread #2, may happen in parallel with the previous restore + * ``` + * + * All implementations of [CopyableThreadContextElement] should be thread-safe and guard their internal mutable state + * within an element accordingly. + */ +@DelicateCoroutinesApi +@ExperimentalCoroutinesApi +public interface CopyableThreadContextElement : ThreadContextElement { + + /** + * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * coroutine's context that is under construction if the added context does not contain an element with the same [key]. + * + * This function is called on the element each time a new coroutine inherits a context containing it, + * and the returned value is folded into the context given to the child. + * + * Since this method is called whenever a new coroutine is launched in a context containing this + * [CopyableThreadContextElement], implementations are performance-sensitive. + */ + public fun copyForChild(): CopyableThreadContextElement + + /** + * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * coroutine's context that is under construction if the added context does contain an element with the same [key]. + * + * This method is invoked on the original element, accepting as the parameter + * the element that is supposed to overwrite it. + */ + public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext +} diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt index c52d35c128..2dc52ec7e1 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt @@ -1,5 +1,96 @@ package kotlinx.coroutines.internal +import kotlinx.coroutines.ThreadContextElement import kotlin.coroutines.* +import kotlin.jvm.JvmField internal expect fun threadContextElements(context: CoroutineContext): Any + +@JvmField +internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") + +// Used when there are >= 2 active elements in the context +@Suppress("UNCHECKED_CAST") +private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { + private val values = arrayOfNulls(n) + private val elements = arrayOfNulls>(n) + private var i = 0 + + fun append(element: ThreadContextElement<*>, value: Any?) { + values[i] = value + elements[i++] = element as ThreadContextElement + } + + fun restore(context: CoroutineContext) { + for (i in elements.indices.reversed()) { + elements[i]!!.restoreThreadContext(context, values[i]) + } + } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +private val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +private val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +private val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element, element.updateThreadContext(state.context)) + } + return state + } + +internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! + +// countOrElement is pre-cached in dispatched continuation +// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements +internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { + @Suppress("NAME_SHADOWING") + val countOrElement = countOrElement ?: threadContextElements(context) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements + // ^^^ identity comparison for speed, we know zero always has the same identity + countOrElement is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, countOrElement), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = countOrElement as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.restore(context) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} diff --git a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt new file mode 100644 index 0000000000..75f92e656c --- /dev/null +++ b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt @@ -0,0 +1,146 @@ +package kotlinx.coroutines + +import kotlinx.coroutines.testing.* +import kotlin.test.* +import kotlinx.coroutines.flow.* +import kotlin.coroutines.* + +class ThreadContextElementTest: TestBase() { + + @Test + fun testUndispatched() = runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val data = MyData() + val element = MyElement(data) + val job = GlobalScope.launch( + context = Dispatchers.Default + exceptionHandler + element, + start = CoroutineStart.UNDISPATCHED + ) { + assertSame(data, myThreadLocal.get()) + yield() + assertSame(data, myThreadLocal.get()) + } + assertNull(myThreadLocal.get()) + job.join() + assertNull(myThreadLocal.get()) + } + + /** + * For stability of the test, it is important to make sure that + * the parent job actually suspends when calling + * `withContext(dispatcher2 + CoroutineName("dispatched"))`. + * + * Here this requirement is fulfilled by forcing execution on a single thread. + * However, dispatching is performed with two non-equal dispatchers to force dispatching. + * + * Suspend of the parent coroutine [kotlinx.coroutines.DispatchedCoroutine.trySuspend] is out of the control of the test, + * while being executed concurrently with resume of the child coroutine [kotlinx.coroutines.DispatchedCoroutine.tryResume]. + */ + @Test + fun testWithContextJobAccess() = runTest { + val executor = Executors.newSingleThreadExecutor() + // Emulate non-equal dispatchers + val executor1 = object : ExecutorService by executor {} + val executor2 = object : ExecutorService by executor {} + val dispatcher1 = executor1.asCoroutineDispatcher() + val dispatcher2 = executor2.asCoroutineDispatcher() + val captor = JobCaptor() + val manuallyCaptured = mutableListOf() + + fun registerUpdate(job: Job?) = manuallyCaptured.add("Update: $job") + fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job") + + var rootJob: Job? = null + runBlocking(captor + dispatcher1) { + rootJob = coroutineContext.job + registerUpdate(rootJob) + var undispatchedJob: Job? = null + withContext(CoroutineName("undispatched")) { + undispatchedJob = coroutineContext.job + registerUpdate(undispatchedJob) + // These 2 restores and the corresponding next 2 updates happen only if the following `withContext` + // call actually suspends. + registerRestore(undispatchedJob) + registerRestore(rootJob) + // Without forcing of single backing thread the code inside `withContext` + // may already complete at the moment when the parent coroutine decides + // whether it needs to suspend or not. + var dispatchedJob: Job? = null + withContext(dispatcher2 + CoroutineName("dispatched")) { + dispatchedJob = coroutineContext.job + registerUpdate(dispatchedJob) + } + registerRestore(dispatchedJob) + // Context restored, captured again + registerUpdate(undispatchedJob) + } + registerRestore(undispatchedJob) + // Context restored, captured again + registerUpdate(rootJob) + } + registerRestore(rootJob) + + // Restores may be called concurrently to the update calls in other threads, so their order is not checked. + val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") + val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") + assertEquals(expected, actual) + executor.shutdownNow() + } + + @Test + fun testThreadLocalFlowOn() = runTest { + val myData = MyData() + myThreadLocal.set(myData) + expect(1) + flow { + assertEquals(myData, myThreadLocal.get()) + emit(1) + } + .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default) + .single() + myThreadLocal.set(null) + finish(2) + } +} + +class MyData + +class JobCaptor(val capturees: MutableList = CopyOnWriteArrayList()) : ThreadContextElement { + + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key<*> get() = Key + + override fun updateThreadContext(context: CoroutineContext) { + capturees.add("Update: ${context.job}") + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) { + capturees.add("Restore: ${context.job}") + } +} + +// declare thread local variable holding MyData +private val myThreadLocal = ThreadLocal() + +// declare context element holding MyData +class MyElement(val data: MyData) : ThreadContextElement { + // declare companion object for a key of this element in coroutine context + companion object Key : CoroutineContext.Key + + // provide the key of the corresponding context element + override val key: CoroutineContext.Key + get() = Key + + // this is invoked before coroutine is resumed on current thread + override fun updateThreadContext(context: CoroutineContext): MyData? { + val oldState = myThreadLocal.get() + myThreadLocal.set(data) + return oldState + } + + // this is invoked after coroutine has suspended on current thread + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + myThreadLocal.set(oldState) + } +} diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt similarity index 100% rename from kotlinx-coroutines-core/jvm/test/ThreadContextMutableCopiesTest.kt rename to kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt diff --git a/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt new file mode 100644 index 0000000000..4f66fef813 --- /dev/null +++ b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt @@ -0,0 +1,171 @@ +package kotlinx.coroutines + +import kotlinx.coroutines.testing.* +import kotlin.test.* + +class ThreadContextElementConcurrentTest: TestBase() { + + @Test + fun testWithContext() = runTest { + expect(1) + newSingleThreadContext("withContext").use { + val data = MyData() + GlobalScope.async(Dispatchers.Default + MyElement(data)) { + assertSame(data, myThreadLocal.get()) + expect(2) + + val newData = MyData() + GlobalScope.async(it + MyElement(newData)) { + assertSame(newData, myThreadLocal.get()) + expect(3) + }.await() + + withContext(it + MyElement(newData)) { + assertSame(newData, myThreadLocal.get()) + expect(4) + } + + GlobalScope.async(it) { + assertNull(myThreadLocal.get()) + expect(5) + }.await() + + expect(6) + }.await() + } + + finish(7) + } + + @Test + fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest { + var parentElement: MyElement? = null + var inheritedElement: MyElement? = null + + newSingleThreadContext("withContext").use { + withContext(it + MyElement(MyData())) { + parentElement = coroutineContext[MyElement.Key] + launch { + inheritedElement = coroutineContext[MyElement.Key] + } + } + } + + assertSame(inheritedElement, parentElement, + "Inner and outer coroutines did not have the same object reference to a" + + " ThreadContextElement that did not override `copyForChildCoroutine()`") + } + + @Test + fun testCopyableElementCopiedOnLaunch() = runTest { + var parentElement: CopyForChildCoroutineElement? = null + var inheritedElement: CopyForChildCoroutineElement? = null + + newSingleThreadContext("withContext").use { + withContext(it + CopyForChildCoroutineElement(MyData())) { + parentElement = coroutineContext[CopyForChildCoroutineElement.Key] + launch { + inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key] + } + } + } + + assertNotSame(inheritedElement, parentElement, + "Inner coroutine did not copy its copyable ThreadContextElement.") + } + + @Test + fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest { + newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { + withContext(it + CopyForChildCoroutineElement(MyData())) { + val forBlockData = MyData() + myThreadLocal.setForBlock(forBlockData) { + assertSame(myThreadLocal.get(), forBlockData) + launch { + assertSame(myThreadLocal.get(), forBlockData) + } + launch { + assertSame(myThreadLocal.get(), forBlockData) + // Modify value in child coroutine. Writes to the ThreadLocal and + // the (copied) ThreadLocalElement's memory are not visible to peer or + // ancestor coroutines, so this write is both threadsafe and coroutinesafe. + val innerCoroutineData = MyData() + myThreadLocal.setForBlock(innerCoroutineData) { + assertSame(myThreadLocal.get(), innerCoroutineData) + } + assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored. + } + launch { + val innerCoroutineData = MyData() + myThreadLocal.setForBlock(innerCoroutineData) { + assertSame(myThreadLocal.get(), innerCoroutineData) + } + assertSame(myThreadLocal.get(), forBlockData) + } + } + assertNull(myThreadLocal.get()) // Asserts value was restored to its origin + } + } + } +} + + +/** + * A [ThreadContextElement] that implements copy semantics in [copyForChild]. + */ +class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { + companion object Key : CoroutineContext.Key + + override val key: CoroutineContext.Key + get() = Key + + override fun updateThreadContext(context: CoroutineContext): MyData? { + val oldState = myThreadLocal.get() + myThreadLocal.set(data) + return oldState + } + + override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement { + TODO("Not used in tests") + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + myThreadLocal.set(oldState) + } + + /** + * At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new + * child coroutine, and that value is copied to a new, unique, ThreadContextElement memory + * reference for the child coroutine to use uniquely. + * + * n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not + * the value initially passed to the ThreadContextElement in order to reflect writes made to the + * ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes + * will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the + * thread and calls [restoreThreadContext]. + */ + override fun copyForChild(): CopyForChildCoroutineElement { + return CopyForChildCoroutineElement(myThreadLocal.get()) + } +} + +/** + * Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block]. + * + * When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a + * [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically + * at every statement reached, whether that statement is reached immediately, across suspend and + * redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal` + * by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal` + * by the parent coroutine _after_ launching a child coroutine will not be visible to that child + * coroutine. + */ +private inline fun ThreadLocal.setForBlock( + value: ThreadLocalT, + crossinline block: () -> OutputT +) { + val priorValue = get() + set(value) + block() + set(priorValue) +} diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index 7628d6ac85..8fde7d8b01 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -4,279 +4,6 @@ import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.coroutines.jvm.internal.CoroutineStackFrame -/** - * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or - * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) - * and copyable-thread-local facilities on JVM. - * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. - */ -@ExperimentalCoroutinesApi -public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = foldCopies(coroutineContext, context, true) - val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined - return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) - debug + Dispatchers.Default else debug -} - -/** - * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. - * @suppress - */ -@InternalCoroutinesApi -public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { - /* - * Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements) - * contains copyable elements. - */ - if (!addedContext.hasCopyableElements()) return this + addedContext - return foldCopies(this, addedContext, false) -} - -private fun CoroutineContext.hasCopyableElements(): Boolean = - fold(false) { result, it -> result || it is CopyableThreadContextElement<*> } - -/** - * Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary. - * The rules are the following: - * - If neither context has CTCE, the sum of two contexts is returned - * - Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context - * is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`. - * - Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild] - * - Every CTCE from the right-hand side context that hasn't been merged is copied - * - Everything else is added to the resulting context as is. - */ -private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext { - // Do we have something to copy left-hand side? - val hasElementsLeft = originalContext.hasCopyableElements() - val hasElementsRight = appendContext.hasCopyableElements() - - // Nothing to fold, so just return the sum of contexts - if (!hasElementsLeft && !hasElementsRight) { - return originalContext + appendContext - } - - var leftoverContext = appendContext - val folded = originalContext.fold(EmptyCoroutineContext) { result, element -> - if (element !is CopyableThreadContextElement<*>) return@fold result + element - // Will this element be overwritten? - val newElement = leftoverContext[element.key] - // No, just copy it - if (newElement == null) { - // For 'withContext'-like builders we do not copy as the element is not shared - return@fold result + if (isNewCoroutine) element.copyForChild() else element - } - // Yes, then first remove the element from append context - leftoverContext = leftoverContext.minusKey(element.key) - // Return the sum - @Suppress("UNCHECKED_CAST") - return@fold result + (element as CopyableThreadContextElement).mergeForChild(newElement) - } - - if (hasElementsRight) { - leftoverContext = leftoverContext.fold(EmptyCoroutineContext) { result, element -> - // We're appending new context element -- we have to copy it, otherwise it may be shared with others - if (element is CopyableThreadContextElement<*>) { - return@fold result + element.copyForChild() - } - return@fold result + element - } - } - return folded + leftoverContext -} - -/** - * Executes a block using a given coroutine context. - */ -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { - val oldValue = updateThreadContext(context, countOrElement) - try { - return block() - } finally { - restoreThreadContext(context, oldValue) - } -} - -/** - * Executes a block using a context of a given continuation. - */ -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { - val context = continuation.context - val oldValue = updateThreadContext(context, countOrElement) - val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { - // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them - continuation.updateUndispatchedCompletion(context, oldValue) - } else { - null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context - } - try { - return block() - } finally { - if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) { - restoreThreadContext(context, oldValue) - } - } -} - -internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { - if (this !is CoroutineStackFrame) return null - /* - * Fast-path to detect whether we have undispatched coroutine at all in our stack. - * - * Implementation note. - * If we ever find that stackwalking for thread-locals is way too slow, here is another idea: - * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance - * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker` - * from the context when creating dispatched coroutine in `withContext`. - * Another option is to "unmark it" instead of removing to save an allocation. - * Both options should work, but it requires more careful studying of the performance - * and, mostly, maintainability impact. - */ - val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null - if (!potentiallyHasUndispatchedCoroutine) return null - val completion = undispatchedCompletion() - completion?.saveThreadContext(context, oldValue) - return completion -} - -internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { - // Find direct completion of this continuation - val completion: CoroutineStackFrame = when (this) { - is DispatchedCoroutine<*> -> return null - else -> callerFrame ?: return null // something else -- not supported - } - if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine! - return completion.undispatchedCompletion() // walk up the call stack with tail call -} - -/** - * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. - * Used as a performance optimization to avoid stack walking where it is not necessary. - */ -private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { - override val key: CoroutineContext.Key<*> - get() = this -} - -// Used by withContext when context changes, but dispatcher stays the same -internal actual class UndispatchedCoroutineactual constructor ( - context: CoroutineContext, - uCont: Continuation -) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { - - /** - * The state of [ThreadContextElement]s associated with the current undispatched coroutine. - * It is stored in a thread local because this coroutine can be used concurrently in suspend-resume race scenario. - * See the followin, boiled down example with inlined `withContinuationContext` body: - * ``` - * val state = saveThreadContext(ctx) - * try { - * invokeSmthWithThisCoroutineAsCompletion() // Completion implies that 'afterResume' will be called - * // COROUTINE_SUSPENDED is returned - * } finally { - * thisCoroutine().clearThreadContext() // Concurrently the "smth" could've been already resumed on a different thread - * // and it also calls saveThreadContext and clearThreadContext - * } - * ``` - * - * Usage note: - * - * This part of the code is performance-sensitive. - * It is a well-established pattern to wrap various activities into system-specific undispatched - * `withContext` for the sake of logging, MDC, tracing etc., meaning that there exists thousands of - * undispatched coroutines. - * Each access to Java's [ThreadLocal] leaves a footprint in the corresponding Thread's `ThreadLocalMap` - * that is cleared automatically as soon as the associated thread-local (-> UndispatchedCoroutine) is garbage collected - * when either the corresponding thread is GC'ed or it cleans up its stale entries on other TL accesses. - * When such coroutines are promoted to old generation, `ThreadLocalMap`s become bloated and an arbitrary accesses to thread locals - * start to consume significant amount of CPU because these maps are open-addressed and cleaned up incrementally on each access. - * (You can read more about this effect as "GC nepotism"). - * - * To avoid that, we attempt to narrow down the lifetime of this thread local as much as possible: - * - It's never accessed when we are sure there are no thread context elements - * - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished. - */ - private val threadStateToRecover = ThreadLocal>() - - /* - * Indicates that a coroutine has at least one thread context element associated with it - * and that 'threadStateToRecover' is going to be set in case of dispatchhing in order to preserve them. - * Better than nullable thread-local for easier debugging. - * - * It is used as a performance optimization to avoid 'threadStateToRecover' initialization - * (note: tl.get() initializes thread local), - * and is prone to false-positives as it is never reset: otherwise - * it may lead to logical data races between suspensions point where - * coroutine is yet being suspended in one thread while already being resumed - * in another. - */ - @Volatile - private var threadLocalIsSet = false - - init { - /* - * This is a hack for a very specific case in #2930 unless #3253 is implemented. - * 'ThreadLocalStressTest' covers this change properly. - * - * The scenario this change covers is the following: - * 1) The coroutine is being started as plain non kotlinx.coroutines related suspend function, - * e.g. `suspend fun main` or, more importantly, Ktor `SuspendFunGun`, that is invoking - * `withContext(tlElement)` which creates `UndispatchedCoroutine`. - * 2) It (original continuation) is then not wrapped into `DispatchedContinuation` via `intercept()` - * and goes neither through `DC.run` nor through `resumeUndispatchedWith` that both - * do thread context element tracking. - * 3) So thread locals never got chance to get properly set up via `saveThreadContext`, - * but when `withContext` finishes, it attempts to recover thread locals in its `afterResume`. - * - * Here we detect precisely this situation and properly setup context to recover later. - * - */ - if (uCont.context[ContinuationInterceptor] !is CoroutineDispatcher) { - /* - * We cannot just "read" the elements as there is no such API, - * so we update-restore it immediately and use the intermediate value - * as the initial state, leveraging the fact that thread context element - * is idempotent and such situations are increasingly rare. - */ - val values = updateThreadContext(context, null) - restoreThreadContext(context, values) - saveThreadContext(context, values) - } - } - - fun saveThreadContext(context: CoroutineContext, oldValue: Any?) { - threadLocalIsSet = true // Specify that thread-local is touched at all - threadStateToRecover.set(context to oldValue) - } - - fun clearThreadContext(): Boolean { - return !(threadLocalIsSet && threadStateToRecover.get() == null).also { - threadStateToRecover.remove() - } - } - - override fun afterCompletionUndispatched() { - clearThreadLocal() - } - - override fun afterResume(state: Any?) { - clearThreadLocal() - // resume undispatched -- update context but stay on the same dispatcher - val result = recoverResult(state, uCont) - withContinuationContext(uCont, null) { - uCont.resumeWith(result) - } - } - - private fun clearThreadLocal() { - if (threadLocalIsSet) { - threadStateToRecover.get()?.let { (ctx, value) -> - restoreThreadContext(ctx, value) - } - threadStateToRecover.remove() - } - } -} - internal actual val CoroutineContext.coroutineName: String? get() { if (!DEBUG) return null val coroutineId = this[CoroutineId] ?: return null diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index 5015a259f0..efe879b6c6 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -3,190 +3,6 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -/** - * Defines elements in a [CoroutineContext] that are installed into the thread context - * every time the coroutine with this element in the context is resumed on a thread. - * - * Implementations of this interface define a type [S] of the thread-local state that they need to store - * upon resuming a coroutine and restore later upon suspension. - * The infrastructure provides the corresponding storage. - * - * Example usage looks like this: - * - * ``` - * // Appends "name" of a coroutine to a current thread name when coroutine is executed - * class CoroutineName(val name: String) : ThreadContextElement { - * // declare companion object for a key of this element in coroutine context - * companion object Key : CoroutineContext.Key - * - * // provide the key of the corresponding context element - * override val key: CoroutineContext.Key - * get() = Key - * - * // this is invoked before coroutine is resumed on current thread - * override fun updateThreadContext(context: CoroutineContext): String { - * val previousName = Thread.currentThread().name - * Thread.currentThread().name = "$previousName # $name" - * return previousName - * } - * - * // this is invoked after coroutine has suspended on current thread - * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { - * Thread.currentThread().name = oldState - * } - * } - * - * // Usage - * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } - * ``` - * - * Every time this coroutine is resumed on a thread, UI thread name is updated to - * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when - * this coroutine suspends. - * - * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. - * - * ### Reentrancy and thread-safety - * - * Correct implementations of this interface must expect that calls to [restoreThreadContext] - * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. - * See [CopyableThreadContextElement] for advanced interleaving details. - * - * All implementations of [ThreadContextElement] should be thread-safe and guard their internal mutable state - * within an element accordingly. - */ -public interface ThreadContextElement : CoroutineContext.Element { - /** - * Updates context of the current thread. - * This function is invoked before the coroutine in the specified [context] is resumed in the current thread - * when the context of the coroutine this element. - * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - */ - public fun updateThreadContext(context: CoroutineContext): S - - /** - * Restores context of the current thread. - * This function is invoked after the coroutine in the specified [context] is suspended in the current thread - * if [updateThreadContext] was previously invoked on resume of this coroutine. - * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should - * be restored in the thread-local state by this function. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. - * - * @param context the coroutine context. - * @param oldState the value returned by the previous invocation of [updateThreadContext]. - */ - public fun restoreThreadContext(context: CoroutineContext, oldState: S) -} - -/** - * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. - * - * When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement] - * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. - * - * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine - * will be visible to _itself_ and any child coroutine launched _after_ that write. - * - * Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen - * to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_ - * launching a child coroutine will not be visible to that child coroutine. - * - * This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and - * correctly, regardless of the coroutine's structured concurrency. - * - * This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace - * is in a coroutine: - * - * ``` - * class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement { - * companion object Key : CoroutineContext.Key - * - * override val key: CoroutineContext.Key = Key - * - * override fun updateThreadContext(context: CoroutineContext): TraceData? { - * val oldState = traceThreadLocal.get() - * traceThreadLocal.set(traceData) - * return oldState - * } - * - * override fun restoreThreadContext(context: CoroutineContext, oldState: TraceData?) { - * traceThreadLocal.set(oldState) - * } - * - * override fun copyForChild(): TraceContextElement { - * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes - * // ThreadLocal writes between resumption of the parent coroutine and the launch of the - * // child coroutine visible to the child. - * return TraceContextElement(traceThreadLocal.get()?.copy()) - * } - * - * override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext { - * // Merge operation defines how to handle situations when both - * // the parent coroutine has an element in the context and - * // an element with the same key was also - * // explicitly passed to the child coroutine. - * // If merging does not require special behavior, - * // the copy of the element can be returned. - * return TraceContextElement(traceThreadLocal.get()?.copy()) - * } - * } - * ``` - * - * A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's - * value is installed into the target thread local. - * - * ### Reentrancy and thread-safety - * - * Correct implementations of this interface must expect that calls to [restoreThreadContext] - * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. - * - * Even though an element is copied for each child coroutine, an implementation should be able to handle the following - * interleaving when a coroutine with the corresponding element is launched on a multithreaded dispatcher: - * - * ``` - * coroutine.updateThreadContext() // Thread #1 - * ... coroutine body ... - * // suspension + immediate dispatch happen here - * coroutine.updateThreadContext() // Thread #2, coroutine is already resumed - * // ... coroutine body after suspension point on Thread #2 ... - * coroutine.restoreThreadContext() // Thread #1, is invoked late because Thread #1 is slow - * coroutine.restoreThreadContext() // Thread #2, may happen in parallel with the previous restore - * ``` - * - * All implementations of [CopyableThreadContextElement] should be thread-safe and guard their internal mutable state - * within an element accordingly. - */ -@DelicateCoroutinesApi -@ExperimentalCoroutinesApi -public interface CopyableThreadContextElement : ThreadContextElement { - - /** - * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child - * coroutine's context that is under construction if the added context does not contain an element with the same [key]. - * - * This function is called on the element each time a new coroutine inherits a context containing it, - * and the returned value is folded into the context given to the child. - * - * Since this method is called whenever a new coroutine is launched in a context containing this - * [CopyableThreadContextElement], implementations are performance-sensitive. - */ - public fun copyForChild(): CopyableThreadContextElement - - /** - * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child - * coroutine's context that is under construction if the added context does contain an element with the same [key]. - * - * This method is invoked on the original element, accepting as the parameter - * the element that is supposed to overwrite it. - */ - public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext -} - /** * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement] * maintains the given [value] of the given [ThreadLocal] for a coroutine regardless of the actual thread it is resumed on. diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 8f21b13c25..205583c624 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -3,94 +3,6 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* -@JvmField -internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") - -// Used when there are >= 2 active elements in the context -@Suppress("UNCHECKED_CAST") -private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { - private val values = arrayOfNulls(n) - private val elements = arrayOfNulls>(n) - private var i = 0 - - fun append(element: ThreadContextElement<*>, value: Any?) { - values[i] = value - elements[i++] = element as ThreadContextElement - } - - fun restore(context: CoroutineContext) { - for (i in elements.indices.reversed()) { - elements[i]!!.restoreThreadContext(context, values[i]) - } - } -} - -// Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) -private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { - if (element is ThreadContextElement<*>) { - val inCount = countOrElement as? Int ?: 1 - return if (inCount == 0) element else inCount + 1 - } - return countOrElement - } - -// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one -private val findOne = - fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { - if (found != null) return found - return element as? ThreadContextElement<*> - } - -// Updates state for ThreadContextElements in the context using the given ThreadState -private val updateState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - if (element is ThreadContextElement<*>) { - state.append(element, element.updateThreadContext(state.context)) - } - return state - } - -internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! - -// countOrElement is pre-cached in dispatched continuation -// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements -internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { - @Suppress("NAME_SHADOWING") - val countOrElement = countOrElement ?: threadContextElements(context) - @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") - return when { - countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements - // ^^^ identity comparison for speed, we know zero always has the same identity - countOrElement is Int -> { - // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values - context.fold(ThreadState(context, countOrElement), updateState) - } - else -> { - // fast path for one ThreadContextElement (no allocations, no additional context scan) - @Suppress("UNCHECKED_CAST") - val element = countOrElement as ThreadContextElement - element.updateThreadContext(context) - } - } -} - -internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { - when { - oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements - oldState is ThreadState -> { - // slow path with multiple stored ThreadContextElements - oldState.restore(context) - } - else -> { - // fast path for one ThreadContextElement, but need to find it - @Suppress("UNCHECKED_CAST") - val element = context.fold(null, findOne) as ThreadContextElement - element.restoreThreadContext(context, oldState) - } - } -} // top-level data class for a nicer out-of-the-box toString representation and class name @PublishedApi diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt index 54e88677e1..afd27e674b 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt @@ -37,305 +37,4 @@ class ThreadContextElementTest : TestBase() { assertNull(myThreadLocal.get()) } - @Test - fun testUndispatched() = runTest { - val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! - val data = MyData() - val element = MyElement(data) - val job = GlobalScope.launch( - context = Dispatchers.Default + exceptionHandler + element, - start = CoroutineStart.UNDISPATCHED - ) { - assertSame(data, myThreadLocal.get()) - yield() - assertSame(data, myThreadLocal.get()) - } - assertNull(myThreadLocal.get()) - job.join() - assertNull(myThreadLocal.get()) - } - - @Test - fun testWithContext() = runTest { - expect(1) - newSingleThreadContext("withContext").use { - val data = MyData() - GlobalScope.async(Dispatchers.Default + MyElement(data)) { - assertSame(data, myThreadLocal.get()) - expect(2) - - val newData = MyData() - GlobalScope.async(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) - expect(3) - }.await() - - withContext(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) - expect(4) - } - - GlobalScope.async(it) { - assertNull(myThreadLocal.get()) - expect(5) - }.await() - - expect(6) - }.await() - } - - finish(7) - } - - @Test - fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest { - var parentElement: MyElement? = null - var inheritedElement: MyElement? = null - - newSingleThreadContext("withContext").use { - withContext(it + MyElement(MyData())) { - parentElement = coroutineContext[MyElement.Key] - launch { - inheritedElement = coroutineContext[MyElement.Key] - } - } - } - - assertSame(inheritedElement, parentElement, - "Inner and outer coroutines did not have the same object reference to a" + - " ThreadContextElement that did not override `copyForChildCoroutine()`") - } - - @Test - fun testCopyableElementCopiedOnLaunch() = runTest { - var parentElement: CopyForChildCoroutineElement? = null - var inheritedElement: CopyForChildCoroutineElement? = null - - newSingleThreadContext("withContext").use { - withContext(it + CopyForChildCoroutineElement(MyData())) { - parentElement = coroutineContext[CopyForChildCoroutineElement.Key] - launch { - inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key] - } - } - } - - assertNotSame(inheritedElement, parentElement, - "Inner coroutine did not copy its copyable ThreadContextElement.") - } - - @Test - fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest { - newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { - withContext(it + CopyForChildCoroutineElement(MyData())) { - val forBlockData = MyData() - myThreadLocal.setForBlock(forBlockData) { - assertSame(myThreadLocal.get(), forBlockData) - launch { - assertSame(myThreadLocal.get(), forBlockData) - } - launch { - assertSame(myThreadLocal.get(), forBlockData) - // Modify value in child coroutine. Writes to the ThreadLocal and - // the (copied) ThreadLocalElement's memory are not visible to peer or - // ancestor coroutines, so this write is both threadsafe and coroutinesafe. - val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) - } - assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored. - } - launch { - val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) - } - assertSame(myThreadLocal.get(), forBlockData) - } - } - assertNull(myThreadLocal.get()) // Asserts value was restored to its origin - } - } - } - - class JobCaptor(val capturees: MutableList = CopyOnWriteArrayList()) : ThreadContextElement { - - companion object Key : CoroutineContext.Key - - override val key: CoroutineContext.Key<*> get() = Key - - override fun updateThreadContext(context: CoroutineContext) { - capturees.add("Update: ${context.job}") - } - - override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) { - capturees.add("Restore: ${context.job}") - } - } - - /** - * For stability of the test, it is important to make sure that - * the parent job actually suspends when calling - * `withContext(dispatcher2 + CoroutineName("dispatched"))`. - * - * Here this requirement is fulfilled by forcing execution on a single thread. - * However, dispatching is performed with two non-equal dispatchers to force dispatching. - * - * Suspend of the parent coroutine [kotlinx.coroutines.DispatchedCoroutine.trySuspend] is out of the control of the test, - * while being executed concurrently with resume of the child coroutine [kotlinx.coroutines.DispatchedCoroutine.tryResume]. - */ - @Test - fun testWithContextJobAccess() = runTest { - val executor = Executors.newSingleThreadExecutor() - // Emulate non-equal dispatchers - val executor1 = object : ExecutorService by executor {} - val executor2 = object : ExecutorService by executor {} - val dispatcher1 = executor1.asCoroutineDispatcher() - val dispatcher2 = executor2.asCoroutineDispatcher() - val captor = JobCaptor() - val manuallyCaptured = mutableListOf() - - fun registerUpdate(job: Job?) = manuallyCaptured.add("Update: $job") - fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job") - - var rootJob: Job? = null - runBlocking(captor + dispatcher1) { - rootJob = coroutineContext.job - registerUpdate(rootJob) - var undispatchedJob: Job? = null - withContext(CoroutineName("undispatched")) { - undispatchedJob = coroutineContext.job - registerUpdate(undispatchedJob) - // These 2 restores and the corresponding next 2 updates happen only if the following `withContext` - // call actually suspends. - registerRestore(undispatchedJob) - registerRestore(rootJob) - // Without forcing of single backing thread the code inside `withContext` - // may already complete at the moment when the parent coroutine decides - // whether it needs to suspend or not. - var dispatchedJob: Job? = null - withContext(dispatcher2 + CoroutineName("dispatched")) { - dispatchedJob = coroutineContext.job - registerUpdate(dispatchedJob) - } - registerRestore(dispatchedJob) - // Context restored, captured again - registerUpdate(undispatchedJob) - } - registerRestore(undispatchedJob) - // Context restored, captured again - registerUpdate(rootJob) - } - registerRestore(rootJob) - - // Restores may be called concurrently to the update calls in other threads, so their order is not checked. - val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") - val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") - assertEquals(expected, actual) - executor.shutdownNow() - } - - @Test - fun testThreadLocalFlowOn() = runTest { - val myData = MyData() - myThreadLocal.set(myData) - expect(1) - flow { - assertEquals(myData, myThreadLocal.get()) - emit(1) - } - .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default) - .single() - myThreadLocal.set(null) - finish(2) - } -} - -class MyData - -// declare thread local variable holding MyData -private val myThreadLocal = ThreadLocal() - -// declare context element holding MyData -class MyElement(val data: MyData) : ThreadContextElement { - // declare companion object for a key of this element in coroutine context - companion object Key : CoroutineContext.Key - - // provide the key of the corresponding context element - override val key: CoroutineContext.Key - get() = Key - - // this is invoked before coroutine is resumed on current thread - override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) - return oldState - } - - // this is invoked after coroutine has suspended on current thread - override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) - } } - -/** - * A [ThreadContextElement] that implements copy semantics in [copyForChild]. - */ -class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { - companion object Key : CoroutineContext.Key - - override val key: CoroutineContext.Key - get() = Key - - override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) - return oldState - } - - override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement { - TODO("Not used in tests") - } - - override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) - } - - /** - * At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new - * child coroutine, and that value is copied to a new, unique, ThreadContextElement memory - * reference for the child coroutine to use uniquely. - * - * n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not - * the value initially passed to the ThreadContextElement in order to reflect writes made to the - * ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes - * will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the - * thread and calls [restoreThreadContext]. - */ - override fun copyForChild(): CopyForChildCoroutineElement { - return CopyForChildCoroutineElement(myThreadLocal.get()) - } -} - - -/** - * Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block]. - * - * When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a - * [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically - * at every statement reached, whether that statement is reached immediately, across suspend and - * redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal` - * by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal` - * by the parent coroutine _after_ launching a child coroutine will not be visible to that child - * coroutine. - */ -private inline fun ThreadLocal.setForBlock( - value: ThreadLocalT, - crossinline block: () -> OutputT -) { - val priorValue = get() - set(value) - block() - set(priorValue) -} - From 0423482640c6a19ecbbee1bd0d79b2c528307673 Mon Sep 17 00:00:00 2001 From: Dmitry Khalanskiy Date: Mon, 3 Mar 2025 12:05:20 +0100 Subject: [PATCH 2/3] Make MPP ThreadContextElement compile and run --- .../api/kotlinx-coroutines-core.klib.api | 10 +++ .../common/src/Builders.common.kt | 11 +--- .../common/src/CoroutineContext.common.kt | 28 ++------- .../src/CoroutineContext.sharedWithJvm.kt | 17 ++--- .../src/internal/ThreadContext.common.kt | 14 ++--- .../common/src/internal/ThreadLocal.common.kt | 1 + .../common/test/ThreadContextElementTest.kt | 63 ++++++++++++++++--- .../test/ThreadContextMutableCopiesTest.kt | 16 +++-- .../ThreadContextElementConcurrentTest.kt | 6 +- .../jsAndWasmShared/src/CoroutineContext.kt | 21 +------ .../src/internal/ThreadContext.kt | 5 +- .../src/internal/ThreadLocal.kt | 1 + .../jvm/src/CoroutineContext.kt | 11 +++- .../jvm/src/internal/ThreadContext.kt | 3 + ...Test.kt => ThreadContextElementJvmTest.kt} | 6 +- .../native/src/CoroutineContext.kt | 21 +------ .../native/src/internal/ThreadContext.kt | 5 +- .../native/src/internal/ThreadLocal.kt | 1 + 18 files changed, 125 insertions(+), 115 deletions(-) rename kotlinx-coroutines-core/jvm/test/{ThreadContextElementTest.kt => ThreadContextElementJvmTest.kt} (85%) diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api index 373a1eee52..c5d44acf03 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.klib.api @@ -186,6 +186,16 @@ abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CompletableDeferred : ko abstract fun completeExceptionally(kotlin/Throwable): kotlin/Boolean // kotlinx.coroutines/CompletableDeferred.completeExceptionally|completeExceptionally(kotlin.Throwable){}[0] } +abstract interface <#A: kotlin/Any?> kotlinx.coroutines/CopyableThreadContextElement : kotlinx.coroutines/ThreadContextElement<#A> { // kotlinx.coroutines/CopyableThreadContextElement|null[0] + abstract fun copyForChild(): kotlinx.coroutines/CopyableThreadContextElement<#A> // kotlinx.coroutines/CopyableThreadContextElement.copyForChild|copyForChild(){}[0] + abstract fun mergeForChild(kotlin.coroutines/CoroutineContext.Element): kotlin.coroutines/CoroutineContext // kotlinx.coroutines/CopyableThreadContextElement.mergeForChild|mergeForChild(kotlin.coroutines.CoroutineContext.Element){}[0] +} + +abstract interface <#A: kotlin/Any?> kotlinx.coroutines/ThreadContextElement : kotlin.coroutines/CoroutineContext.Element { // kotlinx.coroutines/ThreadContextElement|null[0] + abstract fun restoreThreadContext(kotlin.coroutines/CoroutineContext, #A) // kotlinx.coroutines/ThreadContextElement.restoreThreadContext|restoreThreadContext(kotlin.coroutines.CoroutineContext;1:0){}[0] + abstract fun updateThreadContext(kotlin.coroutines/CoroutineContext): #A // kotlinx.coroutines/ThreadContextElement.updateThreadContext|updateThreadContext(kotlin.coroutines.CoroutineContext){}[0] +} + abstract interface <#A: kotlin/Throwable & kotlinx.coroutines/CopyableThrowable<#A>> kotlinx.coroutines/CopyableThrowable { // kotlinx.coroutines/CopyableThrowable|null[0] abstract fun createCopy(): #A? // kotlinx.coroutines/CopyableThrowable.createCopy|createCopy(){}[0] } diff --git a/kotlinx-coroutines-core/common/src/Builders.common.kt b/kotlinx-coroutines-core/common/src/Builders.common.kt index 165a12c705..d68e0f3603 100644 --- a/kotlinx-coroutines-core/common/src/Builders.common.kt +++ b/kotlinx-coroutines-core/common/src/Builders.common.kt @@ -9,6 +9,7 @@ import kotlinx.atomicfu.* import kotlinx.coroutines.internal.* import kotlinx.coroutines.intrinsics.* import kotlinx.coroutines.selects.* +import kotlin.concurrent.Volatile import kotlin.contracts.* import kotlin.coroutines.* import kotlin.coroutines.intrinsics.* @@ -206,13 +207,7 @@ private class LazyStandaloneCoroutine( } // Used by withContext when context changes, but dispatcher stays the same -internal expect class UndispatchedCoroutine( - context: CoroutineContext, - uCont: Continuation -) : ScopeCoroutine - -// Used by withContext when context changes, but dispatcher stays the same -internal actual class UndispatchedCoroutineactual constructor ( +internal class UndispatchedCoroutine( context: CoroutineContext, uCont: Continuation ) : ScopeCoroutine(if (context[UndispatchedMarker] == null) context + UndispatchedMarker else context, uCont) { @@ -249,7 +244,7 @@ internal actual class UndispatchedCoroutineactual constructor ( * - It's never accessed when we are sure there are no thread context elements * - It's cleaned up via [ThreadLocal.remove] as soon as the coroutine is suspended or finished. */ - private val threadStateToRecover = ThreadLocal>() + private val threadStateToRecover = commonThreadLocal?>(Symbol("UndispatchedCoroutine")) /* * Indicates that a coroutine has at least one thread context element associated with it diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt index bc932d32eb..c25a37cc4b 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.common.kt @@ -3,34 +3,18 @@ package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -/** - * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or - * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) - * and copyable-thread-local facilities on JVM. - */ -public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext - -/** - * Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext]. - * @suppress - */ -@InternalCoroutinesApi -public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext - @PublishedApi // to have unmangled name when using from other modules via suppress @Suppress("PropertyName") internal expect val DefaultDelay: Delay -// countOrElement -- pre-cached value for ThreadContext.kt -internal expect inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T -internal expect inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T internal expect fun Continuation<*>.toDebugString(): String internal expect val CoroutineContext.coroutineName: String? +internal expect fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext /** * Executes a block using a given coroutine context. */ -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { +internal inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T { val oldValue = updateThreadContext(context, countOrElement) try { return block() @@ -42,7 +26,7 @@ internal actual inline fun withCoroutineContext(context: CoroutineContext, c /** * Executes a block using a context of a given continuation. */ -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { +internal inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T { val context = continuation.context val oldValue = updateThreadContext(context, countOrElement) val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) { @@ -60,7 +44,7 @@ internal actual inline fun withContinuationContext(continuation: Continuatio } } -internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { +private fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? { if (this !is CoroutineStackFrame) return null /* * Fast-path to detect whether we have undispatched coroutine at all in our stack. @@ -81,7 +65,7 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont return completion } -internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { +private tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? { // Find direct completion of this continuation val completion: CoroutineStackFrame = when (this) { is DispatchedCoroutine<*> -> return null @@ -95,7 +79,7 @@ internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedC * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack. * Used as a performance optimization to avoid stack walking where it is not necessary. */ -private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { +internal object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key { override val key: CoroutineContext.Key<*> get() = this } diff --git a/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt b/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt index d293b223cf..77c560349d 100644 --- a/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt +++ b/kotlinx-coroutines-core/common/src/CoroutineContext.sharedWithJvm.kt @@ -1,19 +1,22 @@ +// This file should be a part of `CoroutineContext.common.kt`, but adding `JvmName` to that fails: KT-75248 +@file:JvmName("CoroutineContextKt") +@file:JvmMultifileClass package kotlinx.coroutines +import kotlin.coroutines.ContinuationInterceptor import kotlin.coroutines.CoroutineContext import kotlin.coroutines.EmptyCoroutineContext - +import kotlin.jvm.JvmMultifileClass +import kotlin.jvm.JvmName /** * Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or - * [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on) - * and copyable-thread-local facilities on JVM. - * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. + * [ContinuationInterceptor] is specified and */ @ExperimentalCoroutinesApi -public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { +public fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { val combined = foldCopies(coroutineContext, context, true) - val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined + val debug = wrapContextWithDebug(combined) return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) debug + Dispatchers.Default else debug } @@ -23,7 +26,7 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): * @suppress */ @InternalCoroutinesApi -public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { +public fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { /* * Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements) * contains copyable elements. diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt index 2dc52ec7e1..184245c66e 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadContext.common.kt @@ -4,8 +4,6 @@ import kotlinx.coroutines.ThreadContextElement import kotlin.coroutines.* import kotlin.jvm.JvmField -internal expect fun threadContextElements(context: CoroutineContext): Any - @JvmField internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS") @@ -29,9 +27,9 @@ private class ThreadState(@JvmField val context: CoroutineContext, n: Int) { } // Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) +// Any here is Int | ThreadContextElement (when count is one) private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + fun (countOrElement: Any, element: CoroutineContext.Element): Any { if (element is ThreadContextElement<*>) { val inCount = countOrElement as? Int ?: 1 return if (inCount == 0) element else inCount + 1 @@ -55,17 +53,15 @@ private val updateState = return state } -internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!! +internal expect inline fun isZeroCount(countOrElement: Any?): Boolean // countOrElement is pre-cached in dispatched continuation // returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? { @Suppress("NAME_SHADOWING") val countOrElement = countOrElement ?: threadContextElements(context) - @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") return when { - countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements - // ^^^ identity comparison for speed, we know zero always has the same identity + isZeroCount(countOrElement) -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements countOrElement is Int -> { // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values context.fold(ThreadState(context, countOrElement), updateState) @@ -94,3 +90,5 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { } } } + +internal fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll) diff --git a/kotlinx-coroutines-core/common/src/internal/ThreadLocal.common.kt b/kotlinx-coroutines-core/common/src/internal/ThreadLocal.common.kt index 15622597d5..93f79825cb 100644 --- a/kotlinx-coroutines-core/common/src/internal/ThreadLocal.common.kt +++ b/kotlinx-coroutines-core/common/src/internal/ThreadLocal.common.kt @@ -3,6 +3,7 @@ package kotlinx.coroutines.internal internal expect class CommonThreadLocal { fun get(): T fun set(value: T) + fun remove() } /** diff --git a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt index 75f92e656c..af4edf9f64 100644 --- a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt @@ -1,8 +1,13 @@ package kotlinx.coroutines +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop import kotlinx.coroutines.testing.* import kotlin.test.* import kotlinx.coroutines.flow.* +import kotlinx.coroutines.internal.CommonThreadLocal +import kotlinx.coroutines.internal.Symbol +import kotlinx.coroutines.internal.commonThreadLocal import kotlin.coroutines.* class ThreadContextElementTest: TestBase() { @@ -38,12 +43,10 @@ class ThreadContextElementTest: TestBase() { */ @Test fun testWithContextJobAccess() = runTest { - val executor = Executors.newSingleThreadExecutor() // Emulate non-equal dispatchers - val executor1 = object : ExecutorService by executor {} - val executor2 = object : ExecutorService by executor {} - val dispatcher1 = executor1.asCoroutineDispatcher() - val dispatcher2 = executor2.asCoroutineDispatcher() + val dispatcher = Dispatchers.Default.limitedParallelism(1) + val dispatcher1 = dispatcher.limitedParallelism(1, "dispatcher1") + val dispatcher2 = dispatcher.limitedParallelism(1, "dispatcher2") val captor = JobCaptor() val manuallyCaptured = mutableListOf() @@ -51,7 +54,7 @@ class ThreadContextElementTest: TestBase() { fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job") var rootJob: Job? = null - runBlocking(captor + dispatcher1) { + withContext(captor + dispatcher1) { rootJob = coroutineContext.job registerUpdate(rootJob) var undispatchedJob: Job? = null @@ -84,7 +87,6 @@ class ThreadContextElementTest: TestBase() { val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") assertEquals(expected, actual) - executor.shutdownNow() } @Test @@ -96,7 +98,7 @@ class ThreadContextElementTest: TestBase() { assertEquals(myData, myThreadLocal.get()) emit(1) } - .flowOn(myThreadLocal.asContextElement() + Dispatchers.Default) + .flowOn(myThreadLocal.asCtxElement() + Dispatchers.Default) .single() myThreadLocal.set(null) finish(2) @@ -105,7 +107,7 @@ class ThreadContextElementTest: TestBase() { class MyData -class JobCaptor(val capturees: MutableList = CopyOnWriteArrayList()) : ThreadContextElement { +private class JobCaptor(val capturees: CopyOnWriteList = CopyOnWriteList()) : ThreadContextElement { companion object Key : CoroutineContext.Key @@ -121,7 +123,7 @@ class JobCaptor(val capturees: MutableList = CopyOnWriteArrayList()) : T } // declare thread local variable holding MyData -private val myThreadLocal = ThreadLocal() +internal val myThreadLocal = commonThreadLocal(Symbol("myElement")) // declare context element holding MyData class MyElement(val data: MyData) : ThreadContextElement { @@ -144,3 +146,44 @@ class MyElement(val data: MyData) : ThreadContextElement { myThreadLocal.set(oldState) } } + + +private class CommonThreadLocalContextElement( + private val threadLocal: CommonThreadLocal, + private val value: T = threadLocal.get() +): ThreadContextElement, CoroutineContext.Key> { + // provide the key of the corresponding context element + override val key: CoroutineContext.Key> + get() = this + + // this is invoked before coroutine is resumed on current thread + override fun updateThreadContext(context: CoroutineContext): T { + val oldState = threadLocal.get() + threadLocal.set(value) + return oldState + } + + // this is invoked after coroutine has suspended on current thread + override fun restoreThreadContext(context: CoroutineContext, oldState: T) { + threadLocal.set(oldState) + } +} + +// overload resolution issues if this is called `asContextElement` +internal fun CommonThreadLocal.asCtxElement(value: T = get()): ThreadContextElement = + CommonThreadLocalContextElement(this, value) + +private class CopyOnWriteList private constructor(list: List) { + private val field = atomic(list) + + constructor() : this(emptyList()) + + fun add(value: T) { + field.loop { current -> + val new = current + value + if (field.compareAndSet(current, new)) return + } + } + + fun filter(predicate: (T) -> Boolean): List = field.value.filter(predicate) +} diff --git a/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt index 0b174ec539..cd362e81f5 100644 --- a/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt +++ b/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt @@ -2,12 +2,16 @@ package kotlinx.coroutines import kotlinx.coroutines.testing.* import kotlinx.coroutines.flow.* +import kotlinx.coroutines.internal.Symbol +import kotlinx.coroutines.internal.commonThreadLocal import kotlin.coroutines.* import kotlin.test.* class ThreadContextMutableCopiesTest : TestBase() { companion object { - val threadLocalData: ThreadLocal> = ThreadLocal.withInitial { ArrayList() } + internal val threadLocalData = commonThreadLocal>(Symbol("ThreadLocalData")).also { + it.set(mutableListOf()) + } } class MyMutableElement( @@ -42,7 +46,7 @@ class ThreadContextMutableCopiesTest : TestBase() { @Test fun testDataIsCopied() = runTest { val root = MyMutableElement(ArrayList()) - runBlocking(root) { + launch(root) { val data = threadLocalData.get() expect(1) launch(root) { @@ -56,7 +60,7 @@ class ThreadContextMutableCopiesTest : TestBase() { @Test fun testDataIsNotOverwritten() = runTest { val root = MyMutableElement(ArrayList()) - runBlocking(root) { + withContext(root) { expect(1) val originalData = threadLocalData.get() threadLocalData.get().add("X") @@ -75,7 +79,7 @@ class ThreadContextMutableCopiesTest : TestBase() { @Test fun testDataIsMerged() = runTest { val root = MyMutableElement(ArrayList()) - runBlocking(root) { + withContext(root) { expect(1) val originalData = threadLocalData.get() threadLocalData.get().add("X") @@ -94,7 +98,7 @@ class ThreadContextMutableCopiesTest : TestBase() { @Test fun testDataIsNotOverwrittenWithContext() = runTest { val root = MyMutableElement(ArrayList()) - runBlocking(root) { + withContext(root) { val originalData = threadLocalData.get() threadLocalData.get().add("X") expect(1) @@ -114,7 +118,7 @@ class ThreadContextMutableCopiesTest : TestBase() { fun testDataIsCopiedForRunBlocking() = runTest { val root = MyMutableElement(ArrayList()) val originalData = root.mutableData - runBlocking(root) { + withContext(root) { assertNotSame(originalData, threadLocalData.get()) } } diff --git a/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt index 4f66fef813..14931459f5 100644 --- a/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt +++ b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt @@ -1,6 +1,8 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.CommonThreadLocal import kotlinx.coroutines.testing.* +import kotlin.coroutines.CoroutineContext import kotlin.test.* class ThreadContextElementConcurrentTest: TestBase() { @@ -113,7 +115,7 @@ class ThreadContextElementConcurrentTest: TestBase() { /** * A [ThreadContextElement] that implements copy semantics in [copyForChild]. */ -class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { +private class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement { companion object Key : CoroutineContext.Key override val key: CoroutineContext.Key @@ -160,7 +162,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle * by the parent coroutine _after_ launching a child coroutine will not be visible to that child * coroutine. */ -private inline fun ThreadLocal.setForBlock( +private inline fun CommonThreadLocal.setForBlock( value: ThreadLocalT, crossinline block: () -> OutputT ) { diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt index 82862ac8aa..70ffea4649 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/CoroutineContext.kt @@ -1,31 +1,12 @@ package kotlinx.coroutines -import kotlinx.coroutines.internal.ScopeCoroutine import kotlin.coroutines.* @PublishedApi // Used from kotlinx-coroutines-test via suppress, not part of ABI internal actual val DefaultDelay: Delay get() = Dispatchers.Default as Delay -public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = coroutineContext + context - return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) - combined + Dispatchers.Default else combined -} - -public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { - return this + addedContext -} - // No debugging facilities on Wasm and JS -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on Wasm and JS - -internal actual class UndispatchedCoroutine actual constructor( - context: CoroutineContext, - uCont: Continuation -) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) -} +internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext = context diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt index 3f56f99d6c..7fa08d66e0 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadContext.kt @@ -1,5 +1,4 @@ package kotlinx.coroutines.internal -import kotlin.coroutines.* - -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +@Suppress("NOTHING_TO_INLINE") +internal actual inline fun isZeroCount(countOrElement: Any?): Boolean = countOrElement is Int && countOrElement == 0 diff --git a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadLocal.kt b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadLocal.kt index 94eecfa0ee..fedddf60a3 100644 --- a/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadLocal.kt +++ b/kotlinx-coroutines-core/jsAndWasmShared/src/internal/ThreadLocal.kt @@ -5,6 +5,7 @@ internal actual class CommonThreadLocal { @Suppress("UNCHECKED_CAST") actual fun get(): T = value as T actual fun set(value: T) { this.value = value } + actual fun remove() { value = null } } internal actual fun commonThreadLocal(name: Symbol): CommonThreadLocal = CommonThreadLocal() \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt index 8fde7d8b01..389ca1f9cc 100644 --- a/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/jvm/src/CoroutineContext.kt @@ -1,8 +1,17 @@ +@file:JvmName("CoroutineContextKt") +@file:JvmMultifileClass package kotlinx.coroutines import kotlinx.coroutines.internal.* import kotlin.coroutines.* -import kotlin.coroutines.jvm.internal.CoroutineStackFrame + +/** + * Adds optional support for debugging facilities (when turned on) + * and copyable-thread-local facilities on JVM. + * See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM. + */ +internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext = + if (DEBUG) context + CoroutineId(COROUTINE_ID.incrementAndGet()) else context internal actual val CoroutineContext.coroutineName: String? get() { if (!DEBUG) return null diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 205583c624..7b1672caa1 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -3,6 +3,9 @@ package kotlinx.coroutines.internal import kotlinx.coroutines.* import kotlin.coroutines.* +// identity comparison for speed, we know zero always has the same identity +@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS", "KotlinConstantConditions", "NOTHING_TO_INLINE") +internal actual inline fun isZeroCount(countOrElement: Any?): Boolean = countOrElement === 0 // top-level data class for a nicer out-of-the-box toString representation and class name @PublishedApi diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt similarity index 85% rename from kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt rename to kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt index afd27e674b..68e2042520 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt @@ -1,15 +1,11 @@ package kotlinx.coroutines -import kotlinx.coroutines.flow.* import kotlinx.coroutines.testing.* import org.junit.Test -import java.util.concurrent.CopyOnWriteArrayList -import java.util.concurrent.ExecutorService -import java.util.concurrent.Executors import kotlin.coroutines.* import kotlin.test.* -class ThreadContextElementTest : TestBase() { +class ThreadContextElementJvmTest : TestBase() { @Test fun testExample() = runTest { diff --git a/kotlinx-coroutines-core/native/src/CoroutineContext.kt b/kotlinx-coroutines-core/native/src/CoroutineContext.kt index 3f4c8d9a01..063833a208 100644 --- a/kotlinx-coroutines-core/native/src/CoroutineContext.kt +++ b/kotlinx-coroutines-core/native/src/CoroutineContext.kt @@ -1,6 +1,5 @@ package kotlinx.coroutines -import kotlinx.coroutines.internal.* import kotlin.coroutines.* internal actual object DefaultExecutor : CoroutineDispatcher(), Delay { @@ -29,25 +28,7 @@ internal expect fun createDefaultDispatcher(): CoroutineDispatcher @PublishedApi internal actual val DefaultDelay: Delay = DefaultExecutor -public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext { - val combined = coroutineContext + context - return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null) - combined + Dispatchers.Default else combined -} - -public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext { - return this + addedContext -} - // No debugging facilities on native -internal actual inline fun withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block() -internal actual inline fun withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block() internal actual fun Continuation<*>.toDebugString(): String = toString() internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on native - -internal actual class UndispatchedCoroutine actual constructor( - context: CoroutineContext, - uCont: Continuation -) : ScopeCoroutine(context, uCont) { - override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont)) -} +internal actual fun wrapContextWithDebug(context: CoroutineContext): CoroutineContext = context diff --git a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt index 3f56f99d6c..7fa08d66e0 100644 --- a/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/native/src/internal/ThreadContext.kt @@ -1,5 +1,4 @@ package kotlinx.coroutines.internal -import kotlin.coroutines.* - -internal actual fun threadContextElements(context: CoroutineContext): Any = 0 +@Suppress("NOTHING_TO_INLINE") +internal actual inline fun isZeroCount(countOrElement: Any?): Boolean = countOrElement is Int && countOrElement == 0 diff --git a/kotlinx-coroutines-core/native/src/internal/ThreadLocal.kt b/kotlinx-coroutines-core/native/src/internal/ThreadLocal.kt index 0c803a7e36..ca6d62de58 100644 --- a/kotlinx-coroutines-core/native/src/internal/ThreadLocal.kt +++ b/kotlinx-coroutines-core/native/src/internal/ThreadLocal.kt @@ -6,6 +6,7 @@ internal actual class CommonThreadLocal(private val name: Symbol) { @Suppress("UNCHECKED_CAST") actual fun get(): T = Storage[name] as T actual fun set(value: T) { Storage[name] = value } + actual fun remove() { Storage.remove(name) } } internal actual fun commonThreadLocal(name: Symbol): CommonThreadLocal = CommonThreadLocal(name) From 0776bb61dfee033970bbe7303f5ff6a42f09e153 Mon Sep 17 00:00:00 2001 From: Dmitry Khalanskiy Date: Mon, 3 Mar 2025 12:38:37 +0100 Subject: [PATCH 3/3] Improve the documentation and tests --- .../common/src/ThreadContextElement.common.kt | 115 +++++++++------ .../common/test/ThreadContextElementTest.kt | 138 +++++++++++------- .../test/ThreadContextMutableCopiesTest.kt | 51 +++++-- .../ThreadContextElementConcurrentTest.kt | 69 ++++----- ...hreadContextMutableCopiesConcurrentTest.kt | 15 ++ .../test/FailingCoroutinesMachineryTest.kt | 4 +- .../jvm/test/ThreadContextElementJvmTest.kt | 36 ----- .../jvm/test/ThreadLocalTest.kt | 10 ++ 8 files changed, 255 insertions(+), 183 deletions(-) create mode 100644 kotlinx-coroutines-core/concurrent/test/ThreadContextMutableCopiesConcurrentTest.kt delete mode 100644 kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt diff --git a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt index 0fe5bc9351..076ce48848 100644 --- a/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt +++ b/kotlinx-coroutines-core/common/src/ThreadContextElement.common.kt @@ -6,14 +6,20 @@ import kotlin.coroutines.* * Defines elements in a [CoroutineContext] that are installed into the thread context * every time the coroutine with this element in the context is resumed on a thread. * + * In this context, by a "thread" we mean an environment where coroutines are executed in parallel to coroutines + * other threads. + * On JVM and Native, this is the same as an operating system thread. + * On JS, Wasm/JS, and Wasm/WASI, because coroutines can not actually execute in parallel, + * we say that there is a single thread running all coroutines. + * * Implementations of this interface define a type [S] of the thread-local state that they need to store - * upon resuming a coroutine and restore later upon suspension. - * The infrastructure provides the corresponding storage. + * when the coroutine is resumed and restore later on when it suspends. + * The coroutines infrastructure provides the corresponding storage. * * Example usage looks like this: * * ``` - * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * // Appends "name" of a coroutine to the current thread name when a coroutine is executed * class CoroutineName(val name: String) : ThreadContextElement { * // declare companion object for a key of this element in coroutine context * companion object Key : CoroutineContext.Key @@ -22,14 +28,14 @@ import kotlin.coroutines.* * override val key: CoroutineContext.Key * get() = Key * - * // this is invoked before coroutine is resumed on current thread + * // this is invoked before a coroutine is resumed on the current thread * override fun updateThreadContext(context: CoroutineContext): String { * val previousName = Thread.currentThread().name * Thread.currentThread().name = "$previousName # $name" * return previousName * } * - * // this is invoked after coroutine has suspended on current thread + * // this is invoked after a coroutine has suspended on the current thread * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { * Thread.currentThread().name = oldState * } @@ -39,13 +45,13 @@ import kotlin.coroutines.* * launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... } * ``` * - * Every time this coroutine is resumed on a thread, UI thread name is updated to - * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * Every time this coroutine is resumed on a thread, the name of the thread backing [Dispatchers.Main] is updated to + * "UI thread original name # Progress bar coroutine", and the thread name is restored to the original one when * this coroutine suspends. * - * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. + * On JVM, to use a `ThreadLocal` variable within the coroutine, use the `ThreadLocal.asContextElement` function. * - * ### Reentrancy and thread-safety + * ### Reentrancy and thread safety * * Correct implementations of this interface must expect that calls to [restoreThreadContext] * may happen in parallel to the subsequent [updateThreadContext] and [restoreThreadContext] operations. @@ -56,50 +62,65 @@ import kotlin.coroutines.* */ public interface ThreadContextElement : CoroutineContext.Element { /** - * Updates context of the current thread. - * This function is invoked before the coroutine in the specified [context] is resumed in the current thread - * when the context of the coroutine this element. - * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. + * Updates the context of the current thread. + * + * This function is invoked before the coroutine in the specified [context] is started or resumed + * in the current thread when this element is present in the context of the coroutine. + * The result of this function is the old value of the thread-local state + * that will be passed to [restoreThreadContext] when the coroutine eventually suspends or completes. + * This method should handle its own exceptions and not rethrow them. + * Thrown exceptions will leave the coroutine whose context is updated in an undefined state + * and may crash the application. * - * @param context the coroutine context. + * @param context the context of the coroutine that's being started or resumed. */ public fun updateThreadContext(context: CoroutineContext): S /** - * Restores context of the current thread. - * This function is invoked after the coroutine in the specified [context] is suspended in the current thread - * if [updateThreadContext] was previously invoked on resume of this coroutine. - * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should - * be restored in the thread-local state by this function. - * This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which - * context is updated in an undefined state and may crash an application. + * Restores the context of the current thread. * - * @param context the coroutine context. - * @param oldState the value returned by the previous invocation of [updateThreadContext]. + * This function is invoked after the coroutine in the specified [context] has suspended or finished + * in the current thread if [updateThreadContext] was previously invoked when this coroutine was started or resumed. + * [oldState] is the result of the preceding invocation of [updateThreadContext], + * and this value should be restored in the thread-local state by this function. + * This method should handle its own exceptions and not rethrow them. + * Thrown exceptions will leave the coroutine whose context is updated in an undefined state + * and may crash the application. + * + * @param context the context of the coroutine that has suspended or finished. + * @param oldState the value returned by the preceding invocation of [updateThreadContext]. */ public fun restoreThreadContext(context: CoroutineContext, oldState: S) } /** - * A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it. - * - * When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement] - * can give coroutines "coroutine-safe" write access to that `ThreadLocal`. - * - * A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine - * will be visible to _itself_ and any child coroutine launched _after_ that write. - * - * Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen - * to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_ - * launching a child coroutine will not be visible to that child coroutine. - * - * This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and + * A [ThreadContextElement] that is copied whenever a child coroutine inherits a context containing it. + * + * [ThreadContextElement] can be used to ensure that when several coroutines share the same thread, + * they can each have their personal (though immutable) thread-local state without affecting the other coroutines. + * Often, however, it is desirable to propagate the thread-local state across coroutine suspensions + * and to child coroutines. + * A [CopyableThreadContextElement] is an instrument for implementing exactly this kind of + * hierarchical mutable thread-local state. + * + * A change made to a thread-local value with a matching [CopyableThreadContextElement] by a coroutine + * will be visible to _itself_ (even after the coroutine suspends and subsequently resumes) + * and any child coroutine launched _after_ that write. + * Changes introduced to the thread-local value by the parent coroutine _after_ launching a child coroutine + * will not be visible to that child coroutine. + * Changes will not be visible to the parent coroutine, peer coroutines, or coroutines that also have + * this [CopyableThreadContextElement] in their context and simply happen to use the same thread. + * + * This can be used to allow a coroutine to use a mutable-thread-local-value-based API transparently and * correctly, regardless of the coroutine's structured concurrency. * - * This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace - * is in a coroutine: + * The changes *may* be visible to unrelated coroutines that are launched on the same thread if those coroutines + * do not have a [CopyableThreadContextElement] with the same key in their context. + * Because of this, it is an error to access a thread-local value from a coroutine without the corresponding + * [CopyableThreadContextElement] when other coroutines may have modified it. + * + * This example adapts thread-local-value-based method tracing to follow coroutine switches and child coroutine creation. + * when the tracing happens inside a coroutine: * * ``` * class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement { @@ -118,14 +139,14 @@ public interface ThreadContextElement : CoroutineContext.Element { * } * * override fun copyForChild(): TraceContextElement { - * // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes - * // ThreadLocal writes between resumption of the parent coroutine and the launch of the + * // Copy from the ThreadLocal source of truth at the child coroutine launch time. This makes + * // ThreadLocal writes between the resumption of the parent coroutine and the launch of the * // child coroutine visible to the child. * return TraceContextElement(traceThreadLocal.get()?.copy()) * } * * override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext { - * // Merge operation defines how to handle situations when both + * // The merge operation defines how to handle situations when both * // the parent coroutine has an element in the context and * // an element with the same key was also * // explicitly passed to the child coroutine. @@ -136,8 +157,8 @@ public interface ThreadContextElement : CoroutineContext.Element { * } * ``` * - * A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's - * value is installed into the target thread local. + * A coroutine using this mechanism can safely call coroutine-oblivious code that assumes + * a specific thread local element's value is installed into the target thread local. * * ### Reentrancy and thread-safety * @@ -165,7 +186,7 @@ public interface ThreadContextElement : CoroutineContext.Element { public interface CopyableThreadContextElement : ThreadContextElement { /** - * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * Returns the [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child * coroutine's context that is under construction if the added context does not contain an element with the same [key]. * * This function is called on the element each time a new coroutine inherits a context containing it, @@ -177,7 +198,7 @@ public interface CopyableThreadContextElement : ThreadContextElement { public fun copyForChild(): CopyableThreadContextElement /** - * Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child + * Returns the [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child * coroutine's context that is under construction if the added context does contain an element with the same [key]. * * This method is invoked on the original element, accepting as the parameter diff --git a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt index af4edf9f64..329b9cb8a8 100644 --- a/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt +++ b/kotlinx-coroutines-core/common/test/ThreadContextElementTest.kt @@ -2,34 +2,85 @@ package kotlinx.coroutines import kotlinx.atomicfu.atomic import kotlinx.atomicfu.loop +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.flow.single import kotlinx.coroutines.testing.* -import kotlin.test.* -import kotlinx.coroutines.flow.* -import kotlinx.coroutines.internal.CommonThreadLocal -import kotlinx.coroutines.internal.Symbol -import kotlinx.coroutines.internal.commonThreadLocal import kotlin.coroutines.* +import kotlin.test.* +import kotlinx.coroutines.internal.* + +class ThreadContextElementTest : TestBase() { + interface TestThreadContextElement : ThreadContextElement { + companion object Key : CoroutineContext.Key + } + + @Test + fun testUpdatesAndRestores() = runTest { + var updateCount = 0 + var restoreCount = 0 + val threadContextElement = object : TestThreadContextElement { + override fun updateThreadContext(context: CoroutineContext): Int { + updateCount++ + return 0 + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: Int) { + restoreCount++ + } -class ThreadContextElementTest: TestBase() { + override val key: CoroutineContext.Key<*> + get() = TestThreadContextElement.Key + } + launch(Dispatchers.Unconfined + threadContextElement) { + assertEquals(1, updateCount) + assertEquals(0, restoreCount) + } + assertEquals(1, updateCount) + assertEquals(1, restoreCount) + } + + @Test + fun testDispatched() = runTest { + val mainDispatcher = coroutineContext[ContinuationInterceptor]!! + val data = MyData() + val element = threadContextElementThreadLocal.asCtxElement(data) + assertNull(threadContextElementThreadLocal.get()) + val job = launch(Dispatchers.Default + element) { + assertSame(element, coroutineContext[element.key]) + assertSame(data, threadContextElementThreadLocal.get()) + withContext(mainDispatcher) { + assertSame(element, coroutineContext[element.key]) + assertSame(data, threadContextElementThreadLocal.get()) + } + assertSame(element, coroutineContext[element.key]) + assertSame(data, threadContextElementThreadLocal.get()) + } + assertNull(threadContextElementThreadLocal.get()) + job.join() + assertNull(threadContextElementThreadLocal.get()) + } @Test fun testUndispatched() = runTest { val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! val data = MyData() - val element = MyElement(data) - val job = GlobalScope.launch( + val element = threadContextElementThreadLocal.asCtxElement(data) + val job = launch( context = Dispatchers.Default + exceptionHandler + element, start = CoroutineStart.UNDISPATCHED ) { - assertSame(data, myThreadLocal.get()) + assertSame(element, coroutineContext[element.key]) + assertSame(data, threadContextElementThreadLocal.get()) yield() - assertSame(data, myThreadLocal.get()) + assertSame(element, coroutineContext[element.key]) + assertSame(data, threadContextElementThreadLocal.get()) } - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) job.join() - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) } - + /** * For stability of the test, it is important to make sure that * the parent job actually suspends when calling @@ -48,10 +99,10 @@ class ThreadContextElementTest: TestBase() { val dispatcher1 = dispatcher.limitedParallelism(1, "dispatcher1") val dispatcher2 = dispatcher.limitedParallelism(1, "dispatcher2") val captor = JobCaptor() - val manuallyCaptured = mutableListOf() + val manuallyCaptured = mutableListOf>() - fun registerUpdate(job: Job?) = manuallyCaptured.add("Update: $job") - fun registerRestore(job: Job?) = manuallyCaptured.add("Restore: $job") + fun registerUpdate(job: Job?) = manuallyCaptured.add(Event.Update(job)) + fun registerRestore(job: Job?) = manuallyCaptured.add(Event.Restore(job)) var rootJob: Job? = null withContext(captor + dispatcher1) { @@ -84,69 +135,47 @@ class ThreadContextElementTest: TestBase() { registerRestore(rootJob) // Restores may be called concurrently to the update calls in other threads, so their order is not checked. - val expected = manuallyCaptured.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") - val actual = captor.capturees.filter { it.startsWith("Update: ") }.joinToString(separator = "\n") + val expected = manuallyCaptured.mapNotNull { (it as? Event.Update)?.value }.joinToString(separator = "\n") + val actual = captor.capturees.mapNotNull { (it as? Event.Update)?.value }.joinToString(separator = "\n") assertEquals(expected, actual) } + // #3787 @Test fun testThreadLocalFlowOn() = runTest { val myData = MyData() - myThreadLocal.set(myData) + threadContextElementThreadLocal.set(myData) expect(1) flow { - assertEquals(myData, myThreadLocal.get()) + assertEquals(myData, threadContextElementThreadLocal.get()) emit(1) } - .flowOn(myThreadLocal.asCtxElement() + Dispatchers.Default) + .flowOn(threadContextElementThreadLocal.asCtxElement(threadContextElementThreadLocal.get()!!) + Dispatchers.Default) .single() - myThreadLocal.set(null) + threadContextElementThreadLocal.set(null) finish(2) } } -class MyData +internal class MyData -private class JobCaptor(val capturees: CopyOnWriteList = CopyOnWriteList()) : ThreadContextElement { +private class JobCaptor(val capturees: CopyOnWriteList> = CopyOnWriteList()) : ThreadContextElement { - companion object Key : CoroutineContext.Key + companion object Key : CoroutineContext.Key> override val key: CoroutineContext.Key<*> get() = Key override fun updateThreadContext(context: CoroutineContext) { - capturees.add("Update: ${context.job}") + capturees.add(Event.Update(context.job)) } override fun restoreThreadContext(context: CoroutineContext, oldState: Unit) { - capturees.add("Restore: ${context.job}") + capturees.add(Event.Restore(context.job)) } } // declare thread local variable holding MyData -internal val myThreadLocal = commonThreadLocal(Symbol("myElement")) - -// declare context element holding MyData -class MyElement(val data: MyData) : ThreadContextElement { - // declare companion object for a key of this element in coroutine context - companion object Key : CoroutineContext.Key - - // provide the key of the corresponding context element - override val key: CoroutineContext.Key - get() = Key - - // this is invoked before coroutine is resumed on current thread - override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) - return oldState - } - - // this is invoked after coroutine has suspended on current thread - override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) - } -} - +internal val threadContextElementThreadLocal = commonThreadLocal(Symbol("ThreadContextElementTest")) private class CommonThreadLocalContextElement( private val threadLocal: CommonThreadLocal, @@ -173,6 +202,11 @@ private class CommonThreadLocalContextElement( internal fun CommonThreadLocal.asCtxElement(value: T = get()): ThreadContextElement = CommonThreadLocalContextElement(this, value) +private sealed class Event { + class Update(val value: T): Event() + class Restore(val value: T): Event() +} + private class CopyOnWriteList private constructor(list: List) { private val field = atomic(list) @@ -185,5 +219,5 @@ private class CopyOnWriteList private constructor(list: List) { } } - fun filter(predicate: (T) -> Boolean): List = field.value.filter(predicate) + fun mapNotNull(transform: (T) -> R): List = field.value.mapNotNull(transform) } diff --git a/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt b/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt index cd362e81f5..263ed537ad 100644 --- a/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt +++ b/kotlinx-coroutines-core/common/test/ThreadContextMutableCopiesTest.kt @@ -1,11 +1,15 @@ package kotlinx.coroutines -import kotlinx.coroutines.testing.* -import kotlinx.coroutines.flow.* +import kotlinx.coroutines.flow.flow +import kotlinx.coroutines.flow.flowOn +import kotlinx.coroutines.flow.single import kotlinx.coroutines.internal.Symbol import kotlinx.coroutines.internal.commonThreadLocal -import kotlin.coroutines.* -import kotlin.test.* +import kotlinx.coroutines.testing.TestBase +import kotlin.coroutines.CoroutineContext +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertNotSame class ThreadContextMutableCopiesTest : TestBase() { companion object { @@ -115,22 +119,45 @@ class ThreadContextMutableCopiesTest : TestBase() { } @Test - fun testDataIsCopiedForRunBlocking() = runTest { + fun testDataIsCopiedForCoroutine() = runTest { val root = MyMutableElement(ArrayList()) val originalData = root.mutableData - withContext(root) { + expect(1) + launch(root) { assertNotSame(originalData, threadLocalData.get()) + finish(2) } } @Test - fun testDataIsCopiedForCoroutine() = runTest { + fun testDataIsNotResetOnSuspensions() = runTest { val root = MyMutableElement(ArrayList()) - val originalData = root.mutableData - expect(1) - launch(root) { - assertNotSame(originalData, threadLocalData.get()) - finish(2) + withContext(root) { + threadLocalData.get().add("X") + assertEquals(listOf("X"), threadLocalData.get()) + yield() + assertEquals(listOf("X"), threadLocalData.get()) + threadLocalData.get().add("Y") + launch { + assertEquals(listOf("X", "Y"), threadLocalData.get()) + threadLocalData.get().add("Z") + yield() + assertEquals(listOf("X", "Y", "Z"), threadLocalData.get()) + } + } + } + + @Test + fun testDataIsNotVisibleToUndispatchedCoroutines() = runTest { + threadLocalData.set(mutableListOf()) + val root = MyMutableElement(ArrayList()) + val anotherRootScope = CoroutineScope(Dispatchers.Unconfined + root) + withContext(root) { + threadLocalData.get().add("X") + assertEquals(listOf("X"), threadLocalData.get()) + anotherRootScope.launch { + assertEquals(listOf(), threadLocalData.get()) + } } } diff --git a/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt index 14931459f5..c05be10b69 100644 --- a/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt +++ b/kotlinx-coroutines-core/concurrent/test/ThreadContextElementConcurrentTest.kt @@ -5,30 +5,29 @@ import kotlinx.coroutines.testing.* import kotlin.coroutines.CoroutineContext import kotlin.test.* -class ThreadContextElementConcurrentTest: TestBase() { - +class ThreadContextElementConcurrentTest : TestBase() { @Test fun testWithContext() = runTest { expect(1) newSingleThreadContext("withContext").use { val data = MyData() - GlobalScope.async(Dispatchers.Default + MyElement(data)) { - assertSame(data, myThreadLocal.get()) + GlobalScope.async(Dispatchers.Default + threadContextElementThreadLocal.asCtxElement(data)) { + assertSame(data, threadContextElementThreadLocal.get()) expect(2) val newData = MyData() - GlobalScope.async(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) + GlobalScope.async(it + threadContextElementThreadLocal.asCtxElement(newData)) { + assertSame(newData, threadContextElementThreadLocal.get()) expect(3) }.await() - withContext(it + MyElement(newData)) { - assertSame(newData, myThreadLocal.get()) + withContext(it + threadContextElementThreadLocal.asCtxElement(newData)) { + assertSame(newData, threadContextElementThreadLocal.get()) expect(4) } GlobalScope.async(it) { - assertNull(myThreadLocal.get()) + assertNull(threadContextElementThreadLocal.get()) expect(5) }.await() @@ -41,21 +40,24 @@ class ThreadContextElementConcurrentTest: TestBase() { @Test fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest { - var parentElement: MyElement? = null - var inheritedElement: MyElement? = null + var parentElement: Any? = null + var inheritedElement: Any? = null newSingleThreadContext("withContext").use { - withContext(it + MyElement(MyData())) { - parentElement = coroutineContext[MyElement.Key] + val myElement = threadContextElementThreadLocal.asCtxElement(MyData()) + withContext(it + myElement) { + parentElement = coroutineContext[myElement.key] launch { - inheritedElement = coroutineContext[MyElement.Key] + inheritedElement = coroutineContext[myElement.key] } } } - assertSame(inheritedElement, parentElement, + assertSame( + inheritedElement, parentElement, "Inner and outer coroutines did not have the same object reference to a" + - " ThreadContextElement that did not override `copyForChildCoroutine()`") + " ThreadContextElement that did not override `copyForChildCoroutine()`" + ) } @Test @@ -81,37 +83,36 @@ class ThreadContextElementConcurrentTest: TestBase() { newFixedThreadPoolContext(nThreads = 4, name = "withContext").use { withContext(it + CopyForChildCoroutineElement(MyData())) { val forBlockData = MyData() - myThreadLocal.setForBlock(forBlockData) { - assertSame(myThreadLocal.get(), forBlockData) + threadContextElementThreadLocal.setForBlock(forBlockData) { + assertSame(threadContextElementThreadLocal.get(), forBlockData) launch { - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) } launch { - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) // Modify value in child coroutine. Writes to the ThreadLocal and // the (copied) ThreadLocalElement's memory are not visible to peer or // ancestor coroutines, so this write is both threadsafe and coroutinesafe. val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) + threadContextElementThreadLocal.setForBlock(innerCoroutineData) { + assertSame(threadContextElementThreadLocal.get(), innerCoroutineData) } - assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored. + assertSame(threadContextElementThreadLocal.get(), forBlockData) // Asserts value was restored. } launch { val innerCoroutineData = MyData() - myThreadLocal.setForBlock(innerCoroutineData) { - assertSame(myThreadLocal.get(), innerCoroutineData) + threadContextElementThreadLocal.setForBlock(innerCoroutineData) { + assertSame(threadContextElementThreadLocal.get(), innerCoroutineData) } - assertSame(myThreadLocal.get(), forBlockData) + assertSame(threadContextElementThreadLocal.get(), forBlockData) } } - assertNull(myThreadLocal.get()) // Asserts value was restored to its origin + assertNull(threadContextElementThreadLocal.get()) // Asserts value was restored to its origin } } } } - /** * A [ThreadContextElement] that implements copy semantics in [copyForChild]. */ @@ -122,8 +123,8 @@ private class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadCo get() = Key override fun updateThreadContext(context: CoroutineContext): MyData? { - val oldState = myThreadLocal.get() - myThreadLocal.set(data) + val oldState = threadContextElementThreadLocal.get() + threadContextElementThreadLocal.set(data) return oldState } @@ -132,7 +133,7 @@ private class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadCo } override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - myThreadLocal.set(oldState) + threadContextElementThreadLocal.set(oldState) } /** @@ -147,14 +148,14 @@ private class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadCo * thread and calls [restoreThreadContext]. */ override fun copyForChild(): CopyForChildCoroutineElement { - return CopyForChildCoroutineElement(myThreadLocal.get()) + return CopyForChildCoroutineElement(threadContextElementThreadLocal.get()) } } /** - * Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block]. + * Calls [block], setting the value of [this] [CommonThreadLocal] for the duration of [block]. * - * When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a + * When a [CopyForChildCoroutineElement] for `this` [CommonThreadLocal] is used within a * [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically * at every statement reached, whether that statement is reached immediately, across suspend and * redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal` diff --git a/kotlinx-coroutines-core/concurrent/test/ThreadContextMutableCopiesConcurrentTest.kt b/kotlinx-coroutines-core/concurrent/test/ThreadContextMutableCopiesConcurrentTest.kt new file mode 100644 index 0000000000..f35d052d4e --- /dev/null +++ b/kotlinx-coroutines-core/concurrent/test/ThreadContextMutableCopiesConcurrentTest.kt @@ -0,0 +1,15 @@ +package kotlinx.coroutines + +import kotlinx.coroutines.testing.* +import kotlin.test.* + +class ThreadContextMutableCopiesConcurrentTest : TestBase() { + @Test + fun testDataIsCopiedForRunBlocking() = runTest { + val root = ThreadContextMutableCopiesTest.MyMutableElement(ArrayList()) + val originalData = root.mutableData + runBlocking(root) { + assertNotSame(originalData, ThreadContextMutableCopiesTest.Companion.threadLocalData.get()) + } + } +} diff --git a/kotlinx-coroutines-core/jvm/test/FailingCoroutinesMachineryTest.kt b/kotlinx-coroutines-core/jvm/test/FailingCoroutinesMachineryTest.kt index 144e4e9dc4..7b73e7c37e 100644 --- a/kotlinx-coroutines-core/jvm/test/FailingCoroutinesMachineryTest.kt +++ b/kotlinx-coroutines-core/jvm/test/FailingCoroutinesMachineryTest.kt @@ -34,7 +34,7 @@ class FailingCoroutinesMachineryTest( private val lazyOuterDispatcher = lazy { newFixedThreadPoolContext(1, "") } private object FailingUpdate : ThreadContextElement { - private object Key : CoroutineContext.Key + private object Key : CoroutineContext.Key> override val key: CoroutineContext.Key<*> get() = Key @@ -49,7 +49,7 @@ class FailingCoroutinesMachineryTest( } private object FailingRestore : ThreadContextElement { - private object Key : CoroutineContext.Key + private object Key : CoroutineContext.Key> override val key: CoroutineContext.Key<*> get() = Key diff --git a/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt deleted file mode 100644 index 68e2042520..0000000000 --- a/kotlinx-coroutines-core/jvm/test/ThreadContextElementJvmTest.kt +++ /dev/null @@ -1,36 +0,0 @@ -package kotlinx.coroutines - -import kotlinx.coroutines.testing.* -import org.junit.Test -import kotlin.coroutines.* -import kotlin.test.* - -class ThreadContextElementJvmTest : TestBase() { - - @Test - fun testExample() = runTest { - val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! - val mainDispatcher = coroutineContext[ContinuationInterceptor]!! - val mainThread = Thread.currentThread() - val data = MyData() - val element = MyElement(data) - assertNull(myThreadLocal.get()) - val job = GlobalScope.launch(element + exceptionHandler) { - assertTrue(mainThread != Thread.currentThread()) - assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) - withContext(mainDispatcher) { - assertSame(mainThread, Thread.currentThread()) - assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) - } - assertTrue(mainThread != Thread.currentThread()) - assertSame(element, coroutineContext[MyElement]) - assertSame(data, myThreadLocal.get()) - } - assertNull(myThreadLocal.get()) - job.join() - assertNull(myThreadLocal.get()) - } - -} diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt index 79a2490fc5..e48b7cb727 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt @@ -114,6 +114,16 @@ class ThreadLocalTest : TestBase() { assertEquals(42, intThreadLocal.get()) } + @Test + fun testWritesLostOnSuspensions() = runTest { + withContext(intThreadLocal.asContextElement(1)) { + assertEquals(1, intThreadLocal.get()) + intThreadLocal.set(5) + yield() + assertEquals(1, intThreadLocal.get()) + } + } + @Test fun testThreadLocalModification() = runTest { stringThreadLocal.set("main")