@@ -17,31 +17,34 @@ import kotlin.coroutines.experimental.*
17
17
* Example usage looks like this:
18
18
*
19
19
* ```
20
- * // declare thread local variable holding MyData
21
- * private val myThreadLocal = ThreadLocal<MyData?>()
22
- *
23
- * // declare context element holding MyData
24
- * class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
20
+ * // Appends "name" of a coroutine to a current thread name when coroutine is executed
21
+ * class CoroutineName(val name: String) : ThreadContextElement<String> {
25
22
* // declare companion object for a key of this element in coroutine context
26
- * companion object Key : CoroutineContext.Key<MyElement >
23
+ * companion object Key : CoroutineContext.Key<CoroutineName >
27
24
*
28
25
* // provide the key of the corresponding context element
29
- * override val key: CoroutineContext.Key<MyElement >
26
+ * override val key: CoroutineContext.Key<CoroutineName >
30
27
* get() = Key
31
28
*
32
29
* // this is invoked before coroutine is resumed on current thread
33
- * override fun updateThreadContext(context: CoroutineContext): MyData? {
34
- * val oldState = myThreadLocal.get()
35
- * myThreadLocal.set(data)
36
- * return oldState
30
+ * override fun updateThreadContext(context: CoroutineContext): String {
31
+ * val previousName = Thread.currentThread().name
32
+ * Thread.currentThread().name = "$previousName # $name"
33
+ * return previousName
37
34
* }
38
35
*
39
36
* // this is invoked after coroutine has suspended on current thread
40
- * override fun restoreThreadContext(context: CoroutineContext, oldState: MyData? ) {
41
- * myThreadLocal.set(oldState)
37
+ * override fun restoreThreadContext(context: CoroutineContext, oldState: String ) {
38
+ * Thread.currentThread().name = oldState
42
39
* }
43
40
* }
41
+ *
42
+ * // Usage
43
+ * launch(UI + CoroutineName("Progress bar coroutine")) { ... }
44
44
* ```
45
+ * Every time launched coroutine is executed, UI thread name will be updated to "UI thread original name # Progress bar coroutine"
46
+ *
47
+ * Note that for raw [ThreadLocal]s [asContextElement] factory should be used without any intermediate [ThreadContextElement] implementations
45
48
*/
46
49
public interface ThreadContextElement <S > : CoroutineContext .Element {
47
50
/* *
@@ -67,87 +70,40 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
67
70
public fun restoreThreadContext (context : CoroutineContext , oldState : S )
68
71
}
69
72
70
- private val ZERO = Symbol (" ZERO" )
71
-
72
- // Used when there are >= 2 active elements in the context
73
- private class ThreadState (val context : CoroutineContext , n : Int ) {
74
- private var a = arrayOfNulls<Any >(n)
75
- private var i = 0
76
-
77
- fun append (value : Any? ) { a[i++ ] = value }
78
- fun take () = a[i++ ]
79
- fun start () { i = 0 }
80
- }
81
-
82
- // Counts ThreadContextElements in the context
83
- // Any? here is Int | ThreadContextElement (when count is one)
84
- private val countAll =
85
- fun (countOrElement : Any? , element : CoroutineContext .Element ): Any? {
86
- if (element is ThreadContextElement <* >) {
87
- val inCount = countOrElement as ? Int ? : 1
88
- return if (inCount == 0 ) element else inCount + 1
89
- }
90
- return countOrElement
91
- }
92
-
93
- // Find one (first) ThreadContextElement in the context, it is used when we know there is exactly one
94
- private val findOne =
95
- fun (found : ThreadContextElement <* >? , element : CoroutineContext .Element ): ThreadContextElement <* >? {
96
- if (found != null ) return found
97
- return element as ? ThreadContextElement <* >
98
- }
99
-
100
- // Updates state for ThreadContextElements in the context using the given ThreadState
101
- private val updateState =
102
- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
103
- if (element is ThreadContextElement <* >) {
104
- state.append(element.updateThreadContext(state.context))
105
- }
106
- return state
107
- }
108
-
109
- // Restores state for all ThreadContextElements in the context from the given ThreadState
110
- private val restoreState =
111
- fun (state : ThreadState , element : CoroutineContext .Element ): ThreadState {
112
- @Suppress(" UNCHECKED_CAST" )
113
- if (element is ThreadContextElement <* >) {
114
- (element as ThreadContextElement <Any ?>).restoreThreadContext(state.context, state.take())
115
- }
116
- return state
117
- }
118
-
119
- internal fun updateThreadContext (context : CoroutineContext ): Any? {
120
- val count = context.fold(0 , countAll)
121
- @Suppress(" IMPLICIT_BOXING_IN_IDENTITY_EQUALS" )
122
- return when {
123
- count == = 0 -> ZERO // very fast path when there are no active ThreadContextElements
124
- // ^^^ identity comparison for speed, we know zero always has the same identity
125
- count is Int -> {
126
- // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
127
- context.fold(ThreadState (context, count), updateState)
128
- }
129
- else -> {
130
- // fast path for one ThreadContextElement (no allocations, no additional context scan)
131
- @Suppress(" UNCHECKED_CAST" )
132
- val element = count as ThreadContextElement <Any ?>
133
- element.updateThreadContext(context)
134
- }
135
- }
136
- }
137
-
138
- internal fun restoreThreadContext (context : CoroutineContext , oldState : Any? ) {
139
- when {
140
- oldState == = ZERO -> return // very fast path when there are no ThreadContextElements
141
- oldState is ThreadState -> {
142
- // slow path with multiple stored ThreadContextElements
143
- oldState.start()
144
- context.fold(oldState, restoreState)
145
- }
146
- else -> {
147
- // fast path for one ThreadContextElement, but need to find it
148
- @Suppress(" UNCHECKED_CAST" )
149
- val element = context.fold(null , findOne) as ThreadContextElement <Any ?>
150
- element.restoreThreadContext(context, oldState)
151
- }
152
- }
73
+ /* *
74
+ * Wraps [ThreadLocal] into [ThreadContextElement]. Resulting [ThreadContextElement] will
75
+ * maintain given [ThreadLocal] value for coroutine not depending on actual thread it's run on.
76
+ * By default [ThreadLocal.get] is used as a initial value for the element, but it can be overridden with [initialValue] parameter.
77
+ *
78
+ * Example usage looks like this:
79
+ * ```
80
+ * val myThreadLocal = ThreadLocal<String?>()
81
+ * ...
82
+ * println(myThreadLocal.get()) // Will print "null"
83
+ * launch(CommonPool + myThreadLocal.asContextElement(initialValue = "foo")) {
84
+ * println(myThreadLocal.get()) // Will print "foo"
85
+ * withContext(UI) {
86
+ * println(myThreadLocal.get()) // Will print "foo", but it's UI thread
87
+ * }
88
+ * }
89
+ *
90
+ * println(myThreadLocal.get()) // Will print "null"
91
+ * ```
92
+ *
93
+ * Note that context element doesn't track modifications of thread local, for example
94
+ *
95
+ * ```
96
+ * myThreadLocal.set("main")
97
+ * withContext(UI) {
98
+ * println(myThreadLocal.get()) // will print "main"
99
+ * myThreadLocal.set("UI")
100
+ * }
101
+ *
102
+ * println(myThreadLocal.get()) // will print "main", not "UI"
103
+ * ```
104
+ *
105
+ * For modifications mutable boxes should be used instead
106
+ */
107
+ public fun <T > ThreadLocal<T>.asContextElement (initialValue : T = get()): ThreadContextElement <T > {
108
+ return ThreadLocalElement (initialValue, this )
153
109
}
0 commit comments