From 33c85c777a0cca428685a7297d01fb1459a5f47a Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Wed, 25 Jul 2018 12:22:46 +0300 Subject: [PATCH 1/2] Introduce ThreadContextElement API to integrate with thread-local sensitive code * Debug thread name is redesigned using ThreadContextElement API where the name of thread reflects the name of currently coroutine. * Intrinsics for startCoroutineUndispatched that correspond to CoroutineStart.UNDISPATCHED properly update coroutine context. * New intrinsics named startCoroutineUnintercepted are introduced. They do not update thread context. * withContext logic is fixed properly update context is various situations. * DebugThreadNameTest is introduced. * Reporting of unhandled errors in TestBase is improved. Its CoroutineExceptionHandler records but does not rethrow exception. This makes sure that failed tests actually fail and do not hang in recursive attempt to handle unhandled coroutine exception. Fixes #119 --- .../kotlinx-coroutines-core.txt | 16 +- build.gradle | 2 + .../src/Builders.common.kt | 26 ++- .../src/JobSupport.kt | 4 +- .../src/ResumeMode.kt | 4 +- .../src/channels/AbstractChannel.kt | 8 +- .../src/channels/ConflatedBroadcastChannel.kt | 2 +- .../src/intrinsics/Undispatched.kt | 61 +++++-- .../src/selects/Select.kt | 2 +- .../src/sync/Mutex.kt | 2 +- .../test/selects/SelectArrayChannelTest.kt | 2 +- .../selects/SelectRendezvousChannelTest.kt | 2 +- .../src/CoroutineContext.kt | 51 +++--- .../src/ThreadContextElement.kt | 153 ++++++++++++++++++ .../test/DebugThreadNameTest.kt | 74 +++++++++ core/kotlinx-coroutines-core/test/TestBase.kt | 16 +- .../test/ThreadContextElementTest.kt | 82 ++++++++++ .../test/selects/SelectChannelStressTest.kt | 2 +- .../kotlinx-coroutines-quasar/src/Quasar.kt | 3 +- .../test/TestBase.kt | 16 +- .../test/TestBase.kt | 15 +- 21 files changed, 469 insertions(+), 74 deletions(-) create mode 100644 core/kotlinx-coroutines-core/src/ThreadContextElement.kt create mode 100644 core/kotlinx-coroutines-core/test/DebugThreadNameTest.kt create mode 100644 core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt index fd84630b59..6307e06dbf 100644 --- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt +++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt @@ -136,8 +136,6 @@ public final class kotlinx/coroutines/experimental/CoroutineContextKt { public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext; public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;)Lkotlin/coroutines/experimental/CoroutineContext; public static synthetic fun newCoroutineContext$default (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;ILjava/lang/Object;)Lkotlin/coroutines/experimental/CoroutineContext; - public static final fun restoreThreadContext (Ljava/lang/String;)V - public static final fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/String; } public abstract class kotlinx/coroutines/experimental/CoroutineDispatcher : kotlin/coroutines/experimental/AbstractCoroutineContextElement, kotlin/coroutines/experimental/ContinuationInterceptor { @@ -436,6 +434,18 @@ public final class kotlinx/coroutines/experimental/ScheduledKt { public static synthetic fun withTimeoutOrNull$default (JLjava/util/concurrent/TimeUnit;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/experimental/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public abstract interface class kotlinx/coroutines/experimental/ThreadContextElement : kotlin/coroutines/experimental/CoroutineContext$Element { + public abstract fun restoreThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;Ljava/lang/Object;)V + public abstract fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/Object; +} + +public final class kotlinx/coroutines/experimental/ThreadContextElement$DefaultImpls { + public static fun fold (Lkotlinx/coroutines/experimental/ThreadContextElement;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; + public static fun get (Lkotlinx/coroutines/experimental/ThreadContextElement;Lkotlin/coroutines/experimental/CoroutineContext$Key;)Lkotlin/coroutines/experimental/CoroutineContext$Element; + public static fun minusKey (Lkotlinx/coroutines/experimental/ThreadContextElement;Lkotlin/coroutines/experimental/CoroutineContext$Key;)Lkotlin/coroutines/experimental/CoroutineContext; + public static fun plus (Lkotlinx/coroutines/experimental/ThreadContextElement;Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext; +} + public final class kotlinx/coroutines/experimental/ThreadPoolDispatcher : kotlinx/coroutines/experimental/ExecutorCoroutineDispatcherBase { public fun close ()V public fun getExecutor ()Ljava/util/concurrent/Executor; @@ -939,6 +949,8 @@ public final class kotlinx/coroutines/experimental/intrinsics/CancellableKt { public final class kotlinx/coroutines/experimental/intrinsics/UndispatchedKt { public static final fun startCoroutineUndispatched (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/experimental/Continuation;)V public static final fun startCoroutineUndispatched (Lkotlin/jvm/functions/Function2;Ljava/lang/Object;Lkotlin/coroutines/experimental/Continuation;)V + public static final fun startCoroutineUnintercepted (Lkotlin/jvm/functions/Function1;Lkotlin/coroutines/experimental/Continuation;)V + public static final fun startCoroutineUnintercepted (Lkotlin/jvm/functions/Function2;Ljava/lang/Object;Lkotlin/coroutines/experimental/Continuation;)V public static final fun startUndispatchedOrReturn (Lkotlinx/coroutines/experimental/AbstractCoroutine;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object; public static final fun startUndispatchedOrReturn (Lkotlinx/coroutines/experimental/AbstractCoroutine;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object; } diff --git a/build.gradle b/build.gradle index e3bd895a74..187ce662c8 100644 --- a/build.gradle +++ b/build.gradle @@ -110,8 +110,10 @@ configure(subprojects.findAll { !it.name.contains(sourceless) && it.name != "ben sourceSets { main.kotlin.srcDirs = ['src'] test.kotlin.srcDirs = ['test'] + // todo: do we still need this workaround? if (!projectName.endsWith("-native")) { main.resources.srcDirs = ['resources'] + test.resources.srcDirs = ['test-resources'] } } } diff --git a/common/kotlinx-coroutines-core-common/src/Builders.common.kt b/common/kotlinx-coroutines-core-common/src/Builders.common.kt index 82e8fea711..e5745b9506 100644 --- a/common/kotlinx-coroutines-core-common/src/Builders.common.kt +++ b/common/kotlinx-coroutines-core-common/src/Builders.common.kt @@ -119,8 +119,11 @@ public suspend fun withContext( // fast path #3 if the new dispatcher is the same as the old one. // `equals` is used by design (see equals implementation is wrapper context like ExecutorCoroutineDispatcher) if (newContext[ContinuationInterceptor] == oldContext[ContinuationInterceptor]) { - val newContinuation = RunContinuationDirect(newContext, uCont) - return@sc block.startCoroutineUninterceptedOrReturn(newContinuation) + val newContinuation = RunContinuationUnintercepted(newContext, uCont) + // There are some other changes in the context, so this thread needs to be updated + withCoroutineContext(newContext) { + return@sc block.startCoroutineUninterceptedOrReturn(newContinuation) + } } // slowest path otherwise -- use new interceptor, sync to its result via a full-blown instance of RunCompletion require(!start.isLazy) { "$start start is not supported" } @@ -130,7 +133,6 @@ public suspend fun withContext( resumeMode = if (start == CoroutineStart.ATOMIC) MODE_ATOMIC_DEFAULT else MODE_CANCELLABLE ) completion.initParentJobInternal(newContext[Job]) // attach to job - @Suppress("DEPRECATION") start(block, completion) completion.getResult() } @@ -178,10 +180,22 @@ private class LazyStandaloneCoroutine( } } -private class RunContinuationDirect( +private class RunContinuationUnintercepted( override val context: CoroutineContext, - continuation: Continuation -) : Continuation by continuation + private val continuation: Continuation +): Continuation { + override fun resume(value: T) { + withCoroutineContext(continuation.context) { + continuation.resume(value) + } + } + + override fun resumeWithException(exception: Throwable) { + withCoroutineContext(continuation.context) { + continuation.resumeWithException(exception) + } + } +} @Suppress("UNCHECKED_CAST") private class RunCompletion( diff --git a/common/kotlinx-coroutines-core-common/src/JobSupport.kt b/common/kotlinx-coroutines-core-common/src/JobSupport.kt index 0f6818e634..4ab879af6f 100644 --- a/common/kotlinx-coroutines-core-common/src/JobSupport.kt +++ b/common/kotlinx-coroutines-core-common/src/JobSupport.kt @@ -531,7 +531,7 @@ internal open class JobSupport constructor(active: Boolean) : Job, SelectClause0 // already complete -- select result if (select.trySelect(null)) { select.completion.context.checkCompletion() // always check for our completion - block.startCoroutineUndispatched(select.completion) + block.startCoroutineUnintercepted(select.completion) } return } @@ -992,7 +992,7 @@ internal open class JobSupport constructor(active: Boolean) : Job, SelectClause0 if (state is CompletedExceptionally) select.resumeSelectCancellableWithException(state.cause) else - block.startCoroutineUndispatched(state as T, select.completion) + block.startCoroutineUnintercepted(state as T, select.completion) } return } diff --git a/common/kotlinx-coroutines-core-common/src/ResumeMode.kt b/common/kotlinx-coroutines-core-common/src/ResumeMode.kt index 6faf515911..a44434747c 100644 --- a/common/kotlinx-coroutines-core-common/src/ResumeMode.kt +++ b/common/kotlinx-coroutines-core-common/src/ResumeMode.kt @@ -43,7 +43,7 @@ internal fun Continuation.resumeUninterceptedMode(value: T, mode: Int) { MODE_ATOMIC_DEFAULT -> intercepted().resume(value) MODE_CANCELLABLE -> intercepted().resumeCancellable(value) MODE_DIRECT -> resume(value) - MODE_UNDISPATCHED -> resume(value) + MODE_UNDISPATCHED -> withCoroutineContext(context) { resume(value) } MODE_IGNORE -> {} else -> error("Invalid mode $mode") } @@ -54,7 +54,7 @@ internal fun Continuation.resumeUninterceptedWithExceptionMode(exception: MODE_ATOMIC_DEFAULT -> intercepted().resumeWithException(exception) MODE_CANCELLABLE -> intercepted().resumeCancellableWithException(exception) MODE_DIRECT -> resumeWithException(exception) - MODE_UNDISPATCHED -> resumeWithException(exception) + MODE_UNDISPATCHED -> withCoroutineContext(context) { resumeWithException(exception) } MODE_IGNORE -> {} else -> error("Invalid mode $mode") } diff --git a/common/kotlinx-coroutines-core-common/src/channels/AbstractChannel.kt b/common/kotlinx-coroutines-core-common/src/channels/AbstractChannel.kt index 42fbb20bbc..2a1aa61653 100644 --- a/common/kotlinx-coroutines-core-common/src/channels/AbstractChannel.kt +++ b/common/kotlinx-coroutines-core-common/src/channels/AbstractChannel.kt @@ -414,7 +414,7 @@ public abstract class AbstractSendChannel : SendChannel { offerResult === ALREADY_SELECTED -> return offerResult === OFFER_FAILED -> {} // retry offerResult === OFFER_SUCCESS -> { - block.startCoroutineUndispatched(receiver = this, completion = select.completion) + block.startCoroutineUnintercepted(receiver = this, completion = select.completion) return } offerResult is Closed<*> -> throw offerResult.sendException @@ -753,7 +753,7 @@ public abstract class AbstractChannel : AbstractSendChannel(), Channel pollResult === POLL_FAILED -> {} // retry pollResult is Closed<*> -> throw pollResult.receiveException else -> { - block.startCoroutineUndispatched(pollResult as E, select.completion) + block.startCoroutineUnintercepted(pollResult as E, select.completion) return } } @@ -788,14 +788,14 @@ public abstract class AbstractChannel : AbstractSendChannel(), Channel pollResult is Closed<*> -> { if (pollResult.closeCause == null) { if (select.trySelect(null)) - block.startCoroutineUndispatched(null, select.completion) + block.startCoroutineUnintercepted(null, select.completion) return } else throw pollResult.closeCause } else -> { // selected successfully - block.startCoroutineUndispatched(pollResult as E, select.completion) + block.startCoroutineUnintercepted(pollResult as E, select.completion) return } } diff --git a/common/kotlinx-coroutines-core-common/src/channels/ConflatedBroadcastChannel.kt b/common/kotlinx-coroutines-core-common/src/channels/ConflatedBroadcastChannel.kt index 34ed4de68b..3bdd4c4a47 100644 --- a/common/kotlinx-coroutines-core-common/src/channels/ConflatedBroadcastChannel.kt +++ b/common/kotlinx-coroutines-core-common/src/channels/ConflatedBroadcastChannel.kt @@ -262,7 +262,7 @@ public class ConflatedBroadcastChannel() : BroadcastChannel { select.resumeSelectCancellableWithException(it.sendException) return } - block.startCoroutineUndispatched(receiver = this, completion = select.completion) + block.startCoroutineUnintercepted(receiver = this, completion = select.completion) } @Suppress("DEPRECATION") diff --git a/common/kotlinx-coroutines-core-common/src/intrinsics/Undispatched.kt b/common/kotlinx-coroutines-core-common/src/intrinsics/Undispatched.kt index e7ab254e4f..80048a8497 100644 --- a/common/kotlinx-coroutines-core-common/src/intrinsics/Undispatched.kt +++ b/common/kotlinx-coroutines-core-common/src/intrinsics/Undispatched.kt @@ -9,38 +9,69 @@ import kotlin.coroutines.experimental.* import kotlin.coroutines.experimental.intrinsics.* /** - * Use this function to restart coroutine directly from inside of [suspendCoroutine] in the same context. + * Use this function to restart coroutine directly from inside of [suspendCoroutine], + * when the code is already in the context of this coroutine. + * It does not use [ContinuationInterceptor] and does not update context of the current thread. */ -@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "UNCHECKED_CAST") -public fun (suspend () -> T).startCoroutineUndispatched(completion: Continuation) { - val value = try { +public fun (suspend () -> T).startCoroutineUnintercepted(completion: Continuation) { + startDirect(completion) { startCoroutineUninterceptedOrReturn(completion) - } catch (e: Throwable) { - completion.resumeWithException(e) - return } - if (value !== COROUTINE_SUSPENDED) - completion.resume(value as T) } /** - * Use this function to restart coroutine directly from inside of [suspendCoroutine] in the same context. + * Use this function to restart coroutine directly from inside of [suspendCoroutine], + * when the code is already in the context of this coroutine. + * It does not use [ContinuationInterceptor] and does not update context of the current thread. + */ +public fun (suspend (R) -> T).startCoroutineUnintercepted(receiver: R, completion: Continuation) { + startDirect(completion) { + startCoroutineUninterceptedOrReturn(receiver, completion) + } +} + +/** + * Use this function to start new coroutine in [CoroutineStart.UNDISPATCHED] mode — + * immediately execute coroutine in the current thread until next suspension. + * It does not use [ContinuationInterceptor], but updates the context of the current thread for the new coroutine. + */ +public fun (suspend () -> T).startCoroutineUndispatched(completion: Continuation) { + startDirect(completion) { + withCoroutineContext(completion.context) { + startCoroutineUninterceptedOrReturn(completion) + } + } +} + +/** + * Use this function to start new coroutine in [CoroutineStart.UNDISPATCHED] mode — + * immediately execute coroutine in the current thread until next suspension. + * It does not use [ContinuationInterceptor], but updates the context of the current thread for the new coroutine. */ -@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN", "UNCHECKED_CAST") public fun (suspend (R) -> T).startCoroutineUndispatched(receiver: R, completion: Continuation) { + startDirect(completion) { + withCoroutineContext(completion.context) { + startCoroutineUninterceptedOrReturn(receiver, completion) + } + } +} + +private inline fun startDirect(completion: Continuation, block: () -> Any?) { val value = try { - startCoroutineUninterceptedOrReturn(receiver, completion) + block() } catch (e: Throwable) { completion.resumeWithException(e) return } - if (value !== COROUTINE_SUSPENDED) + if (value !== COROUTINE_SUSPENDED) { + @Suppress("UNCHECKED_CAST") completion.resume(value as T) + } } /** * Starts this coroutine with the given code [block] in the same context and returns result when it - * completes without suspnesion. + * completes without suspension. * This function shall be invoked at most once on this coroutine. * * First, this function initializes parent job from the `parentContext` of this coroutine that was passed to it @@ -53,7 +84,7 @@ public fun AbstractCoroutine.startUndispatchedOrReturn(block: suspend () /** * Starts this coroutine with the given code [block] in the same context and returns result when it - * completes without suspnesion. + * completes without suspension. * This function shall be invoked at most once on this coroutine. * * First, this function initializes parent job from the `parentContext` of this coroutine that was passed to it diff --git a/common/kotlinx-coroutines-core-common/src/selects/Select.kt b/common/kotlinx-coroutines-core-common/src/selects/Select.kt index eae68adfa2..23b6752cd0 100644 --- a/common/kotlinx-coroutines-core-common/src/selects/Select.kt +++ b/common/kotlinx-coroutines-core-common/src/selects/Select.kt @@ -407,7 +407,7 @@ internal class SelectBuilderImpl( override fun onTimeout(time: Long, unit: TimeUnit, block: suspend () -> R) { if (time <= 0L) { if (trySelect(null)) - block.startCoroutineUndispatched(completion) + block.startCoroutineUnintercepted(completion) return } val action = Runnable { diff --git a/common/kotlinx-coroutines-core-common/src/sync/Mutex.kt b/common/kotlinx-coroutines-core-common/src/sync/Mutex.kt index 7657e77560..997be13861 100644 --- a/common/kotlinx-coroutines-core-common/src/sync/Mutex.kt +++ b/common/kotlinx-coroutines-core-common/src/sync/Mutex.kt @@ -252,7 +252,7 @@ internal class MutexImpl(locked: Boolean) : Mutex, SelectClause2 { val failure = select.performAtomicTrySelect(TryLockDesc(this, owner)) when { failure == null -> { // success - block.startCoroutineUndispatched(receiver = this, completion = select.completion) + block.startCoroutineUnintercepted(receiver = this, completion = select.completion) return } failure === ALREADY_SELECTED -> return // already selected -- bail out diff --git a/common/kotlinx-coroutines-core-common/test/selects/SelectArrayChannelTest.kt b/common/kotlinx-coroutines-core-common/test/selects/SelectArrayChannelTest.kt index d4a9be4fdb..a80577042c 100644 --- a/common/kotlinx-coroutines-core-common/test/selects/SelectArrayChannelTest.kt +++ b/common/kotlinx-coroutines-core-common/test/selects/SelectArrayChannelTest.kt @@ -289,6 +289,6 @@ class SelectArrayChannelTest : TestBase() { internal fun SelectBuilder.default(block: suspend () -> R) { this as SelectBuilderImpl // type assertion if (!trySelect(null)) return - block.startCoroutineUndispatched(this) + block.startCoroutineUnintercepted(this) } } diff --git a/common/kotlinx-coroutines-core-common/test/selects/SelectRendezvousChannelTest.kt b/common/kotlinx-coroutines-core-common/test/selects/SelectRendezvousChannelTest.kt index 631230628d..2f7f63b500 100644 --- a/common/kotlinx-coroutines-core-common/test/selects/SelectRendezvousChannelTest.kt +++ b/common/kotlinx-coroutines-core-common/test/selects/SelectRendezvousChannelTest.kt @@ -310,6 +310,6 @@ class SelectRendezvousChannelTest : TestBase() { internal fun SelectBuilder.default(block: suspend () -> R) { this as SelectBuilderImpl // type assertion if (!trySelect(null)) return - block.startCoroutineUndispatched(this) + block.startCoroutineUnintercepted(this) } } diff --git a/core/kotlinx-coroutines-core/src/CoroutineContext.kt b/core/kotlinx-coroutines-core/src/CoroutineContext.kt index ea8d034366..247ba5042a 100644 --- a/core/kotlinx-coroutines-core/src/CoroutineContext.kt +++ b/core/kotlinx-coroutines-core/src/CoroutineContext.kt @@ -4,6 +4,7 @@ package kotlinx.coroutines.experimental +import java.util.* import kotlinx.coroutines.experimental.internal.* import kotlinx.coroutines.experimental.scheduling.* import java.util.concurrent.atomic.* @@ -98,31 +99,14 @@ public actual fun newCoroutineContext(context: CoroutineContext, parent: Job? = * Executes a block using a given coroutine context. */ internal actual inline fun withCoroutineContext(context: CoroutineContext, block: () -> T): T { - val oldName = context.updateThreadContext() + val oldValue = updateThreadContext(context) try { return block() } finally { - restoreThreadContext(oldName) + restoreThreadContext(context, oldValue) } } -@PublishedApi -internal fun CoroutineContext.updateThreadContext(): String? { - if (!DEBUG) return null - val coroutineId = this[CoroutineId] ?: return null - val coroutineName = this[CoroutineName]?.name ?: "coroutine" - val currentThread = Thread.currentThread() - val oldName = currentThread.name - currentThread.name = buildString(oldName.length + coroutineName.length + 10) { - append(oldName) - append(" @") - append(coroutineName) - append('#') - append(coroutineId.id) - } - return oldName -} - internal actual val CoroutineContext.coroutineName: String? get() { if (!DEBUG) return null val coroutineId = this[CoroutineId] ?: return null @@ -130,12 +114,31 @@ internal actual val CoroutineContext.coroutineName: String? get() { return "$coroutineName#${coroutineId.id}" } -@PublishedApi -internal fun restoreThreadContext(oldName: String?) { - if (oldName != null) Thread.currentThread().name = oldName -} +private const val DEBUG_THREAD_NAME_SEPARATOR = " @" -private class CoroutineId(val id: Long) : AbstractCoroutineContextElement(CoroutineId) { +internal data class CoroutineId( + val id: Long +) : ThreadContextElement, AbstractCoroutineContextElement(CoroutineId) { companion object Key : CoroutineContext.Key override fun toString(): String = "CoroutineId($id)" + + override fun updateThreadContext(context: CoroutineContext): String { + val coroutineName = context[CoroutineName]?.name ?: "coroutine" + val currentThread = Thread.currentThread() + val oldName = currentThread.name + var lastIndex = oldName.lastIndexOf(DEBUG_THREAD_NAME_SEPARATOR) + if (lastIndex < 0) lastIndex = oldName.length + currentThread.name = buildString(lastIndex + coroutineName.length + 10) { + append(oldName.substring(0, lastIndex)) + append(DEBUG_THREAD_NAME_SEPARATOR) + append(coroutineName) + append('#') + append(id) + } + return oldName + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + Thread.currentThread().name = oldState + } } diff --git a/core/kotlinx-coroutines-core/src/ThreadContextElement.kt b/core/kotlinx-coroutines-core/src/ThreadContextElement.kt new file mode 100644 index 0000000000..140c9c3770 --- /dev/null +++ b/core/kotlinx-coroutines-core/src/ThreadContextElement.kt @@ -0,0 +1,153 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.experimental + +import kotlinx.coroutines.experimental.internal.* +import kotlin.coroutines.experimental.* + +/** + * Defines elements in [CoroutineContext] that are installed into thread context + * every time the coroutine with this element in the context is resumed on a thread. + * + * Implementations of this interface define a type [S] of the thread-local state that they need to store on + * resume of a coroutine and restore later on suspend and the infrastructure provides the corresponding storage. + * + * Example usage looks like this: + * + * ``` + * // declare thread local variable holding MyData + * private val myThreadLocal = ThreadLocal() + * + * // declare context element holding MyData + * class MyElement(val data: MyData) : ThreadContextElement { + * // declare companion object for a key of this element in coroutine context + * companion object Key : CoroutineContext.Key + * + * // provide the key of the corresponding context element + * override val key: CoroutineContext.Key + * get() = Key + * + * // this is invoked before coroutine is resumed on current thread + * override fun updateThreadContext(context: CoroutineContext): MyData? { + * val oldState = myThreadLocal.get() + * myThreadLocal.set(data) + * return oldState + * } + * + * // this is invoked after coroutine has suspended on current thread + * override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + * myThreadLocal.set(oldState) + * } + * } + * ``` + */ +public interface ThreadContextElement : CoroutineContext.Element { + /** + * Updates context of the current thread. + * This function is invoked before the coroutine in the specified [context] is resumed in the current thread + * when the context of the coroutine this element. + * The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext]. + * + * @param context the coroutine context. + */ + public fun updateThreadContext(context: CoroutineContext): S + + /** + * Restores context of the current thread. + * This function is invoked after the coroutine in the specified [context] is suspended in the current thread + * if [updateThreadContext] was previously invoked on resume of this coroutine. + * The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should + * be restored in the thread-local state by this function. + * + * @param context the coroutine context. + * @param oldState the value returned by the previous invocation of [updateThreadContext]. + */ + public fun restoreThreadContext(context: CoroutineContext, oldState: S) +} + +private val ZERO = Symbol("ZERO") + +// Used when there are >= 2 active elements in the context +private class ThreadState(val context: CoroutineContext, n: Int) { + private var a = arrayOfNulls(n) + private var i = 0 + + fun append(value: Any?) { a[i++] = value } + fun take() = a[i++] + fun start() { i = 0 } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +private val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +private val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +private val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element.updateThreadContext(state.context)) + } + return state + } + +// Restores state for all ThreadContextElements in the context from the given ThreadState +private val restoreState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + @Suppress("UNCHECKED_CAST") + if (element is ThreadContextElement<*>) { + (element as ThreadContextElement).restoreThreadContext(state.context, state.take()) + } + return state + } + +internal fun updateThreadContext(context: CoroutineContext): Any? { + val count = context.fold(0, countAll) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + count === 0 -> ZERO // very fast path when there are no active ThreadContextElements + // ^^^ identity comparison for speed, we know zero always has the same identity + count is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, count), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = count as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === ZERO -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.start() + context.fold(oldState, restoreState) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} diff --git a/core/kotlinx-coroutines-core/test/DebugThreadNameTest.kt b/core/kotlinx-coroutines-core/test/DebugThreadNameTest.kt new file mode 100644 index 0000000000..ff91555c88 --- /dev/null +++ b/core/kotlinx-coroutines-core/test/DebugThreadNameTest.kt @@ -0,0 +1,74 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.experimental + +import kotlin.coroutines.experimental.* +import kotlin.test.* + +class DebugThreadNameTest : TestBase() { + @BeforeTest + fun resetName() { + resetCoroutineId() + } + + @Test + fun testLaunchId() = runTest { + assertName("coroutine#1") + launch(coroutineContext) { + assertName("coroutine#2") + yield() + assertName("coroutine#2") + } + assertName("coroutine#1") + } + + @Test + fun testLaunchIdUndispatched() = runTest { + assertName("coroutine#1") + launch(coroutineContext, start = CoroutineStart.UNDISPATCHED) { + assertName("coroutine#2") + yield() + assertName("coroutine#2") + } + assertName("coroutine#1") + } + + @Test + fun testLaunchName() = runTest { + assertName("coroutine#1") + launch(coroutineContext + CoroutineName("TEST")) { + assertName("TEST#2") + yield() + assertName("TEST#2") + } + assertName("coroutine#1") + } + + @Test + fun testWithContext() = runTest { + assertName("coroutine#1") + withContext(DefaultDispatcher) { + assertName("coroutine#1") + yield() + assertName("coroutine#1") + withContext(CoroutineName("TEST")) { + assertName("TEST#1") + yield() + assertName("TEST#1") + } + assertName("coroutine#1") + yield() + assertName("coroutine#1") + } + assertName("coroutine#1") + } + + private fun assertName(expected: String) { + val name = Thread.currentThread().name + val split = name.split(Regex(" @")) + assertEquals(2, split.size, "Thread name '$name' is expected to contain one coroutine name") + assertEquals(expected, split[1], "Thread name '$name' is expected to end with coroutine name '$expected'") + } +} \ No newline at end of file diff --git a/core/kotlinx-coroutines-core/test/TestBase.kt b/core/kotlinx-coroutines-core/test/TestBase.kt index 2ef6cdd6b7..aa5a1e6a35 100644 --- a/core/kotlinx-coroutines-core/test/TestBase.kt +++ b/core/kotlinx-coroutines-core/test/TestBase.kt @@ -55,6 +55,12 @@ public actual open class TestBase actual constructor() { throw exception } + private fun printError(message: String, cause: Throwable) { + error.compareAndSet(null, cause) + println("$message: $cause") + cause.printStackTrace(System.out) + } + /** * Throws [IllegalStateException] when `value` is false like `check` in stdlib, but also ensures that the * test will not complete successfully even if this exception is consumed somewhere in the test. @@ -132,10 +138,12 @@ public actual open class TestBase actual constructor() { runBlocking(block = block, context = CoroutineExceptionHandler { context, e -> if (e is CancellationException) return@CoroutineExceptionHandler // are ignored exCount++ - if (exCount > unhandled.size) - error("Too many unhandled exceptions $exCount, expected ${unhandled.size}, got: $e", e) - if (!unhandled[exCount - 1](e)) - error("Unhandled exception was unexpected: $e", e) + when { + exCount > unhandled.size -> + printError("Too many unhandled exceptions $exCount, expected ${unhandled.size}, got: $e", e) + !unhandled[exCount - 1](e) -> + printError("Unhandled exception was unexpected: $e", e) + } context[Job]?.cancel(e) }) } catch (e: Throwable) { diff --git a/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt b/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt new file mode 100644 index 0000000000..0c670f219a --- /dev/null +++ b/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt @@ -0,0 +1,82 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.experimental + +import org.junit.Test +import kotlin.coroutines.experimental.* +import kotlin.test.* + +class ThreadContextElementTest : TestBase() { + @Test + fun testExample() = runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val mainDispatcher = coroutineContext[ContinuationInterceptor]!! + val mainThread = Thread.currentThread() + val data = MyData() + val element = MyElement(data) + assertNull(myThreadLocal.get()) + val job = launch(element + exceptionHandler) { + assertTrue(mainThread != Thread.currentThread()) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, myThreadLocal.get()) + withContext(mainDispatcher) { + assertSame(mainThread, Thread.currentThread()) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, myThreadLocal.get()) + } + assertTrue(mainThread != Thread.currentThread()) + assertSame(element, coroutineContext[MyElement]) + assertSame(data, myThreadLocal.get()) + } + assertNull(myThreadLocal.get()) + job.join() + assertNull(myThreadLocal.get()) + } + + @Test + fun testUndispatched()= runTest { + val exceptionHandler = coroutineContext[CoroutineExceptionHandler]!! + val data = MyData() + val element = MyElement(data) + val job = launch( + context = DefaultDispatcher + exceptionHandler + element, + start = CoroutineStart.UNDISPATCHED + ) { + assertSame(data, myThreadLocal.get()) + yield() + assertSame(data, myThreadLocal.get()) + } + assertNull(myThreadLocal.get()) + job.join() + assertNull(myThreadLocal.get()) + } +} + +class MyData + +// declare thread local variable holding MyData +private val myThreadLocal = ThreadLocal() + +// declare context element holding MyData +class MyElement(val data: MyData) : ThreadContextElement { + // declare companion object for a key of this element in coroutine context + companion object Key : CoroutineContext.Key + + // provide the key of the corresponding context element + override val key: CoroutineContext.Key + get() = Key + + // this is invoked before coroutine is resumed on current thread + override fun updateThreadContext(context: CoroutineContext): MyData? { + val oldState = myThreadLocal.get() + myThreadLocal.set(data) + return oldState + } + + // this is invoked after coroutine has suspended on current thread + override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { + myThreadLocal.set(oldState) + } +} diff --git a/core/kotlinx-coroutines-core/test/selects/SelectChannelStressTest.kt b/core/kotlinx-coroutines-core/test/selects/SelectChannelStressTest.kt index 95d661c3af..a83b1c4b59 100644 --- a/core/kotlinx-coroutines-core/test/selects/SelectChannelStressTest.kt +++ b/core/kotlinx-coroutines-core/test/selects/SelectChannelStressTest.kt @@ -71,6 +71,6 @@ class SelectChannelStressTest: TestBase() { internal fun SelectBuilder.default(block: suspend () -> R) { this as SelectBuilderImpl // type assertion if (!trySelect(null)) return - block.startCoroutineUndispatched(this) + block.startCoroutineUnintercepted(this) } } diff --git a/integration/kotlinx-coroutines-quasar/src/Quasar.kt b/integration/kotlinx-coroutines-quasar/src/Quasar.kt index 73a7f71494..7dd693e01f 100644 --- a/integration/kotlinx-coroutines-quasar/src/Quasar.kt +++ b/integration/kotlinx-coroutines-quasar/src/Quasar.kt @@ -43,7 +43,8 @@ fun runFiberBlocking(block: suspend () -> T): T = private class CoroutineAsync( private val block: suspend () -> T ) : FiberAsync(), Continuation { - override val context: CoroutineContext = Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher() + override val context: CoroutineContext = + newCoroutineContext(Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher()) override fun resume(value: T) { asyncCompleted(value) } override fun resumeWithException(exception: Throwable) { asyncFailed(exception) } diff --git a/js/kotlinx-coroutines-core-js/test/TestBase.kt b/js/kotlinx-coroutines-core-js/test/TestBase.kt index 061b1d7c8a..8db40dbd8e 100644 --- a/js/kotlinx-coroutines-core-js/test/TestBase.kt +++ b/js/kotlinx-coroutines-core-js/test/TestBase.kt @@ -27,6 +27,12 @@ public actual open class TestBase actual constructor() { throw exception } + private fun printError(message: String, cause: Throwable) { + if (error == null) error = cause + println("$message: $cause") + console.log(cause) + } + /** * Asserts that this invocation is `index`-th in the execution sequence (counting from one). */ @@ -69,10 +75,12 @@ public actual open class TestBase actual constructor() { return promise(block = block, context = CoroutineExceptionHandler { context, e -> if (e is CancellationException) return@CoroutineExceptionHandler // are ignored exCount++ - if (exCount > unhandled.size) - error("Too many unhandled exceptions $exCount, expected ${unhandled.size}", e) - if (!unhandled[exCount - 1](e)) - error("Unhandled exception was unexpected", e) + when { + exCount > unhandled.size -> + printError("Too many unhandled exceptions $exCount, expected ${unhandled.size}, got: $e", e) + !unhandled[exCount - 1](e) -> + printError("Unhandled exception was unexpected: $e", e) + } context[Job]?.cancel(e) }).catch { e -> ex = e diff --git a/native/kotlinx-coroutines-core-native/test/TestBase.kt b/native/kotlinx-coroutines-core-native/test/TestBase.kt index f2873524f8..7f75e4439e 100644 --- a/native/kotlinx-coroutines-core-native/test/TestBase.kt +++ b/native/kotlinx-coroutines-core-native/test/TestBase.kt @@ -23,6 +23,11 @@ public actual open class TestBase actual constructor() { throw exception } + private fun printError(message: String, cause: Throwable) { + if (error == null) error = cause + println("$message: $cause") + } + /** * Asserts that this invocation is `index`-th in the execution sequence (counting from one). */ @@ -65,10 +70,12 @@ public actual open class TestBase actual constructor() { runBlocking(block = block, context = CoroutineExceptionHandler { context, e -> if (e is CancellationException) return@CoroutineExceptionHandler // are ignored exCount++ - if (exCount > unhandled.size) - error("Too many unhandled exceptions $exCount, expected ${unhandled.size}, got: $e", e) - if (!unhandled[exCount - 1](e)) - error("Unhandled exception was unexpected: $e", e) + when { + exCount > unhandled.size -> + printError("Too many unhandled exceptions $exCount, expected ${unhandled.size}, got: $e", e) + !unhandled[exCount - 1](e) -> + printError("Unhandled exception was unexpected: $e", e) + } context[Job]?.cancel(e) }) } catch (e: Throwable) { From 0dc7bd8fad549c1b8e3fa6f7ae9bf85dc34e6b45 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Wed, 22 Aug 2018 19:41:57 +0300 Subject: [PATCH 2/2] Introduce ThreadLocal.asContextElement() * Move implementation to internal package * Add guide section --- .../kotlinx-coroutines-core.txt | 5 + .../src/CoroutineContext.kt | 1 - .../src/ThreadContextElement.kt | 159 ++++++-------- .../src/internal/ThreadContext.kt | 122 +++++++++++ .../test/ThreadContextElementTest.kt | 33 +++ .../test/ThreadLocalTest.kt | 199 ++++++++++++++++++ .../test/guide/example-context-11.kt | 23 ++ .../test/guide/test/GuideTest.kt | 10 + coroutines-guide.md | 61 ++++++ 9 files changed, 514 insertions(+), 99 deletions(-) create mode 100644 core/kotlinx-coroutines-core/src/internal/ThreadContext.kt create mode 100644 core/kotlinx-coroutines-core/test/ThreadLocalTest.kt create mode 100644 core/kotlinx-coroutines-core/test/guide/example-context-11.kt diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt index 6307e06dbf..500d3debb3 100644 --- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt +++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt @@ -446,6 +446,11 @@ public final class kotlinx/coroutines/experimental/ThreadContextElement$DefaultI public static fun plus (Lkotlinx/coroutines/experimental/ThreadContextElement;Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext; } +public final class kotlinx/coroutines/experimental/ThreadContextElementKt { + public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/experimental/ThreadContextElement; + public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/experimental/ThreadContextElement; +} + public final class kotlinx/coroutines/experimental/ThreadPoolDispatcher : kotlinx/coroutines/experimental/ExecutorCoroutineDispatcherBase { public fun close ()V public fun getExecutor ()Ljava/util/concurrent/Executor; diff --git a/core/kotlinx-coroutines-core/src/CoroutineContext.kt b/core/kotlinx-coroutines-core/src/CoroutineContext.kt index 247ba5042a..2fcd014e15 100644 --- a/core/kotlinx-coroutines-core/src/CoroutineContext.kt +++ b/core/kotlinx-coroutines-core/src/CoroutineContext.kt @@ -4,7 +4,6 @@ package kotlinx.coroutines.experimental -import java.util.* import kotlinx.coroutines.experimental.internal.* import kotlinx.coroutines.experimental.scheduling.* import java.util.concurrent.atomic.* diff --git a/core/kotlinx-coroutines-core/src/ThreadContextElement.kt b/core/kotlinx-coroutines-core/src/ThreadContextElement.kt index 140c9c3770..b43497d38d 100644 --- a/core/kotlinx-coroutines-core/src/ThreadContextElement.kt +++ b/core/kotlinx-coroutines-core/src/ThreadContextElement.kt @@ -12,36 +12,42 @@ import kotlin.coroutines.experimental.* * every time the coroutine with this element in the context is resumed on a thread. * * Implementations of this interface define a type [S] of the thread-local state that they need to store on - * resume of a coroutine and restore later on suspend and the infrastructure provides the corresponding storage. + * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage. * * Example usage looks like this: * * ``` - * // declare thread local variable holding MyData - * private val myThreadLocal = ThreadLocal() - * - * // declare context element holding MyData - * class MyElement(val data: MyData) : ThreadContextElement { + * // Appends "name" of a coroutine to a current thread name when coroutine is executed + * class CoroutineName(val name: String) : ThreadContextElement { * // declare companion object for a key of this element in coroutine context - * companion object Key : CoroutineContext.Key + * companion object Key : CoroutineContext.Key * * // provide the key of the corresponding context element - * override val key: CoroutineContext.Key + * override val key: CoroutineContext.Key * get() = Key * * // this is invoked before coroutine is resumed on current thread - * override fun updateThreadContext(context: CoroutineContext): MyData? { - * val oldState = myThreadLocal.get() - * myThreadLocal.set(data) - * return oldState + * override fun updateThreadContext(context: CoroutineContext): String { + * val previousName = Thread.currentThread().name + * Thread.currentThread().name = "$previousName # $name" + * return previousName * } * * // this is invoked after coroutine has suspended on current thread - * override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) { - * myThreadLocal.set(oldState) + * override fun restoreThreadContext(context: CoroutineContext, oldState: String) { + * Thread.currentThread().name = oldState * } * } + * + * // Usage + * launch(UI + CoroutineName("Progress bar coroutine")) { ... } * ``` + * + * Every time this coroutine is resumed on a thread, UI thread name is updated to + * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when + * this coroutine suspends. + * + * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function. */ public interface ThreadContextElement : CoroutineContext.Element { /** @@ -67,87 +73,44 @@ public interface ThreadContextElement : CoroutineContext.Element { public fun restoreThreadContext(context: CoroutineContext, oldState: S) } -private val ZERO = Symbol("ZERO") - -// Used when there are >= 2 active elements in the context -private class ThreadState(val context: CoroutineContext, n: Int) { - private var a = arrayOfNulls(n) - private var i = 0 - - fun append(value: Any?) { a[i++] = value } - fun take() = a[i++] - fun start() { i = 0 } -} - -// Counts ThreadContextElements in the context -// Any? here is Int | ThreadContextElement (when count is one) -private val countAll = - fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { - if (element is ThreadContextElement<*>) { - val inCount = countOrElement as? Int ?: 1 - return if (inCount == 0) element else inCount + 1 - } - return countOrElement - } - -// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one -private val findOne = - fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { - if (found != null) return found - return element as? ThreadContextElement<*> - } - -// Updates state for ThreadContextElements in the context using the given ThreadState -private val updateState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - if (element is ThreadContextElement<*>) { - state.append(element.updateThreadContext(state.context)) - } - return state - } - -// Restores state for all ThreadContextElements in the context from the given ThreadState -private val restoreState = - fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { - @Suppress("UNCHECKED_CAST") - if (element is ThreadContextElement<*>) { - (element as ThreadContextElement).restoreThreadContext(state.context, state.take()) - } - return state - } - -internal fun updateThreadContext(context: CoroutineContext): Any? { - val count = context.fold(0, countAll) - @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") - return when { - count === 0 -> ZERO // very fast path when there are no active ThreadContextElements - // ^^^ identity comparison for speed, we know zero always has the same identity - count is Int -> { - // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values - context.fold(ThreadState(context, count), updateState) - } - else -> { - // fast path for one ThreadContextElement (no allocations, no additional context scan) - @Suppress("UNCHECKED_CAST") - val element = count as ThreadContextElement - element.updateThreadContext(context) - } - } -} - -internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { - when { - oldState === ZERO -> return // very fast path when there are no ThreadContextElements - oldState is ThreadState -> { - // slow path with multiple stored ThreadContextElements - oldState.start() - context.fold(oldState, restoreState) - } - else -> { - // fast path for one ThreadContextElement, but need to find it - @Suppress("UNCHECKED_CAST") - val element = context.fold(null, findOne) as ThreadContextElement - element.restoreThreadContext(context, oldState) - } - } -} +/** + * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement] + * maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on. + * By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter. + * + * Example usage looks like this: + * + * ``` + * val myThreadLocal = ThreadLocal() + * ... + * println(myThreadLocal.get()) // Prints "null" + * launch(CommonPool + myThreadLocal.asContextElement(initialValue = "foo")) { + * println(myThreadLocal.get()) // Prints "foo" + * withContext(UI) { + * println(myThreadLocal.get()) // Prints "foo", but it's on UI thread + * } + * } + * println(myThreadLocal.get()) // Prints "null" + * ``` + * + * Note that the context element does not track modifications of the thread-local variable, for example: + * + * ``` + * myThreadLocal.set("main") + * withContext(UI) { + * println(myThreadLocal.get()) // Prints "main" + * myThreadLocal.set("UI") + * } + * println(myThreadLocal.get()) // Prints "main", not "UI" + * ``` + * + * Use `withContext` to update the corresponding thread-local variable to a different value, for example: + * + * ``` + * withContext(myThreadLocal.asContextElement("foo")) { + * println(myThreadLocal.get()) // Prints "foo" + * } + * ``` + */ +public fun ThreadLocal.asContextElement(value: T = get()): ThreadContextElement = + ThreadLocalElement(value, this) diff --git a/core/kotlinx-coroutines-core/src/internal/ThreadContext.kt b/core/kotlinx-coroutines-core/src/internal/ThreadContext.kt new file mode 100644 index 0000000000..abee55b749 --- /dev/null +++ b/core/kotlinx-coroutines-core/src/internal/ThreadContext.kt @@ -0,0 +1,122 @@ +package kotlinx.coroutines.experimental.internal + +import kotlinx.coroutines.experimental.* +import kotlin.coroutines.experimental.* + + +private val ZERO = Symbol("ZERO") + +// Used when there are >= 2 active elements in the context +private class ThreadState(val context: CoroutineContext, n: Int) { + private var a = arrayOfNulls(n) + private var i = 0 + + fun append(value: Any?) { a[i++] = value } + fun take() = a[i++] + fun start() { i = 0 } +} + +// Counts ThreadContextElements in the context +// Any? here is Int | ThreadContextElement (when count is one) +private val countAll = + fun (countOrElement: Any?, element: CoroutineContext.Element): Any? { + if (element is ThreadContextElement<*>) { + val inCount = countOrElement as? Int ?: 1 + return if (inCount == 0) element else inCount + 1 + } + return countOrElement + } + +// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one +private val findOne = + fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? { + if (found != null) return found + return element as? ThreadContextElement<*> + } + +// Updates state for ThreadContextElements in the context using the given ThreadState +private val updateState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + if (element is ThreadContextElement<*>) { + state.append(element.updateThreadContext(state.context)) + } + return state + } + +// Restores state for all ThreadContextElements in the context from the given ThreadState +private val restoreState = + fun (state: ThreadState, element: CoroutineContext.Element): ThreadState { + @Suppress("UNCHECKED_CAST") + if (element is ThreadContextElement<*>) { + (element as ThreadContextElement).restoreThreadContext(state.context, state.take()) + } + return state + } + +internal fun updateThreadContext(context: CoroutineContext): Any? { + val count = context.fold(0, countAll) + @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS") + return when { + count === 0 -> ZERO // very fast path when there are no active ThreadContextElements + // ^^^ identity comparison for speed, we know zero always has the same identity + count is Int -> { + // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values + context.fold(ThreadState(context, count), updateState) + } + else -> { + // fast path for one ThreadContextElement (no allocations, no additional context scan) + @Suppress("UNCHECKED_CAST") + val element = count as ThreadContextElement + element.updateThreadContext(context) + } + } +} + +internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { + when { + oldState === ZERO -> return // very fast path when there are no ThreadContextElements + oldState is ThreadState -> { + // slow path with multiple stored ThreadContextElements + oldState.start() + context.fold(oldState, restoreState) + } + else -> { + // fast path for one ThreadContextElement, but need to find it + @Suppress("UNCHECKED_CAST") + val element = context.fold(null, findOne) as ThreadContextElement + element.restoreThreadContext(context, oldState) + } + } +} + +// top-level data class for a nicer out-of-the-box toString representation and class name +private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key> + +internal class ThreadLocalElement( + private val value: T, + private val threadLocal: ThreadLocal +) : ThreadContextElement { + override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal) + + override fun updateThreadContext(context: CoroutineContext): T { + val oldState = threadLocal.get() + threadLocal.set(value) + return oldState + } + + override fun restoreThreadContext(context: CoroutineContext, oldState: T) { + threadLocal.set(oldState) + } + + // this method is overridden to perform value comparison (==) on key + override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext { + return if (this.key == key) EmptyCoroutineContext else this + } + + // this method is overridden to perform value comparison (==) on key + public override operator fun get(key: CoroutineContext.Key): E? = + @Suppress("UNCHECKED_CAST") + if (this.key == key) this as E else null + + override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)" +} diff --git a/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt b/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt index 0c670f219a..a8b17d20b8 100644 --- a/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt +++ b/core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt @@ -52,6 +52,39 @@ class ThreadContextElementTest : TestBase() { job.join() assertNull(myThreadLocal.get()) } + + + @Test + fun testWithContext() = runTest { + expect(1) + newSingleThreadContext("withContext").use { + val data = MyData() + async(CommonPool + MyElement(data)) { + assertSame(data, myThreadLocal.get()) + expect(2) + + val newData = MyData() + async(it + MyElement(newData)) { + assertSame(newData, myThreadLocal.get()) + expect(3) + }.await() + + withContext(it + MyElement(newData)) { + assertSame(newData, myThreadLocal.get()) + expect(4) + } + + async(it) { + assertNull(myThreadLocal.get()) + expect(5) + }.await() + + expect(6) + }.await() + } + + finish(7) + } } class MyData diff --git a/core/kotlinx-coroutines-core/test/ThreadLocalTest.kt b/core/kotlinx-coroutines-core/test/ThreadLocalTest.kt new file mode 100644 index 0000000000..b932e75456 --- /dev/null +++ b/core/kotlinx-coroutines-core/test/ThreadLocalTest.kt @@ -0,0 +1,199 @@ + +package kotlinx.coroutines.experimental + +import org.junit.* +import org.junit.Test +import kotlin.coroutines.experimental.* +import kotlin.test.* + +@Suppress("RedundantAsync") +class ThreadLocalTest : TestBase() { + private val stringThreadLocal = ThreadLocal() + private val intThreadLocal = ThreadLocal() + private val executor = newFixedThreadPoolContext(1, "threadLocalTest") + + @After + fun tearDown() { + executor.close() + } + + @Test + fun testThreadLocal() = runTest { + assertNull(stringThreadLocal.get()) + val deferred = async(CommonPool + stringThreadLocal.asContextElement("value")) { + assertEquals("value", stringThreadLocal.get()) + withContext(executor) { + assertEquals("value", stringThreadLocal.get()) + } + assertEquals("value", stringThreadLocal.get()) + } + + assertNull(stringThreadLocal.get()) + deferred.await() + assertNull(stringThreadLocal.get()) + } + + @Test + fun testThreadLocalInitialValue() = runTest { + intThreadLocal.set(42) + val deferred = async(CommonPool + intThreadLocal.asContextElement(239)) { + assertEquals(239, intThreadLocal.get()) + withContext(executor) { + assertEquals(239, intThreadLocal.get()) + } + assertEquals(239, intThreadLocal.get()) + } + + deferred.await() + assertEquals(42, intThreadLocal.get()) + } + + @Test + fun testMultipleThreadLocals() = runTest { + stringThreadLocal.set("test") + intThreadLocal.set(314) + + val deferred = async(CommonPool + + intThreadLocal.asContextElement(value = 239) + stringThreadLocal.asContextElement(value = "pew")) { + assertEquals(239, intThreadLocal.get()) + assertEquals("pew", stringThreadLocal.get()) + + withContext(executor) { + assertEquals(239, intThreadLocal.get()) + assertEquals("pew", stringThreadLocal.get()) + } + + assertEquals(239, intThreadLocal.get()) + assertEquals("pew", stringThreadLocal.get()) + } + + deferred.await() + assertEquals(314, intThreadLocal.get()) + assertEquals("test", stringThreadLocal.get()) + } + + @Test + fun testConflictingThreadLocals() = runTest { + intThreadLocal.set(42) + + val deferred = async(CommonPool + + intThreadLocal.asContextElement(1)) { + assertEquals(1, intThreadLocal.get()) + + withContext(executor + intThreadLocal.asContextElement(42)) { + assertEquals(42, intThreadLocal.get()) + } + + assertEquals(1, intThreadLocal.get()) + + val deferred = async(coroutineContext + intThreadLocal.asContextElement(53)) { + assertEquals(53, intThreadLocal.get()) + } + + deferred.await() + assertEquals(1, intThreadLocal.get()) + + val deferred2 = async(executor) { + assertNull(intThreadLocal.get()) + } + + deferred2.await() + assertEquals(1, intThreadLocal.get()) + } + + deferred.await() + assertEquals(42, intThreadLocal.get()) + } + + @Test + fun testThreadLocalModification() = runTest { + stringThreadLocal.set("main") + + val deferred = async(CommonPool + + stringThreadLocal.asContextElement("initial")) { + assertEquals("initial", stringThreadLocal.get()) + + stringThreadLocal.set("overridden") // <- this value is not reflected in the context, so it's not restored + + withContext(executor + stringThreadLocal.asContextElement("ctx")) { + assertEquals("ctx", stringThreadLocal.get()) + } + + val deferred = async(coroutineContext + stringThreadLocal.asContextElement("async")) { + assertEquals("async", stringThreadLocal.get()) + } + + deferred.await() + assertEquals("initial", stringThreadLocal.get()) // <- not restored + } + + deferred.await() + assertEquals("main", stringThreadLocal.get()) + } + + + + private data class Counter(var cnt: Int) + private val myCounterLocal = ThreadLocal() + + @Test + fun testThreadLocalModificationMutableBox() = runTest { + myCounterLocal.set(Counter(42)) + + val deferred = async(CommonPool + + myCounterLocal.asContextElement(Counter(0))) { + assertEquals(0, myCounterLocal.get().cnt) + + // Mutate + myCounterLocal.get().cnt = 71 + + withContext(executor + myCounterLocal.asContextElement(Counter(-1))) { + assertEquals(-1, myCounterLocal.get().cnt) + ++myCounterLocal.get().cnt + } + + val deferred = async(coroutineContext + myCounterLocal.asContextElement(Counter(31))) { + assertEquals(31, myCounterLocal.get().cnt) + ++myCounterLocal.get().cnt + } + + deferred.await() + assertEquals(71, myCounterLocal.get().cnt) + } + + deferred.await() + assertEquals(42, myCounterLocal.get().cnt) + } + + @Test + fun testWithContext() = runTest { + expect(1) + newSingleThreadContext("withContext").use { + val data = 42 + async(CommonPool + intThreadLocal.asContextElement(42)) { + + assertSame(data, intThreadLocal.get()) + expect(2) + + async(it + intThreadLocal.asContextElement(31)) { + assertEquals(31, intThreadLocal.get()) + expect(3) + }.await() + + withContext(it + intThreadLocal.asContextElement(2)) { + assertSame(2, intThreadLocal.get()) + expect(4) + } + + async(it) { + assertNull(intThreadLocal.get()) + expect(5) + }.await() + + expect(6) + }.await() + } + + finish(7) + } +} diff --git a/core/kotlinx-coroutines-core/test/guide/example-context-11.kt b/core/kotlinx-coroutines-core/test/guide/example-context-11.kt new file mode 100644 index 0000000000..4d43911446 --- /dev/null +++ b/core/kotlinx-coroutines-core/test/guide/example-context-11.kt @@ -0,0 +1,23 @@ +/* + * Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +// This file was automatically generated from coroutines-guide.md by Knit tool. Do not edit. +package kotlinx.coroutines.experimental.guide.context11 + +import kotlinx.coroutines.experimental.* +import kotlin.coroutines.experimental.* + +val threadLocal = ThreadLocal() // declare thread-local variable + +fun main(args: Array) = runBlocking { + threadLocal.set("main") + println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + val job = launch(CommonPool + threadLocal.asContextElement(value = "launch"), start = CoroutineStart.UNDISPATCHED) { + println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + yield() + println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + } + job.join() + println("Post-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") +} diff --git a/core/kotlinx-coroutines-core/test/guide/test/GuideTest.kt b/core/kotlinx-coroutines-core/test/guide/test/GuideTest.kt index ec527a6a47..82958a2392 100644 --- a/core/kotlinx-coroutines-core/test/guide/test/GuideTest.kt +++ b/core/kotlinx-coroutines-core/test/guide/test/GuideTest.kt @@ -267,6 +267,16 @@ class GuideTest { ) } + @Test + fun testKotlinxCoroutinesExperimentalGuideContext11() { + test("KotlinxCoroutinesExperimentalGuideContext11") { kotlinx.coroutines.experimental.guide.context11.main(emptyArray()) }.verifyLinesFlexibleThread( + "Pre-main, current thread: Thread[main @coroutine#1,5,main], thread local value: 'main'", + "Launch start, current thread: Thread[main @coroutine#2,5,main], thread local value: 'launch'", + "After yield, current thread: Thread[ForkJoinPool.commonPool-worker-1 @coroutine#2,5,main], thread local value: 'launch'", + "Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value: 'main'" + ) + } + @Test fun testKotlinxCoroutinesExperimentalGuideExceptions01() { test("KotlinxCoroutinesExperimentalGuideExceptions01") { kotlinx.coroutines.experimental.guide.exceptions01.main(emptyArray()) }.verifyExceptions( diff --git a/coroutines-guide.md b/coroutines-guide.md index bfdbe8c44d..c1d8679ad3 100644 --- a/coroutines-guide.md +++ b/coroutines-guide.md @@ -67,6 +67,7 @@ You need to add a dependency on `kotlinx-coroutines-core` module as explained * [Parental responsibilities](#parental-responsibilities) * [Naming coroutines for debugging](#naming-coroutines-for-debugging) * [Cancellation via explicit job](#cancellation-via-explicit-job) + * [Thread-local data](#thread-local-data) * [Exception handling](#exception-handling) * [Exception propagation](#exception-propagation) * [CoroutineExceptionHandler](#coroutineexceptionhandler) @@ -1266,6 +1267,64 @@ and cancel it when activity is destroyed. We cannot `join` them in the case of A since it is synchronous, but this joining ability is useful when building backend services to ensure bounded resource usage. +### Thread-local data + +Sometimes it is very convenient to have an ability to pass some thread-local data, but, for coroutines, which +are not bound to any particular thread, it is hard to achieve it manually without writing a lot of boilerplate. + +For [`ThreadLocal`](https://docs.oracle.com/javase/8/docs/api/java/lang/ThreadLocal.html), +[asContextElement] is here for the rescue. It creates an additional context element, +which keep the value of the given `ThreadLocal` and restores it every time the coroutine switches its context. + +It is easy to demonstrate it in action: + + + +```kotlin +val threadLocal = ThreadLocal() // declare thread-local variable + +fun main(args: Array) = runBlocking { + threadLocal.set("main") + println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + val job = launch(CommonPool + threadLocal.asContextElement(value = "launch"), start = CoroutineStart.UNDISPATCHED) { + println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + yield() + println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + } + job.join() + println("Post-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") +} +``` + +> You can get full code [here](core/kotlinx-coroutines-core/test/guide/example-context-11.kt) + +The output of this example is: + +```text +Pre-main, current thread: Thread[main @coroutine#1,5,main], thread local value: 'main' +Launch start, current thread: Thread[main @coroutine#2,5,main], thread local value: 'launch' +After yield, current thread: Thread[ForkJoinPool.commonPool-worker-1 @coroutine#2,5,main], thread local value: 'launch' +Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value: 'main' +``` + + + +Note how thread-local value is restored properly, no matter on what thread the coroutine is executed. +`ThreadLocal` has first-class support and can be used with any primitive `kotlinx.corotuines` provides. +It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller +(as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension. +Use [withContext] to update the value of the thread-local in a coroutine, see [asContextElement] for more details. + +Alternatively, a value can be stored in a mutable box like `class Counter(var i: Int)`, which is, in turn, +is stored in a thread-local variable. However, in this case you are fully responsible to synchronize +potentially concurrent modifications to the variable in this box. + +For advanced usage, for example for integration with logging MDC, transactional contexts or any other libraries +which internally use thread-locals for passing data, see documentation for [ThreadContextElement] interface +that should be implemented. + ## Exception handling