Skip to content

Commit e342597

Browse files
qwwdfsadelizarov
authored andcommitted
Introduce ThreadLocal.asContextElement()
* Move implementation to internal package * Add guide section
1 parent 7587eba commit e342597

File tree

9 files changed

+514
-99
lines changed

9 files changed

+514
-99
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+5
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,11 @@ public final class kotlinx/coroutines/experimental/ThreadContextElement$DefaultI
446446
public static fun plus (Lkotlinx/coroutines/experimental/ThreadContextElement;Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext;
447447
}
448448

449+
public final class kotlinx/coroutines/experimental/ThreadContextElementKt {
450+
public static final fun asContextElement (Ljava/lang/ThreadLocal;Ljava/lang/Object;)Lkotlinx/coroutines/experimental/ThreadContextElement;
451+
public static synthetic fun asContextElement$default (Ljava/lang/ThreadLocal;Ljava/lang/Object;ILjava/lang/Object;)Lkotlinx/coroutines/experimental/ThreadContextElement;
452+
}
453+
449454
public final class kotlinx/coroutines/experimental/ThreadPoolDispatcher : kotlinx/coroutines/experimental/ExecutorCoroutineDispatcherBase {
450455
public fun close ()V
451456
public fun getExecutor ()Ljava/util/concurrent/Executor;

core/kotlinx-coroutines-core/src/CoroutineContext.kt

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
package kotlinx.coroutines.experimental
66

7-
import java.util.*
87
import kotlinx.coroutines.experimental.internal.*
98
import kotlinx.coroutines.experimental.scheduling.*
109
import java.util.concurrent.atomic.*

core/kotlinx-coroutines-core/src/ThreadContextElement.kt

+61-98
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,42 @@ import kotlin.coroutines.experimental.*
1212
* every time the coroutine with this element in the context is resumed on a thread.
1313
*
1414
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
15-
* resume of a coroutine and restore later on suspend and the infrastructure provides the corresponding storage.
15+
* resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
1616
*
1717
* Example usage looks like this:
1818
*
1919
* ```
20-
* // declare thread local variable holding MyData
21-
* private val myThreadLocal = ThreadLocal<MyData?>()
22-
*
23-
* // declare context element holding MyData
24-
* class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
20+
* // Appends "name" of a coroutine to a current thread name when coroutine is executed
21+
* class CoroutineName(val name: String) : ThreadContextElement<String> {
2522
* // declare companion object for a key of this element in coroutine context
26-
* companion object Key : CoroutineContext.Key<MyElement>
23+
* companion object Key : CoroutineContext.Key<CoroutineName>
2724
*
2825
* // provide the key of the corresponding context element
29-
* override val key: CoroutineContext.Key<MyElement>
26+
* override val key: CoroutineContext.Key<CoroutineName>
3027
* get() = Key
3128
*
3229
* // this is invoked before coroutine is resumed on current thread
33-
* override fun updateThreadContext(context: CoroutineContext): MyData? {
34-
* val oldState = myThreadLocal.get()
35-
* myThreadLocal.set(data)
36-
* return oldState
30+
* override fun updateThreadContext(context: CoroutineContext): String {
31+
* val previousName = Thread.currentThread().name
32+
* Thread.currentThread().name = "$previousName # $name"
33+
* return previousName
3734
* }
3835
*
3936
* // this is invoked after coroutine has suspended on current thread
40-
* override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
41-
* myThreadLocal.set(oldState)
37+
* override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
38+
* Thread.currentThread().name = oldState
4239
* }
4340
* }
41+
*
42+
* // Usage
43+
* launch(UI + CoroutineName("Progress bar coroutine")) { ... }
4444
* ```
45+
*
46+
* Every time this coroutine is resumed on a thread, UI thread name is updated to
47+
* "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
48+
* this coroutine suspends.
49+
*
50+
* To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
4551
*/
4652
public interface ThreadContextElement<S> : CoroutineContext.Element {
4753
/**
@@ -67,87 +73,44 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
6773
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
6874
}
6975

70-
private val ZERO = Symbol("ZERO")
71-
72-
// Used when there are >= 2 active elements in the context
73-
private class ThreadState(val context: CoroutineContext, n: Int) {
74-
private var a = arrayOfNulls<Any>(n)
75-
private var i = 0
76-
77-
fun append(value: Any?) { a[i++] = value }
78-
fun take() = a[i++]
79-
fun start() { i = 0 }
80-
}
81-
82-
// Counts ThreadContextElements in the context
83-
// Any? here is Int | ThreadContextElement (when count is one)
84-
private val countAll =
85-
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
86-
if (element is ThreadContextElement<*>) {
87-
val inCount = countOrElement as? Int ?: 1
88-
return if (inCount == 0) element else inCount + 1
89-
}
90-
return countOrElement
91-
}
92-
93-
// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
94-
private val findOne =
95-
fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
96-
if (found != null) return found
97-
return element as? ThreadContextElement<*>
98-
}
99-
100-
// Updates state for ThreadContextElements in the context using the given ThreadState
101-
private val updateState =
102-
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
103-
if (element is ThreadContextElement<*>) {
104-
state.append(element.updateThreadContext(state.context))
105-
}
106-
return state
107-
}
108-
109-
// Restores state for all ThreadContextElements in the context from the given ThreadState
110-
private val restoreState =
111-
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
112-
@Suppress("UNCHECKED_CAST")
113-
if (element is ThreadContextElement<*>) {
114-
(element as ThreadContextElement<Any?>).restoreThreadContext(state.context, state.take())
115-
}
116-
return state
117-
}
118-
119-
internal fun updateThreadContext(context: CoroutineContext): Any? {
120-
val count = context.fold(0, countAll)
121-
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
122-
return when {
123-
count === 0 -> ZERO // very fast path when there are no active ThreadContextElements
124-
// ^^^ identity comparison for speed, we know zero always has the same identity
125-
count is Int -> {
126-
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
127-
context.fold(ThreadState(context, count), updateState)
128-
}
129-
else -> {
130-
// fast path for one ThreadContextElement (no allocations, no additional context scan)
131-
@Suppress("UNCHECKED_CAST")
132-
val element = count as ThreadContextElement<Any?>
133-
element.updateThreadContext(context)
134-
}
135-
}
136-
}
137-
138-
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
139-
when {
140-
oldState === ZERO -> return // very fast path when there are no ThreadContextElements
141-
oldState is ThreadState -> {
142-
// slow path with multiple stored ThreadContextElements
143-
oldState.start()
144-
context.fold(oldState, restoreState)
145-
}
146-
else -> {
147-
// fast path for one ThreadContextElement, but need to find it
148-
@Suppress("UNCHECKED_CAST")
149-
val element = context.fold(null, findOne) as ThreadContextElement<Any?>
150-
element.restoreThreadContext(context, oldState)
151-
}
152-
}
153-
}
76+
/**
77+
* Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
78+
* maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
79+
* By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter.
80+
*
81+
* Example usage looks like this:
82+
*
83+
* ```
84+
* val myThreadLocal = ThreadLocal<String?>()
85+
* ...
86+
* println(myThreadLocal.get()) // Prints "null"
87+
* launch(CommonPool + myThreadLocal.asContextElement(initialValue = "foo")) {
88+
* println(myThreadLocal.get()) // Prints "foo"
89+
* withContext(UI) {
90+
* println(myThreadLocal.get()) // Prints "foo", but it's on UI thread
91+
* }
92+
* }
93+
* println(myThreadLocal.get()) // Prints "null"
94+
* ```
95+
*
96+
* Note that the context element does not track modifications of the thread-local variable, for example:
97+
*
98+
* ```
99+
* myThreadLocal.set("main")
100+
* withContext(UI) {
101+
* println(myThreadLocal.get()) // Prints "main"
102+
* myThreadLocal.set("UI")
103+
* }
104+
* println(myThreadLocal.get()) // Prints "main", not "UI"
105+
* ```
106+
*
107+
* Use `withContext` to update the corresponding thread-local variable to a different value, for example:
108+
*
109+
* ```
110+
* withContext(myThreadLocal.asContextElement("foo")) {
111+
* println(myThreadLocal.get()) // Prints "foo"
112+
* }
113+
* ```
114+
*/
115+
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
116+
ThreadLocalElement(value, this)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
package kotlinx.coroutines.experimental.internal
2+
3+
import kotlinx.coroutines.experimental.*
4+
import kotlin.coroutines.experimental.*
5+
6+
7+
private val ZERO = Symbol("ZERO")
8+
9+
// Used when there are >= 2 active elements in the context
10+
private class ThreadState(val context: CoroutineContext, n: Int) {
11+
private var a = arrayOfNulls<Any>(n)
12+
private var i = 0
13+
14+
fun append(value: Any?) { a[i++] = value }
15+
fun take() = a[i++]
16+
fun start() { i = 0 }
17+
}
18+
19+
// Counts ThreadContextElements in the context
20+
// Any? here is Int | ThreadContextElement (when count is one)
21+
private val countAll =
22+
fun (countOrElement: Any?, element: CoroutineContext.Element): Any? {
23+
if (element is ThreadContextElement<*>) {
24+
val inCount = countOrElement as? Int ?: 1
25+
return if (inCount == 0) element else inCount + 1
26+
}
27+
return countOrElement
28+
}
29+
30+
// Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
31+
private val findOne =
32+
fun (found: ThreadContextElement<*>?, element: CoroutineContext.Element): ThreadContextElement<*>? {
33+
if (found != null) return found
34+
return element as? ThreadContextElement<*>
35+
}
36+
37+
// Updates state for ThreadContextElements in the context using the given ThreadState
38+
private val updateState =
39+
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
40+
if (element is ThreadContextElement<*>) {
41+
state.append(element.updateThreadContext(state.context))
42+
}
43+
return state
44+
}
45+
46+
// Restores state for all ThreadContextElements in the context from the given ThreadState
47+
private val restoreState =
48+
fun (state: ThreadState, element: CoroutineContext.Element): ThreadState {
49+
@Suppress("UNCHECKED_CAST")
50+
if (element is ThreadContextElement<*>) {
51+
(element as ThreadContextElement<Any?>).restoreThreadContext(state.context, state.take())
52+
}
53+
return state
54+
}
55+
56+
internal fun updateThreadContext(context: CoroutineContext): Any? {
57+
val count = context.fold(0, countAll)
58+
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
59+
return when {
60+
count === 0 -> ZERO // very fast path when there are no active ThreadContextElements
61+
// ^^^ identity comparison for speed, we know zero always has the same identity
62+
count is Int -> {
63+
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
64+
context.fold(ThreadState(context, count), updateState)
65+
}
66+
else -> {
67+
// fast path for one ThreadContextElement (no allocations, no additional context scan)
68+
@Suppress("UNCHECKED_CAST")
69+
val element = count as ThreadContextElement<Any?>
70+
element.updateThreadContext(context)
71+
}
72+
}
73+
}
74+
75+
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
76+
when {
77+
oldState === ZERO -> return // very fast path when there are no ThreadContextElements
78+
oldState is ThreadState -> {
79+
// slow path with multiple stored ThreadContextElements
80+
oldState.start()
81+
context.fold(oldState, restoreState)
82+
}
83+
else -> {
84+
// fast path for one ThreadContextElement, but need to find it
85+
@Suppress("UNCHECKED_CAST")
86+
val element = context.fold(null, findOne) as ThreadContextElement<Any?>
87+
element.restoreThreadContext(context, oldState)
88+
}
89+
}
90+
}
91+
92+
// top-level data class for a nicer out-of-the-box toString representation and class name
93+
private data class ThreadLocalKey(private val threadLocal: ThreadLocal<*>) : CoroutineContext.Key<ThreadLocalElement<*>>
94+
95+
internal class ThreadLocalElement<T>(
96+
private val value: T,
97+
private val threadLocal: ThreadLocal<T>
98+
) : ThreadContextElement<T> {
99+
override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)
100+
101+
override fun updateThreadContext(context: CoroutineContext): T {
102+
val oldState = threadLocal.get()
103+
threadLocal.set(value)
104+
return oldState
105+
}
106+
107+
override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
108+
threadLocal.set(oldState)
109+
}
110+
111+
// this method is overridden to perform value comparison (==) on key
112+
override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext {
113+
return if (this.key == key) EmptyCoroutineContext else this
114+
}
115+
116+
// this method is overridden to perform value comparison (==) on key
117+
public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? =
118+
@Suppress("UNCHECKED_CAST")
119+
if (this.key == key) this as E else null
120+
121+
override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
122+
}

core/kotlinx-coroutines-core/test/ThreadContextElementTest.kt

+33
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,39 @@ class ThreadContextElementTest : TestBase() {
5252
job.join()
5353
assertNull(myThreadLocal.get())
5454
}
55+
56+
57+
@Test
58+
fun testWithContext() = runTest {
59+
expect(1)
60+
newSingleThreadContext("withContext").use {
61+
val data = MyData()
62+
async(CommonPool + MyElement(data)) {
63+
assertSame(data, myThreadLocal.get())
64+
expect(2)
65+
66+
val newData = MyData()
67+
async(it + MyElement(newData)) {
68+
assertSame(newData, myThreadLocal.get())
69+
expect(3)
70+
}.await()
71+
72+
withContext(it + MyElement(newData)) {
73+
assertSame(newData, myThreadLocal.get())
74+
expect(4)
75+
}
76+
77+
async(it) {
78+
assertNull(myThreadLocal.get())
79+
expect(5)
80+
}.await()
81+
82+
expect(6)
83+
}.await()
84+
}
85+
86+
finish(7)
87+
}
5588
}
5689

5790
class MyData

0 commit comments

Comments
 (0)