diff --git a/kotlinx-coroutines-core/common/src/CloseableCoroutineDispatcher.kt b/kotlinx-coroutines-core/common/src/CloseableCoroutineDispatcher.kt
index 9c6703291a..541b3082e2 100644
--- a/kotlinx-coroutines-core/common/src/CloseableCoroutineDispatcher.kt
+++ b/kotlinx-coroutines-core/common/src/CloseableCoroutineDispatcher.kt
@@ -19,8 +19,8 @@ public expect abstract class CloseableCoroutineDispatcher() : CoroutineDispatcher {
 
     /**
      * Initiate the closing sequence of the coroutine dispatcher.
-     * After a successful call to [close], no new tasks will
-     * be accepted to be [dispatched][dispatch], but the previously dispatched tasks will be run.
+     * After a successful call to [close], no new tasks will be accepted to be [dispatched][dispatch].
+     * The previously-submitted tasks will still be run, but [close] is not guaranteed to wait for them to finish.
      *
      * Invocations of `close` are idempotent and thread-safe.
      */
diff --git a/kotlinx-coroutines-core/concurrent/test/MultithreadedDispatcherStressTest.kt b/kotlinx-coroutines-core/concurrent/test/MultithreadedDispatcherStressTest.kt
new file mode 100644
index 0000000000..4e4583f20a
--- /dev/null
+++ b/kotlinx-coroutines-core/concurrent/test/MultithreadedDispatcherStressTest.kt
@@ -0,0 +1,35 @@
+/*
+ * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines
+
+import kotlinx.atomicfu.*
+import kotlin.coroutines.*
+import kotlin.test.*
+
+class MultithreadedDispatcherStressTest {
+    val shared = atomic(0)
+
+    /**
+     * Tests that [newFixedThreadPoolContext] will not drop tasks when closed.
+     */
+    @Test
+    fun testClosingNotDroppingTasks() {
+        repeat(7) {
+            shared.value = 0
+            val nThreads = it + 1
+            val dispatcher = newFixedThreadPoolContext(nThreads, "testMultiThreadedContext")
+            repeat(1_000) {
+                dispatcher.dispatch(EmptyCoroutineContext, Runnable {
+                    shared.incrementAndGet()
+                })
+            }
+            dispatcher.close()
+            while (shared.value < 1_000) {
+                // spin.
+                // the test will hang here if the dispatcher drops tasks.
+            }
+        }
+    }
+}
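For context, the documented close() contract as a minimal usage sketch (hypothetical, not part of the patch; it assumes the Kotlin/Native implementation patched below, where dispatching after close() throws IllegalStateException, and the "demo" name and println messages are illustrative):

import kotlinx.coroutines.*
import kotlin.coroutines.EmptyCoroutineContext

fun main() {
    val dispatcher = newFixedThreadPoolContext(2, "demo")
    // Submitted before close(), so the task is still run, even though
    // close() does not necessarily wait for it to finish.
    dispatcher.dispatch(EmptyCoroutineContext, Runnable { println("still runs") })
    dispatcher.close()
    dispatcher.close() // idempotent: the second call is a no-op
    try {
        dispatcher.dispatch(EmptyCoroutineContext, Runnable { })
    } catch (e: IllegalStateException) {
        println("rejected after close(): ${e.message}")
    }
}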
diff --git a/kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt b/kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt
index bf91e7003b..0012ff65db 100644
--- a/kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt
+++ b/kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt
@@ -4,6 +4,7 @@
 
 package kotlinx.coroutines
 
+import kotlinx.atomicfu.*
 import kotlinx.coroutines.channels.*
 import kotlinx.coroutines.internal.*
 import kotlin.coroutines.*
@@ -73,38 +74,78 @@ private class MultiWorkerDispatcher(
     workersCount: Int
 ) : CloseableCoroutineDispatcher() {
     private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED)
+    private val availableWorkers = Channel<CancellableContinuation<Runnable>>(Channel.UNLIMITED)
     private val workerPool = OnDemandAllocatingPool(workersCount) {
         Worker.start(name = "$name-$it").apply {
             executeAfter { workerRunLoop() }
         }
     }
 
+    /**
+     * (number of tasks - number of workers) * 2 + (1 if closed)
+     */
+    private val tasksAndWorkersCounter = atomic(0L)
+
+    private inline fun Long.isClosed() = this and 1L == 1L
+    private inline fun Long.hasTasks() = this >= 2
+    private inline fun Long.hasWorkers() = this < 0
+
     private fun workerRunLoop() = runBlocking {
-        // NB: we leverage tail-call optimization in this loop, do not replace it with
-        // .receive() without proper evaluation
-        for (task in tasksQueue) {
-            /**
-             * Any unhandled exception here will pass through worker's boundary and will be properly reported.
-             */
-            task.run()
+        while (true) {
+            val state = tasksAndWorkersCounter.getAndUpdate {
+                if (it.isClosed() && !it.hasTasks()) return@runBlocking
+                it - 2
+            }
+            if (state.hasTasks()) {
+                // we promised to process a task, and there are some
+                tasksQueue.receive().run()
+            } else {
+                try {
+                    suspendCancellableCoroutine {
+                        val result = availableWorkers.trySend(it)
+                        checkChannelResult(result)
+                    }.run()
+                } catch (e: CancellationException) {
+                    /** we are cancelled from [close] and thus will never get back to this branch of code,
+                     * but there may still be pending work, so we can't just exit here. */
+                }
+            }
         }
     }
 
+    // a worker that promised to be here and should actually arrive, so we wait for it in a blocking manner.
+    private fun obtainWorker(): CancellableContinuation<Runnable> =
+        availableWorkers.tryReceive().getOrNull() ?: runBlocking { availableWorkers.receive() }
+
     override fun dispatch(context: CoroutineContext, block: Runnable) {
-        fun throwClosed(block: Runnable) {
-            throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
+        val state = tasksAndWorkersCounter.getAndUpdate {
+            if (it.isClosed())
+                throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
+            it + 2
         }
-
-        if (!workerPool.allocate()) throwClosed(block) // Do not even try to send to avoid race
-
-        tasksQueue.trySend(block).onClosed {
-            throwClosed(block)
+        if (state.hasWorkers()) {
+            // there are workers that have nothing to do, let's grab one of them
+            obtainWorker().resume(block)
+        } else {
+            workerPool.allocate()
+            // no workers are available, we must queue the task
+            val result = tasksQueue.trySend(block)
+            checkChannelResult(result)
         }
     }
 
     override fun close() {
-        val workers = workerPool.close()
-        tasksQueue.close()
+        tasksAndWorkersCounter.getAndUpdate { if (it.isClosed()) it else it or 1L }
+        val workers = workerPool.close() // no new workers will be created
+        while (true) {
+            // check if there are workers that await tasks in their personal channels, we need to wake them up
+            val state = tasksAndWorkersCounter.getAndUpdate {
+                if (it.hasWorkers()) it + 2 else it
+            }
+            if (!state.hasWorkers())
+                break
+            obtainWorker().cancel()
+        }
         /*
          * Here we cannot avoid waiting on `.result`, otherwise it will lead
          * to a native memory leak, including a pthread handle.
@@ -112,4 +153,12 @@ private class MultiWorkerDispatcher(
         val requests = workers.map { it.requestTermination() }
         requests.map { it.result }
     }
+
+    private fun checkChannelResult(result: ChannelResult<*>) {
+        if (!result.isSuccess)
+            throw IllegalStateException(
+                "Internal invariants of $this were violated, please file a bug to kotlinx.coroutines",
+                result.exceptionOrNull()
+            )
+    }
 }
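The heart of the fix is the tasksAndWorkersCounter state machine above. Restated as a self-contained sketch (the names StateCounter, onWorkerReady, and onTaskSubmitted are mine, not from the patch; it assumes only kotlinx-atomicfu):

import kotlinx.atomicfu.*

// The low bit stores the "closed" flag; the remaining bits store
// (tasks submitted - workers waiting), shifted left by one bit,
// so adding or subtracting 2 adjusts the balance without touching the flag.
class StateCounter {
    private val state = atomic(0L)

    private fun Long.isClosed() = this and 1L == 1L
    private fun Long.hasTasks() = this >= 2   // balance >= +1: queued tasks exist
    private fun Long.hasWorkers() = this < 0  // balance <= -1: idle workers exist

    // A worker commits to either taking a queued task or going idle: balance -= 1.
    fun onWorkerReady(): Long = state.getAndUpdate { it - 2 }

    // A task is submitted: balance += 1; rejected once the closed bit is set.
    fun onTaskSubmitted(): Long = state.getAndUpdate {
        check(!it.isClosed()) { "dispatcher is closed" }
        it + 2
    }

    // Sets only the closed bit, leaving the balance intact; idempotent.
    fun close() {
        state.getAndUpdate { if (it.isClosed()) it else it or 1L }
    }
}

Because every transition is a single getAndUpdate on one atomic, "hand the task to an idle worker or queue it" and "take a queued task or go idle" can never miss each other, which closes the window in which tasks could previously be dropped during close() — the scenario the stress test above exercises.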
diff --git a/kotlinx-coroutines-core/native/test/MultithreadedDispatchersTest.kt b/kotlinx-coroutines-core/native/test/MultithreadedDispatchersTest.kt
new file mode 100644
index 0000000000..ce433cc3e3
--- /dev/null
+++ b/kotlinx-coroutines-core/native/test/MultithreadedDispatchersTest.kt
@@ -0,0 +1,66 @@
+/*
+ * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
+ */
+
+package kotlinx.coroutines
+
+import kotlinx.atomicfu.*
+import kotlinx.coroutines.channels.*
+import kotlinx.coroutines.internal.*
+import kotlin.native.concurrent.*
+import kotlin.test.*
+
+private class BlockingBarrier(val n: Int) {
+    val counter = atomic(0)
+    val wakeUp = Channel<Unit>(n - 1)
+    fun await() {
+        val count = counter.addAndGet(1)
+        if (count == n) {
+            repeat(n - 1) {
+                runBlocking {
+                    wakeUp.send(Unit)
+                }
+            }
+        } else if (count < n) {
+            runBlocking {
+                wakeUp.receive()
+            }
+        }
+    }
+}
+
+class MultithreadedDispatchersTest {
+    /**
+     * Test that [newFixedThreadPoolContext] does not allocate more workers than it needs to.
+     * Incidentally also tests that it will allocate enough workers for its needs; otherwise, the test will hang.
+     */
+    @Test
+    fun testNotAllocatingExtraDispatchers() {
+        val barrier = BlockingBarrier(2)
+        val lock = SynchronizedObject()
+        suspend fun spin(set: MutableSet<Worker>) {
+            repeat(100) {
+                synchronized(lock) { set.add(Worker.current) }
+                delay(1)
+            }
+        }
+        val dispatcher = newFixedThreadPoolContext(64, "test")
+        try {
+            runBlocking {
+                val encounteredWorkers = mutableSetOf<Worker>()
+                val coroutine1 = launch(dispatcher) {
+                    barrier.await()
+                    spin(encounteredWorkers)
+                }
+                val coroutine2 = launch(dispatcher) {
+                    barrier.await()
+                    spin(encounteredWorkers)
+                }
+                listOf(coroutine1, coroutine2).joinAll()
+                assertEquals(2, encounteredWorkers.size)
+            }
+        } finally {
+            dispatcher.close()
+        }
    }
}
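For clarity, how BlockingBarrier behaves in isolation, as a hypothetical demo (not part of the patch; it would have to live in the same file, since the class is private):

import kotlin.native.concurrent.*

fun demoBarrier() {
    val barrier = BlockingBarrier(2)
    val worker = Worker.start()
    // The producer lambda hands the barrier to the worker's job,
    // which then blocks in await() until the second party arrives.
    val future = worker.execute(TransferMode.SAFE, { barrier }) {
        it.await()
        "released"
    }
    barrier.await() // the n-th arrival sends n - 1 wake-ups, releasing everyone else
    println(future.result) // prints "released" only after both parties met
    worker.requestTermination().result
}

In the test, the barrier guarantees that both coroutines demand a thread at the same time, so the pool must use exactly two distinct workers: fewer would hang on the barrier, while a pool that over-allocates would let the spin loops observe more than two.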