Skip to content

Commit 4b0f4de

Browse files
elizarovqwwdfsad
authored andcommitted
Never loose coroutines when CoroutineScheduler is closed
* CoroutineScheduler close/shutdown sequence now guaranteed that a submitted task is either executed by the scheduler that is being concurrently shutdown or RejectedExecutionException is thrown. * RejectedExecutionException is caught by the coroutine dispatcher implementation and coroutines are rescheduled into DefaultExecutor if a (custom) instance of CoroutineScheduler is being shutdown. * Added stress test for coroutine scheduler closing.
1 parent 85725e8 commit 4b0f4de

File tree

6 files changed

+159
-22
lines changed

6 files changed

+159
-22
lines changed

core/kotlinx-coroutines-core/src/internal/LockFreeMPMCQueue.kt

+26-5
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ import kotlinx.atomicfu.*
99
internal open class LockFreeMPMCQueueNode<T> {
1010
val next = atomic<T?>(null)
1111

12+
// internal declarations for inline functions
1213
@PublishedApi internal val nextValue: T? get() = next.value
14+
@PublishedApi internal fun nextCas(expect: T?, update: T?) = next.compareAndSet(expect, update)
1315
}
1416

1517
/*
@@ -23,9 +25,14 @@ internal open class LockFreeMPMCQueue<T : LockFreeMPMCQueueNode<T>> {
2325
atomic(LockFreeMPMCQueueNode<T>() as T) // sentinel
2426

2527
private val tail = atomic(head.value)
26-
internal val headValue: T get() = head.value
2728

28-
public fun addLast(node: T): Boolean {
29+
// internal declarations for inline functions
30+
@PublishedApi internal val headValue: T get() = head.value
31+
@PublishedApi internal val tailValue: T get() = tail.value
32+
@PublishedApi internal fun headCas(curHead: T, update: T) = head.compareAndSet(curHead, update)
33+
@PublishedApi internal fun tailCas(curTail: T, update: T) = tail.compareAndSet(curTail, update)
34+
35+
public fun addLast(node: T) {
2936
tail.loop { curTail ->
3037
val curNext = curTail.next.value
3138
if (curNext != null) {
@@ -34,6 +41,22 @@ internal open class LockFreeMPMCQueue<T : LockFreeMPMCQueueNode<T>> {
3441
}
3542
if (curTail.next.compareAndSet(null, node)) {
3643
tail.compareAndSet(curTail, node)
44+
return
45+
}
46+
}
47+
}
48+
49+
public inline fun addLastIfPrev(node: T, predicate: (prev: Any) -> Boolean): Boolean {
50+
while(true) {
51+
val curTail = tailValue
52+
val curNext = curTail.nextValue
53+
if (curNext != null) {
54+
tailCas(curTail, curNext)
55+
continue // retry
56+
}
57+
if (!predicate(curTail)) return false
58+
if (curTail.nextCas(null, node)) {
59+
tailCas(curTail, node)
3760
return true
3861
}
3962
}
@@ -48,9 +71,7 @@ internal open class LockFreeMPMCQueue<T : LockFreeMPMCQueueNode<T>> {
4871
}
4972
}
5073

51-
fun headCas(curHead: T, update: T) = head.compareAndSet(curHead, update)
52-
53-
public inline fun removeFirstOrNullIf(predicate: (T) -> Boolean): T? {
74+
public inline fun removeFirstOrNullIf(predicate: (first: T) -> Boolean): T? {
5475
while (true) {
5576
val curHead = headValue
5677
val next = curHead.nextValue ?: return null

core/kotlinx-coroutines-core/src/scheduling/CoroutineScheduler.kt

+27-11
Original file line numberDiff line numberDiff line change
@@ -301,23 +301,29 @@ internal class CoroutineScheduler(
301301
val currentWorker = Thread.currentThread() as? Worker
302302
// Capture # of created workers that cannot change anymore (mind the synchronized block!)
303303
val created = synchronized(workers) { createdWorkers }
304+
// Shutdown all workers with the only exception of the current thread
304305
for (i in 1..created) {
305306
val worker = workers[i]!!
306-
if (worker.isAlive && worker !== currentWorker) {
307-
LockSupport.unpark(worker)
308-
worker.join(timeout)
307+
if (worker !== currentWorker) {
308+
while (worker.isAlive) {
309+
LockSupport.unpark(worker)
310+
worker.join(timeout)
311+
}
312+
val state = worker.state
313+
check(state === WorkerState.TERMINATED) { "Expected TERMINATED state, but found $state"}
309314
worker.localQueue.offloadAllWork(globalQueue)
310315
}
311-
312316
}
317+
// Make sure no more work is added to GlobalQueue from anywhere
318+
check(globalQueue.add(CLOSED_TASK)) { "GlobalQueue could not be closed yet" }
313319
// Finish processing tasks from globalQueue and/or from this worker's local queue
314320
while (true) {
315-
val task = currentWorker?.findTask() ?: globalQueue.removeFirstOrNull() ?: break
321+
val task = currentWorker?.findTask() ?: globalQueue.removeFirstIfNotClosed() ?: break
316322
runSafely(task)
317323
}
318324
// Shutdown current thread
319325
currentWorker?.tryReleaseCpu(WorkerState.TERMINATED)
320-
// cleanup state to make sure that tryUnpark tries to create new threads and fails because isTerminated
326+
// check & cleanup state
321327
assert(cpuPermits.availablePermits() == corePoolSize)
322328
parkedWorkersStack.value = 0L
323329
controlState.value = 0L
@@ -339,8 +345,12 @@ internal class CoroutineScheduler(
339345
when (submitToLocalQueue(task, fair)) {
340346
ADDED -> return
341347
NOT_ADDED -> {
342-
globalQueue.addLast(task) // offload task to local queue
343-
requestCpuWorker() // ask for help
348+
// try to offload task to global queue
349+
if (!globalQueue.add(task)) {
350+
// Global queue is closed in the last step of close/shutdown -- no more tasks should be accepted
351+
throw RejectedExecutionException("$schedulerName was terminated")
352+
}
353+
requestCpuWorker()
344354
}
345355
else -> requestCpuWorker() // ask for help
346356
}
@@ -439,7 +449,7 @@ internal class CoroutineScheduler(
439449
private fun createNewWorker(): Int {
440450
synchronized(workers) {
441451
// Make sure we're not trying to resurrect terminated scheduler
442-
if (isTerminated) throw RejectedExecutionException("$schedulerName was terminated")
452+
if (isTerminated) return -1
443453
val state = controlState.value
444454
val created = createdWorkers(state)
445455
val blocking = blockingWorkers(state)
@@ -464,6 +474,12 @@ internal class CoroutineScheduler(
464474
?: return NOT_ADDED
465475
if (worker.scheduler !== this) return NOT_ADDED // different scheduler's worker (!!!)
466476

477+
/*
478+
* This worker could have been already terminated from this thread by close/shutdown and it should not
479+
* accept any more tasks into its local queue.
480+
*/
481+
if (worker.state === WorkerState.TERMINATED) return NOT_ADDED
482+
467483
var result = ADDED
468484
if (task.mode == TaskMode.NON_BLOCKING) {
469485
/*
@@ -923,9 +939,9 @@ internal class CoroutineScheduler(
923939
* once per two core pool size iterations
924940
*/
925941
val globalFirst = nextInt(2 * corePoolSize) == 0
926-
if (globalFirst) globalQueue.removeFirstOrNull()?.let { return it }
942+
if (globalFirst) globalQueue.removeFirstIfNotClosed()?.let { return it }
927943
localQueue.poll()?.let { return it }
928-
if (!globalFirst) globalQueue.removeFirstOrNull()?.let { return it }
944+
if (!globalFirst) globalQueue.removeFirstIfNotClosed()?.let { return it }
929945
return trySteal()
930946
}
931947

core/kotlinx-coroutines-core/src/scheduling/Dispatcher.kt

+15-3
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,18 @@ open class ExperimentalCoroutineDispatcher(
4747
private var coroutineScheduler = createScheduler()
4848

4949
override fun dispatch(context: CoroutineContext, block: Runnable): Unit =
50-
coroutineScheduler.dispatch(block)
50+
try {
51+
coroutineScheduler.dispatch(block)
52+
} catch (e: RejectedExecutionException) {
53+
DefaultExecutor.dispatch(context, block)
54+
}
5155

5256
override fun dispatchYield(context: CoroutineContext, block: Runnable): Unit =
53-
coroutineScheduler.dispatch(block, fair = true)
57+
try {
58+
coroutineScheduler.dispatch(block, fair = true)
59+
} catch (e: RejectedExecutionException) {
60+
DefaultExecutor.dispatchYield(context, block)
61+
}
5462

5563
override fun close() = coroutineScheduler.close()
5664

@@ -84,7 +92,11 @@ open class ExperimentalCoroutineDispatcher(
8492
}
8593

8694
internal fun dispatchWithContext(block: Runnable, context: TaskContext, fair: Boolean): Unit =
87-
coroutineScheduler.dispatch(block, context, fair)
95+
try {
96+
coroutineScheduler.dispatch(block, context, fair)
97+
} catch (e: RejectedExecutionException) {
98+
DefaultExecutor.execute(block)
99+
}
88100

89101
private fun createScheduler() = CoroutineScheduler(corePoolSize, maxPoolSize, idleWorkerKeepAliveNs)
90102

core/kotlinx-coroutines-core/src/scheduling/Tasks.kt

+11
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,19 @@ internal class Task(
101101
"Task[${block.classSimpleName}@${block.hexAddress}, $submissionTime, $taskContext]"
102102
}
103103

104+
private val EMPTY_RUNNABLE = Runnable {}
105+
internal val CLOSED_TASK = Task(EMPTY_RUNNABLE, 0, NonBlockingContext)
106+
104107
// Open for tests
105108
internal open class GlobalQueue : LockFreeMPMCQueue<Task>() {
109+
// Returns false when GlobalQueue was was already closed
110+
public fun add(task: Task): Boolean =
111+
addLastIfPrev(task) { prev -> prev !== CLOSED_TASK }
112+
113+
// Returns null when GlobalQueue was was already closed
114+
public fun removeFirstIfNotClosed(): Task? =
115+
removeFirstOrNullIf { first -> first !== CLOSED_TASK }
116+
106117
// Open for tests
107118
public open fun removeFirstBlockingModeOrNull(): Task? =
108119
removeFirstOrNullIf { it.mode == TaskMode.PROBABLY_BLOCKING }

core/kotlinx-coroutines-core/src/scheduling/WorkQueue.kt

+12-3
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,23 @@ internal class WorkQueue {
140140
private fun offloadWork(globalQueue: GlobalQueue) {
141141
repeat((bufferSize / 2).coerceAtLeast(1)) {
142142
val task = pollExternal() ?: return
143-
globalQueue.addLast(task)
143+
addToGlobalQueue(globalQueue, task)
144144
}
145145
}
146146

147+
private fun addToGlobalQueue(globalQueue: GlobalQueue, task: Task) {
148+
/*
149+
* globalQueue is closed as the very last step in the shutdown sequence when all worker threads had
150+
* been already shutdown (with the only exception of the last worker thread that might be performing
151+
* shutdown procedure itself). As a consistency check we do a [cheap!] check that it is not closed here yet.
152+
*/
153+
check(globalQueue.add(task)) { "GlobalQueue could not be closed yet" }
154+
}
155+
147156
internal fun offloadAllWork(globalQueue: GlobalQueue) {
157+
lastScheduledTask.getAndSet(null)?.let { addToGlobalQueue(globalQueue, it) }
148158
while (true) {
149-
val task = pollExternal() ?: return
150-
globalQueue.addLast(task)
159+
addToGlobalQueue(globalQueue, pollExternal() ?: return)
151160
}
152161
}
153162

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.experimental.scheduling
6+
7+
import kotlinx.atomicfu.*
8+
import kotlinx.coroutines.experimental.*
9+
import org.junit.Test
10+
import java.util.*
11+
import kotlin.test.*
12+
13+
class CoroutineSchedulerCloseStressTest : TestBase() {
14+
private val N_REPEAT = 2 * stressTestMultiplier
15+
private val MAX_LEVEL = 5
16+
private val N_COROS = (1 shl (MAX_LEVEL + 1)) - 1
17+
private val N_THREADS = 4
18+
private val rnd = Random()
19+
20+
private lateinit var dispatcher: ExecutorCoroutineDispatcher
21+
private var closeIndex = -1
22+
23+
private val started = atomic(0)
24+
private val finished = atomic(0)
25+
26+
@Test
27+
fun testNormalClose() {
28+
try {
29+
launchCoroutines()
30+
} finally {
31+
dispatcher.close()
32+
}
33+
}
34+
35+
@Test
36+
fun testRacingClose() {
37+
repeat(N_REPEAT) {
38+
closeIndex = rnd.nextInt(N_COROS)
39+
launchCoroutines()
40+
}
41+
}
42+
43+
private fun launchCoroutines() = runBlocking {
44+
dispatcher = ExperimentalCoroutineDispatcher(N_THREADS)
45+
started.value = 0
46+
finished.value = 0
47+
withContext(dispatcher) {
48+
launchChild(0, 0)
49+
}
50+
assertEquals(N_COROS, started.value)
51+
assertEquals(N_COROS, finished.value)
52+
}
53+
54+
private fun CoroutineScope.launchChild(index: Int, level: Int): Job = launch(start = CoroutineStart.ATOMIC) {
55+
started.incrementAndGet()
56+
try {
57+
if (index == closeIndex) dispatcher.close()
58+
if (level < MAX_LEVEL) {
59+
launchChild(2 * index + 1, level + 1)
60+
launchChild(2 * index + 2, level + 1)
61+
} else {
62+
delay(1000)
63+
}
64+
} finally {
65+
finished.incrementAndGet()
66+
}
67+
}
68+
}

0 commit comments

Comments
 (0)