@@ -12,36 +12,42 @@ import kotlin.coroutines.experimental.*
12
12
* every time the coroutine with this element in the context is resumed on a thread.
13
13
*
14
14
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
15
- * resume of a coroutine and restore later on suspend and the infrastructure provides the corresponding storage.
15
+ * resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
16
16
*
17
17
* Example usage looks like this:
18
18
*
19
19
* ```
20
- * // declare thread local variable holding MyData
21
- * private val myThreadLocal = ThreadLocal<MyData?>()
22
- *
23
- * // declare context element holding MyData
24
- * class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
20
+ * // Appends "name" of a coroutine to a current thread name when coroutine is executed
21
+ * class CoroutineName(val name: String) : ThreadContextElement<String> {
25
22
* // declare companion object for a key of this element in coroutine context
26
- * companion object Key : CoroutineContext.Key<MyElement >
23
+ * companion object Key : CoroutineContext.Key<CoroutineName >
27
24
*
28
25
* // provide the key of the corresponding context element
29
- * override val key: CoroutineContext.Key<MyElement >
26
+ * override val key: CoroutineContext.Key<CoroutineName >
30
27
* get() = Key
31
28
*
32
29
* // this is invoked before coroutine is resumed on current thread
33
- * override fun updateThreadContext(context: CoroutineContext): MyData? {
34
- * val oldState = myThreadLocal.get()
35
- * myThreadLocal.set(data)
36
- * return oldState
30
+ * override fun updateThreadContext(context: CoroutineContext): String {
31
+ * val previousName = Thread.currentThread().name
32
+ * Thread.currentThread().name = "$previousName # $name"
33
+ * return previousName
37
34
* }
38
35
*
39
36
* // this is invoked after coroutine has suspended on current thread
40
- * override fun restoreThreadContext(context: CoroutineContext, oldState: MyData? ) {
41
- * myThreadLocal.set(oldState)
37
+ * override fun restoreThreadContext(context: CoroutineContext, oldState: String ) {
38
+ * Thread.currentThread().name = oldState
42
39
* }
43
40
* }
41
+ *
42
+ * // Usage
43
+ * launch(UI + CoroutineName("Progress bar coroutine")) { ... }
44
44
* ```
45
+ *
46
+ * Every time this coroutine is resumed on a thread, UI thread name is updated to
47
+ * "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
48
+ * this coroutine suspends.
49
+ *
50
+ * To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
45
51
*/
46
52
public interface ThreadContextElement <S > : CoroutineContext .Element {
47
53
/* *
@@ -67,87 +73,44 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
67
73
public fun restoreThreadContext (context : CoroutineContext , oldState : S )
68
74
}
69
75
70
- private val ZERO = Symbol (" ZERO" )
71
-
72
- // Used when there are >= 2 active elements in the context
73
- private class ThreadState (val context : CoroutineContext , n : Int ) {
74
- private var a = arrayOfNulls<Any >(n)
75
- private var i = 0
76
-
77
- fun append (value : Any? ) { a[i++ ] = value }
78
- fun take () = a[i++ ]
79
- fun start () { i = 0 }
80
- }
81
-
82
- // Counts ThreadContextElements in the context
83
- // Any? here is Int | ThreadContextElement (when count is one)
84
- private val countAll =
85
- fun (countOrElement : Any? , element : CoroutineContext .Element ): Any? {
86
- if (element is ThreadContextElement <* >) {
87
- val inCount = countOrElement as ? Int ? : 1
88
- return if (inCount == 0 ) element else inCount + 1
89
- }
90
- return countOrElement
91
- }
92
-
93
- // Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
94
- private val findOne =
95
- fun (found : ThreadContextElement <* >? , element : CoroutineContext .Element ): ThreadContextElement <* >? {
96
- if (found != null ) return found
97
- return element as ? ThreadContextElement <* >
98
- }
99
-
100
- // Updates state for ThreadContextElements in the context using the given ThreadState
101
- private val updateState =
102
- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
103
- if (element is ThreadContextElement <* >) {
104
- state.append(element.updateThreadContext(state.context))
105
- }
106
- return state
107
- }
108
-
109
- // Restores state for all ThreadContextElements in the context from the given ThreadState
110
- private val restoreState =
111
- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
112
- @Suppress(" UNCHECKED_CAST" )
113
- if (element is ThreadContextElement <* >) {
114
- (element as ThreadContextElement <Any ?>).restoreThreadContext(state.context, state.take())
115
- }
116
- return state
117
- }
118
-
119
- internal fun updateThreadContext (context : CoroutineContext ): Any? {
120
- val count = context.fold(0 , countAll)
121
- @Suppress(" IMPLICIT_BOXING_IN_IDENTITY_EQUALS" )
122
- return when {
123
- count == = 0 -> ZERO // very fast path when there are no active ThreadContextElements
124
- // ^^^ identity comparison for speed, we know zero always has the same identity
125
- count is Int -> {
126
- // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
127
- context.fold(ThreadState (context, count), updateState)
128
- }
129
- else -> {
130
- // fast path for one ThreadContextElement (no allocations, no additional context scan)
131
- @Suppress(" UNCHECKED_CAST" )
132
- val element = count as ThreadContextElement <Any ?>
133
- element.updateThreadContext(context)
134
- }
135
- }
136
- }
137
-
138
- internal fun restoreThreadContext (context : CoroutineContext , oldState : Any? ) {
139
- when {
140
- oldState == = ZERO -> return // very fast path when there are no ThreadContextElements
141
- oldState is ThreadState -> {
142
- // slow path with multiple stored ThreadContextElements
143
- oldState.start()
144
- context.fold(oldState, restoreState)
145
- }
146
- else -> {
147
- // fast path for one ThreadContextElement, but need to find it
148
- @Suppress(" UNCHECKED_CAST" )
149
- val element = context.fold(null , findOne) as ThreadContextElement <Any ?>
150
- element.restoreThreadContext(context, oldState)
151
- }
152
- }
153
- }
76
+ /* *
77
+ * Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
78
+ * maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
79
+ * By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter.
80
+ *
81
+ * Example usage looks like this:
82
+ *
83
+ * ```
84
+ * val myThreadLocal = ThreadLocal<String?>()
85
+ * ...
86
+ * println(myThreadLocal.get()) // Prints "null"
87
+ * launch(CommonPool + myThreadLocal.asContextElement(initialValue = "foo")) {
88
+ * println(myThreadLocal.get()) // Prints "foo"
89
+ * withContext(UI) {
90
+ * println(myThreadLocal.get()) // Prints "foo", but it's on UI thread
91
+ * }
92
+ * }
93
+ * println(myThreadLocal.get()) // Prints "null"
94
+ * ```
95
+ *
96
+ * Note that the context element does not track modifications of the thread-local variable, for example:
97
+ *
98
+ * ```
99
+ * myThreadLocal.set("main")
100
+ * withContext(UI) {
101
+ * println(myThreadLocal.get()) // Prints "main"
102
+ * myThreadLocal.set("UI")
103
+ * }
104
+ * println(myThreadLocal.get()) // Prints "main", not "UI"
105
+ * ```
106
+ *
107
+ * Use `withContext` to update the corresponding thread-local variable to a different value, for example:
108
+ *
109
+ * ```
110
+ * withContext(myThreadLocal.asContextElement("foo")) {
111
+ * println(myThreadLocal.get()) // Prints "foo"
112
+ * }
113
+ * ```
114
+ */
115
+ public fun <T > ThreadLocal<T>.asContextElement (value : T = get()): ThreadContextElement <T > =
116
+ ThreadLocalElement (value, this )
0 commit comments