Skip to content

Pr/2982 #3025

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 17, 2021
Merged

Pr/2982 #3025

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,17 @@ public final class kotlinx/coroutines/CompletionHandlerException : java/lang/Run
public fun <init> (Ljava/lang/String;Ljava/lang/Throwable;)V
}

public abstract interface class kotlinx/coroutines/CopyableThreadContextElement : kotlinx/coroutines/ThreadContextElement {
public abstract fun copyForChildCoroutine ()Lkotlinx/coroutines/CopyableThreadContextElement;
}

public final class kotlinx/coroutines/CopyableThreadContextElement$DefaultImpls {
public static fun fold (Lkotlinx/coroutines/CopyableThreadContextElement;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)Ljava/lang/Object;
public static fun get (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element;
public static fun minusKey (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext;
public static fun plus (Lkotlinx/coroutines/CopyableThreadContextElement;Lkotlin/coroutines/CoroutineContext;)Lkotlin/coroutines/CoroutineContext;
}

public abstract interface class kotlinx/coroutines/CopyableThrowable {
public abstract fun createCopy ()Ljava/lang/Throwable;
}
Expand Down
21 changes: 20 additions & 1 deletion kotlinx-coroutines-core/jvm/src/CoroutineContext.kt
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,31 @@ import kotlin.coroutines.jvm.internal.CoroutineStackFrame
*/
@ExperimentalCoroutinesApi
public actual fun CoroutineScope.newCoroutineContext(context: CoroutineContext): CoroutineContext {
val combined = coroutineContext + context
val combined = coroutineContext.foldCopiesForChildCoroutine() + context
val debug = if (DEBUG) combined + CoroutineId(COROUTINE_ID.incrementAndGet()) else combined
return if (combined !== Dispatchers.Default && combined[ContinuationInterceptor] == null)
debug + Dispatchers.Default else debug
}

/**
* Returns the [CoroutineContext] for a child coroutine to inherit.
*
* If any [CopyableThreadContextElement] is in the [this], calls
* [CopyableThreadContextElement.copyForChildCoroutine] on each, returning a new [CoroutineContext]
* by folding the returned copied elements into [this].
*
* Returns [this] if `this` has zero [CopyableThreadContextElement] in it.
*/
private fun CoroutineContext.foldCopiesForChildCoroutine(): CoroutineContext {
val hasToCopy = fold(false) { result, it ->
result || it is CopyableThreadContextElement<*>
}
if (!hasToCopy) return this
return fold<CoroutineContext>(EmptyCoroutineContext) { combined, it ->
combined + if (it is CopyableThreadContextElement<*>) it.copyForChildCoroutine() else it
}
}

/**
* Executes a block using a given coroutine context.
*/
Expand Down
63 changes: 63 additions & 0 deletions kotlinx-coroutines-core/jvm/src/ThreadContextElement.kt
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,69 @@ public interface ThreadContextElement<S> : CoroutineContext.Element {
public fun restoreThreadContext(context: CoroutineContext, oldState: S)
}

/**
* A [ThreadContextElement] copied whenever a child coroutine inherits a context containing it.
*
* When an API uses a _mutable_ `ThreadLocal` for consistency, a [CopyableThreadContextElement]
* can give coroutines "coroutine-safe" write access to that `ThreadLocal`.
*
* A write made to a `ThreadLocal` with a matching [CopyableThreadContextElement] by a coroutine
* will be visible to _itself_ and any child coroutine launched _after_ that write.
*
* Writes will not be visible to the parent coroutine, peer coroutines, or coroutines that happen
* to use the same thread. Writes made to the `ThreadLocal` by the parent coroutine _after_
* launching a child coroutine will not be visible to that child coroutine.
*
* This can be used to allow a coroutine to use a mutable ThreadLocal API transparently and
* correctly, regardless of the coroutine's structured concurrency.
*
* This example adapts a `ThreadLocal` method trace to be "coroutine local" while the method trace
* is in a coroutine:
*
* ```
* class TraceContextElement(val traceData: TraceData?) : CopyableThreadContextElement<TraceData?> {
* companion object Key : CoroutineContext.Key<ThreadTraceContextElement>
* override val key: CoroutineContext.Key<ThreadTraceContextElement>
* get() = Key
*
* override fun updateThreadContext(context: CoroutineContext): TraceData? {
* val oldState = traceThreadLocal.get()
* traceThreadLocal.set(data)
* return oldState
* }
*
* override fun restoreThreadContext(context: CoroutineContext, oldData: TraceData?) {
* traceThreadLocal.set(oldState)
* }
*
* override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
* // Copy from the ThreadLocal source of truth at child coroutine launch time. This makes
* // ThreadLocal writes between resumption of the parent coroutine and the launch of the
* // child coroutine visible to the child.
* return CopyForChildCoroutineElement(traceThreadLocal.get())
* }
* }
* ```
*
* A coroutine using this mechanism can safely call Java code that assumes it's called using a
* `Thread`.
*/
@ExperimentalCoroutinesApi
public interface CopyableThreadContextElement<S> : ThreadContextElement<S> {

/**
* Returns a [CopyableThreadContextElement] to replace `this` `CopyableThreadContextElement` in the child
* coroutine's context that is under construction.
*
* This function is called on the element each time a new coroutine inherits a context containing it,
* and the returned value is folded into the context given to the child.
*
* Since this method is called whenever a new coroutine is launched in a context containing this
* [CopyableThreadContextElement], implementations are performance-sensitive.
*/
public fun copyForChildCoroutine(): CopyableThreadContextElement<S>
}

/**
* Wraps [ThreadLocal] into [ThreadContextElement]. The resulting [ThreadContextElement]
* maintains the given [value] of the given [ThreadLocal] for coroutine regardless of the actual thread its is resumed on.
Expand Down
130 changes: 129 additions & 1 deletion kotlinx-coroutines-core/jvm/test/ThreadContextElementTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ class ThreadContextElementTest : TestBase() {
assertNull(myThreadLocal.get())
}


@Test
fun testWithContext() = runTest {
expect(1)
Expand Down Expand Up @@ -86,6 +85,78 @@ class ThreadContextElementTest : TestBase() {

finish(7)
}

@Test
fun testNonCopyableElementReferenceInheritedOnLaunch() = runTest {
var parentElement: MyElement? = null
var inheritedElement: MyElement? = null

newSingleThreadContext("withContext").use {
withContext(it + MyElement(MyData())) {
parentElement = coroutineContext[MyElement.Key]
launch {
inheritedElement = coroutineContext[MyElement.Key]
}
}
}

assertSame(inheritedElement, parentElement,
"Inner and outer coroutines did not have the same object reference to a" +
" ThreadContextElement that did not override `copyForChildCoroutine()`")
}

@Test
fun testCopyableElementCopiedOnLaunch() = runTest {
var parentElement: CopyForChildCoroutineElement? = null
var inheritedElement: CopyForChildCoroutineElement? = null

newSingleThreadContext("withContext").use {
withContext(it + CopyForChildCoroutineElement(MyData())) {
parentElement = coroutineContext[CopyForChildCoroutineElement.Key]
launch {
inheritedElement = coroutineContext[CopyForChildCoroutineElement.Key]
}
}
}

assertNotSame(inheritedElement, parentElement,
"Inner coroutine did not copy its copyable ThreadContextElement.")
}

@Test
fun testCopyableThreadContextElementImplementsWriteVisibility() = runTest {
newFixedThreadPoolContext(nThreads = 4, name = "withContext").use {
val startData = MyData()
withContext(it + CopyForChildCoroutineElement(startData)) {
val forBlockData = MyData()
myThreadLocal.setForBlock(forBlockData) {
assertSame(myThreadLocal.get(), forBlockData)
launch {
assertSame(myThreadLocal.get(), forBlockData)
}
launch {
assertSame(myThreadLocal.get(), forBlockData)
// Modify value in child coroutine. Writes to the ThreadLocal and
// the (copied) ThreadLocalElement's memory are not visible to peer or
// ancestor coroutines, so this write is both threadsafe and coroutinesafe.
val innerCoroutineData = MyData()
myThreadLocal.setForBlock(innerCoroutineData) {
assertSame(myThreadLocal.get(), innerCoroutineData)
}
assertSame(myThreadLocal.get(), forBlockData) // Asserts value was restored.
}
launch {
val innerCoroutineData = MyData()
myThreadLocal.setForBlock(innerCoroutineData) {
assertSame(myThreadLocal.get(), innerCoroutineData)
}
assertSame(myThreadLocal.get(), forBlockData)
}
}
assertSame(myThreadLocal.get(), startData) // Asserts value was restored.
}
}
}
}

class MyData
Expand Down Expand Up @@ -114,3 +185,60 @@ class MyElement(val data: MyData) : ThreadContextElement<MyData?> {
myThreadLocal.set(oldState)
}
}

/**
* A [ThreadContextElement] that implements copy semantics in [copyForChildCoroutine].
*/
class CopyForChildCoroutineElement(val data: MyData?) : CopyableThreadContextElement<MyData?> {
companion object Key : CoroutineContext.Key<CopyForChildCoroutineElement>

override val key: CoroutineContext.Key<CopyForChildCoroutineElement>
get() = Key

override fun updateThreadContext(context: CoroutineContext): MyData? {
val oldState = myThreadLocal.get()
myThreadLocal.set(data)
return oldState
}

override fun restoreThreadContext(context: CoroutineContext, oldState: MyData?) {
myThreadLocal.set(oldState)
}

/**
* At coroutine launch time, the _current value of the ThreadLocal_ is inherited by the new
* child coroutine, and that value is copied to a new, unique, ThreadContextElement memory
* reference for the child coroutine to use uniquely.
*
* n.b. the value copied to the child must be the __current value of the ThreadLocal__ and not
* the value initially passed to the ThreadContextElement in order to reflect writes made to the
* ThreadLocal between coroutine resumption and the child coroutine launch point. Those writes
* will be reflected in the parent coroutine's [CopyForChildCoroutineElement] when it yields the
* thread and calls [restoreThreadContext].
*/
override fun copyForChildCoroutine(): CopyableThreadContextElement<MyData?> {
return CopyForChildCoroutineElement(myThreadLocal.get())
}
}

/**
* Calls [block], setting the value of [this] [ThreadLocal] for the duration of [block].
*
* When a [CopyForChildCoroutineElement] for `this` [ThreadLocal] is used within a
* [CoroutineContext], a ThreadLocal set this way will have the "correct" value expected lexically
* at every statement reached, whether that statement is reached immediately, across suspend and
* redispatch within one coroutine, or within a child coroutine. Writes made to the `ThreadLocal`
* by child coroutines will not be visible to the parent coroutine. Writes made to the `ThreadLocal`
* by the parent coroutine _after_ launching a child coroutine will not be visible to that child
* coroutine.
*/
private inline fun <ThreadLocalT, OutputT> ThreadLocal<ThreadLocalT>.setForBlock(
value: ThreadLocalT,
crossinline block: () -> OutputT
) {
val priorValue = get()
set(value)
block()
set(priorValue)
}