Skip to content

Commit ecd36dd

Browse files
committed
LimitedDispatcher fixes
* Support dispatchYield * Fix doc * Short-circuit limitedParallelism(x).limitedParallelism(y) for y >= x
1 parent 00122c5 commit ecd36dd

File tree

3 files changed

+61
-16
lines changed

3 files changed

+61
-16
lines changed

kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public abstract class CoroutineDispatcher :
8787
* private val fileWriterDispatcher = backgroundDispatcher.limitedParallelism(1)
8888
* ```
8989
* Note how in this example, the application have the executor with 4 threads, but the total sum of all limits
90-
* is 5. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism.
90+
* is 6. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism.
9191
*/
9292
@ExperimentalCoroutinesApi
9393
public open fun limitedParallelism(parallelism: Int): CoroutineDispatcher {

kotlinx-coroutines-core/common/src/internal/LimitedDispatcher.kt

+38-14
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ internal class LimitedDispatcher(
2323

2424
private val queue = LockFreeTaskQueue<Runnable>(singleConsumer = false)
2525

26-
@InternalCoroutinesApi
27-
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
28-
dispatcher.dispatchYield(context, block)
26+
@ExperimentalCoroutinesApi
27+
override fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
28+
parallelism.checkParallelism()
29+
if (parallelism >= this.parallelism) return this
30+
return super.limitedParallelism(parallelism)
2931
}
3032

3133
override fun run() {
@@ -59,25 +61,47 @@ internal class LimitedDispatcher(
5961
}
6062

6163
override fun dispatch(context: CoroutineContext, block: Runnable) {
62-
// Add task to queue so running workers will be able to see that
63-
queue.addLast(block)
64-
if (runningWorkers >= parallelism) {
65-
return
64+
dispatchInternal(block) {
65+
if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
66+
dispatcher.dispatch(EmptyCoroutineContext, this)
67+
} else {
68+
run()
69+
}
70+
}
71+
}
72+
73+
@InternalCoroutinesApi
74+
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
75+
dispatchInternal(block) {
76+
dispatcher.dispatchYield(context, this)
6677
}
78+
}
6779

80+
private inline fun dispatchInternal(block: Runnable, dispatch: () -> Unit) {
81+
// Add task to queue so running workers will be able to see that
82+
if (tryAdd(block)) return
6883
/*
69-
* Protect against race when the worker is finished right after our check.
84+
* Protect against the race when the number of workers is enough,
85+
* but one (because of synchronized serialization) attempts to complete,
86+
* and we just observed the number of running workers smaller than the actual
87+
* number (hit right between `--runningWorkers` and `++runningWorkers` in `run()`)
7088
*/
89+
if (enoughWorkers()) return
90+
dispatch()
91+
}
92+
93+
private fun enoughWorkers(): Boolean {
7194
@Suppress("CAST_NEVER_SUCCEEDS")
7295
synchronized(this as SynchronizedObject) {
73-
if (runningWorkers >= parallelism) return
96+
if (runningWorkers >= parallelism) return true
7497
++runningWorkers
98+
return false
7599
}
76-
if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
77-
dispatcher.dispatch(EmptyCoroutineContext, this)
78-
} else {
79-
run()
80-
}
100+
}
101+
102+
private fun tryAdd(block: Runnable): Boolean {
103+
queue.addLast(block)
104+
return runningWorkers >= parallelism
81105
}
82106
}
83107

kotlinx-coroutines-core/jvm/test/LimitedParallelismStressTest.kt

+22-1
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas
3333
}
3434

3535
@Test
36-
fun testLimited() = runTest {
36+
fun testLimitedExecutor() = runTest {
3737
val view = executor.limitedParallelism(targetParallelism)
3838
repeat(iterations) {
3939
launch(view) {
@@ -42,6 +42,27 @@ class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBas
4242
}
4343
}
4444

45+
@Test
46+
fun testLimitedDispatchersIo() = runTest {
47+
val view = Dispatchers.IO.limitedParallelism(targetParallelism)
48+
repeat(iterations) {
49+
launch(view) {
50+
checkParallelism()
51+
}
52+
}
53+
}
54+
55+
@Test
56+
fun testLimitedDispatchersIoDispatchYield() = runTest {
57+
val view = Dispatchers.IO.limitedParallelism(targetParallelism)
58+
repeat(iterations) {
59+
launch(view) {
60+
yield()
61+
checkParallelism()
62+
}
63+
}
64+
}
65+
4566
@Test
4667
fun testUnconfined() = runTest {
4768
val view = Dispatchers.Unconfined.limitedParallelism(targetParallelism)

0 commit comments

Comments
 (0)