-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathThreadContext.kt
130 lines (111 loc) · 5.12 KB
/
ThreadContext.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
/*
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.internal
import kotlinx.coroutines.*
import kotlin.coroutines.*
// Sentinel state: updateThreadContext returns it when the context has no ThreadContextElements,
// and restoreThreadContext compares against it by identity and returns immediately in that case.
@JvmField
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")
// Holder for the old values of several ThreadContextElements.
// Used when there are >= 2 active elements in the context: every appended element is stored
// next to the value its updateThreadContext returned, so that restore can hand that value back.
// `n` is the capacity of the backing arrays, i.e. the number of elements expected to be appended.
@Suppress("UNCHECKED_CAST")
private class ThreadState(@JvmField val context: CoroutineContext, n: Int) {
    private val savedValues = arrayOfNulls<Any>(n)
    private val appended = arrayOfNulls<ThreadContextElement<Any?>>(n)
    private var size = 0

    // Records the element together with the value to pass to it on restore
    fun append(element: ThreadContextElement<*>, value: Any?) {
        val slot = size
        savedValues[slot] = value
        size = slot + 1
        appended[slot] = element as ThreadContextElement<Any?>
    }

    // Walks the recorded elements from the last slot to the first one;
    // slots that were never filled are skipped
    fun restore(context: CoroutineContext) {
        for (slot in appended.lastIndex downTo 0) {
            appended[slot]?.restoreThreadContext(context, savedValues[slot])
        }
    }
}
// Fold operation that counts ThreadContextElements in the context.
// The accumulator (Any?) is either an Int count or the single ThreadContextElement found so far:
//   0           -> no elements seen yet
//   element     -> exactly one element seen
//   Int (>= 2)  -> that many elements seen
private val countAll: (Any?, CoroutineContext.Element) -> Any? = { countOrElement, element ->
    if (element is ThreadContextElement<*>) {
        when (countOrElement) {
            0 -> element // the first one: keep the element itself instead of a count
            is Int -> countOrElement + 1
            else -> 2 // accumulator held the single element found before, this is the second one
        }
    } else {
        countOrElement
    }
}
// Fold operation that yields the first ThreadContextElement in the context (or null if there is none).
// It is used when we know there is exactly one such element.
private val findOne: (ThreadContextElement<*>?, CoroutineContext.Element) -> ThreadContextElement<*>? =
    { found, element -> found ?: (element as? ThreadContextElement<*>) }
// Fold operation that calls updateThreadContext on every ThreadContextElement in the context
// and records the element with the returned value in the given ThreadState
private val updateState: (ThreadState, CoroutineContext.Element) -> ThreadState = { state, element ->
    if (element is ThreadContextElement<*>) {
        val previous = element.updateThreadContext(state.context)
        state.append(element, previous)
    }
    state
}
// Folds the context with countAll starting from 0. The result is 0 when there are no
// ThreadContextElements, the element itself when there is exactly one, and their count otherwise.
internal actual fun threadContextElements(context: CoroutineContext): Any {
    val countOrElement = context.fold(0, countAll)
    return countOrElement!!
}
// countOrElement is pre-cached in dispatched continuation; when it is null it is computed here
// with threadContextElements(context).
// Returns NO_THREAD_ELEMENTS if the context does not have any ThreadContextElements,
// a ThreadState when there are several of them, and otherwise the value returned by the
// single element's updateThreadContext.
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved = countOrElement ?: threadContextElements(context)
    // identity comparison for speed, we know zero always has the same identity
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    val noElements = resolved === 0
    // very fast path when there are no active ThreadContextElements
    if (noElements) return NO_THREAD_ELEMENTS
    if (resolved is Int) {
        // slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
        return context.fold(ThreadState(context, resolved), updateState)
    }
    // fast path for one ThreadContextElement (no allocations, no additional context scan)
    @Suppress("UNCHECKED_CAST")
    val single = resolved as ThreadContextElement<Any?>
    return single.updateThreadContext(context)
}
// Counterpart of updateThreadContext: takes the state it returned and passes the old value(s)
// back to the ThreadContextElement(s) of the context.
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // very fast path when there are no ThreadContextElements
    if (oldState === NO_THREAD_ELEMENTS) return
    if (oldState is ThreadState) {
        // slow path with multiple stored ThreadContextElements
        oldState.restore(context)
        return
    }
    // fast path for one ThreadContextElement, but need to find it first;
    // oldState is the value that element returned from updateThreadContext
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
// Context key of ThreadLocalElement.
// It is a top-level data class to get a nicer out-of-the-box toString representation and class name;
// being a data class also means two keys are equal when their thread-locals are equal.
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
// ThreadContextElement backed by a ThreadLocal: updateThreadContext stores `value` into `threadLocal`
// and returns what was there before, restoreThreadContext writes the given old state back.
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    // ThreadLocalKey is a data class, so keys are compared by value rather than by identity
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    // Reads the current thread-local value, replaces it with this element's value, returns the old one
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    // Puts the value previously returned by updateThreadContext back into the thread-local
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // this method is overridden to perform value comparison (==) on key
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key == key) EmptyCoroutineContext else this

    // this method is overridden to perform value comparison (==) on key
    @Suppress("UNCHECKED_CAST")
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        return this as E
    }

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}