Skip to content

Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods #1043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,8 @@ public final class kotlinx/coroutines/ThreadContextElement$DefaultImpls {
public final class kotlinx/coroutines/ThreadContextElementKt {
public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class kotlinx/coroutines/ThreadPoolDispatcherKt {
Expand Down
7 changes: 6 additions & 1 deletion docs/coroutine-context-and-dispatchers.md
Original file line number Diff line number Diff line change
Expand Up @@ -635,7 +635,7 @@ fun main() = runBlocking<Unit> {
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}
Expand Down Expand Up @@ -664,6 +664,10 @@ Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value:

<!--- TEST FLEXIBLE_THREAD -->

Note how easily one may forget the corresponding context element and then still safely access thread local.
To avoid such situations, it is recommended to use [ensurePresent] method
and fail-fast on improper usages.

`ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides.
It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller
(as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension.
Expand Down Expand Up @@ -701,5 +705,6 @@ that should be implemented.
[MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html
[Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html
[asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html
[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html
[ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html
<!--- END -->
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ import org.hamcrest.core.*
import org.junit.*
import org.junit.Assert.*
import org.junit.Test
import java.io.*
import java.util.concurrent.*
import kotlin.test.assertFailsWith

class ListenableFutureTest : TestBase() {
@Before
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import java.util.function.*
import kotlin.concurrent.*
import kotlin.coroutines.*
import kotlin.reflect.*
import kotlin.test.assertFailsWith

class FutureTest : TestBase() {
@Before
Expand Down
37 changes: 37 additions & 0 deletions kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,40 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
*/
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
ThreadLocalElement(value, this)

/**
* Return `true` when current thread local is present in the coroutine context, `false` otherwise.
* Thread local can be present in the context only if it was added via [asContextElement] to the context.
*
* Example of usage:
* ```
* suspend fun processRequest() {
* if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
* // Do some heavy-weight tracing
* }
* // Process request regularly
* }
* ```
*/
public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null

/**
* Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
* It is a good practice to validate that thread local is present in the context, especially in large code-bases,
* to avoid stale thread-local values and to have a strict invariants.
*
* E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
* ```
* public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
* ensurePresent()
* return get()
* }
*
* // Usage
* withContext(...) {
* val value = threadLocal.getSafely() // Fail-fast in case of improper context
* }
* ```
*/
public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit =
check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }
3 changes: 2 additions & 1 deletion kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
}

// top-level data class for a nicer out-of-the-box toString representation and class name
private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
@PublishedApi
internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>

internal class ThreadLocalElement<T>(
private val value: T,
Expand Down
6 changes: 6 additions & 0 deletions kotlinx-coroutines-core/jvm/test/TestBase.kt
Original file line number Diff line number Diff line change
Expand Up @@ -201,4 +201,10 @@ public actual open class TestBase actual constructor() {
if (exCount < unhandled.size)
error("Too few unhandled exceptions $exCount, expected ${unhandled.size}")
}

protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
val result = runCatching(block)
assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result")
return result.exceptionOrNull()!! as T
}
}
18 changes: 18 additions & 0 deletions kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ package kotlinx.coroutines

import org.junit.*
import org.junit.Test
import java.lang.IllegalStateException
import kotlin.test.*

@Suppress("RedundantAsync")
Expand All @@ -22,25 +23,33 @@ class ThreadLocalTest : TestBase() {
@Test
fun testThreadLocal() = runTest {
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
assertEquals("value", stringThreadLocal.get())
assertTrue(stringThreadLocal.isPresent())
withContext(executor) {
assertTrue(stringThreadLocal.isPresent())
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
assertEquals("value", stringThreadLocal.get())
}
assertTrue(stringThreadLocal.isPresent())
assertEquals("value", stringThreadLocal.get())
}

assertNull(stringThreadLocal.get())
deferred.await()
assertNull(stringThreadLocal.get())
assertFalse(stringThreadLocal.isPresent())
}

@Test
fun testThreadLocalInitialValue() = runTest {
intThreadLocal.set(42)
assertFalse(intThreadLocal.isPresent())
val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
assertEquals(239, intThreadLocal.get())
withContext(executor) {
intThreadLocal.ensurePresent()
assertEquals(239, intThreadLocal.get())
}
assertEquals(239, intThreadLocal.get())
Expand All @@ -63,6 +72,8 @@ class ThreadLocalTest : TestBase() {
withContext(executor) {
assertEquals(239, intThreadLocal.get())
assertEquals("pew", stringThreadLocal.get())
intThreadLocal.ensurePresent()
stringThreadLocal.ensurePresent()
}

assertEquals(239, intThreadLocal.get())
Expand Down Expand Up @@ -129,6 +140,7 @@ class ThreadLocalTest : TestBase() {
}

deferred.await()
assertFalse(stringThreadLocal.isPresent())
assertEquals("main", stringThreadLocal.get())
}

Expand Down Expand Up @@ -212,4 +224,10 @@ class ThreadLocalTest : TestBase() {
assertNotSame(mainThread, Thread.currentThread())
}.await()
}

@Test
fun testMissingThreadLocal() = runTest {
assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fun main() = runBlocking<Unit> {
threadLocal.set("main")
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
yield()
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
}
Expand Down