Skip to content

Commit 0355b2c

Browse files
authored
Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods (#1043)
Add ThreadLocal.isPresent and ThreadLocal.ensurePresent methods Fixes #1028
1 parent 5ea9339 commit 0355b2c

File tree

9 files changed

+72
-6
lines changed

9 files changed

+72
-6
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+2
Original file line numberDiff line numberDiff line change
@@ -481,6 +481,8 @@ public final class kotlinx/coroutines/ThreadContextElement$DefaultImpls {
481481
public final class kotlinx/coroutines/ThreadContextElementKt {
482482
public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
483483
public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/ThreadContextElement;
484+
public static final fun ensurePresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
485+
public static final fun isPresent (Ljava/lang/ThreadLocal;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
484486
}
485487

486488
public final class kotlinx/coroutines/ThreadPoolDispatcherKt {

docs/coroutine-context-and-dispatchers.md

+6-1
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ fun main() = runBlocking<Unit> {
635635
threadLocal.set("main")
636636
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
637637
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
638-
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
638+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
639639
yield()
640640
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
641641
}
@@ -664,6 +664,10 @@ Post-main, current thread: Thread[main @coroutine#1,5,main], thread local value:
664664
665665
<!--- TEST FLEXIBLE_THREAD -->
666666
667+
Note how easily one may forget the corresponding context element and then still safely access thread local.
668+
To avoid such situations, it is recommended to use [ensurePresent] method
669+
and fail-fast on improper usages.
670+
667671
`ThreadLocal` has first-class support and can be used with any primitive `kotlinx.coroutines` provides.
668672
It has one key limitation: when thread-local is mutated, a new value is not propagated to the coroutine caller
669673
(as context element cannot track all `ThreadLocal` object accesses) and updated value is lost on the next suspension.
@@ -701,5 +705,6 @@ that should be implemented.
701705
[MainScope()]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-main-scope.html
702706
[Dispatchers.Main]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-dispatchers/-main.html
703707
[asContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/as-context-element.html
708+
[ensurePresent]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/java.lang.-thread-local/ensure-present.html
704709
[ThreadContextElement]: https://kotlin.github.io/kotlinx.coroutines/kotlinx-coroutines-core/kotlinx.coroutines/-thread-context-element/index.html
705710
<!--- END -->

integration/kotlinx-coroutines-guava/test/ListenableFutureTest.kt

-2
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ import org.hamcrest.core.*
1111
import org.junit.*
1212
import org.junit.Assert.*
1313
import org.junit.Test
14-
import java.io.*
1514
import java.util.concurrent.*
16-
import kotlin.test.assertFailsWith
1715

1816
class ListenableFutureTest : TestBase() {
1917
@Before

integration/kotlinx-coroutines-jdk8/test/future/FutureTest.kt

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import java.util.function.*
1616
import kotlin.concurrent.*
1717
import kotlin.coroutines.*
1818
import kotlin.reflect.*
19-
import kotlin.test.assertFailsWith
2019

2120
class FutureTest : TestBase() {
2221
@Before

kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt

+37
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,40 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
135135
*/
136136
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
137137
ThreadLocalElement(value, this)
138+
139+
/**
140+
* Return `true` when current thread local is present in the coroutine context, `false` otherwise.
141+
* Thread local can be present in the context only if it was added via [asContextElement] to the context.
142+
*
143+
* Example of usage:
144+
* ```
145+
* suspend fun processRequest() {
146+
* if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
147+
* // Do some heavy-weight tracing
148+
* }
149+
* // Process request regularly
150+
* }
151+
* ```
152+
*/
153+
public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null
154+
155+
/**
156+
* Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
157+
* It is a good practice to validate that thread local is present in the context, especially in large code-bases,
158+
* to avoid stale thread-local values and to have a strict invariants.
159+
*
160+
* E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
161+
* ```
162+
* public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
163+
* ensurePresent()
164+
* return get()
165+
* }
166+
*
167+
* // Usage
168+
* withContext(...) {
169+
* val value = threadLocal.getSafely() // Fail-fast in case of improper context
170+
* }
171+
* ```
172+
*/
173+
public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit =
174+
check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }

kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
9898
}
9999

100100
// top-level data class for a nicer out-of-the-box toString representation and class name
101-
private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
101+
@PublishedApi
102+
internal data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
102103

103104
internal class ThreadLocalElement<T>(
104105
private val value: T,

kotlinx-coroutines-core/jvm/test/TestBase.kt

+6
Original file line numberDiff line numberDiff line change
@@ -201,4 +201,10 @@ public actual open class TestBase actual constructor() {
201201
if (exCount < unhandled.size)
202202
error("Too few unhandled exceptions $exCount, expected ${unhandled.size}")
203203
}
204+
205+
protected inline fun <reified T: Throwable> assertFailsWith(block: () -> Unit): T {
206+
val result = runCatching(block)
207+
assertTrue(result.exceptionOrNull() is T, "Expected ${T::class}, but had $result")
208+
return result.exceptionOrNull()!! as T
209+
}
204210
}

kotlinx-coroutines-core/jvm/test/ThreadLocalTest.kt

+18
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package kotlinx.coroutines
66

77
import org.junit.*
88
import org.junit.Test
9+
import java.lang.IllegalStateException
910
import kotlin.test.*
1011

1112
@Suppress("RedundantAsync")
@@ -22,25 +23,33 @@ class ThreadLocalTest : TestBase() {
2223
@Test
2324
fun testThreadLocal() = runTest {
2425
assertNull(stringThreadLocal.get())
26+
assertFalse(stringThreadLocal.isPresent())
2527
val deferred = async(Dispatchers.Default + stringThreadLocal.asContextElement("value")) {
2628
assertEquals("value", stringThreadLocal.get())
29+
assertTrue(stringThreadLocal.isPresent())
2730
withContext(executor) {
31+
assertTrue(stringThreadLocal.isPresent())
32+
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
2833
assertEquals("value", stringThreadLocal.get())
2934
}
35+
assertTrue(stringThreadLocal.isPresent())
3036
assertEquals("value", stringThreadLocal.get())
3137
}
3238

3339
assertNull(stringThreadLocal.get())
3440
deferred.await()
3541
assertNull(stringThreadLocal.get())
42+
assertFalse(stringThreadLocal.isPresent())
3643
}
3744

3845
@Test
3946
fun testThreadLocalInitialValue() = runTest {
4047
intThreadLocal.set(42)
48+
assertFalse(intThreadLocal.isPresent())
4149
val deferred = async(Dispatchers.Default + intThreadLocal.asContextElement(239)) {
4250
assertEquals(239, intThreadLocal.get())
4351
withContext(executor) {
52+
intThreadLocal.ensurePresent()
4453
assertEquals(239, intThreadLocal.get())
4554
}
4655
assertEquals(239, intThreadLocal.get())
@@ -63,6 +72,8 @@ class ThreadLocalTest : TestBase() {
6372
withContext(executor) {
6473
assertEquals(239, intThreadLocal.get())
6574
assertEquals("pew", stringThreadLocal.get())
75+
intThreadLocal.ensurePresent()
76+
stringThreadLocal.ensurePresent()
6677
}
6778

6879
assertEquals(239, intThreadLocal.get())
@@ -129,6 +140,7 @@ class ThreadLocalTest : TestBase() {
129140
}
130141

131142
deferred.await()
143+
assertFalse(stringThreadLocal.isPresent())
132144
assertEquals("main", stringThreadLocal.get())
133145
}
134146

@@ -212,4 +224,10 @@ class ThreadLocalTest : TestBase() {
212224
assertNotSame(mainThread, Thread.currentThread())
213225
}.await()
214226
}
227+
228+
@Test
229+
fun testMissingThreadLocal() = runTest {
230+
assertFailsWith<IllegalStateException> { stringThreadLocal.ensurePresent() }
231+
assertFailsWith<IllegalStateException> { intThreadLocal.ensurePresent() }
232+
}
215233
}

kotlinx-coroutines-core/jvm/test/guide/example-context-11.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ fun main() = runBlocking<Unit> {
1414
threadLocal.set("main")
1515
println("Pre-main, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
1616
val job = launch(Dispatchers.Default + threadLocal.asContextElement(value = "launch")) {
17-
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
17+
println("Launch start, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
1818
yield()
1919
println("After yield, current thread: ${Thread.currentThread()}, thread local value: '${threadLocal.get()}'")
2020
}

0 commit comments

Comments
 (0)