Skip to content

Commit 5eb01e9

Browse files
committed
Introduce CoroutineContextThreadLocal API to integrate with thread-local sensitive code
Fixes #119
1 parent 9d31ffc commit 5eb01e9

File tree

8 files changed

+225
-25
lines changed

8 files changed

+225
-25
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+5-2
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,11 @@ public final class kotlinx/coroutines/experimental/CoroutineContextKt {
140140
public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;)Lkotlin/coroutines/experimental/CoroutineContext;
141141
public static final fun newCoroutineContext (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;)Lkotlin/coroutines/experimental/CoroutineContext;
142142
public static synthetic fun newCoroutineContext$default (Lkotlin/coroutines/experimental/CoroutineContext;Lkotlinx/coroutines/experimental/Job;ILjava/lang/Object;)Lkotlin/coroutines/experimental/CoroutineContext;
143-
public static final fun restoreThreadContext (Ljava/lang/String;)V
144-
public static final fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/String;
143+
}
144+
145+
public abstract interface class kotlinx/coroutines/experimental/CoroutineContextThreadLocal {
146+
public abstract fun restoreThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;Ljava/lang/Object;)V
147+
public abstract fun updateThreadContext (Lkotlin/coroutines/experimental/CoroutineContext;)Ljava/lang/Object;
145148
}
146149

147150
public abstract class kotlinx/coroutines/experimental/CoroutineDispatcher : kotlin/coroutines/experimental/AbstractCoroutineContextElement, kotlin/coroutines/experimental/ContinuationInterceptor {

build.gradle

+1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ configure(subprojects.findAll { !it.name.contains(sourceless) && it.name != "ben
105105
main.kotlin.srcDirs = ['src']
106106
test.kotlin.srcDirs = ['test']
107107
main.resources.srcDirs = ['resources']
108+
test.resources.srcDirs = ['test-resources']
108109
}
109110
}
110111

core/kotlinx-coroutines-core/src/CoroutineContext.kt

+33-22
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.coroutines.experimental
66

7+
import java.util.*
78
import java.util.concurrent.atomic.AtomicLong
89
import kotlin.coroutines.experimental.AbstractCoroutineContextElement
910
import kotlin.coroutines.experimental.ContinuationInterceptor
@@ -40,6 +41,17 @@ internal val DEBUG = run {
4041
}
4142
}
4243

44+
@Suppress("UNCHECKED_CAST")
45+
internal val coroutineContextThreadLocal: CoroutineContextThreadLocal<Any?>? = run {
46+
val services = ServiceLoader.load(CoroutineContextThreadLocal::class.java).toMutableList()
47+
if (DEBUG) services.add(0, DebugThreadName)
48+
when (services.size) {
49+
0 -> null
50+
1 -> services.single() as CoroutineContextThreadLocal<Any?>
51+
else -> CoroutineContextThreadLocalList((services as List<CoroutineContextThreadLocal<Any?>>).toTypedArray())
52+
}
53+
}
54+
4355
private val COROUTINE_ID = AtomicLong()
4456

4557
// for tests only
@@ -89,29 +101,33 @@ public actual fun newCoroutineContext(context: CoroutineContext, parent: Job? =
89101
* Executes a block using a given coroutine context.
90102
*/
91103
internal actual inline fun <T> withCoroutineContext(context: CoroutineContext, block: () -> T): T {
92-
val oldName = context.updateThreadContext()
104+
val oldValue = coroutineContextThreadLocal?.updateThreadContext(context)
93105
try {
94106
return block()
95107
} finally {
96-
restoreThreadContext(oldName)
108+
coroutineContextThreadLocal?.restoreThreadContext(context, oldValue)
97109
}
98110
}
99111

100-
@PublishedApi
101-
internal fun CoroutineContext.updateThreadContext(): String? {
102-
if (!DEBUG) return null
103-
val coroutineId = this[CoroutineId] ?: return null
104-
val coroutineName = this[CoroutineName]?.name ?: "coroutine"
105-
val currentThread = Thread.currentThread()
106-
val oldName = currentThread.name
107-
currentThread.name = buildString(oldName.length + coroutineName.length + 10) {
108-
append(oldName)
109-
append(" @")
110-
append(coroutineName)
111-
append('#')
112-
append(coroutineId.id)
112+
private object DebugThreadName : CoroutineContextThreadLocal<String?> {
113+
override fun updateThreadContext(context: CoroutineContext): String? {
114+
val coroutineId = context[CoroutineId] ?: return null
115+
val coroutineName = context[CoroutineName]?.name ?: "coroutine"
116+
val currentThread = Thread.currentThread()
117+
val oldName = currentThread.name
118+
currentThread.name = buildString(oldName.length + coroutineName.length + 10) {
119+
append(oldName)
120+
append(" @")
121+
append(coroutineName)
122+
append('#')
123+
append(coroutineId.id)
124+
}
125+
return oldName
126+
}
127+
128+
override fun restoreThreadContext(context: CoroutineContext, oldValue: String?) {
129+
if (oldValue != null) Thread.currentThread().name = oldValue
113130
}
114-
return oldName
115131
}
116132

117133
internal actual val CoroutineContext.coroutineName: String? get() {
@@ -121,12 +137,7 @@ internal actual val CoroutineContext.coroutineName: String? get() {
121137
return "$coroutineName#${coroutineId.id}"
122138
}
123139

124-
@PublishedApi
125-
internal fun restoreThreadContext(oldName: String?) {
126-
if (oldName != null) Thread.currentThread().name = oldName
127-
}
128-
129-
private class CoroutineId(val id: Long) : AbstractCoroutineContextElement(CoroutineId) {
140+
internal data class CoroutineId(val id: Long) : AbstractCoroutineContextElement(CoroutineId) {
130141
companion object Key : CoroutineContext.Key<CoroutineId>
131142
override fun toString(): String = "CoroutineId($id)"
132143
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.experimental
6+
7+
import kotlin.coroutines.experimental.*
8+
9+
/**
10+
* An extension point to define elements in [CoroutineContext] that are installed into thread local
11+
* variables every time the coroutine from the specified context in resumed on a thread.
12+
*
13+
* Implementations on this interface are looked up via [java.util.ServiceLoader].
14+
*
15+
* Example usage looks like this:
16+
*
17+
* ```
18+
* // declare custom coroutine context element
19+
* class MyElement : AbstractCoroutineContextElement(Key) {
20+
* companion object Key : CoroutineContext.Key<MyElement>
21+
* // some state is kept here
22+
* }
23+
*
24+
* // declare thread local variable
25+
* private val myThreadLocal = ThreadLocal<MyElement?>()
26+
*
27+
* // declare extension point implementation
28+
* class MyCoroutineContextThreadLocal : CoroutineContextThreadLocal<MyElement?> {
29+
* // this is invoked before coroutine is resumed on current thread
30+
* override fun updateThreadContext(context: CoroutineContext): MyElement? {
31+
* val oldValue = myThreadLocal.get()
32+
* myThreadLocal.set(context[MyElement])
33+
* return oldValue
34+
* }
35+
*
36+
* // this is invoked after coroutine has suspended on current thread
37+
* override fun restoreThreadContext(context: CoroutineContext, oldValue: MyElement?) {
38+
* myThreadLocal.set(oldValue)
39+
* }
40+
* }
41+
* ```
42+
*
43+
* Now, `MyCoroutineContextThreadLocal` fully qualified class named shall be registered via
44+
* `META-INF/services/kotlinx.coroutines.experimental.CoroutineContextThreadLocal` file.
45+
*/
46+
public interface CoroutineContextThreadLocal<T> {
47+
/**
48+
* Updates context of the current thread.
49+
* This function is invoked before the coroutine in the specified [context] is resumed in the current thread.
50+
* The result of this function is the old value that will be passed to [restoreThreadContext].
51+
*/
52+
public fun updateThreadContext(context: CoroutineContext): T
53+
54+
/**
55+
* Restores context of the current thread.
56+
* This function is invoked after the coroutine in the specified [context] is suspended in the current thread.
57+
* The value of [oldValue] is the result of the previous invocation of [updateThreadContext].
58+
*/
59+
public fun restoreThreadContext(context: CoroutineContext, oldValue: T)
60+
}
61+
62+
/**
63+
* This class is used when multiple [CoroutineContextThreadLocal] are installed.
64+
*/
65+
internal class CoroutineContextThreadLocalList(
66+
private val impls: Array<CoroutineContextThreadLocal<Any?>>
67+
) : CoroutineContextThreadLocal<Any?> {
68+
init {
69+
require(impls.size > 1)
70+
}
71+
72+
private val threadLocalStack = ThreadLocal<ArrayList<Any?>?>()
73+
74+
override fun updateThreadContext(context: CoroutineContext): Any? {
75+
val stack = threadLocalStack.get() ?: ArrayList<Any?>().also {
76+
threadLocalStack.set(it)
77+
}
78+
val lastIndex = impls.lastIndex
79+
for (i in 0 until lastIndex) {
80+
stack.add(impls[i].updateThreadContext(context))
81+
}
82+
return impls[lastIndex].updateThreadContext(context)
83+
}
84+
85+
override fun restoreThreadContext(context: CoroutineContext, oldValue: Any?) {
86+
val stack = threadLocalStack.get()!! // must be there
87+
val lastIndex = impls.lastIndex
88+
impls[lastIndex].restoreThreadContext(context, oldValue)
89+
for (i in lastIndex - 1 downTo 0) {
90+
impls[i].restoreThreadContext(context, stack.removeAt(stack.lastIndex))
91+
}
92+
}
93+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
kotlinx.coroutines.experimental.MyCoroutineContextThreadLocal
2+
kotlinx.coroutines.experimental.ValidatingCoroutineContextThreadLocal
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.experimental
6+
7+
import org.junit.Test
8+
import kotlin.coroutines.experimental.*
9+
import kotlin.test.*
10+
11+
class CoroutineContextThreadLocalTest : TestBase() {
12+
@Test
13+
fun testExample() = runTest {
14+
val mainDispatcher = coroutineContext[ContinuationInterceptor]!!
15+
val mainThread = Thread.currentThread()
16+
val element = MyElement()
17+
assertEquals(null, myThreadLocal.get())
18+
val job = launch(element) {
19+
assertTrue(mainThread != Thread.currentThread())
20+
assertSame(element, coroutineContext[MyElement])
21+
assertSame(element, myThreadLocal.get())
22+
withContext(mainDispatcher) {
23+
assertSame(mainThread, Thread.currentThread())
24+
assertSame(element, coroutineContext[MyElement])
25+
assertSame(element, myThreadLocal.get())
26+
}
27+
assertTrue(mainThread != Thread.currentThread())
28+
assertSame(element, coroutineContext[MyElement])
29+
assertSame(element, myThreadLocal.get())
30+
}
31+
assertEquals(null, myThreadLocal.get())
32+
job.join()
33+
assertEquals(null, myThreadLocal.get())
34+
}
35+
}
36+
37+
// declare custom coroutine context element
38+
class MyElement : AbstractCoroutineContextElement(Key) {
39+
companion object Key : CoroutineContext.Key<MyElement>
40+
// some state is kept here
41+
}
42+
43+
// declare thread local variable
44+
private val myThreadLocal = ThreadLocal<MyElement?>()
45+
46+
// declare extension point implementation
47+
class MyCoroutineContextThreadLocal : CoroutineContextThreadLocal<MyElement?> {
48+
// this is invoked before coroutine is resumed on current thread
49+
override fun updateThreadContext(context: CoroutineContext): MyElement? {
50+
val oldValue = myThreadLocal.get()
51+
myThreadLocal.set(context[MyElement])
52+
return oldValue
53+
}
54+
55+
// this is invoked after coroutine has suspended on current thread
56+
override fun restoreThreadContext(context: CoroutineContext, oldValue: MyElement?) {
57+
myThreadLocal.set(oldValue)
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.experimental
6+
7+
import kotlin.coroutines.experimental.*
8+
9+
private val currentCoroutineId = ThreadLocal<CoroutineId?>()
10+
11+
internal class ValidatingCoroutineContextThreadLocal : CoroutineContextThreadLocal<CoroutineId?> {
12+
override fun updateThreadContext(context: CoroutineContext): CoroutineId? {
13+
val id = context[CoroutineId] ?: error("Tests should be run in debug mode (enable assertions?)")
14+
val top = currentCoroutineId.get()
15+
require( top != id) {
16+
"Thread ${Thread.currentThread().name} already has coroutine context for coroutine $context"
17+
}
18+
currentCoroutineId.set(id)
19+
return top
20+
}
21+
22+
override fun restoreThreadContext(context: CoroutineContext, oldValue: CoroutineId?) {
23+
val id = context[CoroutineId]
24+
val top = currentCoroutineId.get()
25+
require(top == id) {
26+
"Thread ${Thread.currentThread().name} does not have coroutine context for coroutine $context, but has for coroutine id $top"
27+
}
28+
currentCoroutineId.set(oldValue)
29+
}
30+
}

integration/kotlinx-coroutines-quasar/src/Quasar.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ fun <T> runFiberBlocking(block: suspend () -> T): T =
4343
private class CoroutineAsync<T>(
4444
private val block: suspend () -> T
4545
) : FiberAsync<T, Throwable>(), Continuation<T> {
46-
override val context: CoroutineContext = Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher()
46+
override val context: CoroutineContext =
47+
newCoroutineContext(Fiber.currentFiber().scheduler.executor.asCoroutineDispatcher())
4748
override fun resume(value: T) { asyncCompleted(value) }
4849
override fun resumeWithException(exception: Throwable) { asyncFailed(exception) }
4950

0 commit comments

Comments
 (0)