-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
Copy pathThreadContext.kt
131 lines (111 loc) · 5.07 KB
/
ThreadContext.kt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
/*
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/
package kotlinx.coroutines.internal
import kotlinx.coroutines.*
import kotlin.coroutines.*
// Marker returned by updateThreadContext when the context holds no ThreadContextElements;
// restoreThreadContext recognises it by identity and returns immediately.
private val ZERO: Symbol = Symbol("ZERO")
// Holder for the saved values of several ThreadContextElements (used when the context has >= 2 of them).
// Values are written sequentially with append() and read back sequentially with take();
// start() rewinds the shared cursor so that reading begins again from the first slot.
private class ThreadState(val context: CoroutineContext, n: Int) {
    // Fixed-size storage: one slot per expected value
    private val slots = arrayOfNulls<Any>(n)

    // Index of the next slot to write (append) or read (take)
    private var cursor = 0

    // Stores the value into the current slot and advances the cursor
    fun append(value: Any?) {
        slots[cursor++] = value
    }

    // Returns the value in the current slot and advances the cursor
    fun take(): Any? {
        return slots[cursor++]
    }

    // Rewinds the cursor to the first slot
    fun start() {
        cursor = 0
    }
}
// Fold operation that tallies the ThreadContextElements of a context.
// The accumulator (Any?) is either an Int count or, when exactly one element has been seen so far,
// that ThreadContextElement itself. Elements of any other kind leave the accumulator untouched.
private val countAll: (Any?, CoroutineContext.Element) -> Any? = { countOrElement, element ->
    if (element !is ThreadContextElement<*>) {
        countOrElement
    } else {
        // A non-Int accumulator is the single element found earlier, i.e. a count of one
        val seenSoFar = countOrElement as? Int ?: 1
        if (seenSoFar == 0) element else seenSoFar + 1
    }
}
// Fold operation that keeps the first ThreadContextElement it meets and ignores everything after it.
// It is used when the context is known to contain exactly one such element.
private val findOne: (ThreadContextElement<*>?, CoroutineContext.Element) -> ThreadContextElement<*>? =
    { found, element -> found ?: (element as? ThreadContextElement<*>) }
// Fold operation that calls updateThreadContext on every ThreadContextElement in the context
// and appends the value each call returns to the given ThreadState. Other elements are skipped.
private val updateState: (ThreadState, CoroutineContext.Element) -> ThreadState = { state, element ->
    state.also {
        if (element is ThreadContextElement<*>) {
            val previous = element.updateThreadContext(state.context)
            state.append(previous)
        }
    }
}
// Fold operation that hands every ThreadContextElement in the context the next value
// taken from the given ThreadState via restoreThreadContext. Other elements are skipped.
private val restoreState: (ThreadState, CoroutineContext.Element) -> ThreadState = { state, element ->
    if (element is ThreadContextElement<*>) {
        // The element's state type is erased here, so it is treated as Any?
        @Suppress("UNCHECKED_CAST")
        val target = element as ThreadContextElement<Any?>
        target.restoreThreadContext(state.context, state.take())
    }
    state
}
// Folds the context with countAll starting from 0. The result is the Int 0 when there are no
// ThreadContextElements, the element itself when there is exactly one, or the Int count otherwise.
internal actual fun threadContextElements(context: CoroutineContext): Any {
    val countOrElement: Any? = context.fold(0, countAll)
    return countOrElement!!
}
// Applies the context's ThreadContextElements and returns the old state to be passed to restoreThreadContext.
// countOrElement is pre-cached in dispatched continuation; when it is null it is computed here
// with threadContextElements(context).
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
    val resolved = countOrElement ?: threadContextElements(context)

    // Very fast path when there are no active ThreadContextElements.
    // Identity comparison for speed, we know zero always has the same identity.
    @Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
    val noElements = resolved === 0
    if (noElements) return ZERO

    // Slow path for multiple active ThreadContextElements: allocates a ThreadState for the old values
    if (resolved is Int) {
        return context.fold(ThreadState(context, resolved), updateState)
    }

    // Fast path for one ThreadContextElement (no allocations, no additional context scan)
    @Suppress("UNCHECKED_CAST")
    val single = resolved as ThreadContextElement<Any?>
    return single.updateThreadContext(context)
}
// Restores the context's ThreadContextElements from the old state that updateThreadContext returned.
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
    // Very fast path when there are no ThreadContextElements
    if (oldState === ZERO) return

    // Slow path with multiple stored values: rewind the state and replay it over the context
    if (oldState is ThreadState) {
        oldState.start()
        context.fold(oldState, restoreState)
        return
    }

    // Fast path for one ThreadContextElement: oldState is its saved value, but the element must be found first
    @Suppress("UNCHECKED_CAST")
    val single = context.fold(null, findOne) as ThreadContextElement<Any?>
    single.restoreThreadContext(context, oldState)
}
// Context key for a ThreadLocalElement, wrapping the ThreadLocal it belongs to.
// Declared as a top-level data class for a nicer out-of-the-box toString representation and class name;
// being a data class it also compares by value (==) on the wrapped ThreadLocal.
@PublishedApi
internal data class ThreadLocalKey(
    private val threadLocal: ThreadLocal<*>
) : CoroutineContext.Key<ThreadLocalElement<*>>
// ThreadContextElement that sets the given ThreadLocal to the given value on update
// and writes the previously stored value back on restore.
internal class ThreadLocalElement<T>(
    private val value: T,
    private val threadLocal: ThreadLocal<T>
) : ThreadContextElement<T> {
    override val key: CoroutineContext.Key<*> = ThreadLocalKey(threadLocal)

    // Reads the current thread-local value, then installs this element's value; returns what was read
    override fun updateThreadContext(context: CoroutineContext): T =
        threadLocal.get().also { threadLocal.set(value) }

    // Writes the previously saved value back into the thread-local
    override fun restoreThreadContext(context: CoroutineContext, oldState: T) {
        threadLocal.set(oldState)
    }

    // Overridden so that the key is matched by value comparison (==)
    override fun minusKey(key: CoroutineContext.Key<*>): CoroutineContext =
        if (this.key == key) EmptyCoroutineContext else this

    // Overridden so that the key is matched by value comparison (==)
    public override operator fun <E : CoroutineContext.Element> get(key: CoroutineContext.Key<E>): E? {
        if (this.key != key) return null
        @Suppress("UNCHECKED_CAST")
        return this as E
    }

    override fun toString(): String = "ThreadLocal(value=$value, threadLocal = $threadLocal)"
}