Skip to content

Commit 87d1af9

Browse files
authored
Introduce a separate slot for stealing tasks into in CoroutineScheduler (#3537)
* Introduce a separate slot for stealing tasks into in CoroutineScheduler.
  It solves two problems:
  * Stealing into an exclusively owned local queue no longer requires CAS'es or atomic operations where they were previously not needed. It should save a few cycles on the stealing code path.
  * The overall timing perturbations should be slightly better now: previously it was possible for a stolen task to be immediately stolen again from the stealer thread, because it was actually published to the owner's queue but its submission time was never updated (#3416).
* Move the victim argument in WorkQueue into the receiver position to simplify the overall code structure.
* Fix oversubscription in CoroutineScheduler (-> Dispatchers.Default) (#3418).
  Previously, a worker thread unconditionally processed tasks from its own local queue, even if the tasks were CPU-intensive but a CPU token was not acquired.

Fixes #3416
Fixes #3418
1 parent ebff885 commit 87d1af9

File tree

5 files changed

+209
-70
lines changed

5 files changed

+209
-70
lines changed

kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt

+26-19
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import kotlinx.coroutines.internal.*
1010
import java.io.*
1111
import java.util.concurrent.*
1212
import java.util.concurrent.locks.*
13+
import kotlin.jvm.internal.Ref.ObjectRef
1314
import kotlin.math.*
1415
import kotlin.random.*
1516

@@ -263,16 +264,16 @@ internal class CoroutineScheduler(
263264
val workers = ResizableAtomicArray<Worker>(corePoolSize + 1)
264265

265266
/**
266-
* Long describing state of workers in this pool.
267-
* Currently includes created, CPU-acquired and blocking workers each occupying [BLOCKING_SHIFT] bits.
267+
* The `Long` value describing the state of workers in this pool.
268+
* Currently includes created, CPU-acquired, and blocking workers, each occupying [BLOCKING_SHIFT] bits.
268269
*/
269270
private val controlState = atomic(corePoolSize.toLong() shl CPU_PERMITS_SHIFT)
270271
private val createdWorkers: Int inline get() = (controlState.value and CREATED_MASK).toInt()
271272
private val availableCpuPermits: Int inline get() = availableCpuPermits(controlState.value)
272273

273274
private inline fun createdWorkers(state: Long): Int = (state and CREATED_MASK).toInt()
274275
private inline fun blockingTasks(state: Long): Int = (state and BLOCKING_MASK shr BLOCKING_SHIFT).toInt()
275-
public inline fun availableCpuPermits(state: Long): Int = (state and CPU_PERMITS_MASK shr CPU_PERMITS_SHIFT).toInt()
276+
inline fun availableCpuPermits(state: Long): Int = (state and CPU_PERMITS_MASK shr CPU_PERMITS_SHIFT).toInt()
276277

277278
// Guarded by synchronization
278279
private inline fun incrementCreatedWorkers(): Int = createdWorkers(controlState.incrementAndGet())
@@ -598,6 +599,12 @@ internal class CoroutineScheduler(
598599
@JvmField
599600
val localQueue: WorkQueue = WorkQueue()
600601

602+
/**
603+
* Slot that is used to steal tasks into to avoid re-adding them
604+
* to the local queue. See [trySteal]
605+
*/
606+
private val stolenTask: ObjectRef<Task?> = ObjectRef()
607+
601608
/**
602609
* Worker state. **Updated only by this worker thread**.
603610
* By default, worker is in DORMANT state in the case when it was created, but all CPU tokens or tasks were taken.
@@ -617,7 +624,7 @@ internal class CoroutineScheduler(
617624

618625
/**
619626
* It is set to the termination deadline when started doing [park] and it reset
620-
* when there is a task. It servers as protection against spurious wakeups of parkNanos.
627+
* when there is a task. It serves as protection against spurious wakeups of parkNanos.
621628
*/
622629
private var terminationDeadline = 0L
623630

@@ -719,7 +726,6 @@ internal class CoroutineScheduler(
719726
parkedWorkersStackPush(this)
720727
return
721728
}
722-
assert { localQueue.size == 0 }
723729
workerCtl.value = PARKED // Update value once
724730
/*
725731
* inStack() prevents spurious wakeups, while workerCtl.value == PARKED
@@ -866,15 +872,16 @@ internal class CoroutineScheduler(
866872
}
867873
}
868874

869-
fun findTask(scanLocalQueue: Boolean): Task? {
870-
if (tryAcquireCpuPermit()) return findAnyTask(scanLocalQueue)
871-
// If we can't acquire a CPU permit -- attempt to find blocking task
872-
val task = if (scanLocalQueue) {
873-
localQueue.poll() ?: globalBlockingQueue.removeFirstOrNull()
874-
} else {
875-
globalBlockingQueue.removeFirstOrNull()
876-
}
877-
return task ?: trySteal(blockingOnly = true)
875+
fun findTask(mayHaveLocalTasks: Boolean): Task? {
876+
if (tryAcquireCpuPermit()) return findAnyTask(mayHaveLocalTasks)
877+
/*
878+
* If we can't acquire a CPU permit, attempt to find blocking task:
879+
* * Check if our queue has one (maybe mixed in with CPU tasks)
880+
* * Poll global and try steal
881+
*/
882+
return localQueue.pollBlocking()
883+
?: globalBlockingQueue.removeFirstOrNull()
884+
?: trySteal(blockingOnly = true)
878885
}
879886

880887
private fun findAnyTask(scanLocalQueue: Boolean): Task? {
@@ -904,7 +911,6 @@ internal class CoroutineScheduler(
904911
}
905912

906913
private fun trySteal(blockingOnly: Boolean): Task? {
907-
assert { localQueue.size == 0 }
908914
val created = createdWorkers
909915
// 0 to await an initialization and 1 to avoid excess stealing on single-core machines
910916
if (created < 2) {
@@ -918,14 +924,15 @@ internal class CoroutineScheduler(
918924
if (currentIndex > created) currentIndex = 1
919925
val worker = workers[currentIndex]
920926
if (worker !== null && worker !== this) {
921-
assert { localQueue.size == 0 }
922927
val stealResult = if (blockingOnly) {
923-
localQueue.tryStealBlockingFrom(victim = worker.localQueue)
928+
worker.localQueue.tryStealBlocking(stolenTask)
924929
} else {
925-
localQueue.tryStealFrom(victim = worker.localQueue)
930+
worker.localQueue.trySteal(stolenTask)
926931
}
927932
if (stealResult == TASK_STOLEN) {
928-
return localQueue.poll()
933+
val result = stolenTask.element
934+
stolenTask.element = null
935+
return result
929936
} else if (stealResult > 0) {
930937
minDelay = min(minDelay, stealResult)
931938
}

kotlinx-coroutines-core/jvm/src/scheduling/WorkQueue.kt

+56-33
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
77
import kotlinx.atomicfu.*
88
import kotlinx.coroutines.*
99
import java.util.concurrent.atomic.*
10+
import kotlin.jvm.internal.Ref.ObjectRef
1011

1112
internal const val BUFFER_CAPACITY_BASE = 7
1213
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE
@@ -31,7 +32,7 @@ internal const val NOTHING_TO_STEAL = -2L
3132
* (scheduler workers without a CPU permit steal blocking tasks via this mechanism). Such property enforces us to use CAS in
3233
* order to properly claim value from the buffer.
3334
* Moreover, [Task] objects are reusable, so it may seem that this queue is prone to ABA problem.
34-
* Indeed it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
35+
* Indeed, it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
3536
* I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
3637
*/
3738
internal class WorkQueue {
@@ -46,10 +47,12 @@ internal class WorkQueue {
4647
* [T2] changeProducerIndex (3)
4748
* [T3] changeConsumerIndex (4)
4849
*
49-
* Which can lead to resulting size bigger than actual size at any moment of time.
50-
* This is in general harmless because steal will be blocked by timer
50+
* Which can lead to resulting size being negative or bigger than actual size at any moment of time.
51+
* This is in general harmless because steal will be blocked by timer.
52+
* Negative sizes can be observed only when non-owner reads the size, which happens only
53+
* for diagnostic toString().
5154
*/
52-
internal val bufferSize: Int get() = producerIndex.value - consumerIndex.value
55+
private val bufferSize: Int get() = producerIndex.value - consumerIndex.value
5356
internal val size: Int get() = if (lastScheduledTask.value != null) bufferSize + 1 else bufferSize
5457
private val buffer: AtomicReferenceArray<Task?> = AtomicReferenceArray(BUFFER_CAPACITY)
5558
private val lastScheduledTask = atomic<Task?>(null)
@@ -100,41 +103,61 @@ internal class WorkQueue {
100103
}
101104

102105
/**
103-
* Tries stealing from [victim] queue into this queue.
106+
* Tries stealing from this queue into the [stolenTaskRef] argument.
104107
*
105108
* Returns [NOTHING_TO_STEAL] if queue has nothing to steal, [TASK_STOLEN] if at least task was stolen
106109
* or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
107110
*/
108-
fun tryStealFrom(victim: WorkQueue): Long {
109-
assert { bufferSize == 0 }
110-
val task = victim.pollBuffer()
111+
fun trySteal(stolenTaskRef: ObjectRef<Task?>): Long {
112+
val task = pollBuffer()
111113
if (task != null) {
112-
val notAdded = add(task)
113-
assert { notAdded == null }
114+
stolenTaskRef.element = task
114115
return TASK_STOLEN
115116
}
116-
return tryStealLastScheduled(victim, blockingOnly = false)
117+
return tryStealLastScheduled(stolenTaskRef, blockingOnly = false)
117118
}
118119

119-
fun tryStealBlockingFrom(victim: WorkQueue): Long {
120-
assert { bufferSize == 0 }
121-
var start = victim.consumerIndex.value
122-
val end = victim.producerIndex.value
123-
val buffer = victim.buffer
124-
125-
while (start != end) {
126-
val index = start and MASK
127-
if (victim.blockingTasksInBuffer.value == 0) break
128-
val value = buffer[index]
129-
if (value != null && value.isBlocking && buffer.compareAndSet(index, value, null)) {
130-
victim.blockingTasksInBuffer.decrementAndGet()
131-
add(value)
132-
return TASK_STOLEN
133-
} else {
134-
++start
120+
fun tryStealBlocking(stolenTaskRef: ObjectRef<Task?>): Long {
121+
var start = consumerIndex.value
122+
val end = producerIndex.value
123+
124+
while (start != end && blockingTasksInBuffer.value > 0) {
125+
stolenTaskRef.element = tryExtractBlockingTask(start++) ?: continue
126+
return TASK_STOLEN
127+
}
128+
return tryStealLastScheduled(stolenTaskRef, blockingOnly = true)
129+
}
130+
131+
// Polls for blocking task, invoked only by the owner
132+
fun pollBlocking(): Task? {
133+
while (true) { // Poll the slot
134+
val lastScheduled = lastScheduledTask.value ?: break
135+
if (!lastScheduled.isBlocking) break
136+
if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
137+
return lastScheduled
138+
} // Failed -> someone else stole it
139+
}
140+
141+
val start = consumerIndex.value
142+
var end = producerIndex.value
143+
144+
while (start != end && blockingTasksInBuffer.value > 0) {
145+
val task = tryExtractBlockingTask(--end)
146+
if (task != null) {
147+
return task
135148
}
136149
}
137-
return tryStealLastScheduled(victim, blockingOnly = true)
150+
return null
151+
}
152+
153+
private fun tryExtractBlockingTask(index: Int): Task? {
154+
val arrayIndex = index and MASK
155+
val value = buffer[arrayIndex]
156+
if (value != null && value.isBlocking && buffer.compareAndSet(arrayIndex, value, null)) {
157+
blockingTasksInBuffer.decrementAndGet()
158+
return value
159+
}
160+
return null
138161
}
139162

140163
fun offloadAllWorkTo(globalQueue: GlobalQueue) {
@@ -145,11 +168,11 @@ internal class WorkQueue {
145168
}
146169

147170
/**
148-
* Contract on return value is the same as for [tryStealFrom]
171+
* Contract on return value is the same as for [trySteal]
149172
*/
150-
private fun tryStealLastScheduled(victim: WorkQueue, blockingOnly: Boolean): Long {
173+
private fun tryStealLastScheduled(stolenTaskRef: ObjectRef<Task?>, blockingOnly: Boolean): Long {
151174
while (true) {
152-
val lastScheduled = victim.lastScheduledTask.value ?: return NOTHING_TO_STEAL
175+
val lastScheduled = lastScheduledTask.value ?: return NOTHING_TO_STEAL
153176
if (blockingOnly && !lastScheduled.isBlocking) return NOTHING_TO_STEAL
154177

155178
// TODO time wraparound ?
@@ -163,8 +186,8 @@ internal class WorkQueue {
163186
* If CAS has failed, either someone else had stolen this task or the owner executed this task
164187
* and dispatched another one. In the latter case we should retry to avoid missing task.
165188
*/
166-
if (victim.lastScheduledTask.compareAndSet(lastScheduled, null)) {
167-
add(lastScheduled)
189+
if (lastScheduledTask.compareAndSet(lastScheduled, null)) {
190+
stolenTaskRef.element = lastScheduled
168191
return TASK_STOLEN
169192
}
170193
continue
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright 2016-2022 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.scheduling
6+
7+
import kotlinx.coroutines.*
8+
import org.junit.Test
9+
import java.util.concurrent.*
10+
import java.util.concurrent.atomic.AtomicInteger
11+
12+
class CoroutineSchedulerOversubscriptionTest : TestBase() {
13+
14+
private val inDefault = AtomicInteger(0)
15+
16+
private fun CountDownLatch.runAndCheck() {
17+
if (inDefault.incrementAndGet() > CORE_POOL_SIZE) {
18+
error("Oversubscription detected")
19+
}
20+
21+
await()
22+
inDefault.decrementAndGet()
23+
}
24+
25+
@Test
26+
fun testOverSubscriptionDeterministic() = runTest {
27+
val barrier = CountDownLatch(1)
28+
val threadsOccupiedBarrier = CyclicBarrier(CORE_POOL_SIZE)
29+
// All threads but one
30+
repeat(CORE_POOL_SIZE - 1) {
31+
launch(Dispatchers.Default) {
32+
threadsOccupiedBarrier.await()
33+
barrier.runAndCheck()
34+
}
35+
}
36+
threadsOccupiedBarrier.await()
37+
withContext(Dispatchers.Default) {
38+
// Put a task in a local queue, it will be stolen
39+
launch(Dispatchers.Default) {
40+
barrier.runAndCheck()
41+
}
42+
// Put one more task to trick the local queue check
43+
launch(Dispatchers.Default) {
44+
barrier.runAndCheck()
45+
}
46+
47+
withContext(Dispatchers.IO) {
48+
try {
49+
// Release the thread
50+
delay(100)
51+
} finally {
52+
barrier.countDown()
53+
}
54+
}
55+
}
56+
}
57+
58+
@Test
59+
fun testOverSubscriptionStress() = repeat(1000 * stressTestMultiplierSqrt) {
60+
inDefault.set(0)
61+
runTest {
62+
val barrier = CountDownLatch(1)
63+
val threadsOccupiedBarrier = CyclicBarrier(CORE_POOL_SIZE)
64+
// All threads but one
65+
repeat(CORE_POOL_SIZE - 1) {
66+
launch(Dispatchers.Default) {
67+
threadsOccupiedBarrier.await()
68+
barrier.runAndCheck()
69+
}
70+
}
71+
threadsOccupiedBarrier.await()
72+
withContext(Dispatchers.Default) {
73+
// Put a task in a local queue
74+
launch(Dispatchers.Default) {
75+
barrier.runAndCheck()
76+
}
77+
// Put one more task to trick the local queue check
78+
launch(Dispatchers.Default) {
79+
barrier.runAndCheck()
80+
}
81+
82+
withContext(Dispatchers.IO) {
83+
yield()
84+
barrier.countDown()
85+
}
86+
}
87+
}
88+
}
89+
}

0 commit comments

Comments (0)