|
4 | 4 |
|
5 | 5 | package kotlinx.coroutines
|
6 | 6 |
|
| 7 | +import kotlinx.atomicfu.* |
7 | 8 | import kotlinx.coroutines.channels.*
|
| 9 | +import kotlinx.coroutines.internal.* |
8 | 10 | import kotlin.coroutines.*
|
9 | 11 | import kotlin.native.concurrent.*
|
10 | 12 |
|
@@ -63,28 +65,62 @@ internal class WorkerDispatcher(name: String) : CloseableCoroutineDispatcher(),
|
63 | 65 | }
|
64 | 66 | }
|
65 | 67 |
|
66 |
| -private class MultiWorkerDispatcher(name: String, workersCount: Int) : CloseableCoroutineDispatcher() { |
67 |
| - private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED) |
68 |
| - private val workers = Array(workersCount) { Worker.start(name = "$name-$it") } |
| 68 | +public class MultiWorkerDispatcher( |
| 69 | + private val name: String, private val workersCount: Int |
| 70 | +) : CloseableCoroutineDispatcher() { |
| 71 | + private val runningWorkers = atomic(0) |
| 72 | + private val queue = DispatcherQueue() |
| 73 | + private val workers = atomicArrayOfNulls<Worker>(workersCount) |
| 74 | + private val isTerminated = atomic(false) |
69 | 75 |
|
70 |
| - init { |
71 |
| - workers.forEach { w -> w.executeAfter(0L) { workerRunLoop() } } |
| 76 | + override fun dispatch(context: CoroutineContext, block: Runnable) { |
| 77 | + if (runningWorkers.value != workersCount) { |
| 78 | + tryAddWorker() |
| 79 | + } |
| 80 | + queue.put(block) |
72 | 81 | }
|
73 | 82 |
|
74 |
| - private fun workerRunLoop() = runBlocking { |
75 |
| - for (task in tasksQueue) { |
76 |
| - // TODO error handling |
77 |
| - task.run() |
| 83 | + private fun tryAddWorker() { |
| 84 | + runningWorkers.loop { |
| 85 | + if (it == workersCount) return |
| 86 | + if (runningWorkers.compareAndSet(it, it + 1)) { |
| 87 | + addWorker(it) |
| 88 | + return |
| 89 | + } |
78 | 90 | }
|
79 | 91 | }
|
80 | 92 |
|
81 |
| - override fun dispatch(context: CoroutineContext, block: Runnable) { |
82 |
| - // TODO handle rejections |
83 |
| - tasksQueue.trySend(block) |
| 93 | + private fun addWorker(sequenceNumber: Int) { |
| 94 | + val worker = Worker.start(name = "$name-#$sequenceNumber") |
| 95 | + workers[sequenceNumber].value = worker |
| 96 | + worker.executeAfter(0L) { |
| 97 | + workerLoop() |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + private fun workerLoop() { |
| 102 | + while (!isTerminated.value) { |
| 103 | + val runnable = queue.take() |
| 104 | + runnable.run() |
| 105 | + } |
84 | 106 | }
|
85 | 107 |
|
86 | 108 | override fun close() {
|
87 |
| - tasksQueue.close() |
88 |
| - workers.forEach { it.requestTermination().result } |
| 109 | + // TODO it races with worker creation |
| 110 | + if (!isTerminated.compareAndSet(false, true)) return |
| 111 | + repeat(workersCount) { |
| 112 | + queue.put(Runnable {}) // Empty poison pill to wakeup workers and make them check isTerminated |
| 113 | + } |
| 114 | + |
| 115 | + val requests = ArrayList<Future<Unit>>() |
| 116 | + for (i in 0 until workers.size) { |
| 117 | + val worker = workers[i].value ?: continue |
| 118 | + requests += worker.requestTermination(false) |
| 119 | + } |
| 120 | + for (request in requests) { |
| 121 | + request.result // Wait for workers termination |
| 122 | + } |
| 123 | + |
| 124 | +// queue.close() |
89 | 125 | }
|
90 | 126 | }
|
0 commit comments