Skip to content

Commit 1a3a29c

Browse files
committed
Introduce CoroutineDispatcher.limitedParallelism
1 parent 3f459d5 commit 1a3a29c

9 files changed

+228
-1
lines changed

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+2
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ public abstract class kotlinx/coroutines/CoroutineDispatcher : kotlin/coroutines
156156
public fun get (Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext$Element;
157157
public final fun interceptContinuation (Lkotlin/coroutines/Continuation;)Lkotlin/coroutines/Continuation;
158158
public fun isDispatchNeeded (Lkotlin/coroutines/CoroutineContext;)Z
159+
public fun limitedParallelism (I)Lkotlinx/coroutines/CoroutineDispatcher;
159160
public fun minusKey (Lkotlin/coroutines/CoroutineContext$Key;)Lkotlin/coroutines/CoroutineContext;
160161
public final fun plus (Lkotlinx/coroutines/CoroutineDispatcher;)Lkotlinx/coroutines/CoroutineDispatcher;
161162
public final fun releaseInterceptedContinuation (Lkotlin/coroutines/Continuation;)V
@@ -446,6 +447,7 @@ public class kotlinx/coroutines/JobSupport : kotlinx/coroutines/ChildJob, kotlin
446447
public abstract class kotlinx/coroutines/MainCoroutineDispatcher : kotlinx/coroutines/CoroutineDispatcher {
447448
public fun <init> ()V
448449
public abstract fun getImmediate ()Lkotlinx/coroutines/MainCoroutineDispatcher;
450+
public fun limitedParallelism (I)Lkotlinx/coroutines/CoroutineDispatcher;
449451
public fun toString ()Ljava/lang/String;
450452
protected final fun toStringInternalImpl ()Ljava/lang/String;
451453
}

kotlinx-coroutines-core/common/src/CoroutineDispatcher.kt

+34
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,40 @@ public abstract class CoroutineDispatcher :
6161
*/
6262
public open fun isDispatchNeeded(context: CoroutineContext): Boolean = true
6363

64+
/**
65+
* Creates a view of the current dispatcher that limits the parallelism to the given [value][parallelism].
66+
* The resulting view uses the original dispatcher for execution, but with the guarantee that
67+
* no more than [parallelism] coroutines are executed at the same time.
68+
*
69+
* This method does not impose restrictions on the number of views or the total sum of parallelism values,
70+
* each view controls its own parallelism independently with the guarantee that the effective parallelism
71+
* of all views cannot exceed the actual parallelism of the original dispatcher.
72+
*
73+
* ### Limitations
74+
*
75+
* The default implementation of `limitedParallelism` does not support direct dispatchers,
76+
* such as calls to [dispatch] execute the given runnable in place. For direct dispatchers,
77+
* it is recommended to override this method and provide a domain-specific implementation.
78+
*
79+
* ### Example of usage
80+
* ```
81+
* private val backgroundDispatcher = newFixedThreadPoolContext(4, "App Background")
82+
* // At most 2 threads will be processing images as it is really slow and CPU-intensive
83+
* private val imageProcessingDispatcher = backgroundDispatcher.limitedParallelism(2)
84+
* // At most 3 threads will be processing JSON to avoid image processing starvation
85+
* private val imageProcessingDispatcher = backgroundDispatcher.limitedParallelism(3)
86+
* // At most 1 thread will be doing IO
87+
* private val fileWriterDispatcher = backgroundDispatcher.limitedParallelism(1)
88+
* ```
89+
* Note how in this example, the application have the executor with 4 threads, but the total sum of all limits
90+
* is 5. Yet at most 4 coroutines can be executed simultaneously as each view limits only its own parallelism.
91+
*/
92+
@ExperimentalCoroutinesApi
93+
public open fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
94+
parallelism.checkParallelism()
95+
return LimitedDispatcher(this, parallelism)
96+
}
97+
6498
/**
6599
* Dispatches execution of a runnable [block] onto another thread in the given [context].
66100
* This method should guarantee that the given [block] will be eventually invoked,

kotlinx-coroutines-core/common/src/EventLoop.common.kt

+5-1
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,11 @@ internal abstract class EventLoop : CoroutineDispatcher() {
115115
}
116116
}
117117

118+
final override fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
119+
parallelism.checkParallelism()
120+
return this
121+
}
122+
118123
protected open fun shutdown() {}
119124
}
120125

@@ -525,4 +530,3 @@ internal expect fun nanoTime(): Long
525530
internal expect object DefaultExecutor {
526531
public fun enqueue(task: Runnable)
527532
}
528-

kotlinx-coroutines-core/common/src/MainCoroutineDispatcher.kt

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
package kotlinx.coroutines
66

7+
import kotlinx.coroutines.internal.*
8+
79
/**
810
* Base class for special [CoroutineDispatcher] which is confined to application "Main" or "UI" thread
911
* and used for any UI-based activities. Instance of `MainDispatcher` can be obtained by [Dispatchers.Main].
@@ -51,6 +53,12 @@ public abstract class MainCoroutineDispatcher : CoroutineDispatcher() {
5153
*/
5254
override fun toString(): String = toStringInternalImpl() ?: "$classSimpleName@$hexAddress"
5355

56+
override fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
57+
parallelism.checkParallelism()
58+
// MainCoroutineDispatcher is single-threaded -- short-circuit any attempts to limit it
59+
return this
60+
}
61+
5462
/**
5563
* Internal method for more specific [toString] implementations. It returns non-null
5664
* string if this dispatcher is set in the platform as the main one.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/*
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.internal
6+
7+
import kotlinx.coroutines.*
8+
import kotlin.coroutines.*
9+
import kotlin.jvm.*
10+
11+
/**
12+
* The result of .limitedParallelism(x) call, dispatcher
13+
* that wraps the given dispatcher, but limits the parallelism level, while
14+
* trying to emulate fairness.
15+
*/
16+
internal class LimitedDispatcher(
17+
private val dispatcher: CoroutineDispatcher,
18+
private val parallelism: Int
19+
) : CoroutineDispatcher(), Runnable, Delay by (dispatcher as? Delay ?: DefaultDelay) {
20+
21+
@Volatile
22+
private var runningWorkers = 0
23+
24+
private val queue = LockFreeTaskQueue<Runnable>(singleConsumer = false)
25+
26+
@InternalCoroutinesApi
27+
override fun dispatchYield(context: CoroutineContext, block: Runnable) {
28+
dispatcher.dispatchYield(context, block)
29+
}
30+
31+
override fun run() {
32+
var fairnessCounter = 0
33+
while (true) {
34+
val task = queue.removeFirstOrNull()
35+
if (task != null) {
36+
task.run()
37+
// 16 is our out-of-thin-air constant to emulate fairness. Used in JS dispatchers as well
38+
if (++fairnessCounter >= 16 && dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
39+
// Do "yield" to let other views to execute their runnable as well
40+
// Note that we do not decrement 'runningWorkers' as we still committed to do our part of work
41+
dispatcher.dispatch(EmptyCoroutineContext, this)
42+
return
43+
}
44+
continue
45+
}
46+
47+
@Suppress("CAST_NEVER_SUCCEEDS")
48+
synchronized(this as SynchronizedObject) {
49+
--runningWorkers
50+
if (queue.size == 0) return
51+
++runningWorkers
52+
fairnessCounter = 0
53+
}
54+
}
55+
}
56+
57+
override fun dispatch(context: CoroutineContext, block: Runnable) {
58+
// Add task to queue so running workers will be able to see that
59+
queue.addLast(block)
60+
if (runningWorkers >= parallelism) {
61+
return
62+
}
63+
64+
/*
65+
* Protect against race when the worker is finished
66+
* right after our check
67+
*/
68+
@Suppress("CAST_NEVER_SUCCEEDS")
69+
synchronized(this as SynchronizedObject) {
70+
if (runningWorkers >= parallelism) return
71+
++runningWorkers
72+
}
73+
if (dispatcher.isDispatchNeeded(EmptyCoroutineContext)) {
74+
dispatcher.dispatch(EmptyCoroutineContext, this)
75+
} else {
76+
run()
77+
}
78+
}
79+
}
80+
81+
// Save a few bytecode ops
82+
internal fun Int.checkParallelism() = require(this >= 1) { "Expected positive parallelism level, but got $this" }

kotlinx-coroutines-core/js/src/JSDispatcher.kt

+5
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ internal sealed class SetTimeoutBasedDispatcher: CoroutineDispatcher(), Delay {
3131

3232
abstract fun scheduleQueueProcessing()
3333

34+
override fun limitedParallelism(parallelism: Int): CoroutineDispatcher {
35+
parallelism.checkParallelism()
36+
return this
37+
}
38+
3439
override fun dispatch(context: CoroutineContext, block: Runnable) {
3540
messageQueue.enqueue(block)
3641
}

kotlinx-coroutines-core/jvm/src/internal/MainDispatchers.kt

+3
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@ private class MissingMainCoroutineDispatcher(
9393
override fun isDispatchNeeded(context: CoroutineContext): Boolean =
9494
missing()
9595

96+
override fun limitedParallelism(parallelism: Int): CoroutineDispatcher =
97+
missing()
98+
9699
override suspend fun delay(time: Long) =
97100
missing()
98101

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/*
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import org.junit.*
8+
import org.junit.Test
9+
import org.junit.runner.*
10+
import org.junit.runners.*
11+
import java.util.concurrent.atomic.*
12+
import kotlin.test.*
13+
14+
@RunWith(Parameterized::class)
15+
class LimitedParallelismStressTest(private val targetParallelism: Int) : TestBase() {
16+
17+
companion object {
18+
@Parameterized.Parameters(name = "{0}")
19+
@JvmStatic
20+
fun params(): Collection<Array<Any>> = listOf(1, 2, 3, 4).map { arrayOf(it) }
21+
}
22+
23+
@get:Rule
24+
val executor = ExecutorRule(targetParallelism * 2)
25+
private val iterations = 100_000 * stressTestMultiplier
26+
27+
private val parallelism = AtomicInteger(0)
28+
29+
private fun checkParallelism() {
30+
val value = parallelism.incrementAndGet()
31+
assertTrue { value <= targetParallelism }
32+
parallelism.decrementAndGet()
33+
}
34+
35+
@Test
36+
fun testLimited() = runTest {
37+
val view = executor.limitedParallelism(targetParallelism)
38+
repeat(iterations) {
39+
launch(view) {
40+
checkParallelism()
41+
}
42+
}
43+
}
44+
45+
@Test
46+
fun testUnconfined() = runTest {
47+
val view = Dispatchers.Unconfined.limitedParallelism(targetParallelism)
48+
repeat(iterations) {
49+
launch(executor) {
50+
withContext(view) {
51+
checkParallelism()
52+
}
53+
}
54+
}
55+
}
56+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines
6+
7+
import org.junit.*
8+
9+
class LimitedParallelismTest : TestBase() {
10+
11+
@Test
12+
fun testParallelismSpec() {
13+
assertFailsWith<IllegalArgumentException> { Dispatchers.Default.limitedParallelism(0) }
14+
assertFailsWith<IllegalArgumentException> { Dispatchers.Default.limitedParallelism(-1) }
15+
assertFailsWith<IllegalArgumentException> { Dispatchers.Default.limitedParallelism(Int.MIN_VALUE) }
16+
Dispatchers.Default.limitedParallelism(Int.MAX_VALUE)
17+
}
18+
19+
@Test
20+
fun testTaskFairness() = runTest {
21+
val executor = newSingleThreadContext("test")
22+
val view = executor.limitedParallelism(1)
23+
val view2 = executor.limitedParallelism(1)
24+
val j1 = launch(view) {
25+
while (true) {
26+
yield()
27+
}
28+
}
29+
val j2 = launch(view2) { j1.cancel() }
30+
joinAll(j1, j2)
31+
executor.close()
32+
}
33+
}

0 commit comments

Comments
 (0)