@@ -76,40 +76,78 @@ class ThreadContextElementTest : TestBase() {
76
76
assertNull(threadContextElementThreadLocal.get())
77
77
}
78
78
79
+ class JobCaptor (val capturees : MutableList <String > = mutableListOf()) : ThreadContextElement<Unit> {
80
+
81
+ companion object Key : CoroutineContext.Key<MyElement>
82
+
83
+ override val key: CoroutineContext .Key <* > get() = Key
84
+
85
+ override fun updateThreadContext (context : CoroutineContext ) {
86
+ capturees.add(" Update: ${context.job} " )
87
+ }
88
+
89
+ override fun restoreThreadContext (context : CoroutineContext , oldState : Unit ) {
90
+ capturees.add(" Restore: ${context.job} " )
91
+ }
92
+ }
93
+
94
+ /* *
95
+ * For stability of the test, it is important to make sure that
96
+ * the parent job actually suspends when calling
97
+ * `withContext(dispatcher2 + CoroutineName("dispatched"))`.
98
+ *
99
+ * Here this requirement is fulfilled by forcing execution on a single thread.
100
+ * However, dispatching is performed with two non-equal dispatchers to force dispatching.
101
+ *
102
+ * Suspend of the parent coroutine [kotlinx.coroutines.DispatchedCoroutine.trySuspend] is out of the control of the test,
103
+ * while being executed concurrently with resume of the child coroutine [kotlinx.coroutines.DispatchedCoroutine.tryResume].
104
+ */
79
105
@Test
80
106
fun testWithContextJobAccess () = runTest {
107
+ // Emulate non-equal dispatchers
108
+ val dispatcher = Dispatchers .Default .limitedParallelism(1 )
109
+ val dispatcher1 = dispatcher.limitedParallelism(1 , " dispatcher1" )
110
+ val dispatcher2 = dispatcher.limitedParallelism(1 , " dispatcher2" )
81
111
val captor = JobCaptor ()
82
- val manuallyCaptured = ArrayList <Job >()
83
- withContext(captor) {
84
- manuallyCaptured + = coroutineContext.job
112
+ val manuallyCaptured = mutableListOf<String >()
113
+
114
+ fun registerUpdate (job : Job ? ) = manuallyCaptured.add(" Update: $job " )
115
+ fun registerRestore (job : Job ? ) = manuallyCaptured.add(" Restore: $job " )
116
+
117
+ var rootJob: Job ? = null
118
+ withContext(captor + dispatcher1) {
119
+ rootJob = coroutineContext.job
120
+ registerUpdate(rootJob)
121
+ var undispatchedJob: Job ? = null
85
122
withContext(CoroutineName (" undispatched" )) {
86
- manuallyCaptured + = coroutineContext.job
87
- withContext(Dispatchers .Default ) {
88
- manuallyCaptured + = coroutineContext.job
123
+ undispatchedJob = coroutineContext.job
124
+ registerUpdate(undispatchedJob)
125
+ // These 2 restores and the corresponding next 2 updates happen only if the following `withContext`
126
+ // call actually suspends.
127
+ registerRestore(undispatchedJob)
128
+ registerRestore(rootJob)
129
+ // Without forcing of single backing thread the code inside `withContext`
130
+ // may already complete at the moment when the parent coroutine decides
131
+ // whether it needs to suspend or not.
132
+ var dispatchedJob: Job ? = null
133
+ withContext(dispatcher2 + CoroutineName (" dispatched" )) {
134
+ dispatchedJob = coroutineContext.job
135
+ registerUpdate(dispatchedJob)
89
136
}
137
+ registerRestore(dispatchedJob)
90
138
// Context restored, captured again
91
- manuallyCaptured + = coroutineContext.job
139
+ registerUpdate(undispatchedJob)
92
140
}
141
+ registerRestore(undispatchedJob)
93
142
// Context restored, captured again
94
- manuallyCaptured + = coroutineContext.job
143
+ registerUpdate(rootJob)
95
144
}
96
- assertEquals(manuallyCaptured, captor.capturees)
97
- }
98
- }
99
-
100
- private class JobCaptor () : ThreadContextElement<Unit> {
101
-
102
- val capturees: MutableList <Job > = mutableListOf ()
103
-
104
- companion object Key : CoroutineContext.Key<MyElement>
105
-
106
- override val key: CoroutineContext .Key <* > get() = Key
107
-
108
- override fun updateThreadContext (context : CoroutineContext ) {
109
- capturees.add(context.job)
110
- }
145
+ registerRestore(rootJob)
111
146
112
- override fun restoreThreadContext (context : CoroutineContext , oldState : Unit ) {
147
+ // Restores may be called concurrently to the update calls in other threads, so their order is not checked.
148
+ val expected = manuallyCaptured.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
149
+ val actual = captor.capturees.filter { it.startsWith(" Update: " ) }.joinToString(separator = " \n " )
150
+ assertEquals(expected, actual)
113
151
}
114
152
}
115
153
0 commit comments