-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathThreadContextElement.kt
237 lines (228 loc) · 10 KB
/
ThreadContextElement.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
/**
* Defines elements in [CoroutineContext] that are installed into thread context
* every time the coroutine with this element in the context is resumed on a thread.
*
* Implementations of this interface define a type [S] of the thread-local state that they need to store on
* resume of a coroutine and restore later on suspend. The infrastructure provides the corresponding storage.
*
* Example usage looks like this:
*
* ```
* // Appends "name" of a coroutine to a current thread name when coroutine is executed
* class CoroutineName(val name: String) : ThreadContextElement<String> {
* // declare companion object for a key of this element in coroutine context
* companion object Key : CoroutineContext.Key<CoroutineName>
*
* // provide the key of the corresponding context element
* override val key: CoroutineContext.Key<CoroutineName>
* get() = Key
*
* // this is invoked before coroutine is resumed on current thread
* override fun updateThreadContext(context: CoroutineContext): String {
* val previousName = Thread.currentThread().name
* Thread.currentThread().name = "$previousName # $name"
* return previousName
* }
*
* // this is invoked after coroutine has suspended on current thread
* override fun restoreThreadContext(context: CoroutineContext, oldState: String) {
* Thread.currentThread().name = oldState
* }
* }
*
* // Usage
* launch(Dispatchers.Main + CoroutineName("Progress bar coroutine")) { ... }
* ```
*
* Every time this coroutine is resumed on a thread, UI thread name is updated to
* "UI thread original name # Progress bar coroutine" and the thread name is restored to the original one when
* this coroutine suspends.
*
* To use [ThreadLocal] variable within the coroutine use [ThreadLocal.asContextElement][asContextElement] function.
*/
public interface ThreadContextElement<S> : CoroutineContext.Element {
/**
* Updates context of the current thread.
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread
* when the context of the coroutine this element.
* The result of this function is the old value of the thread-local state that will be passed to [restoreThreadContext].
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
*/
public fun updateThreadContext(context: CoroutineContext): S
/**
* Restores context of the current thread.
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread
* if [updateThreadContext] was previously invoked on resume of this coroutine.
* The value of [oldState] is the result of the previous invocation of [updateThreadContext] and it should
* be restored in the thread-local state by this function.
* This method should handle its own exceptions and do not rethrow it. Thrown exceptions will leave coroutine which
* context is updated in an undefined state and may crash an application.
*
* @param context the coroutine context.
* @param oldState the value returned by the previous invocation of [updateThreadContext].
*/
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}
/**
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
*
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
*
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
* will be visible to _itself_ and any child coroutine launched _after_ that write.
*
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
* launching a child coroutine will not be visible to that child coroutine.
*
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
* correctly, regardless of the coroutine's structured concurrency.
*
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
* is in a coroutine:
*
* ```
* class TraceContextElement(val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
* companion object Key : CoroutineContext.Key<ThreadTraceContextElement>
* override val key: CoroutineContext.Key<ThreadTraceContextElement>
* get() = Key
*
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
* val oldState = traceThreadLocal.get()
* traceThreadLocal.set(data)
* return oldState
* }
*
* override fun restoreThreadContext(context: CoroutineContext, oldData: TraceData?) {
* traceThreadLocal.set(oldState)
* }
*
* override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
* // child coroutine visible to the child.
* return CopyForChildCoroutineElement(traceThreadLocal.get())
* }
* }
* ```
*
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
* `Thread`.
*/
@ExperimentalCoroutinesApi
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {
/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction.
*
* This function is called on the element each time a new coroutine inherits a context containing it,
* and the returned value is folded into the context given to the child.
*
* Since this method is called whenever a new coroutine is launched in a context containing this
* [CopyableThreadContextElement], implementations are performance-sensitive.
*/
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
}
/**
* Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
* maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
* By default [ThreadLocal.get] is used as a value for the thread-local variable, but it can be overridden with [value] parameter.
* Beware that context element **does not track** modifications of the thread-local and accessing thread-local from coroutine
* without the corresponding context element returns **undefined** value. See the examples for a detailed description.
*
*
* Example usage:
* ```
* val myThreadLocal = ThreadLocal<String?>()
* ...
* println(myThreadLocal.get()) // Prints "null"
* launch(Dispatchers.Default + myThreadLocal.asContextElement(value = "foo")) {
* println(myThreadLocal.get()) // Prints "foo"
* withContext(Dispatchers.Main) {
* println(myThreadLocal.get()) // Prints "foo", but it's on UI thread
* }
* }
* println(myThreadLocal.get()) // Prints "null"
* ```
*
* The context element does not track modifications of the thread-local variable, for example:
*
* ```
* myThreadLocal.set("main")
* withContext(Dispatchers.Main) {
* println(myThreadLocal.get()) // Prints "main"
* myThreadLocal.set("UI")
* }
* println(myThreadLocal.get()) // Prints "main", not "UI"
* ```
*
* Use `withContext` to update the corresponding thread-local variable to a different value, for example:
* ```
* withContext(myThreadLocal.asContextElement("foo")) {
* println(myThreadLocal.get()) // Prints "foo"
* }
* ```
*
* Accessing the thread-local without corresponding context element leads to undefined value:
* ```
* val tl = ThreadLocal.withInitial { "initial" }
*
* runBlocking {
* println(tl.get()) // Will print "initial"
* // Change context
* withContext(tl.asContextElement("modified")) {
* println(tl.get()) // Will print "modified"
* }
* // Context is changed again
* println(tl.get()) // <- WARN: can print either "modified" or "initial"
* }
* ```
* to fix this behaviour use `runBlocking(tl.asContextElement())`
*/
public fun <T> ThreadLocal<T>.asContextElement(value: T = get()): ThreadContextElement<T> =
ThreadLocalElement(value, this)
/**
* Return `true` when current thread local is present in the coroutine context, `false` otherwise.
* Thread local can be present in the context only if it was added via [asContextElement] to the context.
*
* Example of usage:
* ```
* suspend fun processRequest() {
* if (traceCurrentRequestThreadLocal.isPresent()) { // Probabilistic tracing
* // Do some heavy-weight tracing
* }
* // Process request regularly
* }
* ```
*/
public suspend inline fun ThreadLocal<*>.isPresent(): Boolean = coroutineContext[ThreadLocalKey(this)] !== null
/**
* Checks whether current thread local is present in the coroutine context and throws [IllegalStateException] if it is not.
* It is a good practice to validate that thread local is present in the context, especially in large code-bases,
* to avoid stale thread-local values and to have a strict invariants.
*
* E.g. one may use the following method to enforce proper use of the thread locals with coroutines:
* ```
* public suspend inline fun <T> ThreadLocal<T>.getSafely(): T {
* ensurePresent()
* return get()
* }
*
* // Usage
* withContext(...) {
* val value = threadLocal.getSafely() // Fail-fast in case of improper context
* }
* ```
*/
public suspend inline fun ThreadLocal<*>.ensurePresent(): Unit =
check(isPresent()) { "ThreadLocal $this is missing from context $coroutineContext" }