diff --git a/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkNoAllocationsBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkNoAllocationsBenchmark.kt new file mode 100644 index 0000000000..dcba8383ad --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkNoAllocationsBenchmark.kt @@ -0,0 +1,37 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package benchmarks + +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.* +import kotlin.coroutines.* + +@Warmup(iterations = 3, time = 1) +@Measurement(iterations = 5, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(1) +open class ChannelSinkNoAllocationsBenchmark { + private val unconfined = Dispatchers.Unconfined + + @Benchmark + fun channelPipeline(): Int = runBlocking { + run(unconfined) + } + + private suspend inline fun run(context: CoroutineContext): Int { + var size = 0 + Channel.range(context).consumeEach { size++ } + return size + } + + private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context) { + for (i in 0 until 100_000) + send(Unit) // no allocations + } +} diff --git a/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt new file mode 100644 index 0000000000..6926db783a --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package benchmarks + +import kotlinx.coroutines.* +import kotlinx.coroutines.sync.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.TimeUnit +import kotlin.test.* + +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 10, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(1) +open class SequentialSemaphoreAsMutexBenchmark { + val s = Semaphore(1) + + @Benchmark + fun benchmark() : Unit = runBlocking { + val s = Semaphore(permits = 1, acquiredPermits = 1) + var step = 0 + launch(Dispatchers.Unconfined) { + repeat(N) { + assertEquals(it * 2, step) + step++ + s.acquire() + } + } + repeat(N) { + assertEquals(it * 2 + 1, step) + step++ + s.release() + } + } +} + +fun main() = SequentialSemaphoreAsMutexBenchmark().benchmark() + +private val N = 1_000_000 \ No newline at end of file diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index e880916cc6..b4f2928d03 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -51,7 +51,7 @@ public final class kotlinx/coroutines/CancellableContinuation$DefaultImpls { public static synthetic fun tryResume$default (Lkotlinx/coroutines/CancellableContinuation;Ljava/lang/Object;Ljava/lang/Object;ILjava/lang/Object;)Ljava/lang/Object; } -public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/channels/Waiter { +public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/Waiter { public fun (Lkotlin/coroutines/Continuation;I)V public final fun callCancelHandler (Lkotlinx/coroutines/CancelHandler;Ljava/lang/Throwable;)V public final fun callOnCancellation (Lkotlin/jvm/functions/Function1;Ljava/lang/Throwable;)V @@ -64,6 +64,7 @@ public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/ public fun getStackTraceElement ()Ljava/lang/StackTraceElement; public fun initCancellability ()V public fun invokeOnCancellation (Lkotlin/jvm/functions/Function1;)V + public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V public fun isActive ()Z public fun isCancelled ()Z public fun isCompleted ()Z @@ -1258,6 +1259,7 @@ public class kotlinx/coroutines/selects/SelectImplementation : kotlinx/coroutine public fun invoke (Lkotlinx/coroutines/selects/SelectClause1;Lkotlin/jvm/functions/Function2;)V public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)V public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Lkotlin/jvm/functions/Function2;)V + public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V public fun onTimeout (JLkotlin/jvm/functions/Function1;)V public fun selectInRegistrationPhase (Ljava/lang/Object;)V public fun trySelect (Ljava/lang/Object;Ljava/lang/Object;)Z diff --git a/kotlinx-coroutines-core/common/src/CancellableContinuation.kt b/kotlinx-coroutines-core/common/src/CancellableContinuation.kt index 5e8d7f9102..478ff72044 100644 --- a/kotlinx-coroutines-core/common/src/CancellableContinuation.kt +++ b/kotlinx-coroutines-core/common/src/CancellableContinuation.kt @@ -328,7 +328,7 @@ public suspend inline fun suspendCancellableCoroutine( * [CancellableContinuationImpl] is reused. */ internal suspend inline fun suspendCancellableCoroutineReusable( - crossinline block: (CancellableContinuation) -> Unit + crossinline block: (CancellableContinuationImpl) -> Unit ): T = suspendCoroutineUninterceptedOrReturn { uCont -> val cancellable = getOrCreateCancellableContinuation(uCont.intercepted()) block(cancellable) diff --git a/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt b/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt index 423cb05d18..8006f3bf8e 100644 --- a/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt +++ b/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt @@ -5,7 +5,6 @@ package kotlinx.coroutines import kotlinx.atomicfu.* -import kotlinx.coroutines.channels.Waiter import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.coroutines.intrinsics.* @@ -15,6 +14,15 @@ private const val UNDECIDED = 0 private const val SUSPENDED = 1 private const val RESUMED = 2 +private const val DECISION_SHIFT = 29 +private const val INDEX_MASK = (1 shl DECISION_SHIFT) - 1 +private const val NO_INDEX = INDEX_MASK + +private inline val Int.decision get() = this shr DECISION_SHIFT +private inline val Int.index get() = this and INDEX_MASK +@Suppress("NOTHING_TO_INLINE") +private inline fun decisionAndIndex(decision: Int, index: Int) = (decision shl DECISION_SHIFT) + index + @JvmField internal val RESUME_TOKEN = Symbol("RESUME_TOKEN") @@ -44,7 +52,7 @@ internal open class CancellableContinuationImpl( * less dependencies. */ - /* decision state machine + /** decision state machine +-----------+ trySuspend +-----------+ | UNDECIDED | -------------> | SUSPENDED | @@ -56,9 +64,12 @@ internal open class CancellableContinuationImpl( | RESUMED | +-----------+ - Note: both tryResume and trySuspend can be invoked at most once, first invocation wins + Note: both tryResume and trySuspend can be invoked at most once, first invocation wins. + If the cancellation handler is specified via a [Segment] instance and the index in it + (so [Segment.onCancellation] should be called), the [_decisionAndIndex] field may store + this index additionally to the "decision" value. */ - private val _decision = atomic(UNDECIDED) + private val _decisionAndIndex = atomic(decisionAndIndex(UNDECIDED, NO_INDEX)) /* === Internal states === @@ -144,7 +155,7 @@ internal open class CancellableContinuationImpl( detachChild() return false } - _decision.value = UNDECIDED + _decisionAndIndex.value = decisionAndIndex(UNDECIDED, NO_INDEX) _state.value = Active return true } @@ -194,10 +205,13 @@ internal open class CancellableContinuationImpl( _state.loop { state -> if (state !is NotCompleted) return false // false if already complete or cancelling // Active -- update to final state - val update = CancelledContinuation(this, cause, handled = state is CancelHandler) + val update = CancelledContinuation(this, cause, handled = state is CancelHandler || state is Segment<*>) if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure // Invoke cancel handler if it was present - (state as? CancelHandler)?.let { callCancelHandler(it, cause) } + when (state) { + is CancelHandler -> callCancelHandler(state, cause) + is Segment<*> -> callSegmentOnCancellation(state, cause) + } // Complete state update detachChildIfNonResuable() dispatchResume(resumeMode) // no need for additional cancellation checks @@ -234,6 +248,12 @@ internal open class CancellableContinuationImpl( fun callCancelHandler(handler: CancelHandler, cause: Throwable?) = callCancelHandlerSafely { handler.invoke(cause) } + private fun callSegmentOnCancellation(segment: Segment<*>, cause: Throwable?) { + val index = _decisionAndIndex.value.index + check(index != NO_INDEX) { "The index for Segment.onCancellation(..) is broken" } + callCancelHandlerSafely { segment.onCancellation(index, cause) } + } + fun callOnCancellation(onCancellation: (cause: Throwable) -> Unit, cause: Throwable) { try { onCancellation.invoke(cause) @@ -253,9 +273,9 @@ internal open class CancellableContinuationImpl( parent.getCancellationException() private fun trySuspend(): Boolean { - _decision.loop { decision -> - when (decision) { - UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, SUSPENDED)) return true + _decisionAndIndex.loop { cur -> + when (cur.decision) { + UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(SUSPENDED, cur.index))) return true RESUMED -> return false else -> error("Already suspended") } @@ -263,9 +283,9 @@ internal open class CancellableContinuationImpl( } private fun tryResume(): Boolean { - _decision.loop { decision -> - when (decision) { - UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, RESUMED)) return true + _decisionAndIndex.loop { cur -> + when (cur.decision) { + UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(RESUMED, cur.index))) return true SUSPENDED -> return false else -> error("Already resumed") } @@ -275,7 +295,7 @@ internal open class CancellableContinuationImpl( @PublishedApi internal fun getResult(): Any? { val isReusable = isReusable() - // trySuspend may fail either if 'block' has resumed/cancelled a continuation + // trySuspend may fail either if 'block' has resumed/cancelled a continuation, // or we got async cancellation from parent. if (trySuspend()) { /* @@ -350,14 +370,44 @@ internal open class CancellableContinuationImpl( override fun resume(value: T, onCancellation: ((cause: Throwable) -> Unit)?) = resumeImpl(value, resumeMode, onCancellation) + /** + * An optimized version for the code below that does not allocate + * a cancellation handler object and efficiently stores the specified + * [segment] and [index] in this [CancellableContinuationImpl]. + * + * The only difference is that `segment.onCancellation(..)` is never + * called if this continuation is already completed; thus, + * the semantics is similar to [BeforeResumeCancelHandler]. + * + * ``` + * invokeOnCancellation { cause -> + * segment.onCancellation(index, cause) + * } + * ``` + */ + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { + _decisionAndIndex.update { + check(it.index == NO_INDEX) { + "invokeOnCancellation should be called at most once" + } + decisionAndIndex(it.decision, index) + } + invokeOnCancellationImpl(segment) + } + public override fun invokeOnCancellation(handler: CompletionHandler) { val cancelHandler = makeCancelHandler(handler) + invokeOnCancellationImpl(cancelHandler) + } + + private fun invokeOnCancellationImpl(handler: Any) { + assert { handler is CancelHandler || handler is Segment<*> } _state.loop { state -> when (state) { is Active -> { - if (_state.compareAndSet(state, cancelHandler)) return // quit on cas success + if (_state.compareAndSet(state, handler)) return // quit on cas success } - is CancelHandler -> multipleHandlersError(handler, state) + is CancelHandler, is Segment<*> -> multipleHandlersError(handler, state) is CompletedExceptionally -> { /* * Continuation was already cancelled or completed exceptionally. @@ -371,7 +421,13 @@ internal open class CancellableContinuationImpl( * because we play type tricks on Kotlin/JS and handler is not necessarily a function there */ if (state is CancelledContinuation) { - callCancelHandler(handler, (state as? CompletedExceptionally)?.cause) + val cause: Throwable? = (state as? CompletedExceptionally)?.cause + if (handler is CancelHandler) { + callCancelHandler(handler, cause) + } else { + val segment = handler as Segment<*> + callSegmentOnCancellation(segment, cause) + } } return } @@ -380,14 +436,16 @@ internal open class CancellableContinuationImpl( * Continuation was already completed, and might already have cancel handler. */ if (state.cancelHandler != null) multipleHandlersError(handler, state) - // BeforeResumeCancelHandler does not need to be called on a completed continuation - if (cancelHandler is BeforeResumeCancelHandler) return + // BeforeResumeCancelHandler and Segment.invokeOnCancellation(..) + // do NOT need to be called on completed continuation. + if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return + handler as CancelHandler if (state.cancelled) { // Was already cancelled while being dispatched -- invoke the handler directly callCancelHandler(handler, state.cancelCause) return } - val update = state.copy(cancelHandler = cancelHandler) + val update = state.copy(cancelHandler = handler) if (_state.compareAndSet(state, update)) return // quit on cas success } else -> { @@ -396,15 +454,16 @@ internal open class CancellableContinuationImpl( * Change its state to CompletedContinuation, unless we have BeforeResumeCancelHandler which * does not need to be called in this case. */ - if (cancelHandler is BeforeResumeCancelHandler) return - val update = CompletedContinuation(state, cancelHandler = cancelHandler) + if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return + handler as CancelHandler + val update = CompletedContinuation(state, cancelHandler = handler) if (_state.compareAndSet(state, update)) return // quit on cas success } } } } - private fun multipleHandlersError(handler: CompletionHandler, state: Any?) { + private fun multipleHandlersError(handler: Any, state: Any?) { error("It's prohibited to register multiple handlers, tried to register $handler, already has $state") } diff --git a/kotlinx-coroutines-core/common/src/Waiter.kt b/kotlinx-coroutines-core/common/src/Waiter.kt new file mode 100644 index 0000000000..79d3dbf564 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/Waiter.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.coroutines.internal.Segment +import kotlinx.coroutines.selects.* + +/** + * All waiters (such as [CancellableContinuationImpl] and [SelectInstance]) in synchronization and + * communication primitives, should implement this interface to make the code faster and easier to read. + */ +internal interface Waiter { + /** + * When this waiter is cancelled, [Segment.onCancellation] with + * the specified [segment] and [index] should be called. + * This function installs the corresponding cancellation handler. + */ + fun invokeOnCancellation(segment: Segment<*>, index: Int) +} diff --git a/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt b/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt index e30486fa8c..d5773df544 100644 --- a/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt @@ -160,38 +160,34 @@ internal open class BufferedChannel( // cancellation, it is guaranteed that the element // has been already buffered or passed to receiver. onRendezvousOrBuffered = { cont.resume(Unit) }, - // Clean the cell on suspension and invoke - // `onUndeliveredElement(..)` if needed. - onSuspend = { segm, i -> cont.prepareSenderForSuspension(segm, i) }, // If the channel is closed, call `onUndeliveredElement(..)` and complete the // continuation with the corresponding exception. onClosed = { onClosedSendOnNoWaiterSuspend(element, cont) }, ) } - private fun CancellableContinuation<*>.prepareSenderForSuspension( + private fun Waiter.prepareSenderForSuspension( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, index: Int ) { if (onUndeliveredElement == null) { - invokeOnCancellation(SenderOrReceiverCancellationHandler(segment, index).asHandler) + invokeOnCancellation(segment, index) } else { - invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context).asHandler) - } - } - - // TODO: Replace with a more efficient cancellation mechanism for segments when #3084 is finished. - private inner class SenderOrReceiverCancellationHandler( - private val segment: ChannelSegment, - private val index: Int - ) : BeforeResumeCancelHandler(), DisposableHandle { - override fun dispose() { - segment.onCancellation(index) + when (this) { + is CancellableContinuation<*> -> { + invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context).asHandler) + } + is SelectInstance<*> -> { + disposeOnCompletion(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context)) + } + is SendBroadcast -> { + cont.invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, cont.context).asHandler) + } + else -> error("unexpected sender: $this") + } } - - override fun invoke(cause: Throwable?) = dispose() } private inner class SenderWithOnUndeliveredElementCancellationHandler( @@ -227,8 +223,8 @@ internal open class BufferedChannel( // or the element has been buffered. onRendezvousOrBuffered = { success(Unit) }, // On suspension, the `INTERRUPTED_SEND` token has been installed, - // and this `trySend(e)` fails. According to the contract, - // we do not need to call [onUndeliveredElement] handler. + // and this `trySend(e)` must fail. According to the contract, + // we do not need to call the [onUndeliveredElement] handler. onSuspend = { segm, _ -> segm.onSlotCleaned() failure() @@ -256,7 +252,7 @@ internal open class BufferedChannel( element = element, waiter = SendBroadcast(cont), onRendezvousOrBuffered = { cont.resume(true) }, - onSuspend = { segm, i -> cont.prepareSenderForSuspension(segm, i) }, + onSuspend = { _, _ -> }, onClosed = { cont.resume(false) } ) } @@ -264,7 +260,9 @@ internal open class BufferedChannel( /** * Specifies waiting [sendBroadcast] operation. */ - private class SendBroadcast(val cont: CancellableContinuation) : Waiter + private class SendBroadcast( + val cont: CancellableContinuation + ) : Waiter by cont as CancellableContinuationImpl /** * Abstract send implementation. @@ -351,6 +349,7 @@ internal open class BufferedChannel( segment.onSlotCleaned() return onClosed() } + (waiter as? Waiter)?.prepareSenderForSuspension(segment, i) return onSuspend(segment, i) } RESULT_CLOSED -> { @@ -377,7 +376,7 @@ internal open class BufferedChannel( } } - private inline fun sendImplOnNoWaiter( + private inline fun sendImplOnNoWaiter( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, @@ -387,21 +386,18 @@ internal open class BufferedChannel( /* The global index of the cell. */ s: Long, /* The waiter to be stored in case of suspension. */ - waiter: Any, + waiter: Waiter, /* This lambda is invoked when the element has been buffered or a rendezvous with a receiver happens.*/ - onRendezvousOrBuffered: () -> R, - /* This lambda is called when the operation suspends in the - cell specified by the segment and the index in it. */ - onSuspend: (segm: ChannelSegment, i: Int) -> R, + onRendezvousOrBuffered: () -> Unit, /* This lambda is called when the channel is observed in the closed state. */ - onClosed: () -> R, - ): R = + onClosed: () -> Unit, + ) { // Update the cell again, now with the non-null waiter, // restarting the operation from the beginning on failure. // Check the `sendImpl(..)` function for the comments. - when(updateCellSend(segment, index, element, s, waiter, false)) { + when (updateCellSend(segment, index, element, s, waiter, false)) { RESULT_RENDEZVOUS -> { segment.cleanPrev() onRendezvousOrBuffered() @@ -410,7 +406,7 @@ internal open class BufferedChannel( onRendezvousOrBuffered() } RESULT_SUSPEND -> { - onSuspend(segment, index) + waiter.prepareSenderForSuspension(segment, index) } RESULT_CLOSED -> { if (s < receiversCounter) segment.cleanPrev() @@ -422,12 +418,13 @@ internal open class BufferedChannel( element = element, waiter = waiter, onRendezvousOrBuffered = onRendezvousOrBuffered, - onSuspend = onSuspend, + onSuspend = { _, _ -> }, onClosed = onClosed, ) } else -> error("unexpected") } + } private fun updateCellSend( /* The working cell is specified by @@ -737,14 +734,13 @@ internal open class BufferedChannel( val onCancellation = onUndeliveredElement?.bindCancellationFun(element, cont.context) cont.resume(element, onCancellation) }, - onSuspend = { segm, i, _ -> cont.prepareReceiverForSuspension(segm, i) }, onClosed = { onClosedReceiveOnNoWaiterSuspend(cont) }, ) } - private fun CancellableContinuation<*>.prepareReceiverForSuspension(segment: ChannelSegment, index: Int) { + private fun Waiter.prepareReceiverForSuspension(segment: ChannelSegment, index: Int) { onReceiveEnqueued() - invokeOnCancellation(SenderOrReceiverCancellationHandler(segment, index).asHandler) + invokeOnCancellation(segment, index) } private fun onClosedReceiveOnNoWaiterSuspend(cont: CancellableContinuation) { @@ -773,15 +769,14 @@ internal open class BufferedChannel( segment: ChannelSegment, index: Int, r: Long - ) = suspendCancellableCoroutineReusable> { cont -> - val waiter = ReceiveCatching(cont) + ) = suspendCancellableCoroutineReusable { cont -> + val waiter = ReceiveCatching(cont as CancellableContinuationImpl>) receiveImplOnNoWaiter( segment, index, r, waiter = waiter, onElementRetrieved = { element -> cont.resume(success(element), onUndeliveredElement?.bindCancellationFun(element, cont.context)) }, - onSuspend = { segm, i, _ -> cont.prepareReceiverForSuspension(segm, i) }, onClosed = { onClosedReceiveCatchingOnNoWaiterSuspend(cont) } ) } @@ -815,7 +810,7 @@ internal open class BufferedChannel( // Finish when an element is successfully retrieved. onElementRetrieved = { element -> success(element) }, // On suspension, the `INTERRUPTED_RCV` token has been - // installed, and this `tryReceive()` fails. + // installed, and this `tryReceive()` must fail. onSuspend = { segm, _, globalIndex -> // Emulate "cancelled" receive, thus invoking 'waitExpandBufferCompletion' manually, // because effectively there were no cancellation @@ -941,6 +936,7 @@ internal open class BufferedChannel( updCellResult === SUSPEND -> { // The operation has decided to suspend and // stored the specified waiter in the cell. + (waiter as? Waiter)?.prepareReceiverForSuspension(segment, i) onSuspend(segment, i, r) } updCellResult === FAILED -> { @@ -970,7 +966,7 @@ internal open class BufferedChannel( } } - private inline fun receiveImplOnNoWaiter( + private inline fun receiveImplOnNoWaiter( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, @@ -978,40 +974,37 @@ internal open class BufferedChannel( /* The global index of the cell. */ r: Long, /* The waiter to be stored in case of suspension. */ - waiter: W, + waiter: Waiter, /* This lambda is invoked when an element has been successfully retrieved, either from the buffer or by making a rendezvous with a suspended sender. */ - onElementRetrieved: (element: E) -> R, - /* This lambda is called when the operation suspends in the cell - specified by the segment and its global and in-segment indices. */ - onSuspend: (segm: ChannelSegment, i: Int, r: Long) -> R, + onElementRetrieved: (element: E) -> Unit, /* This lambda is called when the channel is observed in the closed state and no waiting senders is found, which means that it is closed for receiving. */ - onClosed: () -> R - ): R { + onClosed: () -> Unit + ) { // Update the cell with the non-null waiter, // restarting from the beginning on failure. // Check the `receiveImpl(..)` function for the comments. val updCellResult = updateCellReceive(segment, index, r, waiter) when { updCellResult === SUSPEND -> { - return onSuspend(segment, index, r) + waiter.prepareReceiverForSuspension(segment, index) } updCellResult === FAILED -> { if (r < sendersCounter) segment.cleanPrev() - return receiveImpl( + receiveImpl( waiter = waiter, onElementRetrieved = onElementRetrieved, - onSuspend = onSuspend, + onSuspend = { _, _, _ -> }, onClosed = onClosed ) } else -> { segment.cleanPrev() @Suppress("UNCHECKED_CAST") - return onElementRetrieved(updCellResult as E) + onElementRetrieved(updCellResult as E) } } } @@ -1498,22 +1491,10 @@ internal open class BufferedChannel( element = element as E, waiter = select, onRendezvousOrBuffered = { select.selectInRegistrationPhase(Unit) }, - onSuspend = { segm, i -> select.prepareSenderForSuspension(segm, i) }, + onSuspend = { _, _ -> }, onClosed = { onClosedSelectOnSend(element, select) } ) - private fun SelectInstance<*>.prepareSenderForSuspension( - // The working cell is specified by - // the segment and the index in it. - segment: ChannelSegment, - index: Int - ) { - if (onUndeliveredElement == null) { - disposeOnCompletion(SenderOrReceiverCancellationHandler(segment, index)) - } else { - disposeOnCompletion(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context)) - } - } private fun onClosedSelectOnSend(element: E, select: SelectInstance<*>) { onUndeliveredElement?.callUndeliveredElement(element, select.context) @@ -1557,20 +1538,10 @@ internal open class BufferedChannel( receiveImpl( // <-- this is an inline function waiter = select, onElementRetrieved = { elem -> select.selectInRegistrationPhase(elem) }, - onSuspend = { segm, i, _ -> select.prepareReceiverForSuspension(segm, i) }, + onSuspend = { _, _, _ -> }, onClosed = { onClosedSelectOnReceive(select) } ) - private fun SelectInstance<*>.prepareReceiverForSuspension( - /* The working cell is specified by - the segment and the index in it. */ - segment: ChannelSegment, - index: Int - ) { - onReceiveEnqueued() - disposeOnCompletion(SenderOrReceiverCancellationHandler(segment, index)) - } - private fun onClosedSelectOnReceive(select: SelectInstance<*>) { select.selectInRegistrationPhase(CHANNEL_CLOSED) } @@ -1643,14 +1614,14 @@ internal open class BufferedChannel( // When `hasNext()` suspends, the location where the continuation // is stored is specified via the segment and the index in it. // We need this information in the cancellation handler below. - private var segment: ChannelSegment? = null + private var segment: Segment<*>? = null private var index = -1 /** * Invoked on cancellation, [BeforeResumeCancelHandler] implementation. */ override fun invoke(cause: Throwable?) { - segment?.onCancellation(index) + segment?.onCancellation(index, null) } // `hasNext()` is just a special receive operation. @@ -1706,18 +1677,16 @@ internal open class BufferedChannel( this.continuation = null cont.resume(true, onUndeliveredElement?.bindCancellationFun(element, cont.context)) }, - onSuspend = { segm, i, _ -> prepareForSuspension(segm, i) }, onClosed = { onClosedHasNextNoWaiterSuspend() } ) } - private fun prepareForSuspension(segment: ChannelSegment, index: Int) { + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { this.segment = segment this.index = index // It is possible that this `hasNext()` invocation is already // resumed, and the `continuation` field is already updated to `null`. this.continuation?.invokeOnCancellation(this.asHandler) - onReceiveEnqueued() } private fun onClosedHasNextNoWaiterSuspend() { @@ -2859,6 +2828,10 @@ internal class ChannelSegment(id: Long, prev: ChannelSegment?, channel: Bu // # Cancellation Support # // ######################## + override fun onCancellation(index: Int, cause: Throwable?) { + onCancellation(index) + } + fun onSenderCancellationWithOnUndeliveredElement(index: Int, context: CoroutineContext) { // Read the element first. If the operation has not been successfully resumed // (this cancellation may be caused by prompt cancellation during dispatching), @@ -3062,8 +3035,8 @@ private class WaiterEB(@JvmField val waiter: Waiter) { * uses this wrapper for its continuation. */ private class ReceiveCatching( - @JvmField val cont: CancellableContinuation> -) : Waiter + @JvmField val cont: CancellableContinuationImpl> +) : Waiter by cont /* Internal results for [BufferedChannel.updateCellReceive]. @@ -3148,10 +3121,3 @@ private inline val Long.ebCompletedCounter get() = this and EB_COMPLETED_COUNTER private inline val Long.ebPauseExpandBuffers: Boolean get() = (this and EB_COMPLETED_PAUSE_EXPAND_BUFFERS_BIT) != 0L private fun constructEBCompletedAndPauseFlag(counter: Long, pauseEB: Boolean): Long = (if (pauseEB) EB_COMPLETED_PAUSE_EXPAND_BUFFERS_BIT else 0) + counter - -/** - * All waiters, such as [CancellableContinuationImpl], [SelectInstance], and - * [BufferedChannel.BufferedChannelIterator], should be marked with this interface - * to make the code faster and easier to read. - */ -internal interface Waiter diff --git a/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt b/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt index 6a9f23e958..699030725b 100644 --- a/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt @@ -85,16 +85,16 @@ internal open class ConflatedBufferedChannel( waiter = BUFFERED, // Finish successfully when a rendezvous has happened // or the element has been buffered. - onRendezvousOrBuffered = { success(Unit) }, + onRendezvousOrBuffered = { return success(Unit) }, // In case the algorithm decided to suspend, the element // was added to the buffer. However, as the buffer is now // overflowed, the first (oldest) element has to be extracted. onSuspend = { segm, i -> dropFirstElementUntilTheSpecifiedCellIsInTheBuffer(segm.id * SEGMENT_SIZE + i) - success(Unit) + return success(Unit) }, // If the channel is closed, return the corresponding result. - onClosed = { closed(sendException) } + onClosed = { return closed(sendException) } ) @Suppress("UNCHECKED_CAST") diff --git a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt index 4b27d8491a..f848e37881 100644 --- a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt +++ b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt @@ -188,8 +188,21 @@ internal abstract class ConcurrentLinkedListNode * Each segment in the list has a unique id and is created by the provided to [findSegmentAndMoveForward] method. * Essentially, this is a node in the Michael-Scott queue algorithm, * but with maintaining [prev] pointer for efficient [remove] implementation. + * + * NB: this class cannot be public or leak into user's code as public type as [CancellableContinuationImpl] + * instance-check it and uses a separate code-path for that. */ -internal abstract class Segment>(val id: Long, prev: S?, pointers: Int): ConcurrentLinkedListNode(prev) { +internal abstract class Segment>( + @JvmField val id: Long, prev: S?, pointers: Int +) : ConcurrentLinkedListNode(prev), + // Segments typically store waiting continuations. Thus, on cancellation, the corresponding + // slot should be cleaned and the segment should be removed if it becomes full of cancelled cells. + // To install such a handler efficiently, without creating an extra object, we allow storing + // segments as cancellation handlers in [CancellableContinuationImpl] state, putting the slot + // index in another field. The details are here: https://github.com/Kotlin/kotlinx.coroutines/pull/3084. + // For that, we need segments to implement this internal marker interface. + NotCompleted +{ /** * This property should return the number of slots in this segment, * it is used to define whether the segment is logically removed. @@ -213,6 +226,13 @@ internal abstract class Segment>(val id: Long, prev: S?, pointers // returns `true` if this segment is logically removed after the decrement. internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == numberOfSlots && !isTail + /** + * This function is invoked on continuation cancellation when this segment + * with the specified [index] are installed as cancellation handler via + * `SegmentDisposable.disposeOnCancellation(Segment, Int)`. + */ + abstract fun onCancellation(index: Int, cause: Throwable?) + /** * Invoked on each slot clean-up; should not be invoked twice for the same slot. */ diff --git a/kotlinx-coroutines-core/common/src/selects/Select.kt b/kotlinx-coroutines-core/common/src/selects/Select.kt index b9d128b7f8..ce48de34a1 100644 --- a/kotlinx-coroutines-core/common/src/selects/Select.kt +++ b/kotlinx-coroutines-core/common/src/selects/Select.kt @@ -376,11 +376,24 @@ internal open class SelectImplementation constructor( private var clauses: MutableList>? = ArrayList(2) /** - * Stores the completion action provided through [disposeOnCompletion] during clause registration. - * After that, if the clause is successfully registered (so, it has not completed immediately), - * this [DisposableHandle] is stored into the corresponding [ClauseData] instance. + * Stores the completion action provided through [disposeOnCompletion] or [invokeOnCancellation] + * during clause registration. After that, if the clause is successfully registered + * (so, it has not completed immediately), this handler is stored into + * the corresponding [ClauseData] instance. + * + * Note that either [DisposableHandle] is provided, or a [Segment] instance with + * the index in it, which specify the location of storing this `select`. + * In the latter case, [Segment.onCancellation] should be called on completion/cancellation. + */ + private var disposableHandleOrSegment: Any? = null + + /** + * In case the disposable handle is specified via [Segment] + * and index in it, implying calling [Segment.onCancellation], + * the corresponding index is stored in this field. + * The segment is stored in [disposableHandleOrSegment]. */ - private var disposableHandle: DisposableHandle? = null + private var indexInSegment: Int = -1 /** * Stores the result passed via [selectInRegistrationPhase] during clause registration @@ -469,8 +482,10 @@ internal open class SelectImplementation constructor( // This also guarantees that the list of clauses cannot be cleared // in the registration phase, so it is safe to read it with "!!". if (!reregister) clauses!! += this - disposableHandle = this@SelectImplementation.disposableHandle - this@SelectImplementation.disposableHandle = null + disposableHandleOrSegment = this@SelectImplementation.disposableHandleOrSegment + indexInSegment = this@SelectImplementation.indexInSegment + this@SelectImplementation.disposableHandleOrSegment = null + this@SelectImplementation.indexInSegment = -1 } else { // This clause has been selected! // Update the state correspondingly. @@ -493,7 +508,23 @@ internal open class SelectImplementation constructor( } override fun disposeOnCompletion(disposableHandle: DisposableHandle) { - this.disposableHandle = disposableHandle + this.disposableHandleOrSegment = disposableHandle + } + + /** + * An optimized version for the code below that does not allocate + * a cancellation handler object and efficiently stores the specified + * [segment] and [index]. + * + * ``` + * disposeOnCompletion { + * segment.onCancellation(index, null) + * } + * ``` + */ + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { + this.disposableHandleOrSegment = segment + this.indexInSegment = index } override fun selectInRegistrationPhase(internalResult: Any?) { @@ -556,7 +587,8 @@ internal open class SelectImplementation constructor( */ private fun reregisterClause(clauseObject: Any) { val clause = findClause(clauseObject)!! // it is guaranteed that the corresponding clause is presented - clause.disposableHandle = null + clause.disposableHandleOrSegment = null + clause.indexInSegment = -1 clause.register(reregister = true) } @@ -692,7 +724,7 @@ internal open class SelectImplementation constructor( // Invoke all cancellation handlers except for the // one related to the selected clause, if specified. clauses.forEach { clause -> - if (clause !== selectedClause) clause.disposableHandle?.dispose() + if (clause !== selectedClause) clause.dispose() } // We do need to clean all the data to avoid memory leaks. this.state.value = STATE_COMPLETED @@ -716,7 +748,7 @@ internal open class SelectImplementation constructor( // a concurrent clean-up procedure has already completed, and it is safe to finish. val clauses = this.clauses ?: return // Remove this `select` instance from all the clause object (channels, mutexes, etc.). - clauses.forEach { it.disposableHandle?.dispose() } + clauses.forEach { it.dispose() } // We do need to clean all the data to avoid memory leaks. this.internalResult = NO_RESULT this.clauses = null @@ -731,9 +763,11 @@ internal open class SelectImplementation constructor( private val processResFunc: ProcessResultFunction, private val param: Any?, // the user-specified param private val block: Any, // the user-specified block, which should be called if this clause becomes selected - @JvmField val onCancellationConstructor: OnCancellationConstructor?, - @JvmField var disposableHandle: DisposableHandle? = null + @JvmField val onCancellationConstructor: OnCancellationConstructor? ) { + @JvmField var disposableHandleOrSegment: Any? = null + @JvmField var indexInSegment: Int = -1 + /** * Tries to register the specified [select] instance in [clauseObject] and check * whether the registration succeeded or a rendezvous has happened during the registration. @@ -788,6 +822,16 @@ internal open class SelectImplementation constructor( } } + fun dispose() { + with(disposableHandleOrSegment) { + if (this is Segment<*>) { + this.onCancellation(indexInSegment, null) + } else { + (this as? DisposableHandle)?.dispose() + } + } + } + fun createOnCancellationAction(select: SelectInstance<*>, internalResult: Any?) = onCancellationConstructor?.invoke(select, param, internalResult) } diff --git a/kotlinx-coroutines-core/common/src/sync/Mutex.kt b/kotlinx-coroutines-core/common/src/sync/Mutex.kt index b32c67c0ec..5a75013c64 100644 --- a/kotlinx-coroutines-core/common/src/sync/Mutex.kt +++ b/kotlinx-coroutines-core/common/src/sync/Mutex.kt @@ -165,7 +165,7 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 lockSuspend(owner) } - private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable { cont -> + private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable { cont -> val contWithOwner = CancellableContinuationWithOwner(cont, owner) acquire(contWithOwner) } @@ -230,7 +230,7 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 if (owner != null && holdsLock(owner)) { select.selectInRegistrationPhase(ON_LOCK_ALREADY_LOCKED_BY_OWNER) } else { - onAcquireRegFunction(SelectInstanceWithOwner(select, owner), owner) + onAcquireRegFunction(SelectInstanceWithOwner(select as SelectInstanceInternal<*>, owner), owner) } } @@ -243,10 +243,10 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 private inner class CancellableContinuationWithOwner( @JvmField - val cont: CancellableContinuation, + val cont: CancellableContinuationImpl, @JvmField val owner: Any? - ) : CancellableContinuation by cont { + ) : CancellableContinuation by cont, Waiter by cont { override fun tryResume(value: Unit, idempotent: Any?, onCancellation: ((cause: Throwable) -> Unit)?): Any? { assert { this@MutexImpl.owner.value === NO_OWNER } val token = cont.tryResume(value, idempotent) { @@ -270,10 +270,10 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 private inner class SelectInstanceWithOwner( @JvmField - val select: SelectInstance, + val select: SelectInstanceInternal, @JvmField val owner: Any? - ) : SelectInstanceInternal by select as SelectInstanceInternal { + ) : SelectInstanceInternal by select { override fun trySelect(clauseObject: Any, result: Any?): Boolean { assert { this@MutexImpl.owner.value === NO_OWNER } return select.trySelect(clauseObject, result).also { success -> diff --git a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt index 82c1ed63f6..8ef888d801 100644 --- a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt +++ b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt @@ -195,7 +195,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int @JsName("acquireCont") protected fun acquire(waiter: CancellableContinuation) = acquire( waiter = waiter, - suspend = { cont -> addAcquireToQueue(cont) }, + suspend = { cont -> addAcquireToQueue(cont as Waiter) }, onAcquired = { cont -> cont.resume(Unit, onCancellationRelease) } ) @@ -219,7 +219,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int protected fun onAcquireRegFunction(select: SelectInstance<*>, ignoredParam: Any?) = acquire( waiter = select, - suspend = { s -> addAcquireToQueue(s) }, + suspend = { s -> addAcquireToQueue(s as Waiter) }, onAcquired = { s -> s.selectInRegistrationPhase(Unit) } ) @@ -281,7 +281,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int /** * Returns `false` if the received permit cannot be used and the calling operation should restart. */ - private fun addAcquireToQueue(waiter: Any): Boolean { + private fun addAcquireToQueue(waiter: Waiter): Boolean { val curTail = this.tail.value val enqIdx = enqIdx.getAndIncrement() val createNewSegment = ::createSegment @@ -290,15 +290,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int val i = (enqIdx % SEGMENT_SIZE).toInt() // the regular (fast) path -- if the cell is empty, try to install continuation if (segment.cas(i, null, waiter)) { // installed continuation successfully - when (waiter) { - is CancellableContinuation<*> -> { - waiter.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler) - } - is SelectInstance<*> -> { - waiter.disposeOnCompletion(CancelSemaphoreAcquisitionHandler(segment, i)) - } - else -> error("unexpected: $waiter") - } + waiter.invokeOnCancellation(segment, i) return true } // On CAS failure -- the cell must be either PERMIT or BROKEN @@ -364,20 +356,7 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int } } -private class CancelSemaphoreAcquisitionHandler( - private val segment: SemaphoreSegment, - private val index: Int -) : CancelHandler(), DisposableHandle { - override fun invoke(cause: Throwable?) = dispose() - - override fun dispose() { - segment.cancel(index) - } - - override fun toString() = "CancelSemaphoreAcquisitionHandler[$segment, $index]" -} - -private fun createSegment(id: Long, prev: SemaphoreSegment) = SemaphoreSegment(id, prev, 0) +private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0) private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment(id, prev, pointers) { val acquirers = atomicArrayOfNulls(SEGMENT_SIZE) @@ -399,7 +378,7 @@ private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) // Cleans the acquirer slot located by the specified index // and removes this segment physically if all slots are cleaned. - fun cancel(index: Int) { + override fun onCancellation(index: Int, cause: Throwable?) { // Clean the slot set(index, CANCELLED) // Remove this segment if needed diff --git a/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt b/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt index 3c11182e00..bd6a44fff8 100644 --- a/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt +++ b/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt @@ -6,6 +6,7 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.test.* @@ -159,4 +160,31 @@ class CancellableContinuationHandlersTest : TestBase() { } finish(3) } + + @Test + fun testSegmentAsHandler() = runTest { + class MySegment : Segment(0, null, 0) { + override val numberOfSlots: Int get() = 0 + + var invokeOnCancellationCalled = false + override fun onCancellation(index: Int, cause: Throwable?) { + invokeOnCancellationCalled = true + } + } + val s = MySegment() + expect(1) + try { + suspendCancellableCoroutine { c -> + expect(2) + c as CancellableContinuationImpl<*> + c.invokeOnCancellation(s, 0) + c.cancel() + } + } catch (e: CancellationException) { + expect(3) + } + expect(4) + check(s.invokeOnCancellationCalled) + finish(5) + } }