Skip to content

Commit 34b7b98

Browse files
committed
Optimize CancellableContinuationImpl.invokeOnCancellation(..) for Segments
1 parent b8cfac1 commit 34b7b98

File tree

5 files changed

+119
-38
lines changed

5 files changed

+119
-38
lines changed

benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ import org.openjdk.jmh.annotations.*
1010
import java.util.concurrent.TimeUnit
1111
import kotlin.test.*
1212

13-
@Warmup(iterations = 5, time = 100)
14-
@Measurement(iterations = 10, time = 100)
13+
@Warmup(iterations = 5, time = 1)
14+
@Measurement(iterations = 10, time = 1)
1515
@BenchmarkMode(Mode.AverageTime)
1616
@OutputTimeUnit(TimeUnit.MILLISECONDS)
1717
@State(Scope.Benchmark)

kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt

+69-19
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,18 @@ import kotlin.coroutines.*
1111
import kotlin.coroutines.intrinsics.*
1212
import kotlin.jvm.*
1313

14+
private const val DECISION_SHIFT = 29
15+
private const val INDEX_MASK = (1 shl DECISION_SHIFT) - 1
16+
private const val NO_INDEX = INDEX_MASK
1417
private const val UNDECIDED = 0
1518
private const val SUSPENDED = 1
1619
private const val RESUMED = 2
1720

21+
private inline val Int.decision get() = this shr DECISION_SHIFT
22+
private inline val Int.index get() = this and INDEX_MASK
23+
@Suppress("NOTHING_TO_INLINE")
24+
private inline fun construct(decision: Int, index: Int) = (decision shl DECISION_SHIFT) + index
25+
1826
@JvmField
1927
internal val RESUME_TOKEN = Symbol("RESUME_TOKEN")
2028

@@ -56,9 +64,9 @@ internal open class CancellableContinuationImpl<in T>(
5664
| RESUMED |
5765
+-----------+
5866
59-
Note: both tryResume and trySuspend can be invoked at most once, first invocation wins
67+
Note: both tryResume and trySuspend can be invoked at most once, first invocation wins.
6068
*/
61-
private val _decision = atomic(UNDECIDED)
69+
private val _decisionAndIndex = atomic(construct(UNDECIDED, NO_INDEX))
6270

6371
/*
6472
=== Internal states ===
@@ -144,7 +152,7 @@ internal open class CancellableContinuationImpl<in T>(
144152
detachChild()
145153
return false
146154
}
147-
_decision.value = UNDECIDED
155+
_decisionAndIndex.value = construct(UNDECIDED, NO_INDEX)
148156
_state.value = Active
149157
return true
150158
}
@@ -194,10 +202,11 @@ internal open class CancellableContinuationImpl<in T>(
194202
_state.loop { state ->
195203
if (state !is NotCompleted) return false // false if already complete or cancelling
196204
// Active -- update to final state
197-
val update = CancelledContinuation(this, cause, handled = state is CancelHandler)
205+
val update = CancelledContinuation(this, cause, handled = state is CancelHandler || state is Segment<*>)
198206
if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure
199207
// Invoke cancel handler if it was present
200208
(state as? CancelHandler)?.let { callCancelHandler(it, cause) }
209+
(state as? Segment<*>)?.let { callSegmentOnCancellation(it, cause) }
201210
// Complete state update
202211
detachChildIfNonResuable()
203212
dispatchResume(resumeMode) // no need for additional cancellation checks
@@ -234,6 +243,13 @@ internal open class CancellableContinuationImpl<in T>(
234243
fun callCancelHandler(handler: CancelHandler, cause: Throwable?) =
235244
callCancelHandlerSafely { handler.invoke(cause) }
236245

246+
private fun callSegmentOnCancellation(segment: Segment<*>, cause: Throwable?) {
247+
val index = _decisionAndIndex.value.index
248+
check(index != NO_INDEX) { "The index for segment.invokeOnCancellation(..) is broken" }
249+
callCancelHandlerSafely { segment.invokeOnCancellation(index, cause) }
250+
}
251+
252+
237253
fun callOnCancellation(onCancellation: (cause: Throwable) -> Unit, cause: Throwable) {
238254
try {
239255
onCancellation.invoke(cause)
@@ -253,19 +269,19 @@ internal open class CancellableContinuationImpl<in T>(
253269
parent.getCancellationException()
254270

255271
private fun trySuspend(): Boolean {
256-
_decision.loop { decision ->
257-
when (decision) {
258-
UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, SUSPENDED)) return true
272+
_decisionAndIndex.loop { cur ->
273+
when (cur.decision) {
274+
UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, construct(SUSPENDED, cur.index))) return true
259275
RESUMED -> return false
260276
else -> error("Already suspended")
261277
}
262278
}
263279
}
264280

265281
private fun tryResume(): Boolean {
266-
_decision.loop { decision ->
267-
when (decision) {
268-
UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, RESUMED)) return true
282+
_decisionAndIndex.loop { cur ->
283+
when (cur.decision) {
284+
UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, construct(RESUMED, cur.index))) return true
269285
SUSPENDED -> return false
270286
else -> error("Already resumed")
271287
}
@@ -350,14 +366,39 @@ internal open class CancellableContinuationImpl<in T>(
350366
override fun resume(value: T, onCancellation: ((cause: Throwable) -> Unit)?) =
351367
resumeImpl(value, resumeMode, onCancellation)
352368

369+
/**
370+
* An optimized version for the code below that does not allocate
371+
* a cancellation handler object and efficiently stores the specified
372+
* [segment] and [index] in this [CancellableContinuationImpl].
373+
* ```
374+
* invokeOnCancellation { cause ->
375+
* segment.invokeOnCancellation(index, cause)
376+
* }
377+
* ```
378+
*/
379+
internal fun invokeOnCancellation(segment: Segment<*>, index: Int) {
380+
_decisionAndIndex.update {
381+
check(it.index == NO_INDEX) {
382+
"invokeOnCancellation should be invoked at most once"
383+
}
384+
construct(it.decision, index)
385+
}
386+
invokeOnCancellationImpl(segment)
387+
}
388+
353389
public override fun invokeOnCancellation(handler: CompletionHandler) {
354390
val cancelHandler = makeCancelHandler(handler)
391+
invokeOnCancellationImpl(cancelHandler)
392+
}
393+
394+
private fun invokeOnCancellationImpl(handler: Any) {
395+
assert { handler is CancelHandler || handler is Segment<*> }
355396
_state.loop { state ->
356397
when (state) {
357398
is Active -> {
358-
if (_state.compareAndSet(state, cancelHandler)) return // quit on cas success
399+
if (_state.compareAndSet(state, handler)) return // quit on cas success
359400
}
360-
is CancelHandler -> multipleHandlersError(handler, state)
401+
is CancelHandler, is Segment<*> -> multipleHandlersError(handler, state)
361402
is CompletedExceptionally -> {
362403
/*
363404
* Continuation was already cancelled or completed exceptionally.
@@ -371,7 +412,13 @@ internal open class CancellableContinuationImpl<in T>(
371412
* because we play type tricks on Kotlin/JS and handler is not necessarily a function there
372413
*/
373414
if (state is CancelledContinuation) {
374-
callCancelHandler(handler, (state as? CompletedExceptionally)?.cause)
415+
val cause: Throwable? = (state as? CompletedExceptionally)?.cause
416+
if (handler is CancelHandler) {
417+
callCancelHandler(handler, cause)
418+
} else {
419+
val segment = handler as Segment<*>
420+
callSegmentOnCancellation(segment, cause)
421+
}
375422
}
376423
return
377424
}
@@ -380,14 +427,16 @@ internal open class CancellableContinuationImpl<in T>(
380427
* Continuation was already completed, and might already have cancel handler.
381428
*/
382429
if (state.cancelHandler != null) multipleHandlersError(handler, state)
383-
// BeforeResumeCancelHandler does not need to be called on a completed continuation
384-
if (cancelHandler is BeforeResumeCancelHandler) return
430+
// BeforeResumeCancelHandler and Segment.invokeOnCancellation(..)
431+
// do NOT need to be called on completed continuation.
432+
if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return
433+
handler as CancelHandler
385434
if (state.cancelled) {
386435
// Was already cancelled while being dispatched -- invoke the handler directly
387436
callCancelHandler(handler, state.cancelCause)
388437
return
389438
}
390-
val update = state.copy(cancelHandler = cancelHandler)
439+
val update = state.copy(cancelHandler = handler)
391440
if (_state.compareAndSet(state, update)) return // quit on cas success
392441
}
393442
else -> {
@@ -396,15 +445,16 @@ internal open class CancellableContinuationImpl<in T>(
396445
* Change its state to CompletedContinuation, unless we have BeforeResumeCancelHandler which
397446
* does not need to be called in this case.
398447
*/
399-
if (cancelHandler is BeforeResumeCancelHandler) return
400-
val update = CompletedContinuation(state, cancelHandler = cancelHandler)
448+
if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return
449+
handler as CancelHandler
450+
val update = CompletedContinuation(state, cancelHandler = handler)
401451
if (_state.compareAndSet(state, update)) return // quit on cas success
402452
}
403453
}
404454
}
405455
}
406456

407-
private fun multipleHandlersError(handler: CompletionHandler, state: Any?) {
457+
private fun multipleHandlersError(handler: Any, state: Any?) {
408458
error("It's prohibited to register multiple handlers, tried to register $handler, already has $state")
409459
}
410460

kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt

+16-1
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,15 @@ internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>
188188
* Essentially, this is a node in the Michael-Scott queue algorithm,
189189
* but with maintaining [prev] pointer for efficient [remove] implementation.
190190
*/
191-
internal abstract class Segment<S : Segment<S>>(val id: Long, prev: S?, pointers: Int): ConcurrentLinkedListNode<S>(prev) {
191+
internal abstract class Segment<S : Segment<S>>(val id: Long, prev: S?, pointers: Int) :
192+
ConcurrentLinkedListNode<S>(prev),
193+
// Segments typically store waiting continuations. Thus, on cancellation, the corresponding
194+
// slot should be cleaned and the segment should be removed if it becomes full of cancelled cells.
195+
// To install such a handler efficiently, without creating an extra object, we allow storing
196+
// segments as cancellation handlers in [CancellableContinuationImpl] state, putting the slot
197+
// index in another field. The details are here: https://github.com/Kotlin/kotlinx.coroutines/pull/3084.
198+
NotCompleted
199+
{
192200
/**
193201
* This property should return the number of slots in this segment,
194202
* it is used to define whether the segment is logically removed.
@@ -212,6 +220,13 @@ internal abstract class Segment<S : Segment<S>>(val id: Long, prev: S?, pointers
212220
// returns `true` if this segment is logically removed after the decrement.
213221
internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == numberOfSlots && !isTail
214222

223+
/**
224+
* This function is invoked on continuation cancellation when this segment
225+
* with the specified [index] are installed as cancellation handler via
226+
* `CancellableContinuationImpl.invokeOnCancellation(Segment, Int)`.
227+
*/
228+
internal open fun invokeOnCancellation(index: Int, cause: Throwable?) {}
229+
215230
/**
216231
* Invoked on each slot clean-up; should not be invoked twice for the same slot.
217232
*/

kotlinx-coroutines-core/common/src/sync/Semaphore.kt

+4-16
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,8 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int
291291
if (segment.cas(i, null, waiter)) { // installed continuation successfully
292292
when (waiter) {
293293
is CancellableContinuation<*> -> {
294-
waiter.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler)
294+
waiter as CancellableContinuationImpl<*>
295+
waiter.invokeOnCancellation(segment, i)
295296
}
296297
is SelectInstance<*> -> {
297298
waiter.disposeOnCompletion(CancelSemaphoreAcquisitionHandler(segment, i))
@@ -362,20 +363,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int
362363
}
363364
}
364365

365-
private class CancelSemaphoreAcquisitionHandler(
366-
private val segment: SemaphoreSegment,
367-
private val index: Int
368-
) : CancelHandler(), DisposableHandle {
369-
override fun invoke(cause: Throwable?) = dispose()
370-
371-
override fun dispose() {
372-
segment.cancel(index)
373-
}
374-
375-
override fun toString() = "CancelSemaphoreAcquisitionHandler[$segment, $index]"
376-
}
377-
378-
private fun createSegment(id: Long, prev: SemaphoreSegment) = SemaphoreSegment(id, prev, 0)
366+
private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
379367

380368
private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment<SemaphoreSegment>(id, prev, pointers) {
381369
val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
@@ -397,7 +385,7 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int)
397385

398386
// Cleans the acquirer slot located by the specified index
399387
// and removes this segment physically if all slots are cleaned.
400-
fun cancel(index: Int) {
388+
override fun invokeOnCancellation(index: Int, cause: Throwable?) {
401389
// Clean the slot
402390
set(index, CANCELLED)
403391
// Remove this segment if needed

kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt

+28
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
package kotlinx.coroutines
88

9+
import kotlinx.coroutines.internal.*
910
import kotlin.coroutines.*
1011
import kotlin.test.*
1112

@@ -159,4 +160,31 @@ class CancellableContinuationHandlersTest : TestBase() {
159160
}
160161
finish(3)
161162
}
163+
164+
@Test
165+
fun testSegmentAsHandler() = runTest {
166+
class MySegment : Segment<MySegment>(0, null, 0) {
167+
override val maxSlots: Int get() = 0
168+
169+
var invokeOnCancellationCalled = false
170+
override fun invokeOnCancellation(index: Int, cause: Throwable?) {
171+
invokeOnCancellationCalled = true
172+
}
173+
}
174+
val s = MySegment()
175+
expect(1)
176+
try {
177+
suspendCancellableCoroutine<Unit> { c ->
178+
expect(2)
179+
c as CancellableContinuationImpl<*>
180+
c.invokeOnCancellation(s, 0)
181+
c.cancel()
182+
}
183+
} catch (e: CancellationException) {
184+
expect(3)
185+
}
186+
expect(4)
187+
check(s.invokeOnCancellationCalled)
188+
finish(5)
189+
}
162190
}

0 commit comments

Comments
 (0)