@@ -7,6 +7,7 @@ package kotlinx.coroutines
7
7
import kotlinx.coroutines.internal.*
8
8
import kotlinx.coroutines.scheduling.*
9
9
import kotlin.coroutines.*
10
+ import kotlin.coroutines.jvm.internal.CoroutineStackFrame
10
11
11
12
internal const val COROUTINES_SCHEDULER_PROPERTY_NAME = " kotlinx.coroutines.scheduler"
12
13
@@ -47,6 +48,102 @@ internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, c
47
48
}
48
49
}
49
50
51
+ /* *
52
+ * Executes a block using a context of a given continuation.
53
+ */
54
+ internal actual inline fun <T > withContinuationContext (continuation : Continuation <* >, countOrElement : Any? , block : () -> T ): T {
55
+ val context = continuation.context
56
+ val oldValue = updateThreadContext(context, countOrElement)
57
+ val undispatchedCompletion = if (oldValue != = NO_THREAD_ELEMENTS ) {
58
+ // Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them
59
+ continuation.updateUndispatchedCompletion(context, oldValue)
60
+ } else {
61
+ null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context
62
+ }
63
+ try {
64
+ return block()
65
+ } finally {
66
+ if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) {
67
+ restoreThreadContext(context, oldValue)
68
+ }
69
+ }
70
+ }
71
+
72
+ internal fun Continuation <* >.updateUndispatchedCompletion (context : CoroutineContext , oldValue : Any? ): UndispatchedCoroutine <* >? {
73
+ if (this !is CoroutineStackFrame ) return null
74
+ /*
75
+ * Fast-path to detect whether we have unispatched coroutine at all in our stack.
76
+ *
77
+ * Implementation note.
78
+ * If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
79
+ * 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
80
+ * 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
81
+ * from the context when creating dispatched coroutine in `withContext`.
82
+ * Another option is to "unmark it" instead of removing to save an allocation.
83
+ * Both options should work, but it requires more careful studying of the performance
84
+ * and, mostly, maintainability impact.
85
+ */
86
+ val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker ] != = null
87
+ if (! potentiallyHasUndispatchedCorotuine) return null
88
+ val completion = undispatchedCompletion()
89
+ completion?.saveThreadContext(context, oldValue)
90
+ return completion
91
+ }
92
+
93
+ internal tailrec fun CoroutineStackFrame.undispatchedCompletion (): UndispatchedCoroutine <* >? {
94
+ // Find direct completion of this continuation
95
+ val completion: CoroutineStackFrame = when (this ) {
96
+ is DispatchedCoroutine <* > -> return null
97
+ else -> callerFrame ? : return null // something else -- not supported
98
+ }
99
+ if (completion is UndispatchedCoroutine <* >) return completion // found UndispatchedCoroutine!
100
+ return completion.undispatchedCompletion() // walk up the call stack with tail call
101
+ }
102
+
103
+ /* *
104
+ * Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
105
+ * Used as a performance optimization to avoid stack walking where it is not nesessary.
106
+ */
107
+ private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
108
+ override val key: CoroutineContext .Key <* >
109
+ get() = this
110
+ }
111
+
112
+ // Used by withContext when context changes, but dispatcher stays the same
113
+ internal actual class UndispatchedCoroutine <in T >actual constructor (
114
+ context : CoroutineContext ,
115
+ uCont : Continuation <T >
116
+ ) : ScopeCoroutine<T>(context + UndispatchedMarker , uCont) {
117
+
118
+ private var savedContext: CoroutineContext ? = null
119
+ private var savedOldValue: Any? = null
120
+
121
+ fun saveThreadContext (context : CoroutineContext , oldValue : Any? ) {
122
+ savedContext = context
123
+ savedOldValue = oldValue
124
+ }
125
+
126
+ fun clearThreadContext (): Boolean {
127
+ if (savedContext == null ) return false
128
+ savedContext = null
129
+ savedOldValue = null
130
+ return true
131
+ }
132
+
133
+ override fun afterResume (state : Any? ) {
134
+ savedContext?.let { context ->
135
+ restoreThreadContext(context, savedOldValue)
136
+ savedContext = null
137
+ savedOldValue = null
138
+ }
139
+ // resume undispatched -- update context but stay on the same dispatcher
140
+ val result = recoverResult(state, uCont)
141
+ withContinuationContext(uCont, null ) {
142
+ uCont.resumeWith(result)
143
+ }
144
+ }
145
+ }
146
+
50
147
internal actual val CoroutineContext .coroutineName: String? get() {
51
148
if (! DEBUG ) return null
52
149
val coroutineId = this [CoroutineId ] ? : return null
0 commit comments