
Don't allocate threads on every dispatch in Native's thread pools #3595


Merged: 7 commits, Feb 13, 2023
Changes from 3 commits
@@ -19,8 +19,8 @@ public expect abstract class CloseableCoroutineDispatcher() : CoroutineDispatcher

/**
* Initiate the closing sequence of the coroutine dispatcher.
* After a successful call to [close], no new tasks will
* be accepted to be [dispatched][dispatch], but the previously dispatched tasks will be run.
* After a successful call to [close], no new tasks will be accepted to be [dispatched][dispatch].
* The previously-submitted tasks will still be run before the call to [close] is finished.
*
* Invocations of `close` are idempotent and thread-safe.
*/
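For illustration, a minimal sketch of the documented contract (a hypothetical example mirroring what the new stress test in this PR exercises; opt-in annotations and error handling omitted):

import kotlinx.coroutines.*
import kotlin.coroutines.*

fun main() {
    val dispatcher = newFixedThreadPoolContext(2, "example")
    // dispatched before close(), so the documented contract guarantees it will run
    dispatcher.dispatch(EmptyCoroutineContext, Runnable { println("still runs") })
    // returns only after the previously dispatched task has been run
    dispatcher.close()
    // dispatching here would throw IllegalStateException: the dispatcher no longer accepts tasks
}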
@@ -0,0 +1,33 @@
/*
* Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlin.coroutines.*
import kotlin.test.*

class MultithreadedDispatcherStressTest {
val shared = atomic(0)

/**
* Tests that [newFixedThreadPoolContext] will not drop tasks when closed.
*/
@Test
fun testClosingNotDroppingTasks() {
repeat(7) {
shared.value = 0
val nThreads = it + 1
val dispatcher = newFixedThreadPoolContext(nThreads, "testMultiThreadedContext")
repeat(1_000) {
dispatcher.dispatch(EmptyCoroutineContext, Runnable {
shared.incrementAndGet()
})
}
dispatcher.close()
val m = shared.value
assertEquals(1_000, m, "$nThreads threads")
}
}
}
@@ -44,7 +44,10 @@ private class ClosedAfterGuideTestDispatcher(
}

override fun close() {
(executor as ExecutorService).shutdown()
(executor as ExecutorService).apply {
shutdown()
awaitTermination(1, TimeUnit.MINUTES)
}
}

override fun toString(): String = "ThreadPoolDispatcher[$nThreads, $name]"
70 changes: 54 additions & 16 deletions kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt
@@ -4,6 +4,7 @@

package kotlinx.coroutines

import kotlinx.atomicfu.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
@@ -73,38 +74,75 @@ private class MultiWorkerDispatcher(
workersCount: Int
) : CloseableCoroutineDispatcher() {
private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED)
private val availableWorkers = Channel<Channel<Runnable>>(Channel.UNLIMITED)
private val workerPool = OnDemandAllocatingPool(workersCount) {
Worker.start(name = "$name-$it").apply {
executeAfter { workerRunLoop() }
}
}

/**
* (number of tasks - number of workers) * 2 + (1 if closed)
*/
private val tasksAndWorkersCounter = atomic(0L)

private inline fun Long.isClosed() = this and 1L == 1L
private inline fun Long.hasTasks() = this >= 2
private inline fun Long.hasWorkers() = this < 0
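// For example, under this encoding:
//   0L -> open; as many workers are waiting as there are pending tasks (possibly none of either)
//   4L -> open; two pending tasks have no waiting worker yet
//  -2L -> open; one worker is idle, waiting for a task
//   5L -> closed, with two pending tasks remaining (4 + 1 for the closed bit)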

private fun workerRunLoop() = runBlocking {
// NB: we leverage tail-call optimization in this loop, do not replace it with
// .receive() without proper evaluation
for (task in tasksQueue) {
/**
* Any unhandled exception here will pass through worker's boundary and will be properly reported.
*/
task.run()
val privateChannel = Channel<Runnable>(1)
while (true) {
val state = tasksAndWorkersCounter.getAndUpdate {
if (it.isClosed() && !it.hasTasks()) return@runBlocking
it - 2
}
if (state.hasTasks()) {
// we promised to process a task, and there are some
tasksQueue.receive().run()
} else {
availableWorkers.send(privateChannel)
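// park on the private channel until `dispatch` hands a task over through it;
// if `close()` closes this channel instead, `getOrNull()` returns null, the task is
// skipped, and the next loop iteration re-checks the counter and exits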
privateChannel.receiveCatching().getOrNull()?.run()
}
}
}

override fun dispatch(context: CoroutineContext, block: Runnable) {
fun throwClosed(block: Runnable) {
throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
private fun obtainWorker(): Channel<Runnable> {
// spin loop until a worker that promised to be here actually arrives.
while (true) {
val result = availableWorkers.tryReceive()
return result.getOrNull() ?: continue
Collaborator:
I'm a bit reluctant about having potentially infinite spin loops in case of an unforeseen bug or failure.
It would be nice to replace it with a (non-existent) receiveBlocking, which could have been a direct counterpart of sendBlocking.

Collaborator (Author):
Fixed.
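For context, a blocking receive can be approximated by wrapping receive in runBlocking; a minimal sketch of the idea (hypothetical helper name, not necessarily the change that was made here), assuming it is acceptable to block the calling thread:

private fun obtainWorkerBlocking(): Channel<Runnable> =
    // suspends inside runBlocking instead of spinning; blocks the calling thread
    // until a worker's private channel arrives
    runBlocking { availableWorkers.receive() }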

}
}

if (!workerPool.allocate()) throwClosed(block) // Do not even try to send to avoid race

tasksQueue.trySend(block).onClosed {
throwClosed(block)
override fun dispatch(context: CoroutineContext, block: Runnable) {
val state = tasksAndWorkersCounter.getAndUpdate {
if (it.isClosed())
throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
it + 2
}
if (state.hasWorkers()) {
// there are workers that have nothing to do, let's grab one of them
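// the worker that sent this private channel is parked receiving on it and the channel
// has capacity 1, so this trySend is expected to succeed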
obtainWorker().trySend(block)
} else {
workerPool.allocate()
// no workers are available, we must queue the task
tasksQueue.trySend(block)
}
}

override fun close() {
val workers = workerPool.close()
tasksQueue.close()
tasksAndWorkersCounter.getAndUpdate { if (it.isClosed()) it else it or 1L }
val workers = workerPool.close() // no new workers will be created
while (true) {
// check if there are workers that await tasks in their personal channels, we need to wake them up
val state = tasksAndWorkersCounter.getAndUpdate {
if (it.hasWorkers()) it + 2 else it
}
if (!state.hasWorkers())
break
obtainWorker().close()
}
/*
* Here we cannot avoid waiting on `.result`, otherwise it will lead
* to a native memory leak, including a pthread handle.
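The rest of close() is elided in this diff view. For illustration only (an assumption, not taken from this PR), waiting for workers to terminate in Kotlin/Native typically looks like this:

workers.forEach {
    // requestTermination returns a Future; reading .result blocks until the worker,
    // including its underlying pthread, has actually shut down
    it.requestTermination().result
}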
35 changes: 35 additions & 0 deletions kotlinx-coroutines-core/native/test/WorkerTest.kt
@@ -62,4 +62,39 @@ class WorkerTest : TestBase() {
finished.receive()
}
}

/**
* Test that [newFixedThreadPoolContext] does not allocate more dispatchers than it needs to.
* Incidentally also tests that it will allocate enough workers for its needs. Otherwise, the test will hang.
*/
@Test
fun testNotAllocatingExtraDispatchers() {
suspend fun spin(set: MutableSet<Worker>) {
repeat(100) {
set.add(Worker.current)
delay(1)
}
}
val dispatcher = newFixedThreadPoolContext(64, "test")
try {
runBlocking {
val encounteredWorkers = mutableSetOf<Worker>()
var canStart = false
val coroutine1 = launch(dispatcher) {
while (!canStart) {
// intentionally empty
}
spin(encounteredWorkers)
}
val coroutine2 = launch(dispatcher) {
canStart = true
spin(encounteredWorkers)
}
listOf(coroutine1, coroutine2).joinAll()
assertEquals(2, encounteredWorkers.size)
}
} finally {
dispatcher.close()
}
}
}