Skip to content

Commit 5dc55a6

Browse files
qwwdfsadelizarov
andcommitted
Restore thread context elements when directly resuming to parent.
This fix solves the problem of restoring thread-context when returning to another context in undispatched way. It impacts suspend/resume performance of coroutines that use ThreadContextElement and undispatched coroutines. The kotlinx.coroutines code poisons the context with special 'UndispatchedMarker' element and linear lookup is performed only when the marker is present. The code also contains description of an alternative approach in order to save a linear lookup in complex coroutines hierarchies. Fast-path of coroutine resumption is slowed down by a single context lookup. Fixes #985 Co-authored-by: Roman Elizarov <[email protected]>
1 parent 7061cc2 commit 5dc55a6

File tree

12 files changed

+420
-22
lines changed

12 files changed

+420
-22
lines changed

benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import org.openjdk.jmh.annotations.*
1010
import java.util.concurrent.*
1111
import kotlin.coroutines.*
1212

13-
@Warmup(iterations = 5, time = 1)
13+
@Warmup(iterations = 7, time = 1)
1414
@Measurement(iterations = 5, time = 1)
1515
@BenchmarkMode(Mode.AverageTime)
1616
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@@ -41,7 +41,7 @@ open class ChannelSinkBenchmark {
4141

4242
private suspend inline fun run(context: CoroutineContext): Int {
4343
return Channel
44-
.range(1, 1_000_000, context)
44+
.range(1, 10_000, context)
4545
.filter(context) { it % 4 == 0 }
4646
.fold(0) { a, b -> a + b }
4747
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package benchmarks
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.channels.*
9+
import org.openjdk.jmh.annotations.*
10+
import java.util.concurrent.*
11+
import kotlin.coroutines.*
12+
13+
@Warmup(iterations = 7, time = 1)
14+
@Measurement(iterations = 5, time = 1)
15+
@BenchmarkMode(Mode.AverageTime)
16+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
17+
@State(Scope.Benchmark)
18+
@Fork(2)
19+
open class ChannelSinkDepthBenchmark {
20+
private val tl = ThreadLocal.withInitial({ 42 })
21+
22+
private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement()
23+
24+
@Benchmark
25+
fun depth1(): Int = runBlocking {
26+
run(1, unconfinedOneElement)
27+
}
28+
29+
@Benchmark
30+
fun depth10(): Int = runBlocking {
31+
run(10, unconfinedOneElement)
32+
}
33+
34+
@Benchmark
35+
fun depth100(): Int = runBlocking {
36+
run(100, unconfinedOneElement)
37+
}
38+
39+
@Benchmark
40+
fun depth1000(): Int = runBlocking {
41+
run(1000, unconfinedOneElement)
42+
}
43+
44+
private suspend inline fun run(callTraceDepth: Int, context: CoroutineContext): Int {
45+
return Channel
46+
.range(1, 10_000, context)
47+
.filter(callTraceDepth, context) { it % 4 == 0 }
48+
.fold(0) { a, b -> a + b }
49+
}
50+
51+
private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) =
52+
GlobalScope.produce(context) {
53+
for (i in start until (start + count))
54+
send(i)
55+
}
56+
57+
// Migrated from deprecated operators, are good only for stressing channels
58+
59+
private fun ReceiveChannel<Int>.filter(
60+
callTraceDepth: Int,
61+
context: CoroutineContext = Dispatchers.Unconfined,
62+
predicate: suspend (Int) -> Boolean
63+
): ReceiveChannel<Int> =
64+
GlobalScope.produce(context, onCompletion = { cancel() }) {
65+
deeplyNestedFilter(this, callTraceDepth, predicate)
66+
}
67+
68+
private suspend fun ReceiveChannel<Int>.deeplyNestedFilter(
69+
sink: ProducerScope<Int>,
70+
depth: Int,
71+
predicate: suspend (Int) -> Boolean
72+
) {
73+
if (depth <= 1) {
74+
for (e in this) {
75+
if (predicate(e)) sink.send(e)
76+
}
77+
} else {
78+
deeplyNestedFilter(sink, depth - 1, predicate)
79+
require(true) // tail-call
80+
}
81+
}
82+
83+
private suspend inline fun <E, R> ReceiveChannel<E>.fold(initial: R, operation: (acc: R, E) -> R): R {
84+
var accumulator = initial
85+
consumeEach {
86+
accumulator = operation(accumulator, it)
87+
}
88+
return accumulator
89+
}
90+
}
91+

kotlinx-coroutines-core/common/src/Builders.common.kt

+3-11
Original file line numberDiff line numberDiff line change
@@ -207,25 +207,17 @@ private class LazyStandaloneCoroutine(
207207
}
208208

209209
// Used by withContext when context changes, but dispatcher stays the same
210-
private class UndispatchedCoroutine<in T>(
210+
internal expect class UndispatchedCoroutine<in T>(
211211
context: CoroutineContext,
212212
uCont: Continuation<T>
213-
) : ScopeCoroutine<T>(context, uCont) {
214-
override fun afterResume(state: Any?) {
215-
// resume undispatched -- update context by stay on the same dispatcher
216-
val result = recoverResult(state, uCont)
217-
withCoroutineContext(uCont.context, null) {
218-
uCont.resumeWith(result)
219-
}
220-
}
221-
}
213+
) : ScopeCoroutine<T>
222214

223215
private const val UNDECIDED = 0
224216
private const val SUSPENDED = 1
225217
private const val RESUMED = 2
226218

227219
// Used by withContext when context dispatcher changes
228-
private class DispatchedCoroutine<in T>(
220+
internal class DispatchedCoroutine<in T>(
229221
context: CoroutineContext,
230222
uCont: Continuation<T>
231223
) : ScopeCoroutine<T>(context, uCont) {

kotlinx-coroutines-core/common/src/CoroutineContext.common.kt

+1
Original file line numberDiff line numberDiff line change
@@ -19,5 +19,6 @@ internal expect val DefaultDelay: Delay
1919

2020
// countOrElement -- pre-cached value for ThreadContext.kt
2121
internal expect inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T
22+
internal expect inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T
2223
internal expect fun Continuation<*>.toDebugString(): String
2324
internal expect val CoroutineContext.coroutineName: String?

kotlinx-coroutines-core/common/src/internal/DispatchedContinuation.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ internal class DispatchedContinuation<in T>(
235235

236236
@Suppress("NOTHING_TO_INLINE") // we need it inline to save us an entry on the stack
237237
inline fun resumeUndispatchedWith(result: Result<T>) {
238-
withCoroutineContext(context, countOrElement) {
238+
withContinuationContext(continuation, countOrElement) {
239239
continuation.resumeWith(result)
240240
}
241241
}

kotlinx-coroutines-core/common/src/internal/DispatchedTask.kt

+3-3
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,9 @@ internal abstract class DispatchedTask<in T>(
8585
try {
8686
val delegate = delegate as DispatchedContinuation<T>
8787
val continuation = delegate.continuation
88-
val context = continuation.context
89-
val state = takeState() // NOTE: Must take state in any case, even if cancelled
90-
withCoroutineContext(context, delegate.countOrElement) {
88+
withContinuationContext(continuation, delegate.countOrElement) {
89+
val context = continuation.context
90+
val state = takeState() // NOTE: Must take state in any case, even if cancelled
9191
val exception = getExceptionalResult(state)
9292
/*
9393
* Check whether continuation was originally resumed with an exception.

kotlinx-coroutines-core/js/src/CoroutineContext.kt

+9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.coroutines
66

7+
import kotlinx.coroutines.internal.*
78
import kotlin.browser.*
89
import kotlin.coroutines.*
910

@@ -49,5 +50,13 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext):
4950

5051
// No debugging facilities on JS
5152
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
53+
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()
5254
internal actual fun Continuation<*>.toDebugString(): String = toString()
5355
internal actual val CoroutineContext.coroutineName: String? get() = null // not supported on JS
56+
57+
internal actual class UndispatchedCoroutine<in T> actual constructor(
58+
context: CoroutineContext,
59+
uCont: Continuation<T>
60+
) : ScopeCoroutine<T>(context, uCont) {
61+
override fun afterResume(state: Any?) = uCont.resumeWith(recoverResult(state, uCont))
62+
}

kotlinx-coroutines-core/jvm/src/Builders.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
@file:JvmMultifileClass

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+97
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package kotlinx.coroutines
77
import kotlinx.coroutines.internal.*
88
import kotlinx.coroutines.scheduling.*
99
import kotlin.coroutines.*
10+
import kotlin.coroutines.jvm.internal.CoroutineStackFrame
1011

1112
internal const val COROUTINES_SCHEDULER_PROPERTY_NAME = "kotlinx.coroutines.scheduler"
1213

@@ -47,6 +48,102 @@ internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, c
4748
}
4849
}
4950

51+
/**
52+
* Executes a block using a context of a given continuation.
53+
*/
54+
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T {
55+
val context = continuation.context
56+
val oldValue = updateThreadContext(context, countOrElement)
57+
val undispatchedCompletion = if (oldValue !== NO_THREAD_ELEMENTS) {
58+
// Only if some values were replaced we'll go to the slow path of figuring out where/how to restore them
59+
continuation.updateUndispatchedCompletion(context, oldValue)
60+
} else {
61+
null // fast path -- don't even try to find undispatchedCompletion as there's nothing to restore in the context
62+
}
63+
try {
64+
return block()
65+
} finally {
66+
if (undispatchedCompletion == null || undispatchedCompletion.clearThreadContext()) {
67+
restoreThreadContext(context, oldValue)
68+
}
69+
}
70+
}
71+
72+
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
73+
if (this !is CoroutineStackFrame) return null
74+
/*
75+
* Fast-path to detect whether we have unispatched coroutine at all in our stack.
76+
*
77+
* Implementation note.
78+
* If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
79+
* 1) Store undispatched coroutine right in the `UndispatchedMarker` instance
80+
* 2) To avoid issues with cross-dispatch boundary, remove `UndispatchedMarker`
81+
* from the context when creating dispatched coroutine in `withContext`.
82+
* Another option is to "unmark it" instead of removing to save an allocation.
83+
* Both options should work, but it requires more careful studying of the performance
84+
* and, mostly, maintainability impact.
85+
*/
86+
val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker] !== null
87+
if (!potentiallyHasUndispatchedCorotuine) return null
88+
val completion = undispatchedCompletion()
89+
completion?.saveThreadContext(context, oldValue)
90+
return completion
91+
}
92+
93+
internal tailrec fun CoroutineStackFrame.undispatchedCompletion(): UndispatchedCoroutine<*>? {
94+
// Find direct completion of this continuation
95+
val completion: CoroutineStackFrame = when (this) {
96+
is DispatchedCoroutine<*> -> return null
97+
else -> callerFrame ?: return null // something else -- not supported
98+
}
99+
if (completion is UndispatchedCoroutine<*>) return completion // found UndispatchedCoroutine!
100+
return completion.undispatchedCompletion() // walk up the call stack with tail call
101+
}
102+
103+
/**
104+
* Marker indicating that [UndispatchedCoroutine] exists somewhere up in the stack.
105+
* Used as a performance optimization to avoid stack walking where it is not nesessary.
106+
*/
107+
private object UndispatchedMarker: CoroutineContext.Element, CoroutineContext.Key<UndispatchedMarker> {
108+
override val key: CoroutineContext.Key<*>
109+
get() = this
110+
}
111+
112+
// Used by withContext when context changes, but dispatcher stays the same
113+
internal actual class UndispatchedCoroutine<in T>actual constructor (
114+
context: CoroutineContext,
115+
uCont: Continuation<T>
116+
) : ScopeCoroutine<T>(context + UndispatchedMarker, uCont) {
117+
118+
private var savedContext: CoroutineContext? = null
119+
private var savedOldValue: Any? = null
120+
121+
fun saveThreadContext(context: CoroutineContext, oldValue: Any?) {
122+
savedContext = context
123+
savedOldValue = oldValue
124+
}
125+
126+
fun clearThreadContext(): Boolean {
127+
if (savedContext == null) return false
128+
savedContext = null
129+
savedOldValue = null
130+
return true
131+
}
132+
133+
override fun afterResume(state: Any?) {
134+
savedContext?.let { context ->
135+
restoreThreadContext(context, savedOldValue)
136+
savedContext = null
137+
savedOldValue = null
138+
}
139+
// resume undispatched -- update context but stay on the same dispatcher
140+
val result = recoverResult(state, uCont)
141+
withContinuationContext(uCont, null) {
142+
uCont.resumeWith(result)
143+
}
144+
}
145+
}
146+
50147
internal actual val CoroutineContext.coroutineName: String? get() {
51148
if (!DEBUG) return null
52149
val coroutineId = this[CoroutineId] ?: return null

kotlinx-coroutines-core/jvm/src/internal/ThreadContext.kt

+5-4
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ package kotlinx.coroutines.internal
77
import kotlinx.coroutines.*
88
import kotlin.coroutines.*
99

10-
11-
private val ZERO = Symbol("ZERO")
10+
@JvmField
11+
internal val NO_THREAD_ELEMENTS = Symbol("NO_THREAD_ELEMENTS")
1212

1313
// Used when there are >= 2 active elements in the context
1414
private class ThreadState(val context: CoroutineContext, n: Int) {
@@ -60,12 +60,13 @@ private val restoreState =
6060
internal actual fun threadContextElements(context: CoroutineContext): Any = context.fold(0, countAll)!!
6161

6262
// countOrElement is pre-cached in dispatched continuation
63+
// returns NO_THREAD_ELEMENTS if the contest does not have any ThreadContextElements
6364
internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?): Any? {
6465
@Suppress("NAME_SHADOWING")
6566
val countOrElement = countOrElement ?: threadContextElements(context)
6667
@Suppress("IMPLICIT_BOXING_IN_IDENTITY_EQUALS")
6768
return when {
68-
countOrElement === 0 -> ZERO // very fast path when there are no active ThreadContextElements
69+
countOrElement === 0 -> NO_THREAD_ELEMENTS // very fast path when there are no active ThreadContextElements
6970
// ^^^ identity comparison for speed, we know zero always has the same identity
7071
countOrElement is Int -> {
7172
// slow path for multiple active ThreadContextElements, allocates ThreadState for multiple old values
@@ -82,7 +83,7 @@ internal fun updateThreadContext(context: CoroutineContext, countOrElement: Any?
8283

8384
internal fun restoreThreadContext(context: CoroutineContext, oldState: Any?) {
8485
when {
85-
oldState === ZERO -> return // very fast path when there are no ThreadContextElements
86+
oldState === NO_THREAD_ELEMENTS -> return // very fast path when there are no ThreadContextElements
8687
oldState is ThreadState -> {
8788
// slow path with multiple stored ThreadContextElements
8889
oldState.start()

0 commit comments

Comments
 (0)