From 198515594af9a249b0d7a7aefff0b9d5e117e449 Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Thu, 14 Mar 2019 12:57:49 +0300 Subject: [PATCH 1/2] Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods Fixes #1028 --- .../kotlinx-coroutines-core.txt | 2 ++ docs/coroutine-context-and-dispatchers.md | 7 +++- .../test/ListenableFutureTest.kt | 2 -- .../test/future/FutureTest.kt | 1 - .../jvm/src/ThreadContextElement.kt | 36 +++++++++++++++++++ .../jvm/src/internal/ThreadContext.kt | 2 +- kotlinx-coroutines-core/jvm/test/TestBase.kt | 6 ++++ .../jvm/test/ThreadLocalTest.kt | 18 ++++++++++ .../jvm/test/guide/example-context-11.kt | 2 +- 9 files changed, 70 insertions(+), 6 deletions(-) diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt index 2705acc0e4..21c473accf 100644 --- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt +++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt @@ -481,6 +481,8 @@ public final class kotlinx/coroutines/ThreadContextElement$DefaultImpls { public final class kotlinx/coroutines/ThreadContextElementKt { public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement; public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement; + public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; } public final class kotlinx/coroutines/ThreadPoolDispatcherKt { diff --git a/docs/coroutine-context-and-dispatchers.md b/docs/coroutine-context-and-dispatchers.md index 00b0db9549..29da4b4067 100644 --- a/docs/coroutine-context-and-dispatchers.md +++ b/docs/coroutine-context-and-dispatchers.md @@ -635,7 +635,7 @@ fun main() = runBlocking { threadLocal.set("main") println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) { - println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") yield() println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") } @@ -664,6 +664,10 @@ Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value: +Note how easily one may forget the corresponding context element and then still safely access thread local. +To avoid such situations, it is recommended to use [ensurePresent] method +and fail-fast on improper usages. + `ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides. It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller (as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension. @@ -701,5 +705,6 @@ that should be implemented. [MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html [Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html [asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html +[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html [ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html diff --git a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt index bfb5cfd452..cf82318a47 100644 --- a/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt +++ b/integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt @@ -11,9 +11,7 @@ import org.hamcrest.core.* import org.junit.* import org.junit.Assert.* import org.junit.Test -import java.io.* import java.util.concurrent.* -import kotlin.test.assertFailsWith class ListenableFutureTest : TestBase() { @Before diff --git a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt index 7038363cb5..7d128c6e18 100644 --- a/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt +++ b/integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt @@ -16,7 +16,6 @@ import java.util.function.* import kotlin.concurrent.* import kotlin.coroutines.* import kotlin.reflect.* -import kotlin.test.assertFailsWith class FutureTest : TestBase() { @Before diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index c68ee45cd5..0d173a5386 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -135,3 +135,39 @@ public interface ThreadContextElement : CoroutineContext.Element { */ public fun ThreadLocal.asContextElement(value: T = get()): ThreadContextElement = ThreadLocalElement(value, this) + +/** + * Return `true` when current thread local is present in the coroutine context, `false` otherwise. + * Thread local can be present in the context only if it was added via [asContextElement] to the context. + * + * Example of usage: + * ``` + * suspend fun processRequest() { + * if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing + * // Do some heavy-weight tracing + * } + * // Process request regularly + * } + * ``` + */ +public suspend fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] != null + +/** + * Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not. + * It is a good practice to validate that thread local is present in the context, especially in large code-bases, + * to avoid stale thread-local values and to have a strict invariants. + * + * E.g. one may use the following method to enforce proper use of the thread locals with coroutines: + * ``` + * public suspend inline fun ThreadLocal.getSafely(): T { + * ensurePresent() + * return get() + * } + * + * // Usage + * withContext(...) { + * val value = threadLocal.getSafely() // Fail-fast in case of improper context + * } + * ``` + */ +public suspend fun ThreadLocal<*>.ensurePresent(): Unit = check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" } diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 7dafb4711f..3965e5e125 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -98,7 +98,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { } // top-level data class for a nicer out-of-the-box toString representation and class name -private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key> +internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key> internal class ThreadLocalElement( private val value: T, diff --git a/kotlinx-coroutines-core/jvm/test/TestBase.kt b/kotlinx-coroutines-core/jvm/test/TestBase.kt index db5c53ae80..6fef760af1 100644 --- a/kotlinx-coroutines-core/jvm/test/TestBase.kt +++ b/kotlinx-coroutines-core/jvm/test/TestBase.kt @@ -201,4 +201,10 @@ public actual open class TestBase actual constructor() { if (exCount < unhandled.size) error("Too few unhandled exceptions $exCount, expected ${unhandled.size}") } + + protected inline fun assertFailsWith(block: () -> Unit): T { + val result = runCatching(block) + assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result") + return result.exceptionOrNull()!! as T + } } diff --git a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt index 62a340eef4..5d8c3d5c6d 100644 --- a/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt +++ b/kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt @@ -6,6 +6,7 @@ package kotlinx.coroutines import org.junit.* import org.junit.Test +import java.lang.IllegalStateException import kotlin.test.* @Suppress("RedundantAsync") @@ -22,25 +23,33 @@ class ThreadLocalTest : TestBase() { @Test fun testThreadLocal() = runTest { assertNull(stringThreadLocal.get()) + assertFalse(stringThreadLocal.isPresent()) val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) { assertEquals("value", stringThreadLocal.get()) + assertTrue(stringThreadLocal.isPresent()) withContext(executor) { + assertTrue(stringThreadLocal.isPresent()) + assertFailsWith { intThreadLocal.ensurePresent() } assertEquals("value", stringThreadLocal.get()) } + assertTrue(stringThreadLocal.isPresent()) assertEquals("value", stringThreadLocal.get()) } assertNull(stringThreadLocal.get()) deferred.await() assertNull(stringThreadLocal.get()) + assertFalse(stringThreadLocal.isPresent()) } @Test fun testThreadLocalInitialValue() = runTest { intThreadLocal.set(42) + assertFalse(intThreadLocal.isPresent()) val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) { assertEquals(239, intThreadLocal.get()) withContext(executor) { + intThreadLocal.ensurePresent() assertEquals(239, intThreadLocal.get()) } assertEquals(239, intThreadLocal.get()) @@ -63,6 +72,8 @@ class ThreadLocalTest : TestBase() { withContext(executor) { assertEquals(239, intThreadLocal.get()) assertEquals("pew", stringThreadLocal.get()) + intThreadLocal.ensurePresent() + stringThreadLocal.ensurePresent() } assertEquals(239, intThreadLocal.get()) @@ -129,6 +140,7 @@ class ThreadLocalTest : TestBase() { } deferred.await() + assertFalse(stringThreadLocal.isPresent()) assertEquals("main", stringThreadLocal.get()) } @@ -212,4 +224,10 @@ class ThreadLocalTest : TestBase() { assertNotSame(mainThread, Thread.currentThread()) }.await() } + + @Test + fun testMissingThreadLocal() = runTest { + assertFailsWith { stringThreadLocal.ensurePresent() } + assertFailsWith { intThreadLocal.ensurePresent() } + } } diff --git a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt index 8de958e473..1945495cb8 100644 --- a/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt +++ b/kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt @@ -14,7 +14,7 @@ fun main() = runBlocking { threadLocal.set("main") println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) { - println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") + println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") yield() println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'") } From a33df02bed04d379f7481b882b20d9318f258d5e Mon Sep 17 00:00:00 2001 From: Vsevolod Tolstopyatov Date: Fri, 22 Mar 2019 11:25:08 +0300 Subject: [PATCH 2/2] Make ThreadLocal methods ensurePresent and isPresent inline to avoid generating state machine where it is not necessary --- kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt | 5 +++-- kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt index 0d173a5386..4e8b6cc42e 100644 --- a/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt +++ b/kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt @@ -150,7 +150,7 @@ public fun ThreadLocal.asContextElement(value: T = get()): ThreadContextE * } * ``` */ -public suspend fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] != null +public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null /** * Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not. @@ -170,4 +170,5 @@ public suspend fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[Thread * } * ``` */ -public suspend fun ThreadLocal<*>.ensurePresent(): Unit = check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" } +public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit = + check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" } diff --git a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt index 3965e5e125..375dc60b66 100644 --- a/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt +++ b/kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt @@ -98,6 +98,7 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) { } // top-level data class for a nicer out-of-the-box toString representation and class name +@PublishedApi internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key> internal class ThreadLocalElement(