Skip to content

Commit 20daaa7

Browse files
committed
Introduce a separate slot for stealing tasks into in CoroutineScheduler
It solves two problems: * Stealing into an exclusively owned local queue no longer requires any CASes or atomic operations where they were previously not needed. It should save a few cycles on the stealing code path * The overall timing perturbations should be slightly better now: previously it was possible for a stolen task to be immediately stolen again from the stealer thread, because it was actually published to the owner's queue but its submission time was never updated Fixes #3416
1 parent 287d038 commit 20daaa7

File tree

4 files changed

+49
-31
lines changed

4 files changed

+49
-31
lines changed

kotlinx-coroutines-core/jvm/src/scheduling/CoroutineScheduler.kt

+13-4
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import kotlinx.coroutines.internal.*
1010
import java.io.*
1111
import java.util.concurrent.*
1212
import java.util.concurrent.locks.*
13+
import kotlin.jvm.internal.Ref.ObjectRef
1314
import kotlin.math.*
1415
import kotlin.random.*
1516

@@ -598,6 +599,12 @@ internal class CoroutineScheduler(
598599
@JvmField
599600
val localQueue: WorkQueue = WorkQueue()
600601

602+
/**
603+
* Slot that is used to steal tasks into to avoid re-adding them
604+
* to the local queue. See [trySteal]
605+
*/
606+
private val stolenTask: ObjectRef<Task?> = ObjectRef()
607+
601608
/**
602609
* Worker state. **Updated only by this worker thread**.
603610
* By default, worker is in DORMANT state in the case when it was created, but all CPU tokens or tasks were taken.
@@ -617,7 +624,7 @@ internal class CoroutineScheduler(
617624

618625
/**
619626
* It is set to the termination deadline when the worker starts to [park] and is reset
620-
* when there is a task. It servers as protection against spurious wakeups of parkNanos.
627+
* when there is a task. It serves as protection against spurious wakeups of parkNanos.
621628
*/
622629
private var terminationDeadline = 0L
623630

@@ -920,12 +927,14 @@ internal class CoroutineScheduler(
920927
if (worker !== null && worker !== this) {
921928
assert { localQueue.size == 0 }
922929
val stealResult = if (blockingOnly) {
923-
localQueue.tryStealBlockingFrom(victim = worker.localQueue)
930+
localQueue.tryStealBlockingFrom(victim = worker.localQueue, stolenTask)
924931
} else {
925-
localQueue.tryStealFrom(victim = worker.localQueue)
932+
localQueue.tryStealFrom(victim = worker.localQueue, stolenTask)
926933
}
927934
if (stealResult == TASK_STOLEN) {
928-
return localQueue.poll()
935+
val result = stolenTask.element
936+
stolenTask.element = null
937+
return result
929938
} else if (stealResult > 0) {
930939
minDelay = min(minDelay, stealResult)
931940
}

kotlinx-coroutines-core/jvm/src/scheduling/WorkQueue.kt

+11-11
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
77
import kotlinx.atomicfu.*
88
import kotlinx.coroutines.*
99
import java.util.concurrent.atomic.*
10+
import kotlin.jvm.internal.Ref.ObjectRef
1011

1112
internal const val BUFFER_CAPACITY_BASE = 7
1213
internal const val BUFFER_CAPACITY = 1 shl BUFFER_CAPACITY_BASE
@@ -31,7 +32,7 @@ internal const val NOTHING_TO_STEAL = -2L
3132
* (scheduler workers without a CPU permit steal blocking tasks via this mechanism). Such property enforces us to use CAS in
3233
* order to properly claim value from the buffer.
3334
* Moreover, [Task] objects are reusable, so it may seem that this queue is prone to ABA problem.
34-
* Indeed it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
35+
* Indeed, it formally has ABA-problem, but the whole processing logic is written in the way that such ABA is harmless.
3536
* I have discovered a truly marvelous proof of this, which this KDoc is too narrow to contain.
3637
*/
3738
internal class WorkQueue {
@@ -100,23 +101,22 @@ internal class WorkQueue {
100101
}
101102

102103
/**
103-
* Tries stealing from [victim] queue into this queue.
104+
* Tries stealing from [victim] queue into the [stolenTaskRef] argument.
104105
*
105106
* Returns [NOTHING_TO_STEAL] if queue has nothing to steal, [TASK_STOLEN] if at least one task was stolen
106107
* or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
107108
*/
108-
fun tryStealFrom(victim: WorkQueue): Long {
109+
fun tryStealFrom(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>): Long {
109110
assert { bufferSize == 0 }
110111
val task = victim.pollBuffer()
111112
if (task != null) {
112-
val notAdded = add(task)
113-
assert { notAdded == null }
113+
stolenTaskRef.element = task
114114
return TASK_STOLEN
115115
}
116-
return tryStealLastScheduled(victim, blockingOnly = false)
116+
return tryStealLastScheduled(victim, stolenTaskRef, blockingOnly = false)
117117
}
118118

119-
fun tryStealBlockingFrom(victim: WorkQueue): Long {
119+
fun tryStealBlockingFrom(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>): Long {
120120
assert { bufferSize == 0 }
121121
var start = victim.consumerIndex.value
122122
val end = victim.producerIndex.value
@@ -128,13 +128,13 @@ internal class WorkQueue {
128128
val value = buffer[index]
129129
if (value != null && value.isBlocking && buffer.compareAndSet(index, value, null)) {
130130
victim.blockingTasksInBuffer.decrementAndGet()
131-
add(value)
131+
stolenTaskRef.element = value
132132
return TASK_STOLEN
133133
} else {
134134
++start
135135
}
136136
}
137-
return tryStealLastScheduled(victim, blockingOnly = true)
137+
return tryStealLastScheduled(victim, stolenTaskRef, blockingOnly = true)
138138
}
139139

140140
fun offloadAllWorkTo(globalQueue: GlobalQueue) {
@@ -147,7 +147,7 @@ internal class WorkQueue {
147147
/**
148148
* Contract on return value is the same as for [tryStealFrom]
149149
*/
150-
private fun tryStealLastScheduled(victim: WorkQueue, blockingOnly: Boolean): Long {
150+
private fun tryStealLastScheduled(victim: WorkQueue, stolenTaskRef: ObjectRef<Task?>, blockingOnly: Boolean): Long {
151151
while (true) {
152152
val lastScheduled = victim.lastScheduledTask.value ?: return NOTHING_TO_STEAL
153153
if (blockingOnly && !lastScheduled.isBlocking) return NOTHING_TO_STEAL
@@ -164,7 +164,7 @@ internal class WorkQueue {
164164
* and dispatched another one. In the latter case we should retry to avoid missing task.
165165
*/
166166
if (victim.lastScheduledTask.compareAndSet(lastScheduled, null)) {
167-
add(lastScheduled)
167+
stolenTaskRef.element = lastScheduled
168168
return TASK_STOLEN
169169
}
170170
continue

kotlinx-coroutines-core/jvm/test/scheduling/WorkQueueStressTest.kt

+11-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import org.junit.*
99
import org.junit.Test
1010
import java.util.concurrent.*
1111
import kotlin.concurrent.*
12+
import kotlin.jvm.internal.*
1213
import kotlin.test.*
1314

1415
class WorkQueueStressTest : TestBase() {
@@ -52,17 +53,18 @@ class WorkQueueStressTest : TestBase() {
5253

5354
for (i in 0 until stealersCount) {
5455
threads += thread(name = "stealer $i") {
56+
val ref = Ref.ObjectRef<Task?>()
5557
val myQueue = WorkQueue()
5658
startLatch.await()
5759
while (!producerFinished || producerQueue.size != 0) {
58-
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
59-
myQueue.tryStealFrom(victim = producerQueue)
60+
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
61+
myQueue.tryStealFrom(victim = producerQueue, ref)
6062
}
6163

6264
// Drain last element which is not counted in buffer
63-
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
64-
myQueue.tryStealFrom(producerQueue)
65-
stolenTasks[i].addAll(myQueue.drain().map { task(it) })
65+
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
66+
myQueue.tryStealFrom(producerQueue, ref)
67+
stolenTasks[i].addAll(myQueue.drain(ref).map { task(it) })
6668
}
6769
}
6870

@@ -89,13 +91,14 @@ class WorkQueueStressTest : TestBase() {
8991
val stolen = GlobalQueue()
9092
threads += thread(name = "stealer") {
9193
val myQueue = WorkQueue()
94+
val ref = Ref.ObjectRef<Task?>()
9295
startLatch.await()
9396
while (stolen.size != offerIterations) {
94-
if (myQueue.tryStealFrom(producerQueue) != NOTHING_TO_STEAL) {
95-
stolen.addAll(myQueue.drain().map { task(it) })
97+
if (myQueue.tryStealFrom(producerQueue, ref) != NOTHING_TO_STEAL) {
98+
stolen.addAll(myQueue.drain(ref).map { task(it) })
9699
}
97100
}
98-
stolen.addAll(myQueue.drain().map { task(it) })
101+
stolen.addAll(myQueue.drain(ref).map { task(it) })
99102
}
100103

101104
startLatch.countDown()

kotlinx-coroutines-core/jvm/test/scheduling/WorkQueueTest.kt

+14-8
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ package kotlinx.coroutines.scheduling
77
import kotlinx.coroutines.*
88
import org.junit.*
99
import org.junit.Test
10+
import kotlin.jvm.internal.Ref.ObjectRef
1011
import kotlin.test.*
1112

1213
class WorkQueueTest : TestBase() {
@@ -27,7 +28,7 @@ class WorkQueueTest : TestBase() {
2728
fun testLastScheduledComesFirst() {
2829
val queue = WorkQueue()
2930
(1L..4L).forEach { queue.add(task(it)) }
30-
assertEquals(listOf(4L, 1L, 2L, 3L), queue.drain())
31+
assertEquals(listOf(4L, 1L, 2L, 3L), queue.drain(ObjectRef()))
3132
}
3233

3334
@Test
@@ -38,9 +39,9 @@ class WorkQueueTest : TestBase() {
3839
(0 until size).forEach { queue.add(task(it))?.let { t -> offload.addLast(t) } }
3940

4041
val expectedResult = listOf(129L) + (0L..126L).toList()
41-
val actualResult = queue.drain()
42+
val actualResult = queue.drain(ObjectRef())
4243
assertEquals(expectedResult, actualResult)
43-
assertEquals((0L until size).toSet().minus(expectedResult), offload.drain().toSet())
44+
assertEquals((0L until size).toSet().minus(expectedResult.toSet()), offload.drain().toSet())
4445
}
4546

4647
@Test
@@ -61,23 +62,28 @@ class WorkQueueTest : TestBase() {
6162
timeSource.step(3)
6263

6364
val stealer = WorkQueue()
64-
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim))
65-
assertEquals(arrayListOf(1L), stealer.drain())
65+
val ref = ObjectRef<Task?>()
66+
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim, ref))
67+
assertEquals(arrayListOf(1L), stealer.drain(ref))
6668

67-
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim))
68-
assertEquals(arrayListOf(2L), stealer.drain())
69+
assertEquals(TASK_STOLEN, stealer.tryStealFrom(victim, ref))
70+
assertEquals(arrayListOf(2L), stealer.drain(ref))
6971
}
7072
}
7173

7274
internal fun task(n: Long) = TaskImpl(Runnable {}, n, NonBlockingContext)
7375

74-
internal fun WorkQueue.drain(): List<Long> {
76+
internal fun WorkQueue.drain(ref: ObjectRef<Task?>): List<Long> {
7577
var task: Task? = poll()
7678
val result = arrayListOf<Long>()
7779
while (task != null) {
7880
result += task.submissionTime
7981
task = poll()
8082
}
83+
if (ref.element != null) {
84+
result += ref.element!!.submissionTime
85+
ref.element = null
86+
}
8187
return result
8288
}
8389

0 commit comments

Comments
 (0)