Skip to content

Commit 603bd79

Browse files
authored
Implemented CopyableThreadContextElement with a copyForChildCoroutine(). (#3025)
* This is a `ThreadContextElement` that is copy-constructed when a new coroutine is created and inherits the context. Co-authored-by: Tyson Henning <[email protected]> Fixes #2839
1 parent ae0c842 commit 603bd79

File tree

4 files changed

+223
-2
lines changed

4 files changed

+223
-2
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+11
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,17 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run
140140
public fun <init> (Ljava/lang/String;Ljava/lang/Throwable;)V
141141
}
142142

143+
public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement {
144+
public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement;
145+
}
146+
147+
public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls {
148+
public static fun fold (Lkotlinx/coroutines/CopyableThreadContextElement;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
149+
public static fun get (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element;
150+
public static fun minusKey (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext;
151+
public static fun plus (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
152+
}
153+
143154
public abstract interface class kotlinx/coroutines/CopyableThrowable {
144155
public abstract fun createCopy ()Ljava/lang/Throwable;
145156
}

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+20-1
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,31 @@ import kotlin.coroutines.jvm.internal.CoroutineStackFrame
1616
*/
1717
@ExperimentalCoroutinesApi
1818
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
19-
val combined = coroutineContext + context
19+
val combined = coroutineContext.foldCopiesForChildCoroutine() + context
2020
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
2121
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
2222
debug + Dispatchers.Default else debug
2323
}
2424

25+
/**
26+
* Returns the [CoroutineContext] for a child coroutine to inherit.
27+
*
28+
* If any [CopyableThreadContextElement] is in the [this], calls
29+
* [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext]
30+
* by folding the returned copied elements into [this].
31+
*
32+
* Returns [this] if `this` has zero [CopyableThreadContextElement] in it.
33+
*/
34+
private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext {
35+
val hasToCopy = fold(false) { result, it ->
36+
result || it is CopyableThreadContextElement<*>
37+
}
38+
if (!hasToCopy) return this
39+
return fold<CoroutineContext>(EmptyCoroutineContext) { combined, it ->
40+
combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it
41+
}
42+
}
43+
2544
/**
2645
* Executes a block using a given coroutine context.
2746
*/

kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt

+63
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,69 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
7777
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
7878
}
7979

80+
/**
81+
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
82+
*
83+
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
84+
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
85+
*
86+
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
87+
* will be visible to _itself_ and any child coroutine launched _after_ that write.
88+
*
89+
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
90+
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
91+
* launching a child coroutine will not be visible to that child coroutine.
92+
*
93+
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
94+
* correctly, regardless of the coroutine's structured concurrency.
95+
*
96+
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
97+
* is in a coroutine:
98+
*
99+
* ```
100+
* class TraceContextElement(val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
101+
* companion object Key : CoroutineContext.Key<ThreadTraceContextElement>
102+
* override val key: CoroutineContext.Key<ThreadTraceContextElement>
103+
* get() = Key
104+
*
105+
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
106+
* val oldState = traceThreadLocal.get()
107+
* traceThreadLocal.set(data)
108+
* return oldState
109+
* }
110+
*
111+
* override fun restoreThreadContext(context: CoroutineContext, oldData: TraceData?) {
112+
* traceThreadLocal.set(oldState)
113+
* }
114+
*
115+
* override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
116+
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
117+
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
118+
* // child coroutine visible to the child.
119+
* return CopyForChildCoroutineElement(traceThreadLocal.get())
120+
* }
121+
* }
122+
* ```
123+
*
124+
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
125+
* `Thread`.
126+
*/
127+
@ExperimentalCoroutinesApi
128+
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {
129+
130+
/**
131+
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
132+
* coroutine's context that is under construction.
133+
*
134+
* This function is called on the element each time a new coroutine inherits a context containing it,
135+
* and the returned value is folded into the context given to the child.
136+
*
137+
* Since this method is called whenever a new coroutine is launched in a context containing this
138+
* [CopyableThreadContextElement], implementations are performance-sensitive.
139+
*/
140+
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
141+
}
142+
80143
/**
81144
* Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
82145
* maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.

kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt

+129-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,6 @@ class ThreadContextElementTest : TestBase() {
5454
assertNull(myThreadLocal.get())
5555
}
5656

57-
5857
@Test
5958
fun testWithContext() = runTest {
6059
expect(1)
@@ -86,6 +85,78 @@ class ThreadContextElementTest : TestBase() {
8685

8786
finish(7)
8887
}
88+
89+
@Test
90+
fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest {
91+
var parentElement: MyElement? = null
92+
var inheritedElement: MyElement? = null
93+
94+
newSingleThreadContext("withContext").use {
95+
withContext(it + MyElement(MyData())) {
96+
parentElement = coroutineContext[MyElement.Key]
97+
launch {
98+
inheritedElement = coroutineContext[MyElement.Key]
99+
}
100+
}
101+
}
102+
103+
assertSame(inheritedElement, parentElement,
104+
"Inner and outer coroutines did not have the same object reference to a" +
105+
" ThreadContextElement that did not override `copyForChildCoroutine()`")
106+
}
107+
108+
@Test
109+
fun testCopyableElementCopiedOnLaunch() = runTest {
110+
var parentElement: CopyForChildCoroutineElement? = null
111+
var inheritedElement: CopyForChildCoroutineElement? = null
112+
113+
newSingleThreadContext("withContext").use {
114+
withContext(it + CopyForChildCoroutineElement(MyData())) {
115+
parentElement = coroutineContext[CopyForChildCoroutineElement.Key]
116+
launch {
117+
inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key]
118+
}
119+
}
120+
}
121+
122+
assertNotSame(inheritedElement, parentElement,
123+
"Inner coroutine did not copy its copyable ThreadContextElement.")
124+
}
125+
126+
@Test
127+
fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
128+
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
129+
val startData = MyData()
130+
withContext(it + CopyForChildCoroutineElement(startData)) {
131+
val forBlockData = MyData()
132+
myThreadLocal.setForBlock(forBlockData) {
133+
assertSame(myThreadLocal.get(), forBlockData)
134+
launch {
135+
assertSame(myThreadLocal.get(), forBlockData)
136+
}
137+
launch {
138+
assertSame(myThreadLocal.get(), forBlockData)
139+
// Modify value in child coroutine. Writes to the ThreadLocal and
140+
// the (copied) ThreadLocalElement's memory are not visible to peer or
141+
// ancestor coroutines, so this write is both threadsafe and coroutinesafe.
142+
val innerCoroutineData = MyData()
143+
myThreadLocal.setForBlock(innerCoroutineData) {
144+
assertSame(myThreadLocal.get(), innerCoroutineData)
145+
}
146+
assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored.
147+
}
148+
launch {
149+
val innerCoroutineData = MyData()
150+
myThreadLocal.setForBlock(innerCoroutineData) {
151+
assertSame(myThreadLocal.get(), innerCoroutineData)
152+
}
153+
assertSame(myThreadLocal.get(), forBlockData)
154+
}
155+
}
156+
assertSame(myThreadLocal.get(), startData) // Asserts value was restored.
157+
}
158+
}
159+
}
89160
}
90161

91162
class MyData
@@ -114,3 +185,60 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
114185
myThreadLocal.set(oldState)
115186
}
116187
}
188+
189+
/**
190+
* A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine].
191+
*/
192+
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
193+
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>
194+
195+
override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
196+
get() = Key
197+
198+
override fun updateThreadContext(context: CoroutineContext): MyData? {
199+
val oldState = myThreadLocal.get()
200+
myThreadLocal.set(data)
201+
return oldState
202+
}
203+
204+
override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
205+
myThreadLocal.set(oldState)
206+
}
207+
208+
/**
209+
* At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new
210+
* child coroutine, and that value is copied to a new, unique, ThreadContextElement memory
211+
* reference for the child coroutine to use uniquely.
212+
*
213+
* n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not
214+
* the value initially passed to the ThreadContextElement in order to reflect writes made to the
215+
* ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes
216+
* will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the
217+
* thread and calls [restoreThreadContext].
218+
*/
219+
override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
220+
return CopyForChildCoroutineElement(myThreadLocal.get())
221+
}
222+
}
223+
224+
/**
225+
* Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block].
226+
*
227+
* When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a
228+
* [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically
229+
* at every statement reached, whether that statement is reached immediately, across suspend and
230+
* redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal`
231+
* by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal`
232+
* by the parent coroutine _after_ launching a child coroutine will not be visible to that child
233+
* coroutine.
234+
*/
235+
private inline fun <ThreadLocalT, OutputT> ThreadLocal<ThreadLocalT>.setForBlock(
236+
value: ThreadLocalT,
237+
crossinline block: () -> OutputT
238+
) {
239+
val priorValue = get()
240+
set(value)
241+
block()
242+
set(priorValue)
243+
}
244+

0 commit comments

Comments
 (0)