Skip to content

Commit a5dd74b

Browse files
CopyableThreadContextElement implementation (#3227)
New approach eagerly copies corresponding elements to avoid accidental top-level reuse and also provides merge capability in case when an element is being overwritten. Merge capability is crucial in tracing scenarios to properly preserve the state of linked thread locals Co-authored-by: dkhalanskyjb <[email protected]>
1 parent 8133c97 commit a5dd74b

9 files changed

+259
-35
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+3-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,8 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run
141141
}
142142

143143
public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement {
144-
public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement;
144+
public abstract fun copyForChild ()Lkotlinx/coroutines/CopyableThreadContextElement;
145+
public abstract fun mergeForChild (Lkotlin/coroutines/CoroutineContext$Element;)Lkotlin/coroutines/CoroutineContext;
145146
}
146147

147148
public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls {
@@ -156,6 +157,7 @@ public abstract interface class kotlinx/coroutines/CopyableThrowable {
156157
}
157158

158159
public final class kotlinx/coroutines/CoroutineContextKt {
160+
public static final fun newCoroutineContext (Lkotlin/coroutines/CoroutineContext;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
159161
public static final fun newCoroutineContext (Lkotlinx/coroutines/CoroutineScope;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
160162
}
161163

kotlinx-coroutines-core/common/src/Builders.common.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,8 @@ public suspend fun <T> withContext(
148148
return suspendCoroutineUninterceptedOrReturn sc@ { uCont ->
149149
// compute new context
150150
val oldContext = uCont.context
151-
val newContext = oldContext + context
151+
// Copy CopyableThreadContextElement if necessary
152+
val newContext = oldContext.newCoroutineContext(context)
152153
// always check for cancellation of new context
153154
newContext.ensureActive()
154155
// FAST PATH #1 -- new context is the same as the old one

kotlinx-coroutines-core/common/src/CoroutineContext.common.kt

+10-2
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,19 @@ package kotlinx.coroutines
77
import kotlin.coroutines.*
88

99
/**
10-
* Creates a context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
11-
* [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on).
10+
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
11+
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
12+
* and copyable-thread-local facilities on JVM.
1213
*/
1314
public expect fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext
1415

16+
/**
17+
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
18+
* @suppress
19+
*/
20+
@InternalCoroutinesApi
21+
public expect fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext
22+
1523
@PublishedApi
1624
@Suppress("PropertyName")
1725
internal expect val DefaultDelay: Delay

kotlinx-coroutines-core/js/src/CoroutineContext.kt

+4
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@ public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext):
4242
combined + Dispatchers.Default else combined
4343
}
4444

45+
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
46+
return this + addedContext
47+
}
48+
4549
// No debugging facilities on JS
4650
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, countOrElement: Any?, block: () -> T): T = block()
4751
internal actual inline fun <T> withContinuationContext(continuation: Continuation<*>, countOrElement: Any?, block: () -> T): T = block()

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+67-20
Original file line numberDiff line numberDiff line change
@@ -9,36 +9,83 @@ import kotlin.coroutines.*
99
import kotlin.coroutines.jvm.internal.CoroutineStackFrame
1010

1111
/**
12-
* Creates context for the new coroutine. It installs [Dispatchers.Default] when no other dispatcher nor
13-
* [ContinuationInterceptor] is specified, and adds optional support for debugging facilities (when turned on).
14-
*
12+
* Creates a context for a new coroutine. It installs [Dispatchers.Default] when no other dispatcher or
13+
* [ContinuationInterceptor] is specified and adds optional support for debugging facilities (when turned on)
14+
* and copyable-thread-local facilities on JVM.
1515
* See [DEBUG_PROPERTY_NAME] for description of debugging facilities on JVM.
1616
*/
1717
@ExperimentalCoroutinesApi
1818
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
19-
val combined = coroutineContext.foldCopiesForChildCoroutine() + context
19+
val combined = foldCopies(coroutineContext, context, true)
2020
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
2121
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
2222
debug + Dispatchers.Default else debug
2323
}
2424

2525
/**
26-
* Returns the [CoroutineContext] for a child coroutine to inherit.
27-
*
28-
* If any [CopyableThreadContextElement] is in the [this], calls
29-
* [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext]
30-
* by folding the returned copied elements into [this].
31-
*
32-
* Returns [this] if `this` has zero [CopyableThreadContextElement] in it.
26+
* Creates a context for coroutine builder functions that do not launch a new coroutine, e.g. [withContext].
27+
* @suppress
3328
*/
34-
private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext {
35-
val hasToCopy = fold(false) { result, it ->
36-
result || it is CopyableThreadContextElement<*>
29+
@InternalCoroutinesApi
30+
public actual fun CoroutineContext.newCoroutineContext(addedContext: CoroutineContext): CoroutineContext {
31+
/*
32+
* Fast-path: we only have to copy/merge if 'addedContext' (which typically has one or two elements)
33+
* contains copyable elements.
34+
*/
35+
if (!addedContext.hasCopyableElements()) return this + addedContext
36+
return foldCopies(this, addedContext, false)
37+
}
38+
39+
private fun CoroutineContext.hasCopyableElements(): Boolean =
40+
fold(false) { result, it -> result || it is CopyableThreadContextElement<*> }
41+
42+
/**
43+
* Folds two contexts properly applying [CopyableThreadContextElement] rules when necessary.
44+
* The rules are the following:
45+
* * If neither context has CTCE, the sum of two contexts is returned
46+
* * Every CTCE from the left-hand side context that does not have a matching (by key) element from right-hand side context
47+
* is [copied][CopyableThreadContextElement.copyForChild] if [isNewCoroutine] is `true`.
48+
* * Every CTCE from the left-hand side context that has a matching element in the right-hand side context is [merged][CopyableThreadContextElement.mergeForChild]
49+
* * Every CTCE from the right-hand side context that hasn't been merged is copied
50+
* * Everything else is added to the resulting context as is.
51+
*/
52+
private fun foldCopies(originalContext: CoroutineContext, appendContext: CoroutineContext, isNewCoroutine: Boolean): CoroutineContext {
53+
// Do we have something to copy left-hand side?
54+
val hasElementsLeft = originalContext.hasCopyableElements()
55+
val hasElementsRight = appendContext.hasCopyableElements()
56+
57+
// Nothing to fold, so just return the sum of contexts
58+
if (!hasElementsLeft && !hasElementsRight) {
59+
return originalContext + appendContext
60+
}
61+
62+
var leftoverContext = appendContext
63+
val folded = originalContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
64+
if (element !is CopyableThreadContextElement<*>) return@fold result + element
65+
// Will this element be overwritten?
66+
val newElement = leftoverContext[element.key]
67+
// No, just copy it
68+
if (newElement == null) {
69+
// For 'withContext'-like builders we do not copy as the element is not shared
70+
return@fold result + if (isNewCoroutine) element.copyForChild() else element
71+
}
72+
// Yes, then first remove the element from append context
73+
leftoverContext = leftoverContext.minusKey(element.key)
74+
// Return the sum
75+
@Suppress("UNCHECKED_CAST")
76+
return@fold result + (element as CopyableThreadContextElement<Any?>).mergeForChild(newElement)
3777
}
38-
if (!hasToCopy) return this
39-
return fold<CoroutineContext>(EmptyCoroutineContext) { combined, it ->
40-
combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it
78+
79+
if (hasElementsRight) {
80+
leftoverContext = leftoverContext.fold<CoroutineContext>(EmptyCoroutineContext) { result, element ->
81+
// We're appending new context element -- we have to copy it, otherwise it may be shared with others
82+
if (element is CopyableThreadContextElement<*>) {
83+
return@fold result + element.copyForChild()
84+
}
85+
return@fold result + element
86+
}
4187
}
88+
return folded + leftoverContext
4289
}
4390

4491
/**
@@ -77,7 +124,7 @@ internal actual inline fun <T> withContinuationContext(continuation: Continuatio
77124
internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineContext, oldValue: Any?): UndispatchedCoroutine<*>? {
78125
if (this !is CoroutineStackFrame) return null
79126
/*
80-
* Fast-path to detect whether we have unispatched coroutine at all in our stack.
127+
* Fast-path to detect whether we have undispatched coroutine at all in our stack.
81128
*
82129
* Implementation note.
83130
* If we ever find that stackwalking for thread-locals is way too slow, here is another idea:
@@ -88,8 +135,8 @@ internal fun Continuation<*>.updateUndispatchedCompletion(context: CoroutineCont
88135
* Both options should work, but it requires more careful studying of the performance
89136
* and, mostly, maintainability impact.
90137
*/
91-
val potentiallyHasUndispatchedCorotuine = context[UndispatchedMarker] !== null
92-
if (!potentiallyHasUndispatchedCorotuine) return null
138+
val potentiallyHasUndispatchedCoroutine = context[UndispatchedMarker] !== null
139+
if (!potentiallyHasUndispatchedCoroutine) return null
93140
val completion = undispatchedCompletion()
94141
completion?.saveThreadContext(context, oldValue)
95142
return completion

kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt

+27-6
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
8080
/**
8181
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
8282
*
83-
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
83+
* When an API uses a _mutable_ [ThreadLocal] for consistency, a [CopyableThreadContextElement]
8484
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
8585
*
8686
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
@@ -99,6 +99,7 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
9999
* ```
100100
* class TraceContextElement(private val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
101101
* companion object Key : CoroutineContext.Key<TraceContextElement>
102+
*
102103
* override val key: CoroutineContext.Key<TraceContextElement> = Key
103104
*
104105
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
@@ -111,32 +112,52 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
111112
* traceThreadLocal.set(oldState)
112113
* }
113114
*
114-
* override fun copyForChildCoroutine(): CopyableThreadContextElement<TraceData?> {
115+
* override fun copyForChild(): TraceContextElement {
115116
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
116117
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
117118
* // child coroutine visible to the child.
118119
* return TraceContextElement(traceThreadLocal.get()?.copy())
119120
* }
121+
*
122+
* override fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext {
123+
* // Merge operation defines how to handle situations when both
124+
* // the parent coroutine has an element in the context and
125+
* // an element with the same key was also
126+
* // explicitly passed to the child coroutine.
127+
* // If merging does not require special behavior,
128+
* // the copy of the element can be returned.
129+
* return TraceContextElement(traceThreadLocal.get()?.copy())
130+
* }
120131
* }
121132
* ```
122133
*
123-
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
124-
* `Thread`.
134+
* A coroutine using this mechanism can safely call Java code that assumes the corresponding thread local element's
135+
* value is installed into the target thread local.
125136
*/
137+
@DelicateCoroutinesApi
126138
@ExperimentalCoroutinesApi
127139
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {
128140

129141
/**
130142
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
131-
* coroutine's context that is under construction.
143+
* coroutine's context that is under construction if the added context does not contain an element with the same [key].
132144
*
133145
* This function is called on the element each time a new coroutine inherits a context containing it,
134146
* and the returned value is folded into the context given to the child.
135147
*
136148
* Since this method is called whenever a new coroutine is launched in a context containing this
137149
* [CopyableThreadContextElement], implementations are performance-sensitive.
138150
*/
139-
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
151+
public fun copyForChild(): CopyableThreadContextElement<S>
152+
153+
/**
154+
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
155+
* coroutine's context that is under construction if the added context does contain an element with the same [key].
156+
*
157+
* This method is invoked on the original element, accepting as the parameter
158+
* the element that is supposed to overwrite it.
159+
*/
160+
public fun mergeForChild(overwritingElement: CoroutineContext.Element): CoroutineContext
140161
}
141162

142163
/**

kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt

+8-5
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,7 @@ class ThreadContextElementTest : TestBase() {
126126
@Test
127127
fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
128128
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
129-
val startData = MyData()
130-
withContext(it + CopyForChildCoroutineElement(startData)) {
129+
withContext(it + CopyForChildCoroutineElement(MyData())) {
131130
val forBlockData = MyData()
132131
myThreadLocal.setForBlock(forBlockData) {
133132
assertSame(myThreadLocal.get(), forBlockData)
@@ -153,7 +152,7 @@ class ThreadContextElementTest : TestBase() {
153152
assertSame(myThreadLocal.get(), forBlockData)
154153
}
155154
}
156-
assertSame(myThreadLocal.get(), startData) // Asserts value was restored.
155+
assertNull(myThreadLocal.get()) // Asserts value was restored to its origin
157156
}
158157
}
159158
}
@@ -187,7 +186,7 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
187186
}
188187

189188
/**
190-
* A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine].
189+
* A [ThreadContextElement] that implements copy semantics in [copyForChild].
191190
*/
192191
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
193192
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>
@@ -201,6 +200,10 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
201200
return oldState
202201
}
203202

203+
override fun mergeForChild(overwritingElement: CoroutineContext.Element): CopyForChildCoroutineElement {
204+
TODO("Not used in tests")
205+
}
206+
204207
override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
205208
myThreadLocal.set(oldState)
206209
}
@@ -216,7 +219,7 @@ class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextEle
216219
* will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the
217220
* thread and calls [restoreThreadContext].
218221
*/
219-
override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
222+
override fun copyForChild(): CopyForChildCoroutineElement {
220223
return CopyForChildCoroutineElement(myThreadLocal.get())
221224
}
222225
}

0 commit comments

Comments
 (0)