@@ -16,6 +16,11 @@ internal const val MASK = BUFFER_CAPACITY - 1 // 128 by default
16
16
internal const val TASK_STOLEN = - 1L
17
17
internal const val NOTHING_TO_STEAL = - 2L
18
18
19
+ internal typealias StealingMode = Int
20
+ internal const val STEAL_ANY : StealingMode = - 1
21
+ internal const val STEAL_CPU_ONLY : StealingMode = 0
22
+ internal const val STEAL_BLOCKING_ONLY : StealingMode = 1
23
+
19
24
/* *
20
25
* Tightly coupled with [CoroutineScheduler] queue of pending tasks, but extracted to separate file for simplicity.
21
26
* At any moment queue is used only by [CoroutineScheduler.Worker] threads, has only one producer (worker owning this queue)
@@ -108,24 +113,34 @@ internal class WorkQueue {
108
113
* Returns [NOTHING_TO_STEAL] if queue has nothing to steal, [TASK_STOLEN] if at least task was stolen
109
114
* or positive value of how many nanoseconds should pass until the head of this queue will be available to steal.
110
115
*/
111
- fun trySteal (stolenTaskRef : ObjectRef <Task ?>): Long {
112
- val task = pollBuffer()
116
+ // TODO move it to tests where appropriate
117
+ fun trySteal (stolenTaskRef : ObjectRef <Task ?>): Long = trySteal(STEAL_ANY , stolenTaskRef)
118
+
119
+ fun trySteal (stealingMode : StealingMode , stolenTaskRef : ObjectRef <Task ?>): Long {
120
+ val task = when (stealingMode) {
121
+ STEAL_ANY -> pollBuffer()
122
+ else -> stealWithExclusiveMode(stealingMode)
123
+ }
124
+
113
125
if (task != null ) {
114
126
stolenTaskRef.element = task
115
127
return TASK_STOLEN
116
128
}
117
- return tryStealLastScheduled(stolenTaskRef, blockingOnly = false )
129
+ return tryStealLastScheduled(stealingMode, stolenTaskRef )
118
130
}
119
131
120
- fun tryStealBlocking (stolenTaskRef : ObjectRef <Task ?>): Long {
132
+ // Steal only tasks of a particular kind, potentially invoking full queue scan
133
+ private fun stealWithExclusiveMode (stealingMode : StealingMode ): Task ? {
121
134
var start = consumerIndex.value
122
135
val end = producerIndex.value
123
-
124
- while (start != end && blockingTasksInBuffer.value > 0 ) {
125
- stolenTaskRef.element = tryExtractBlockingTask(start++ ) ? : continue
126
- return TASK_STOLEN
136
+ val onlyBlocking = stealingMode == STEAL_BLOCKING_ONLY
137
+ // CPU or (BLOCKING & hasBlocking)
138
+ val shouldProceed = ! onlyBlocking || blockingTasksInBuffer.value > 0
139
+ while (start != end && shouldProceed) {
140
+ return tryExtractFromTheMiddle(start++ , onlyBlocking) ? : continue
127
141
}
128
- return tryStealLastScheduled(stolenTaskRef, blockingOnly = true )
142
+
143
+ return null
129
144
}
130
145
131
146
// Polls for blocking task, invoked only by the owner
@@ -138,23 +153,41 @@ internal class WorkQueue {
138
153
} // Failed -> someone else stole it
139
154
}
140
155
156
+ return pollWithMode(onlyBlocking = true /* only blocking */ )
157
+ }
158
+
159
+ fun pollCpu (): Task ? {
160
+ while (true ) { // Poll the slot
161
+ val lastScheduled = lastScheduledTask.value ? : break
162
+ if (lastScheduled.isBlocking) break
163
+ if (lastScheduledTask.compareAndSet(lastScheduled, null )) {
164
+ return lastScheduled
165
+ } // Failed -> someone else stole it
166
+ }
167
+
168
+ return pollWithMode(onlyBlocking = false /* only cpu */ )
169
+ }
170
+
171
+ private fun pollWithMode (/* Only blocking OR only CPU */ onlyBlocking : Boolean ): Task ? {
141
172
val start = consumerIndex.value
142
173
var end = producerIndex.value
143
-
144
- while (start != end && blockingTasksInBuffer.value > 0 ) {
145
- val task = tryExtractBlockingTask(-- end)
174
+ // CPU or (BLOCKING & hasBlocking)
175
+ val shouldProceed = ! onlyBlocking || blockingTasksInBuffer.value > 0
176
+ while (start != end && shouldProceed) {
177
+ val task = tryExtractFromTheMiddle(-- end, onlyBlocking)
146
178
if (task != null ) {
147
179
return task
148
180
}
149
181
}
150
182
return null
151
183
}
152
184
153
- private fun tryExtractBlockingTask (index : Int ): Task ? {
185
+ private fun tryExtractFromTheMiddle (index : Int , onlyBlocking : Boolean ): Task ? {
186
+ if (onlyBlocking && blockingTasksInBuffer.value == 0 ) return null
154
187
val arrayIndex = index and MASK
155
188
val value = buffer[arrayIndex]
156
- if (value != null && value.isBlocking && buffer.compareAndSet(arrayIndex, value, null )) {
157
- blockingTasksInBuffer.decrementAndGet()
189
+ if (value != null && value.isBlocking == onlyBlocking && buffer.compareAndSet(arrayIndex, value, null )) {
190
+ if (onlyBlocking) blockingTasksInBuffer.decrementAndGet()
158
191
return value
159
192
}
160
193
return null
@@ -170,10 +203,16 @@ internal class WorkQueue {
170
203
/* *
171
204
* Contract on return value is the same as for [trySteal]
172
205
*/
173
- private fun tryStealLastScheduled (stolenTaskRef : ObjectRef <Task ?>, blockingOnly : Boolean ): Long {
206
+ private fun tryStealLastScheduled (stealingMode : StealingMode , stolenTaskRef : ObjectRef <Task ?>): Long {
174
207
while (true ) {
175
208
val lastScheduled = lastScheduledTask.value ? : return NOTHING_TO_STEAL
176
- if (blockingOnly && ! lastScheduled.isBlocking) return NOTHING_TO_STEAL
209
+ if (lastScheduled.isBlocking) {
210
+ if (stealingMode == STEAL_CPU_ONLY ) {
211
+ return NOTHING_TO_STEAL
212
+ }
213
+ } else if (stealingMode == STEAL_BLOCKING_ONLY ) {
214
+ return NOTHING_TO_STEAL
215
+ }
177
216
178
217
// TODO time wraparound ?
179
218
val time = schedulerTimeSource.nanoTime()
0 commit comments