Skip to content

Commit e946cd7

Browse files
authored
Don't allocate threads on every dispatch in Native's thread pools (#3595)
Related to #3576
1 parent 32af157 commit e946cd7

File tree

4 files changed

+168
-18
lines changed

4 files changed

+168
-18
lines changed

kotlinx-coroutines-core/common/src/CloseableCoroutineDispatcher.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ public expect abstract class CloseableCoroutineDispatcher() : CoroutineDispatche
1919

2020
/**
2121
* Initiate the closing sequence of the coroutine dispatcher.
22-
* After a successful call to [close], no new tasks will
23-
* be accepted to be [dispatched][dispatch], but the previously dispatched tasks will be run.
22+
* After a successful call to [close], no new tasks will be accepted to be [dispatched][dispatch].
23+
* The previously-submitted tasks will still be run, but [close] is not guaranteed to wait for them to finish.
2424
*
2525
* Invocations of `close` are idempotent and thread-safe.
2626
*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/*
2+
* Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlinx.atomicfu.*
8+
import kotlin.coroutines.*
9+
import kotlin.test.*
10+
11+
class MultithreadedDispatcherStressTest {
12+
val shared = atomic(0)
13+
14+
/**
15+
* Tests that [newFixedThreadPoolContext] will not drop tasks when closed.
16+
*/
17+
@Test
18+
fun testClosingNotDroppingTasks() {
19+
repeat(7) {
20+
shared.value = 0
21+
val nThreads = it + 1
22+
val dispatcher = newFixedThreadPoolContext(nThreads, "testMultiThreadedContext")
23+
repeat(1_000) {
24+
dispatcher.dispatch(EmptyCoroutineContext, Runnable {
25+
shared.incrementAndGet()
26+
})
27+
}
28+
dispatcher.close()
29+
while (shared.value < 1_000) {
30+
// spin.
31+
// the test will hang here if the dispatcher drops tasks.
32+
}
33+
}
34+
}
35+
}

kotlinx-coroutines-core/native/src/MultithreadedDispatchers.kt

+65-16
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.coroutines
66

7+
import kotlinx.atomicfu.*
78
import kotlinx.coroutines.channels.*
89
import kotlinx.coroutines.internal.*
910
import kotlin.coroutines.*
@@ -73,43 +74,91 @@ private class MultiWorkerDispatcher(
7374
workersCount: Int
7475
) : CloseableCoroutineDispatcher() {
7576
private val tasksQueue = Channel<Runnable>(Channel.UNLIMITED)
77+
private val availableWorkers = Channel<CancellableContinuation<Runnable>>(Channel.UNLIMITED)
7678
private val workerPool = OnDemandAllocatingPool(workersCount) {
7779
Worker.start(name = "$name-$it").apply {
7880
executeAfter { workerRunLoop() }
7981
}
8082
}
8183

84+
/**
85+
* Encodes `(number of pending tasks - number of idle workers waiting for a task) * 2 + (1 if closed)`.
86+
*/
87+
private val tasksAndWorkersCounter = atomic(0L)
88+
89+
private inline fun Long.isClosed() = this and 1L == 1L
90+
private inline fun Long.hasTasks() = this >= 2
91+
private inline fun Long.hasWorkers() = this < 0
92+
8293
private fun workerRunLoop() = runBlocking {
83-
// NB: we leverage tail-call optimization in this loop, do not replace it with
84-
// .receive() without proper evaluation
85-
for (task in tasksQueue) {
86-
/**
87-
* Any unhandled exception here will pass through worker's boundary and will be properly reported.
88-
*/
89-
task.run()
94+
while (true) {
95+
val state = tasksAndWorkersCounter.getAndUpdate {
96+
if (it.isClosed() && !it.hasTasks()) return@runBlocking
97+
it - 2
98+
}
99+
if (state.hasTasks()) {
100+
// by decrementing the counter we claimed one pending task; `receive` waits if the dispatching thread has not enqueued it yet
101+
tasksQueue.receive().run()
102+
} else {
103+
try {
104+
suspendCancellableCoroutine {
105+
val result = availableWorkers.trySend(it)
106+
checkChannelResult(result)
107+
}.run()
108+
} catch (e: CancellationException) {
109+
/** we are cancelled from [close] and thus will never get back to this branch of code,
110+
but there may still be pending work, so we can't just exit here. */
111+
}
112+
}
90113
}
91114
}
92115

116+
// Called only when the counter shows an idle worker: that worker may not have registered itself in `availableWorkers` yet, so block until it does.
117+
private fun obtainWorker(): CancellableContinuation<Runnable> =
118+
availableWorkers.tryReceive().getOrNull() ?: runBlocking { availableWorkers.receive() }
119+
93120
override fun dispatch(context: CoroutineContext, block: Runnable) {
94-
fun throwClosed(block: Runnable) {
95-
throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
121+
val state = tasksAndWorkersCounter.getAndUpdate {
122+
if (it.isClosed())
123+
throw IllegalStateException("Dispatcher $name was closed, attempted to schedule: $block")
124+
it + 2
96125
}
97-
98-
if (!workerPool.allocate()) throwClosed(block) // Do not even try to send to avoid race
99-
100-
tasksQueue.trySend(block).onClosed {
101-
throwClosed(block)
126+
if (state.hasWorkers()) {
127+
// there are workers that have nothing to do, let's grab one of them
128+
obtainWorker().resume(block)
129+
} else {
130+
workerPool.allocate()
131+
// no workers are available, we must queue the task
132+
val result = tasksQueue.trySend(block)
133+
checkChannelResult(result)
102134
}
103135
}
104136

105137
override fun close() {
106-
val workers = workerPool.close()
107-
tasksQueue.close()
138+
tasksAndWorkersCounter.getAndUpdate { if (it.isClosed()) it else it or 1L }
139+
val workers = workerPool.close() // no new workers will be created
140+
while (true) {
141+
// check whether any idle workers are suspended awaiting a task; each one is claimed and cancelled to wake it up
142+
val state = tasksAndWorkersCounter.getAndUpdate {
143+
if (it.hasWorkers()) it + 2 else it
144+
}
145+
if (!state.hasWorkers())
146+
break
147+
obtainWorker().cancel()
148+
}
108149
/*
109150
* Here we cannot avoid waiting on `.result`, otherwise it will lead
110151
* to a native memory leak, including a pthread handle.
111152
*/
112153
val requests = workers.map { it.requestTermination() }
113154
requests.map { it.result }
114155
}
156+
157+
private fun checkChannelResult(result: ChannelResult<*>) {
158+
if (!result.isSuccess)
159+
throw IllegalStateException(
160+
"Internal invariants of $this were violated, please file a bug to kotlinx.coroutines",
161+
result.exceptionOrNull()
162+
)
163+
}
115164
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import kotlinx.atomicfu.*
8+
import kotlinx.coroutines.channels.*
9+
import kotlinx.coroutines.internal.*
10+
import kotlin.native.concurrent.*
11+
import kotlin.test.*
12+
13+
private class BlockingBarrier(val n: Int) {
14+
val counter = atomic(0)
15+
val wakeUp = Channel<Unit>(n - 1)
16+
fun await() {
17+
val count = counter.addAndGet(1)
18+
if (count == n) {
19+
repeat(n - 1) {
20+
runBlocking {
21+
wakeUp.send(Unit)
22+
}
23+
}
24+
} else if (count < n) {
25+
runBlocking {
26+
wakeUp.receive()
27+
}
28+
}
29+
}
30+
}
31+
32+
class MultithreadedDispatchersTest {
33+
/**
34+
* Tests that [newFixedThreadPoolContext] does not allocate more workers than it needs to.
35+
* Incidentally also tests that it will allocate enough workers for its needs. Otherwise, the test will hang.
36+
*/
37+
@Test
38+
fun testNotAllocatingExtraDispatchers() {
39+
val barrier = BlockingBarrier(2)
40+
val lock = SynchronizedObject()
41+
suspend fun spin(set: MutableSet<Worker>) {
42+
repeat(100) {
43+
synchronized(lock) { set.add(Worker.current) }
44+
delay(1)
45+
}
46+
}
47+
val dispatcher = newFixedThreadPoolContext(64, "test")
48+
try {
49+
runBlocking {
50+
val encounteredWorkers = mutableSetOf<Worker>()
51+
val coroutine1 = launch(dispatcher) {
52+
barrier.await()
53+
spin(encounteredWorkers)
54+
}
55+
val coroutine2 = launch(dispatcher) {
56+
barrier.await()
57+
spin(encounteredWorkers)
58+
}
59+
listOf(coroutine1, coroutine2).joinAll()
60+
assertEquals(2, encounteredWorkers.size)
61+
}
62+
} finally {
63+
dispatcher.close()
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)