diff --git a/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt index f706d3aa03..b4498a0ab6 100644 --- a/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt +++ b/benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt @@ -24,6 +24,11 @@ open class ChannelSinkBenchmark { private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement() private val unconfinedTwoElements = Dispatchers.Unconfined + tl.asContextElement() + tl2.asContextElement() + private val elements = (0 until N).toList() + + @Param("0", "1", "8", "32") + var channelCapacity = 0 + @Benchmark fun channelPipeline(): Int = runBlocking { run(unconfined) @@ -41,14 +46,14 @@ open class ChannelSinkBenchmark { private suspend inline fun run(context: CoroutineContext): Int { return Channel - .range(1, 10_000, context) - .filter(context) { it % 4 == 0 } - .fold(0) { a, b -> a + b } + .range(context) // should not allocate `Int`s! + .filter(context) { it % 4 == 0 } // should not allocate `Int`s! + .fold(0) { a, b -> if (a % 8 == 0) a else b } // should not allocate `Int`s! } - private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context) { - for (i in start until (start + count)) - send(i) + private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context, capacity = channelCapacity) { + for (i in 0 until N) + send(elements[i]) // should not allocate `Int`s! } // Migrated from deprecated operators, are good only for stressing channels @@ -69,3 +74,4 @@ open class ChannelSinkBenchmark { } } +private const val N = 10_000 diff --git a/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt index 6826b7a1a3..97b62a0581 100644 --- a/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt +++ b/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt @@ -48,7 +48,7 @@ open class SemaphoreBenchmark { val semaphore = Semaphore(_3_maxPermits) val jobs = ArrayList(coroutines) repeat(coroutines) { - jobs += GlobalScope.launch { + jobs += GlobalScope.launch(dispatcher) { repeat(n) { semaphore.withPermit { doGeomDistrWork(WORK_INSIDE) @@ -66,7 +66,7 @@ open class SemaphoreBenchmark { val semaphore = Channel(_3_maxPermits) val jobs = ArrayList(coroutines) repeat(coroutines) { - jobs += GlobalScope.launch { + jobs += GlobalScope.launch(dispatcher) { repeat(n) { semaphore.send(Unit) // acquire doGeomDistrWork(WORK_INSIDE) @@ -87,4 +87,4 @@ enum class SemaphoreBenchDispatcherCreator(val create: (parallelism: Int) -> Cor private const val WORK_INSIDE = 50 private const val WORK_OUTSIDE = 50 -private const val BATCH_SIZE = 100000 +private const val BATCH_SIZE = 1000000 diff --git a/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt new file mode 100644 index 0000000000..6926db783a --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/SequentialSemaphoreBenchmark.kt @@ -0,0 +1,43 @@ +/* + * Copyright 2016-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package benchmarks + +import kotlinx.coroutines.* +import kotlinx.coroutines.sync.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.TimeUnit +import kotlin.test.* + +@Warmup(iterations = 5, time = 1) +@Measurement(iterations = 10, time = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +@Fork(1) +open class SequentialSemaphoreAsMutexBenchmark { + val s = Semaphore(1) + + @Benchmark + fun benchmark() : Unit = runBlocking { + val s = Semaphore(permits = 1, acquiredPermits = 1) + var step = 0 + launch(Dispatchers.Unconfined) { + repeat(N) { + assertEquals(it * 2, step) + step++ + s.acquire() + } + } + repeat(N) { + assertEquals(it * 2 + 1, step) + step++ + s.release() + } + } +} + +fun main() = SequentialSemaphoreAsMutexBenchmark().benchmark() + +private val N = 1_000_000 \ No newline at end of file diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index def21f8130..a003d6fb78 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -51,7 +51,7 @@ public final class kotlinx/coroutines/CancellableContinuation$DefaultImpls { public static synthetic fun tryResume$default (Lkotlinx/coroutines/CancellableContinuation;Ljava/lang/Object;Ljava/lang/Object;ILjava/lang/Object;)Ljava/lang/Object; } -public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/channels/Waiter { +public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/jvm/internal/CoroutineStackFrame, kotlinx/coroutines/CancellableContinuation, kotlinx/coroutines/Waiter { public fun (Lkotlin/coroutines/Continuation;I)V public final fun callCancelHandler (Lkotlinx/coroutines/CancelHandler;Ljava/lang/Throwable;)V public final fun callOnCancellation (Lkotlin/jvm/functions/Function1;Ljava/lang/Throwable;)V @@ -64,6 +64,7 @@ public class kotlinx/coroutines/CancellableContinuationImpl : kotlin/coroutines/ public fun getStackTraceElement ()Ljava/lang/StackTraceElement; public fun initCancellability ()V public fun invokeOnCancellation (Lkotlin/jvm/functions/Function1;)V + public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V public fun isActive ()Z public fun isCancelled ()Z public fun isCompleted ()Z @@ -1257,6 +1258,7 @@ public class kotlinx/coroutines/selects/SelectImplementation : kotlinx/coroutine public fun invoke (Lkotlinx/coroutines/selects/SelectClause1;Lkotlin/jvm/functions/Function2;)V public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Ljava/lang/Object;Lkotlin/jvm/functions/Function2;)V public fun invoke (Lkotlinx/coroutines/selects/SelectClause2;Lkotlin/jvm/functions/Function2;)V + public fun invokeOnCancellation (Lkotlinx/coroutines/internal/Segment;I)V public fun onTimeout (JLkotlin/jvm/functions/Function1;)V public fun selectInRegistrationPhase (Ljava/lang/Object;)V public fun trySelect (Ljava/lang/Object;Ljava/lang/Object;)Z @@ -1327,6 +1329,18 @@ public final class kotlinx/coroutines/sync/MutexKt { public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public abstract interface class kotlinx/coroutines/sync/ReadWriteMutex { + public abstract fun getWrite ()Lkotlinx/coroutines/sync/Mutex; + public abstract fun readLock (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun readUnlock ()V +} + +public final class kotlinx/coroutines/sync/ReadWriteMutexKt { + public static final fun ReadWriteMutex ()Lkotlinx/coroutines/sync/ReadWriteMutex; + public static final fun read (Lkotlinx/coroutines/sync/ReadWriteMutex;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun write (Lkotlinx/coroutines/sync/ReadWriteMutex;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public abstract interface class kotlinx/coroutines/sync/Semaphore { public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public abstract fun getAvailablePermits ()I diff --git a/kotlinx-coroutines-core/build.gradle b/kotlinx-coroutines-core/build.gradle index 84d9b0485d..77e7056509 100644 --- a/kotlinx-coroutines-core/build.gradle +++ b/kotlinx-coroutines-core/build.gradle @@ -266,8 +266,8 @@ task jvmStressTest(type: Test, dependsOn: compileTestKotlinJvm) { testLogging.showStandardStreams = true systemProperty 'kotlinx.coroutines.scheduler.keep.alive.sec', '100000' // any unpark problem hangs test // Adjust internal algorithmic parameters to increase the testing quality instead of performance. - systemProperty 'kotlinx.coroutines.semaphore.segmentSize', '1' - systemProperty 'kotlinx.coroutines.semaphore.maxSpinCycles', '10' + systemProperty 'kotlinx.coroutines.cqs.segmentSize', '1' + systemProperty 'kotlinx.coroutines.cqs.maxSpinCycles', '10' systemProperty 'kotlinx.coroutines.bufferedChannel.segmentSize', '2' systemProperty 'kotlinx.coroutines.bufferedChannel.expandBufferCompletionWaitIterations', '1' } @@ -302,8 +302,8 @@ static void configureJvmForLincheck(task, additional = false) { '--add-exports', 'java.base/jdk.internal.util=ALL-UNNAMED'] // in the model checking mode // Adjust internal algorithmic parameters to increase the testing quality instead of performance. var segmentSize = additional ? '2' : '1' - task.systemProperty 'kotlinx.coroutines.semaphore.segmentSize', segmentSize - task.systemProperty 'kotlinx.coroutines.semaphore.maxSpinCycles', '1' // better for the model checking mode + task.systemProperty 'kotlinx.coroutines.cqs.segmentSize', segmentSize + task.systemProperty 'kotlinx.coroutines.cqs.maxSpinCycles', '1' // better for the model checking mode task.systemProperty 'kotlinx.coroutines.bufferedChannel.segmentSize', segmentSize task.systemProperty 'kotlinx.coroutines.bufferedChannel.expandBufferCompletionWaitIterations', '1' } diff --git a/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt b/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt index 423cb05d18..8006f3bf8e 100644 --- a/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt +++ b/kotlinx-coroutines-core/common/src/CancellableContinuationImpl.kt @@ -5,7 +5,6 @@ package kotlinx.coroutines import kotlinx.atomicfu.* -import kotlinx.coroutines.channels.Waiter import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.coroutines.intrinsics.* @@ -15,6 +14,15 @@ private const val UNDECIDED = 0 private const val SUSPENDED = 1 private const val RESUMED = 2 +private const val DECISION_SHIFT = 29 +private const val INDEX_MASK = (1 shl DECISION_SHIFT) - 1 +private const val NO_INDEX = INDEX_MASK + +private inline val Int.decision get() = this shr DECISION_SHIFT +private inline val Int.index get() = this and INDEX_MASK +@Suppress("NOTHING_TO_INLINE") +private inline fun decisionAndIndex(decision: Int, index: Int) = (decision shl DECISION_SHIFT) + index + @JvmField internal val RESUME_TOKEN = Symbol("RESUME_TOKEN") @@ -44,7 +52,7 @@ internal open class CancellableContinuationImpl( * less dependencies. */ - /* decision state machine + /** decision state machine +-----------+ trySuspend +-----------+ | UNDECIDED | -------------> | SUSPENDED | @@ -56,9 +64,12 @@ internal open class CancellableContinuationImpl( | RESUMED | +-----------+ - Note: both tryResume and trySuspend can be invoked at most once, first invocation wins + Note: both tryResume and trySuspend can be invoked at most once, first invocation wins. + If the cancellation handler is specified via a [Segment] instance and the index in it + (so [Segment.onCancellation] should be called), the [_decisionAndIndex] field may store + this index additionally to the "decision" value. */ - private val _decision = atomic(UNDECIDED) + private val _decisionAndIndex = atomic(decisionAndIndex(UNDECIDED, NO_INDEX)) /* === Internal states === @@ -144,7 +155,7 @@ internal open class CancellableContinuationImpl( detachChild() return false } - _decision.value = UNDECIDED + _decisionAndIndex.value = decisionAndIndex(UNDECIDED, NO_INDEX) _state.value = Active return true } @@ -194,10 +205,13 @@ internal open class CancellableContinuationImpl( _state.loop { state -> if (state !is NotCompleted) return false // false if already complete or cancelling // Active -- update to final state - val update = CancelledContinuation(this, cause, handled = state is CancelHandler) + val update = CancelledContinuation(this, cause, handled = state is CancelHandler || state is Segment<*>) if (!_state.compareAndSet(state, update)) return@loop // retry on cas failure // Invoke cancel handler if it was present - (state as? CancelHandler)?.let { callCancelHandler(it, cause) } + when (state) { + is CancelHandler -> callCancelHandler(state, cause) + is Segment<*> -> callSegmentOnCancellation(state, cause) + } // Complete state update detachChildIfNonResuable() dispatchResume(resumeMode) // no need for additional cancellation checks @@ -234,6 +248,12 @@ internal open class CancellableContinuationImpl( fun callCancelHandler(handler: CancelHandler, cause: Throwable?) = callCancelHandlerSafely { handler.invoke(cause) } + private fun callSegmentOnCancellation(segment: Segment<*>, cause: Throwable?) { + val index = _decisionAndIndex.value.index + check(index != NO_INDEX) { "The index for Segment.onCancellation(..) is broken" } + callCancelHandlerSafely { segment.onCancellation(index, cause) } + } + fun callOnCancellation(onCancellation: (cause: Throwable) -> Unit, cause: Throwable) { try { onCancellation.invoke(cause) @@ -253,9 +273,9 @@ internal open class CancellableContinuationImpl( parent.getCancellationException() private fun trySuspend(): Boolean { - _decision.loop { decision -> - when (decision) { - UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, SUSPENDED)) return true + _decisionAndIndex.loop { cur -> + when (cur.decision) { + UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(SUSPENDED, cur.index))) return true RESUMED -> return false else -> error("Already suspended") } @@ -263,9 +283,9 @@ internal open class CancellableContinuationImpl( } private fun tryResume(): Boolean { - _decision.loop { decision -> - when (decision) { - UNDECIDED -> if (this._decision.compareAndSet(UNDECIDED, RESUMED)) return true + _decisionAndIndex.loop { cur -> + when (cur.decision) { + UNDECIDED -> if (this._decisionAndIndex.compareAndSet(cur, decisionAndIndex(RESUMED, cur.index))) return true SUSPENDED -> return false else -> error("Already resumed") } @@ -275,7 +295,7 @@ internal open class CancellableContinuationImpl( @PublishedApi internal fun getResult(): Any? { val isReusable = isReusable() - // trySuspend may fail either if 'block' has resumed/cancelled a continuation + // trySuspend may fail either if 'block' has resumed/cancelled a continuation, // or we got async cancellation from parent. if (trySuspend()) { /* @@ -350,14 +370,44 @@ internal open class CancellableContinuationImpl( override fun resume(value: T, onCancellation: ((cause: Throwable) -> Unit)?) = resumeImpl(value, resumeMode, onCancellation) + /** + * An optimized version for the code below that does not allocate + * a cancellation handler object and efficiently stores the specified + * [segment] and [index] in this [CancellableContinuationImpl]. + * + * The only difference is that `segment.onCancellation(..)` is never + * called if this continuation is already completed; thus, + * the semantics is similar to [BeforeResumeCancelHandler]. + * + * ``` + * invokeOnCancellation { cause -> + * segment.onCancellation(index, cause) + * } + * ``` + */ + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { + _decisionAndIndex.update { + check(it.index == NO_INDEX) { + "invokeOnCancellation should be called at most once" + } + decisionAndIndex(it.decision, index) + } + invokeOnCancellationImpl(segment) + } + public override fun invokeOnCancellation(handler: CompletionHandler) { val cancelHandler = makeCancelHandler(handler) + invokeOnCancellationImpl(cancelHandler) + } + + private fun invokeOnCancellationImpl(handler: Any) { + assert { handler is CancelHandler || handler is Segment<*> } _state.loop { state -> when (state) { is Active -> { - if (_state.compareAndSet(state, cancelHandler)) return // quit on cas success + if (_state.compareAndSet(state, handler)) return // quit on cas success } - is CancelHandler -> multipleHandlersError(handler, state) + is CancelHandler, is Segment<*> -> multipleHandlersError(handler, state) is CompletedExceptionally -> { /* * Continuation was already cancelled or completed exceptionally. @@ -371,7 +421,13 @@ internal open class CancellableContinuationImpl( * because we play type tricks on Kotlin/JS and handler is not necessarily a function there */ if (state is CancelledContinuation) { - callCancelHandler(handler, (state as? CompletedExceptionally)?.cause) + val cause: Throwable? = (state as? CompletedExceptionally)?.cause + if (handler is CancelHandler) { + callCancelHandler(handler, cause) + } else { + val segment = handler as Segment<*> + callSegmentOnCancellation(segment, cause) + } } return } @@ -380,14 +436,16 @@ internal open class CancellableContinuationImpl( * Continuation was already completed, and might already have cancel handler. */ if (state.cancelHandler != null) multipleHandlersError(handler, state) - // BeforeResumeCancelHandler does not need to be called on a completed continuation - if (cancelHandler is BeforeResumeCancelHandler) return + // BeforeResumeCancelHandler and Segment.invokeOnCancellation(..) + // do NOT need to be called on completed continuation. + if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return + handler as CancelHandler if (state.cancelled) { // Was already cancelled while being dispatched -- invoke the handler directly callCancelHandler(handler, state.cancelCause) return } - val update = state.copy(cancelHandler = cancelHandler) + val update = state.copy(cancelHandler = handler) if (_state.compareAndSet(state, update)) return // quit on cas success } else -> { @@ -396,15 +454,16 @@ internal open class CancellableContinuationImpl( * Change its state to CompletedContinuation, unless we have BeforeResumeCancelHandler which * does not need to be called in this case. */ - if (cancelHandler is BeforeResumeCancelHandler) return - val update = CompletedContinuation(state, cancelHandler = cancelHandler) + if (handler is BeforeResumeCancelHandler || handler is Segment<*>) return + handler as CancelHandler + val update = CompletedContinuation(state, cancelHandler = handler) if (_state.compareAndSet(state, update)) return // quit on cas success } } } } - private fun multipleHandlersError(handler: CompletionHandler, state: Any?) { + private fun multipleHandlersError(handler: Any, state: Any?) { error("It's prohibited to register multiple handlers, tried to register $handler, already has $state") } diff --git a/kotlinx-coroutines-core/common/src/Debug.common.kt b/kotlinx-coroutines-core/common/src/Debug.common.kt index 185ad295d8..94a086f8b4 100644 --- a/kotlinx-coroutines-core/common/src/Debug.common.kt +++ b/kotlinx-coroutines-core/common/src/Debug.common.kt @@ -8,6 +8,7 @@ internal expect val DEBUG: Boolean internal expect val Any.hexAddress: String internal expect val Any.classSimpleName: String internal expect fun assert(value: () -> Boolean) +internal inline fun assertNot(crossinline value: () -> Boolean) = assert { !value() } /** * Throwable which can be cloned during stacktrace recovery in a class-specific way. diff --git a/kotlinx-coroutines-core/common/src/Waiter.kt b/kotlinx-coroutines-core/common/src/Waiter.kt new file mode 100644 index 0000000000..2b3caa9980 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/Waiter.kt @@ -0,0 +1,21 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines + +import kotlinx.coroutines.internal.Segment +import kotlinx.coroutines.selects.* + +/** + * All waiters (such as [CancellableContinuationImpl] and [SelectInstance]) in synchronization and + * communication primitives, should implement this interface to make the code faster and easier to read. + */ +internal interface Waiter { + /** + * When this waiter is cancelled, [Segment.onCancellation] with + * the specified [segment] and [index] should be called. + * This function installs the corresponding cancellation handler. + */ + fun invokeOnCancellation(segment: Segment<*>, index: Int) +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt b/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt index 01b5a16b9c..8ee3e34f85 100644 --- a/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt @@ -153,44 +153,40 @@ internal open class BufferedChannel( sendImplOnNoWaiter( // <-- this is an inline function segment = segment, index = index, element = element, s = s, // Store the created continuation as a waiter. - waiter = cont, + waiter = cont as Waiter, // If a rendezvous happens or the element has been buffered, // resume the continuation and finish. In case of prompt // cancellation, it is guaranteed that the element // has been already buffered or passed to receiver. onRendezvousOrBuffered = { cont.resume(Unit) }, - // Clean the cell on suspension and invoke - // `onUndeliveredElement(..)` if needed. - onSuspend = { segm, i -> cont.prepareSenderForSuspension(segm, i) }, // If the channel is closed, call `onUndeliveredElement(..)` and complete the // continuation with the corresponding exception. onClosed = { onClosedSendOnNoWaiterSuspend(element, cont) }, ) } - private fun CancellableContinuation<*>.prepareSenderForSuspension( + private fun Waiter.prepareSenderForSuspension( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, index: Int ) { if (onUndeliveredElement == null) { - invokeOnCancellation(SenderOrReceiverCancellationHandler(segment, index).asHandler) + invokeOnCancellation(segment, index) } else { - invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context).asHandler) - } - } - - // TODO: Replace with a more efficient cancellation mechanism for segments when #3084 is finished. - private inner class SenderOrReceiverCancellationHandler( - private val segment: ChannelSegment, - private val index: Int - ) : BeforeResumeCancelHandler(), DisposableHandle { - override fun dispose() { - segment.onCancellation(index) + when (this) { + is CancellableContinuation<*> -> { + invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context).asHandler) + } + is SelectInstance<*> -> { + disposeOnCompletion(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context)) + } + is SendBroadcast -> { + cont.invokeOnCancellation(SenderWithOnUndeliveredElementCancellationHandler(segment, index, cont.context).asHandler) + } + else -> error("unexpected sender: $this") + } } - - override fun invoke(cause: Throwable?) = dispose() } private inner class SenderWithOnUndeliveredElementCancellationHandler( @@ -226,8 +222,8 @@ internal open class BufferedChannel( // or the element has been buffered. onRendezvousOrBuffered = { success(Unit) }, // On suspension, the `INTERRUPTED_SEND` token has been installed, - // and this `trySend(e)` fails. According to the contract, - // we do not need to call [onUndeliveredElement] handler. + // and this `trySend(e)` must fail. According to the contract, + // we do not need to call the [onUndeliveredElement] handler. onSuspend = { segm, _ -> segm.onSlotCleaned() failure() @@ -255,7 +251,7 @@ internal open class BufferedChannel( element = element, waiter = SendBroadcast(cont), onRendezvousOrBuffered = { cont.resume(true) }, - onSuspend = { segm, i -> cont.prepareSenderForSuspension(segm, i) }, + onSuspend = { _, _ -> }, onClosed = { cont.resume(false) } ) } @@ -263,7 +259,8 @@ internal open class BufferedChannel( /** * Specifies waiting [sendBroadcast] operation. */ - private class SendBroadcast(val cont: CancellableContinuation) : Waiter + private class SendBroadcast(val cont: CancellableContinuation) : + Waiter by cont as CancellableContinuationImpl /** * Abstract send implementation. @@ -350,6 +347,7 @@ internal open class BufferedChannel( segment.onSlotCleaned() return onClosed() } + (waiter as? Waiter)?.prepareSenderForSuspension(segment, i) return onSuspend(segment, i) } RESULT_CLOSED -> { @@ -376,7 +374,7 @@ internal open class BufferedChannel( } } - private inline fun sendImplOnNoWaiter( + private inline fun sendImplOnNoWaiter( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, @@ -386,21 +384,18 @@ internal open class BufferedChannel( /* The global index of the cell. */ s: Long, /* The waiter to be stored in case of suspension. */ - waiter: Any, + waiter: Waiter, /* This lambda is invoked when the element has been buffered or a rendezvous with a receiver happens.*/ - onRendezvousOrBuffered: () -> R, - /* This lambda is called when the operation suspends in the - cell specified by the segment and the index in it. */ - onSuspend: (segm: ChannelSegment, i: Int) -> R, + onRendezvousOrBuffered: () -> Unit, /* This lambda is called when the channel is observed in the closed state. */ - onClosed: () -> R, - ): R = + onClosed: () -> Unit, + ) { // Update the cell again, now with the non-null waiter, // restarting the operation from the beginning on failure. // Check the `sendImpl(..)` function for the comments. - when(updateCellSend(segment, index, element, s, waiter, false)) { + when (updateCellSend(segment, index, element, s, waiter, false)) { RESULT_RENDEZVOUS -> { segment.cleanPrev() onRendezvousOrBuffered() @@ -409,7 +404,7 @@ internal open class BufferedChannel( onRendezvousOrBuffered() } RESULT_SUSPEND -> { - onSuspend(segment, index) + waiter.prepareSenderForSuspension(segment, index) } RESULT_CLOSED -> { if (s < receiversCounter) segment.cleanPrev() @@ -421,12 +416,13 @@ internal open class BufferedChannel( element = element, waiter = waiter, onRendezvousOrBuffered = onRendezvousOrBuffered, - onSuspend = onSuspend, + onSuspend = { _, _ -> }, onClosed = onClosed, ) } else -> error("unexpected") } + } private fun updateCellSend( /* The working cell is specified by @@ -724,7 +720,7 @@ internal open class BufferedChannel( receiveImplOnNoWaiter( // <-- this is an inline function segment = segment, index = index, r = r, // Store the created continuation as a waiter. - waiter = cont, + waiter = cont as Waiter, // In case of successful element retrieval, resume // the continuation with the element and inform the // `BufferedChannel` extensions that the synchronization @@ -736,14 +732,13 @@ internal open class BufferedChannel( val onCancellation = onUndeliveredElement?.bindCancellationFun(element, cont.context) cont.resume(element, onCancellation) }, - onSuspend = { segm, i, _ -> cont.prepareReceiverForSuspension(segm, i) }, onClosed = { onClosedReceiveOnNoWaiterSuspend(cont) }, ) } - private fun CancellableContinuation<*>.prepareReceiverForSuspension(segment: ChannelSegment, index: Int) { + private fun Waiter.prepareReceiverForSuspension(segment: ChannelSegment, index: Int) { onReceiveEnqueued() - invokeOnCancellation(SenderOrReceiverCancellationHandler(segment, index).asHandler) + invokeOnCancellation(segment, index) } private fun onClosedReceiveOnNoWaiterSuspend(cont: CancellableContinuation) { @@ -772,15 +767,14 @@ internal open class BufferedChannel( segment: ChannelSegment, index: Int, r: Long - ) = suspendCancellableCoroutineReusable> { cont -> - val waiter = ReceiveCatching(cont) + ) = suspendCancellableCoroutineReusable { cont -> + val waiter = ReceiveCatching(cont as CancellableContinuationImpl>) receiveImplOnNoWaiter( segment, index, r, waiter = waiter, onElementRetrieved = { element -> cont.resume(success(element), onUndeliveredElement?.bindCancellationFun(element, cont.context)) }, - onSuspend = { segm, i, _ -> cont.prepareReceiverForSuspension(segm, i) }, onClosed = { onClosedReceiveCatchingOnNoWaiterSuspend(cont) } ) } @@ -814,7 +808,7 @@ internal open class BufferedChannel( // Finish when an element is successfully retrieved. onElementRetrieved = { element -> success(element) }, // On suspension, the `INTERRUPTED_RCV` token has been - // installed, and this `tryReceive()` fails. + // installed, and this `tryReceive()` must fail. onSuspend = { segm, _, globalIndex -> // Emulate "cancelled" receive, thus invoking 'waitExpandBufferCompletion' manually, // because effectively there were no cancellation @@ -940,6 +934,7 @@ internal open class BufferedChannel( updCellResult === SUSPEND -> { // The operation has decided to suspend and // stored the specified waiter in the cell. + (waiter as? Waiter)?.prepareReceiverForSuspension(segment, i) onSuspend(segment, i, r) } updCellResult === FAILED -> { @@ -969,7 +964,7 @@ internal open class BufferedChannel( } } - private inline fun receiveImplOnNoWaiter( + private inline fun receiveImplOnNoWaiter( /* The working cell is specified by the segment and the index in it. */ segment: ChannelSegment, @@ -977,40 +972,37 @@ internal open class BufferedChannel( /* The global index of the cell. */ r: Long, /* The waiter to be stored in case of suspension. */ - waiter: W, + waiter: Waiter, /* This lambda is invoked when an element has been successfully retrieved, either from the buffer or by making a rendezvous with a suspended sender. */ - onElementRetrieved: (element: E) -> R, - /* This lambda is called when the operation suspends in the cell - specified by the segment and its global and in-segment indices. */ - onSuspend: (segm: ChannelSegment, i: Int, r: Long) -> R, + onElementRetrieved: (element: E) -> Unit, /* This lambda is called when the channel is observed in the closed state and no waiting senders is found, which means that it is closed for receiving. */ - onClosed: () -> R - ): R { + onClosed: () -> Unit + ) { // Update the cell with the non-null waiter, // restarting from the beginning on failure. // Check the `receiveImpl(..)` function for the comments. val updCellResult = updateCellReceive(segment, index, r, waiter) when { updCellResult === SUSPEND -> { - return onSuspend(segment, index, r) + waiter.prepareReceiverForSuspension(segment, index) } updCellResult === FAILED -> { if (r < sendersCounter) segment.cleanPrev() - return receiveImpl( + receiveImpl( waiter = waiter, onElementRetrieved = onElementRetrieved, - onSuspend = onSuspend, + onSuspend = { _, _, _ -> }, onClosed = onClosed ) } else -> { segment.cleanPrev() @Suppress("UNCHECKED_CAST") - return onElementRetrieved(updCellResult as E) + onElementRetrieved(updCellResult as E) } } } @@ -1497,22 +1489,10 @@ internal open class BufferedChannel( element = element as E, waiter = select, onRendezvousOrBuffered = { select.selectInRegistrationPhase(Unit) }, - onSuspend = { segm, i -> select.prepareSenderForSuspension(segm, i) }, + onSuspend = { _, _ -> }, onClosed = { onClosedSelectOnSend(element, select) } ) - private fun SelectInstance<*>.prepareSenderForSuspension( - // The working cell is specified by - // the segment and the index in it. - segment: ChannelSegment, - index: Int - ) { - if (onUndeliveredElement == null) { - disposeOnCompletion(SenderOrReceiverCancellationHandler(segment, index)) - } else { - disposeOnCompletion(SenderWithOnUndeliveredElementCancellationHandler(segment, index, context)) - } - } private fun onClosedSelectOnSend(element: E, select: SelectInstance<*>) { onUndeliveredElement?.callUndeliveredElement(element, select.context) @@ -1556,20 +1536,10 @@ internal open class BufferedChannel( receiveImpl( // <-- this is an inline function waiter = select, onElementRetrieved = { elem -> select.selectInRegistrationPhase(elem) }, - onSuspend = { segm, i, _ -> select.prepareReceiverForSuspension(segm, i) }, + onSuspend = { _, _, _ -> }, onClosed = { onClosedSelectOnReceive(select) } ) - private fun SelectInstance<*>.prepareReceiverForSuspension( - /* The working cell is specified by - the segment and the index in it. */ - segment: ChannelSegment, - index: Int - ) { - onReceiveEnqueued() - disposeOnCompletion(SenderOrReceiverCancellationHandler(segment, index)) - } - private fun onClosedSelectOnReceive(select: SelectInstance<*>) { select.selectInRegistrationPhase(CHANNEL_CLOSED) } @@ -1642,14 +1612,14 @@ internal open class BufferedChannel( // When `hasNext()` suspends, the location where the continuation // is stored is specified via the segment and the index in it. // We need this information in the cancellation handler below. - private var segment: ChannelSegment? = null + private var segment: Segment<*>? = null private var index = -1 /** * Invoked on cancellation, [BeforeResumeCancelHandler] implementation. */ override fun invoke(cause: Throwable?) { - segment?.onCancellation(index) + segment?.onCancellation(index, null) } // `hasNext()` is just a special receive operation. @@ -1705,18 +1675,16 @@ internal open class BufferedChannel( this.continuation = null cont.resume(true, onUndeliveredElement?.bindCancellationFun(element, cont.context)) }, - onSuspend = { segm, i, _ -> prepareForSuspension(segm, i) }, onClosed = { onClosedHasNextNoWaiterSuspend() } ) } - private fun prepareForSuspension(segment: ChannelSegment, index: Int) { + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { this.segment = segment this.index = index // It is possible that this `hasNext()` invocation is already // resumed, and the `continuation` field is already updated to `null`. this.continuation?.invokeOnCancellation(this.asHandler) - onReceiveEnqueued() } private fun onClosedHasNextNoWaiterSuspend() { @@ -2858,6 +2826,10 @@ internal class ChannelSegment(id: Long, prev: ChannelSegment?, channel: Bu // # Cancellation Support # // ######################## + override fun onCancellation(index: Int, cause: Throwable?) { + onCancellation(index) + } + fun onSenderCancellationWithOnUndeliveredElement(index: Int, context: CoroutineContext) { // Read the element first. If the operation has not been successfully resumed // (this cancellation may be caused by prompt cancellation during dispatching), @@ -3057,8 +3029,8 @@ private class WaiterEB(@JvmField val waiter: Waiter) { * uses this wrapper for its continuation. */ private class ReceiveCatching( - @JvmField val cont: CancellableContinuation> -) : Waiter + @JvmField val cont: CancellableContinuationImpl> +) : Waiter by cont /* Internal results for [BufferedChannel.updateCellReceive]. @@ -3143,10 +3115,3 @@ private inline val Long.ebCompletedCounter get() = this and EB_COMPLETED_COUNTER private inline val Long.ebPauseExpandBuffers: Boolean get() = (this and EB_COMPLETED_PAUSE_EXPAND_BUFFERS_BIT) != 0L private fun constructEBCompletedAndPauseFlag(counter: Long, pauseEB: Boolean): Long = (if (pauseEB) EB_COMPLETED_PAUSE_EXPAND_BUFFERS_BIT else 0) + counter - -/** - * All waiters, such as [CancellableContinuationImpl], [SelectInstance], and - * [BufferedChannel.BufferedChannelIterator], should be marked with this interface - * to make the code faster and easier to read. - */ -internal interface Waiter diff --git a/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt b/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt index 6a9f23e958..699030725b 100644 --- a/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/ConflatedBufferedChannel.kt @@ -85,16 +85,16 @@ internal open class ConflatedBufferedChannel( waiter = BUFFERED, // Finish successfully when a rendezvous has happened // or the element has been buffered. - onRendezvousOrBuffered = { success(Unit) }, + onRendezvousOrBuffered = { return success(Unit) }, // In case the algorithm decided to suspend, the element // was added to the buffer. However, as the buffer is now // overflowed, the first (oldest) element has to be extracted. onSuspend = { segm, i -> dropFirstElementUntilTheSpecifiedCellIsInTheBuffer(segm.id * SEGMENT_SIZE + i) - success(Unit) + return success(Unit) }, // If the channel is closed, return the corresponding result. - onClosed = { closed(sendException) } + onClosed = { return closed(sendException) } ) @Suppress("UNCHECKED_CAST") diff --git a/kotlinx-coroutines-core/common/src/internal/CancellableQueueSynchronizer.kt b/kotlinx-coroutines-core/common/src/internal/CancellableQueueSynchronizer.kt new file mode 100644 index 0000000000..b699f0edfb --- /dev/null +++ b/kotlinx-coroutines-core/common/src/internal/CancellableQueueSynchronizer.kt @@ -0,0 +1,672 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.internal + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.CancellationMode.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.ResumeMode.* +import kotlinx.coroutines.selects.* +import kotlinx.coroutines.sync.* +import kotlin.coroutines.* +import kotlin.math.* +import kotlin.native.concurrent.* + +/** +[CancellableQueueSynchronizer] (CQS) is an abstraction for implementing _fair_ synchronization and communication primitives. +Essentially, It maintains a FIFO queue of waiting requests and provides two main functions: +- [suspend] that stores the specified waiter into the queue, and +- [resume] that tries to retrieve and resume the first waiter, passing the specified value to it. +The key advantage of these semantics is that CQS allows to invoke [resume] before [suspend] as long as +it is known that [suspend] will happen eventually. For example, our [Semaphore] implementation actively +uses this property for better performance. + +A useful mental image of [CancellableQueueSynchronizer] is that of an infinite array with two positioning counters: +one references the next cell in which a new waiter is enqueued as a part of the next [suspend] call, +while another references the next cell for [resume]. The intuition is that [suspend] atomically increments +its counter via `Fetch-and-Add` and stores the waiter in the corresponding cell. Likewise, [resume] increments +its counter, visits the corresponding cell, and resumes the stored waiter with the specified value. + +_Synchronous and Asynchronous Resumption Modes_ + +Notably, [resume] may come to the cell before [suspend] and find the cell in the empty state. +To solve this race, we introduce two [resumption modes][ResumeMode]: [synchronous][SYNC] and [asynchronous][ASYNC]. +In both case, [resume] puts the value into the empty cell, and then either finishes immediately +in the [asynchronous][ASYNC] mode, or waits until the value is taken by a concurrent [suspend] +in the [synchronous][SYNC] one. In the latter case, if the value is not taken within a bounded time, [resume] marks +the cell as _broken_. Thus, both this [resume] and the corresponding [suspend] fail. The intuition is that allowing +for broken cells keeps the balance of pairwise operations, such as [acquire()][Semaphore.acquire] +and [release()][Semaphore.release] in [Semaphore], so these operations simply restart in case of breaking the cell. +This way, we can achieve wait-freedom with the [asynchronous][ASYNC] mode, and obstruction-freedom +with the [synchronous][SYNC] mode. + +_Cancellation Support_ + +We support two cancellation policies in [CancellableQueueSynchronizer]. In the [simple cancellation mode][SIMPLE], +[resume] fails and returns `false` if it finds the cell in the `CANCELLED` state or if the waiter resumption +(see [CancellableContinuation.tryResume]) fails. These failures are typically handled by restarting the operation +from the beginning. With the [smart cancellation][SMART], [resume] efficiently skips `CANCELLED` cells +(the cells where waiter resumption failed are also considered as `CANCELLED`). This way, even if a million of +canceled requests are stored in [CancellableQueueSynchronizer], one [resume] invocation is sufficient to pass +the value to the next alive waiter since it skips all these canceled waiters. However, the smart cancellation mode +provides less intuitive contract and requires users to write more complicated code -- the details are discussed further. + +The main issue with skipping `CANCELLED` cells in [resume] is that it can become illegal to put the value into +the next cell. Consider the following execution: [suspend] is called, then [resume] starts, but the suspended +waiter becomes canceled. This way, no one is waiting in [CancellableQueueSynchronizer] anymore. Thus, if [resume] +skips this canceled cell, puts the value into the next empty cell, and completes, the data structure's state becomes +incorrect. Instead, the value provided by this [resume] should be refused and returned to the outer data structure. +Unfortunately, there is no way for [CancellableQueueSynchronizer] to decide whether the value should be refused or not. +Thus, users should implement a custom cancellation handler by overriding the [onCancellation] function, which must +return `true` if the cancellation completes successfully and `false` if the [resume] that will come to this cell +should be refused. In the latter case, the [resume] that comes to this cell invokes [tryReturnRefusedValue] to return +the value back to the outer data structure. However, it is possible for [tryReturnRefusedValue] to fail, and +[returnValue] is called in this case. Typically, this [returnValue] function coincides with the one that resumes waiters +(e.g., with [release][Semaphore.release] in [Semaphore]). There is also an important difference between [synchronous][SYNC] +and [asynchronous][ASYNC] resumption modes. In the [synchronous][SYNC] mode, the [resume] that comes to a cell with +a canceled waiter (but the cell is not in the `CANCELLED` state yet) waits in a spin-loop until the cancellation handler +is processed and the cell is moved to either `CANCELLED` or `REFUSE` state. In contrast, in the [asynchronous][ASYNC] mode, +[resume] replaces the canceled waiter with the value of this resumption and finishes immediately -- the cancellation handler +completes this [resume] eventually. This way, in the [asynchronous][ASYNC] mode, the value passed to [resume] can be out +of the data structure for a while but is guaranteed to be processed eventually. + +To support prompt cancellation, [CancellableQueueSynchronizer] returns the value back to the data structure by calling +[returnValue] if the continuation is cancelled while dispatching. Typically, [returnValue] delegates to the operation +that calls [resume], such as [release][Semaphore.release] in [Semaphore]. + +_Algorithm Details_ + +Please see the ["CQS: A Formally-Verified Framework for Fair and Abortable Synchronization"](TODO) +paper by Nikita Koval, Dmitry Kaplansky, and Dan Alistarh for the detailed algorithm description. + */ +internal abstract class CancellableQueueSynchronizer { + /* + The counters indicate the total numbers of `suspend` and `resume` calls ever performed. + They are incremented in the beginning of the corresponding operation; + thus, acquiring a unique (for the operation type) cell to process. + The segments reference the last working one for each operation type. + */ + private val suspendIdx = atomic(0L) + private val suspendSegment: AtomicRef + private val resumeIdx = atomic(0L) + private val resumeSegment: AtomicRef + + init { + val s = CQSSegment(id = 0, prev = null, pointers = 2) + resumeSegment = atomic(s) + suspendSegment = atomic(s) + } + + /** + * Specifies whether [resume] should work in + * [synchronous][SYNC] or [asynchronous][ASYNC] mode. + */ + protected open val resumeMode: ResumeMode get() = SYNC + + /** + * Specifies whether [resume] should fail on cancelled waiters ([SIMPLE] mode) or + * skip them ([SMART] mode). Remember that in case of [smart][SMART] cancellation mode, + * the [onCancellation] handler should be implemented. + */ + protected open val cancellationMode: CancellationMode get() = SIMPLE + + /** + * This function is called when waiter is cancelled and smart + * cancellation mode is used (so cancelled cells are skipped by + * [resume]). By design, this handler performs the logical cancellation + * and returns `true` if the cancellation succeeds and the cell can be + * moved to the `CANCELLED` state. In this case, [resume] skips the cell and passes + * the value to the next waiter in the waiting queue. However, if the [resume] + * that comes to this cell should be refused, [onCancellation] should return false. + * In this case, [tryReturnRefusedValue] is invoked with the value of this [resume], + * following by [returnValue] if [tryReturnRefusedValue] fails. + */ + protected open fun onCancellation() : Boolean = error("not implemented") + + /** + * This function specifies how the value refused by this [CancellableQueueSynchronizer] + * (when [onCancellation] returns `false`) should be transferred back to the data structure. + * It returns `true` on success and `false` when the attempt fails. In the latter case, + * [returnValue] is used to complete the returning process. + */ + protected open fun tryReturnRefusedValue(value: T): Boolean = true + + /** + * This function specifies how the value from a failed [resume] should be returned back to + * the data structure. Typically, this function delegates to the one that invokes [resume] + * (e.g., [release()][Semaphore.release] in [Semaphore]). + * + * This function is invoked when [onCancellation] returns `false` and the following [tryReturnRefusedValue] + * fails, or when prompt cancellation occurs and the value should be returned back to the data structure. + * TODO: we need to merge the PR that optimizes this code + */ + protected open fun returnValue(value: T) {} + + /** + * This is a shortcut for [tryReturnRefusedValue] and + * the following [returnValue] invocation on failure. + */ + private fun returnRefusedValue(value: T) { + if (tryReturnRefusedValue(value)) return + returnValue(value) + } + + @Suppress("INFERRED_TYPE_VARIABLE_INTO_POSSIBLE_EMPTY_INTERSECTION") + internal fun suspendCancelled(): T? { + // Increment `suspendIdx` and find the segment + // with the corresponding id. It is guaranteed + // that this segment is not removed since at + // least the cell for this `suspend` invocation + // is not in the `CANCELLED` state. + val curSuspendSegm = this.suspendSegment.value + val suspendIdx = suspendIdx.getAndIncrement() + val segment = this.suspendSegment.findSegmentAndMoveForward(id = suspendIdx / SEGMENT_SIZE, startFrom = curSuspendSegm, + createNewSegment = ::createSegment).segment + assert { segment.id == suspendIdx / SEGMENT_SIZE } + // Try to install the waiter into the cell - this is the regular path. + val i = (suspendIdx % SEGMENT_SIZE).toInt() + if (segment.cas(i, null, CANCELLED)) { + // The continuation is successfully installed, and + // `resume` cannot break the cell now, so this + // suspension is successful. + // Add a cancellation handler if required and finish. + return null + } + // The continuation installation has failed. This happened because a concurrent + // `resume` came earlier to this cell and put its value into it. Remember that + // in the `SYNC` resumption mode this concurrent `resume` can mark the cell as broken. + // + // Try to grab the value if the cell is not in the `BROKEN` state. + val value = segment.get(i) + if (value !== BROKEN && segment.cas(i, value, TAKEN)) { + // The elimination is performed successfully, + // complete with the value stored in the cell. + @Suppress("UNCHECKED_CAST") + return value as T + } + // The cell is broken, this can happen only in the `SYNC` resumption mode. + assert { resumeMode == SYNC && segment.get(i) === BROKEN } + return null + } + + @Suppress("UNCHECKED_CAST", "INFERRED_TYPE_VARIABLE_INTO_POSSIBLE_EMPTY_INTERSECTION") + internal fun suspend(waiter: Waiter): Boolean { + // Increment `suspendIdx` and find the segment + // with the corresponding id. It is guaranteed + // that this segment is not removed since at + // least the cell for this `suspend` invocation + // is not in the `CANCELLED` state. + val curSuspendSegm = this.suspendSegment.value + val suspendIdx = suspendIdx.getAndIncrement() + val segment = this.suspendSegment.findSegmentAndMoveForward(id = suspendIdx / SEGMENT_SIZE, startFrom = curSuspendSegm, + createNewSegment = ::createSegment).segment + assert { segment.id == suspendIdx / SEGMENT_SIZE } + // Try to install the waiter into the cell - this is the regular path. + val i = (suspendIdx % SEGMENT_SIZE).toInt() + if (segment.cas(i, null, waiter)) { + // The continuation is successfully installed, and + // `resume` cannot break the cell now, so this + // suspension is successful. + // Add a cancellation handler if required and finish. + waiter.invokeOnCancellation(segment, i) + return true + } + // The continuation installation has failed. This happened because a concurrent + // `resume` came earlier to this cell and put its value into it. Remember that + // in the `SYNC` resumption mode this concurrent `resume` can mark the cell as broken. + // + // Try to grab the value if the cell is not in the `BROKEN` state. + val value = segment.get(i) + if (value !== BROKEN && segment.cas(i, value, TAKEN)) { + // The elimination is performed successfully, + // complete with the value stored in the cell. + value as T + when (waiter) { + is CancellableContinuation<*> -> { + waiter as CancellableContinuation + waiter.resume(value, { returnValue(value) }) // TODO do we really need this? + } + is SelectInstance<*> -> { + waiter as SelectInstance + waiter.selectInRegistrationPhase(value) + } + } + return true + } + // The cell is broken, this can happen only in the `SYNC` resumption mode. + assert { resumeMode == SYNC && segment.get(i) === BROKEN } + return false + } + + /** + * Tries to resume the next waiter and returns `true` if + * the resumption succeeds. However, it can fail due to + * several reasons. First, if the [synchronous][SYNC] resumption + * mode is used, this [resume] invocation may come before [suspend], + * find the cell in the empty state, mark it as [broken][BROKEN], + * and fail returning `false` as a result. Another reason for [resume] + * to fail is waiter cancellation if the [simple cancellation mode][SIMPLE] + * is used. + * + * Note that with the [smart][SMART] cancellation mode [resume] skips + * cancelled waiters and can fail only in case of unsuccessful elimination + * due to [synchronous][SYNC] resumption. + */ + fun resume(value: T): Boolean { + // Should we skip cancelled cells? + val skipCancelled = cancellationMode != SIMPLE + while (true) { + // Try to resume the next waiter, adjust `resumeIdx` if + // cancelled cells will be skipped anyway. + when (tryResumeImpl(value = value, adjustResumeIdx = skipCancelled)) { + TRY_RESUME_SUCCESS -> return true + TRY_RESUME_FAIL_CANCELLED -> if (!skipCancelled) return false + TRY_RESUME_FAIL_BROKEN -> return false + } + } + } + + /** + * Tries to resume the next waiter, and returns [TRY_RESUME_SUCCESS] on + * success, [TRY_RESUME_FAIL_CANCELLED] if the next waiter is cancelled, + * or [TRY_RESUME_FAIL_BROKEN] if the next cell has been marked as broken + * by this [tryResumeImpl] invocation due to a race in the [SYNC] resumption mode. + * + * In the [smart cancellation mode][SMART], all cells marked as + * [cancelled][CANCELLED] should be skipped, so there is no need + * to increment [resumeIdx] one-by-one if there is a removed segment + * (logically full of [cancelled][CANCELLED] cells). Instead, the algorithm + * moves [resumeIdx] to the first possibly non-cancelled cell, i.e., + * to the first segment id multiplied by [SEGMENT_SIZE]. + */ + @Suppress("UNCHECKED_CAST", "INFERRED_TYPE_VARIABLE_INTO_POSSIBLE_EMPTY_INTERSECTION") + private fun tryResumeImpl(value: T, adjustResumeIdx: Boolean): Int { + // Check that `adjustResumeIdx` is `false` in the simple cancellation mode. + assertNot { cancellationMode == SIMPLE && adjustResumeIdx } + // Increment `resumeIdx` and find the first segment with + // the corresponding or higher (if the required segment + // is physically removed) id. + val curResumeSegm = this.resumeSegment.value + val resumeIdx = resumeIdx.getAndIncrement() + val id = resumeIdx / SEGMENT_SIZE + val segment = this.resumeSegment.findSegmentAndMoveForward(id, startFrom = curResumeSegm, + createNewSegment = ::createSegment).segment + // The previous segments can be safely collected by GC, clean the pointer to them. + segment.cleanPrev() + // Is the required segment physically removed? + if (segment.id > id) { + // Adjust `resumeIdx` to the first non-removed segment if needed. + if (adjustResumeIdx) adjustResumeIdx(segment.id * SEGMENT_SIZE) + // The cell #resumeIdx is in the `CANCELLED` state, return the corresponding failure. + return TRY_RESUME_FAIL_CANCELLED + } + // Modify the cell according to the state machine, + // all the transitions are performed atomically. + val i = (resumeIdx % SEGMENT_SIZE).toInt() + modify_cell@while (true) { + val cellState = segment.get(i) + when { + // Is the cell empty? + cellState === null -> { + // Try to perform an elimination by putting the + // value to the empty cell and wait until it is + // taken by a concurrent `suspend` in case of + // using the synchronous resumption mode. + if (!segment.cas(i, null, value)) continue@modify_cell + // Finish immediately in the asynchronous resumption mode. + if (resumeMode == ASYNC) return TRY_RESUME_SUCCESS + // Wait for a concurrent `suspend` (which should mark + // the cell as taken) for a bounded time in a spin-loop. + var iteration = 0 + while (true) { + if (segment.get(i) === TAKEN) return TRY_RESUME_SUCCESS + iteration++ + if (resumeMode == SYNC && iteration > MAX_SPIN_CYCLES) break + } + // The value is still not taken, try to atomically mark the cell as broken. + // A CAS failure indicates that the value is successfully taken. + return if (segment.cas(i, value, BROKEN)) TRY_RESUME_FAIL_BROKEN else TRY_RESUME_SUCCESS + } + // Is the waiter cancelled? + cellState === CANCELLED -> { + // Return the corresponding failure. + return TRY_RESUME_FAIL_CANCELLED + } + // Should the current `resume` be refused by this CQS? + cellState === REFUSE -> { + // This state should not occur + // in the simple cancellation mode. + assert { cancellationMode != SIMPLE } + // Return the refused value back to the + // data structure and finish successfully. + returnRefusedValue(value) + return TRY_RESUME_SUCCESS + } + // Does the cell store a cancellable continuation? + cellState is Waiter -> { + // Change the cell state to `RESUMED`, so + // the cancellation handler cannot be invoked + // even if the continuation becomes cancelled. + if (!segment.cas(i, cellState, RESUMED)) continue@modify_cell + // Try to resume the continuation. + val resumed = when(cellState) { + is CancellableContinuation<*> -> { + (cellState as CancellableContinuation) + val token = cellState.tryResume(value, null, { returnValue(value) }) + if (token != null) { + // Hooray, the continuation is successfully resumed! + cellState.completeResume(token) + true + } else { + false + } + } + is SelectInstance<*> -> { + cellState.trySelect(this@CancellableQueueSynchronizer, value) + } + else -> error("unexpected") + } + if (!resumed) { + // Unfortunately, the continuation resumption has failed. + // Fail the current `resume` if the simple cancellation mode is used. + if (cancellationMode === SIMPLE) + return TRY_RESUME_FAIL_CANCELLED + // In the smart cancellation mode, the cancellation handler should be invoked. + val cancelled = onCancellation() + if (cancelled) { + // We could mark the cell as `CANCELLED` for consistency, + // but there is no need for this since the cell cannot + // be processed by another operation anymore. + // + // Try to resume the next waiter. If the resumption fails due to + // a race in the synchronous mode, the value should be returned + // back to the data structure. + if (!resume(value)) returnValue(value) + } else { + // The value is refused by this CQS, return it back to the data structure. + returnRefusedValue(value) + } + } + // Once the state is changed to `RESUMED`, `resume` is considered as successful. + return TRY_RESUME_SUCCESS + } + // Does the cell store a cancelling waiter, which is already logically + // cancelled but the cancellation handler has not been completed yet? + cellState === CANCELLING -> { + // Fail in the simple cancellation mode. + if (cancellationMode == SIMPLE) return TRY_RESUME_FAIL_CANCELLED + // In the smart cancellation mode, this cell should be either skipped + // (when it becomes `CANCELLED`), or the current `resume` should be refused. + // + // In the synchronous resumption mode, `resume(..)` waits in a an unbounded spin-loop until + // the state of this cell is changed to either `CANCELLED` or `REFUSE`. While this part makes + // the overall algorithm blocking in theory, this cancellation handler and `resume` overlap occurs + // relatively rare in practice and it is guaranteed that one cancellation can block at most one + // `resume`, what makes the algorithm almost non-blocking in any real-world high-contended scenario. + if (resumeMode == SYNC) continue@modify_cell + // In the asynchronous resumption mode, `resume` puts the resumption value into the cell, + // so the concurrent cancellation handler completes this `resume` after it decides whether + // the cell should be marked as `CANCELLED` or `REFUSE`. Thus, this `resume` is delegated to + // the cancellation handler and can be postponed for a while. + // + // To distinguish continuations related to the `suspend` operation with the continuations passed + // as values (this is strange but possible), we wrap the last ones with `WrappedContinuationValue`. + val valueToStore: Any = if (value is Continuation<*>) WrappedContinuationValue(value) else value + if (segment.cas(i, cellState, valueToStore)) return TRY_RESUME_SUCCESS + } + // The cell stores a plane non-cancellable continuation, we can simply resume it. + cellState is Continuation<*> -> { + // Resume the continuation and mark the cell + // as `RESUMED` to avoid memory leaks. + segment.set(i, RESUMED) + (cellState as Continuation).resume(value) + return TRY_RESUME_SUCCESS + } + else -> error("Unexpected cell state: $cellState") + } + } + } + + /** + * Updates [resumeIdx] to [newValue] if the current value is lower. + */ + private fun adjustResumeIdx(newValue: Long): Unit = resumeIdx.loop { cur -> + if (cur >= newValue) return + if (resumeIdx.compareAndSet(cur, newValue)) return + } + + /** + * These modes define the strategy that [resume] should + * use if it comes to the cell before [suspend] and finds it empty. + * In the [asynchronous][ASYNC] mode, [resume] puts the value into the cell, + * so [suspend] grabs it after that and completes without actual suspension. + * In other words, an elimination happens in this case. + * + * However, such a strategy produces an incorrect behavior when used for some + * data structures (e.g., for [tryAcquire][Semaphore.tryAcquire] in [Semaphore]), + * so the [synchronous][SYNC] mode has been introduced in addition. + * Similarly to the asynchronous one, [resume] puts the value into the cell, + * but do not finish immediately. In opposite, it waits in a bounded spin-loop + * (see [MAX_SPIN_CYCLES]) until the value is taken and completes only after that. + * If the value is not taken after this spin-loop ends, [resume] marks the cell as + * [broken][BROKEN] and fails, so the corresponding [suspend] invocation finds the cell + * [broken][BROKEN] and fails as well. + */ + internal enum class ResumeMode { SYNC, SYNC_BLOCKING, ASYNC } + + /** + * These modes define the strategy that should be used when a waiter becomes cancelled. + * + * In the [simple cancellation mode][SIMPLE], [resume] fails when the waiter in the working cell is cancelled. + * In the [smart cancellation mode][SMART], [resume] skips cancelled cells and passes the value to the first + * non-cancelled waiter. However, it is also possible that the cancelled waiter was the last one, so this + * [resume] should be refused (in this case, the corresponding [onCancellation] call returns false), and + * the value is returned back to the data structure via [returnValue]. + */ + internal enum class CancellationMode { SIMPLE, SMART } + + private fun createSegment(id: Long, prev: CQSSegment?) = CQSSegment(id, prev, 0) + + /** + * The queue of waiters in [CancellableQueueSynchronizer] is represented as a linked list of [CQSSegment]. + */ + private inner class CQSSegment(id: Long, prev: CQSSegment?, pointers: Int) : Segment(id, prev, pointers) { + private val waiters = atomicArrayOfNulls(SEGMENT_SIZE) + override val numberOfSlots: Int get() = SEGMENT_SIZE + + override fun onCancellation(index: Int, cause: Throwable?) { + // Invoke the cancellation handler + // only if the state is not `RESUMED`. + // + // After the state is changed to `RESUMED`, the + // resumption is considered as logically successful, + // and the value can be returned back to the data structure + // only via a `returnValue(..)` call. + if (!tryMarkCancelling(index)) return + // Do we use simple or smart cancellation? + if (cancellationMode === SIMPLE) { + // In the simple cancellation mode the logic + // is straightforward -- mark the cell as + // cancelled to avoid memory leaks and complete. + markCancelled(index) + return + } + // We are in the smart cancellation mode. + // Invoke `onCancellation()` and mark the cell as `CANCELLED` + // if the call returns `true`, or as `REFUSE` if it + // returns `false`. Note that it is possible for a + // concurrent `resume(..)` to put its resumption value + // into the cell in the asynchronous mode. In this case, + // the cancellation handler should complete this `resume(..)`. + val cancelled = onCancellation() + if (cancelled) { + // The cell should be considered as cancelled. + // Mark the cell correspondingly and help a concurrent + // `resume(..)` to process its value if needed. + val value = markCancelled(index) ?: return + // Resume the next waiter with the value + // provided by a concurrent `resume(..)`. + // The value could be put only in the asynchronous mode, + // so the `resume(..)` call above must not fail. + @Suppress("UNCHECKED_CAST") + resume(value as T) + } else { + // The `resume(..)` that will come to this cell should be refused. + // Mark the cell correspondingly and help a concurrent + // `resume(..)` to process its value if needed. + val value = markRefuse(index) ?: return + @Suppress("UNCHECKED_CAST") + returnRefusedValue(value as T) + } + } + + @Suppress("NOTHING_TO_INLINE") + inline fun get(index: Int): Any? = waiters[index].value + + @Suppress("NOTHING_TO_INLINE") + inline fun set(index: Int, value: Any?) { + waiters[index].value = value + } + + @Suppress("NOTHING_TO_INLINE") + inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = waiters[index].compareAndSet(expected, value) + + @Suppress("NOTHING_TO_INLINE") + inline fun getAndSet(index: Int, value: Any?): Any? = waiters[index].getAndSet(value) + + /** + * In the CQS algorithm, we use different handlers for normal and prompt cancellations. + * However, the current [CancellableContinuation] API (which is, hopefully, subject to change) + * does not allow to split the handlers -- the one set by [CancellableContinuation.invokeOnCancellation] + * is always invoked, even when prompt cancellation occurs. To guarantee that only the proper handler is used + * (either the one installed by `invokeOnCancellation { ... }` or the one passed to `tryResume(..)`), + * we use a special intermediate `CANCELLING` state for the normal cancellation. Thus, once the waiter becomes + * cancelled, it should be atomically replaced with the `CANCELLING` marker. At the same time, if the state + * is already `RESUMED`, the continuation is considered as logically resumed, [tryMarkCancelling] returns false, + * and any cancellation is considered as prompt one, so the handler passed to `tryResume(..)` is used. + * This handler simply returns the value back to the data structure via `returnValue(..)` invocation. + */ + fun tryMarkCancelling(index: Int): Boolean { + while (true) { + val cellState = get(index) + when { + cellState === RESUMED -> return false + cellState is Waiter -> { + if (cas(index, cellState, CANCELLING)) return true + } + else -> { + if (cellState is Continuation<*>) + error("Only cancellable continuations can be cancelled, ${cellState::class.simpleName} has been detected") + else + error("Unexpected cell state: $cellState") + } + } + } + } + + /** + * Atomically replaces [CANCELLING] with [CANCELLED] and returns `null` on success. + * However, in the asynchronous resumption mode, [resume] may to come to the cell + * while it is in the [CANCELLING] state, replace the [CANCELLING] marker with the + * resumption value, and finish, delegating the rest of the resumption. In this case, + * the function returns this value, and the caller must complete the resumption. + * + * In addition, this function checks whether the segment becomes full of cancelled + * cells, and physically removes the segment from the linked list in this case; thus, + * avoiding possible memory leaks caused by cancellation. + */ + fun markCancelled(index: Int): Any? = mark(index, CANCELLED).also { + onSlotCleaned() + } + + /** + * Atomically replaces [CANCELLING] with [REFUSE] and returns `null` on success. + * However, in the asynchronous resumption mode, [resume] may to come to the cell + * while it is in the [CANCELLING] state, replace the [CANCELLING] marker with the + * resumption value, and finish, delegating the rest of the resumption. In this case, + * the function returns this value, and the caller must complete the resumption. + */ + fun markRefuse(index: Int): Any? = mark(index, REFUSE) + + /** + * Updates the cell state to either [CANCELLED] or [REFUSE] from [CANCELLING], the + * corresponding update [marker] is passed as an argument. However, in the asynchronous + * resumption mode, it is possible for [resume] to come to the cell while it is in + * the [CANCELLING] state. In this case, [resume] replaces the [CANCELLING] marker + * with the resumption value and finishes, delegating the rest of the resumption to + * the concurrent cancellation handler. Therefore, this [mark] function atomically + * checks whether there is a value put into the cell by a concurrent [resume] and + * either returns this value if found, or `null` if the cell was in the [CANCELLING] state. + */ + private fun mark(index: Int, marker: Any?): Any? { + val old = getAndSet(index, marker) + // The cell should be in the `CANCELLING` state or store + // an asynchronously put value at the point of this update. + assert { old !== RESUMED && old !== CANCELLED && old !== REFUSE && old !== TAKEN && old !== BROKEN } + assert { old !is Continuation<*> } + // Return `null` if no value has been passed in meantime. + if (old === CANCELLING) return null + // A concurrent `resume(..)` has put a value into the cell, return it as a result. + return if (old is WrappedContinuationValue) old.cont else old + } + + override fun toString() = "CQSSegment[id=$id, hashCode=${hashCode()}]" + } + + // We use this string representation for traces in Lincheck tests + override fun toString(): String { + val waiters = ArrayList() + var curSegment = resumeSegment.value + var curIdx = resumeIdx.value + while (curIdx < max(suspendIdx.value, resumeIdx.value)) { + val i = (curIdx % SEGMENT_SIZE).toInt() + waiters += when { + curIdx < curSegment.id * SEGMENT_SIZE -> "CANCELLED" + curSegment.get(i) is Continuation<*> -> "" + else -> curSegment.get(i).toString() + } + curIdx++ + if (curIdx == (curSegment.id + 1) * SEGMENT_SIZE) + curSegment = curSegment.next ?: break + } + return "suspendIdx=${suspendIdx.value},resumeIdx=${resumeIdx.value},waiters=$waiters" + } +} + +/** + * In the [smart cancellation mode][CancellableQueueSynchronizer.CancellationMode.SMART] + * it is possible for [resume] to come to a cell with cancelled continuation and + * asynchronously put the resumption value into the cell, so the cancellation handler decides whether + * this value should be used for resuming the next waiter or be refused. When this + * value is a continuation, it is hard to distinguish it with the one related to the cancelled + * waiter. To solve the problem, such values of type [Continuation] are wrapped with + * [WrappedContinuationValue]. Note that the wrapper is required only in [CancellableQueueSynchronizer.CancellationMode.SMART] + * mode and is used in the asynchronous race resolution logic between cancellation and [resume] + * invocation; this way, it is used relatively rare. + */ +private class WrappedContinuationValue(val cont: Continuation<*>) + +@SharedImmutable +private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.cqs.segmentSize", 16) +@SharedImmutable +private val MAX_SPIN_CYCLES = systemProp("kotlinx.coroutines.cqs.maxSpinCycles", 100) +@SharedImmutable +private val TAKEN = Symbol("TAKEN") +@SharedImmutable +private val BROKEN = Symbol("BROKEN") +@SharedImmutable +private val CANCELLING = Symbol("CANCELLING") +@SharedImmutable +private val CANCELLED = Symbol("CANCELLED") +@SharedImmutable +private val REFUSE = Symbol("REFUSE") +@SharedImmutable +private val RESUMED = Symbol("RESUMED") + +private const val TRY_RESUME_SUCCESS = 0 +private const val TRY_RESUME_FAIL_CANCELLED = 1 +private const val TRY_RESUME_FAIL_BROKEN = 2 \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt index 2bcf97b7ad..839ff78f8d 100644 --- a/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt +++ b/kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt @@ -188,7 +188,16 @@ internal abstract class ConcurrentLinkedListNode * Essentially, this is a node in the Michael-Scott queue algorithm, * but with maintaining [prev] pointer for efficient [remove] implementation. */ -internal abstract class Segment>(val id: Long, prev: S?, pointers: Int): ConcurrentLinkedListNode(prev) { +internal abstract class Segment>(val id: Long, prev: S?, pointers: Int) : + ConcurrentLinkedListNode(prev), + // Segments typically store waiting continuations. Thus, on cancellation, the corresponding + // slot should be cleaned and the segment should be removed if it becomes full of cancelled cells. + // To install such a handler efficiently, without creating an extra object, we allow storing + // segments as cancellation handlers in [CancellableContinuationImpl] state, putting the slot + // index in another field. The details are here: https://github.com/Kotlin/kotlinx.coroutines/pull/3084. + // For that, we need segments to implement this internal marker interface. + NotCompleted +{ /** * This property should return the number of slots in this segment, * it is used to define whether the segment is logically removed. @@ -212,6 +221,13 @@ internal abstract class Segment>(val id: Long, prev: S?, pointers // returns `true` if this segment is logically removed after the decrement. internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == numberOfSlots && !isTail + /** + * This function is invoked on continuation cancellation when this segment + * with the specified [index] are installed as cancellation handler via + * `SegmentDisposable.disposeOnCancellation(Segment, Int)`. + */ + abstract fun onCancellation(index: Int, cause: Throwable?) + /** * Invoked on each slot clean-up; should not be invoked twice for the same slot. */ diff --git a/kotlinx-coroutines-core/common/src/selects/Select.kt b/kotlinx-coroutines-core/common/src/selects/Select.kt index b9d128b7f8..ce48de34a1 100644 --- a/kotlinx-coroutines-core/common/src/selects/Select.kt +++ b/kotlinx-coroutines-core/common/src/selects/Select.kt @@ -376,11 +376,24 @@ internal open class SelectImplementation constructor( private var clauses: MutableList>? = ArrayList(2) /** - * Stores the completion action provided through [disposeOnCompletion] during clause registration. - * After that, if the clause is successfully registered (so, it has not completed immediately), - * this [DisposableHandle] is stored into the corresponding [ClauseData] instance. + * Stores the completion action provided through [disposeOnCompletion] or [invokeOnCancellation] + * during clause registration. After that, if the clause is successfully registered + * (so, it has not completed immediately), this handler is stored into + * the corresponding [ClauseData] instance. + * + * Note that either [DisposableHandle] is provided, or a [Segment] instance with + * the index in it, which specify the location of storing this `select`. + * In the latter case, [Segment.onCancellation] should be called on completion/cancellation. + */ + private var disposableHandleOrSegment: Any? = null + + /** + * In case the disposable handle is specified via [Segment] + * and index in it, implying calling [Segment.onCancellation], + * the corresponding index is stored in this field. + * The segment is stored in [disposableHandleOrSegment]. */ - private var disposableHandle: DisposableHandle? = null + private var indexInSegment: Int = -1 /** * Stores the result passed via [selectInRegistrationPhase] during clause registration @@ -469,8 +482,10 @@ internal open class SelectImplementation constructor( // This also guarantees that the list of clauses cannot be cleared // in the registration phase, so it is safe to read it with "!!". if (!reregister) clauses!! += this - disposableHandle = this@SelectImplementation.disposableHandle - this@SelectImplementation.disposableHandle = null + disposableHandleOrSegment = this@SelectImplementation.disposableHandleOrSegment + indexInSegment = this@SelectImplementation.indexInSegment + this@SelectImplementation.disposableHandleOrSegment = null + this@SelectImplementation.indexInSegment = -1 } else { // This clause has been selected! // Update the state correspondingly. @@ -493,7 +508,23 @@ internal open class SelectImplementation constructor( } override fun disposeOnCompletion(disposableHandle: DisposableHandle) { - this.disposableHandle = disposableHandle + this.disposableHandleOrSegment = disposableHandle + } + + /** + * An optimized version for the code below that does not allocate + * a cancellation handler object and efficiently stores the specified + * [segment] and [index]. + * + * ``` + * disposeOnCompletion { + * segment.onCancellation(index, null) + * } + * ``` + */ + override fun invokeOnCancellation(segment: Segment<*>, index: Int) { + this.disposableHandleOrSegment = segment + this.indexInSegment = index } override fun selectInRegistrationPhase(internalResult: Any?) { @@ -556,7 +587,8 @@ internal open class SelectImplementation constructor( */ private fun reregisterClause(clauseObject: Any) { val clause = findClause(clauseObject)!! // it is guaranteed that the corresponding clause is presented - clause.disposableHandle = null + clause.disposableHandleOrSegment = null + clause.indexInSegment = -1 clause.register(reregister = true) } @@ -692,7 +724,7 @@ internal open class SelectImplementation constructor( // Invoke all cancellation handlers except for the // one related to the selected clause, if specified. clauses.forEach { clause -> - if (clause !== selectedClause) clause.disposableHandle?.dispose() + if (clause !== selectedClause) clause.dispose() } // We do need to clean all the data to avoid memory leaks. this.state.value = STATE_COMPLETED @@ -716,7 +748,7 @@ internal open class SelectImplementation constructor( // a concurrent clean-up procedure has already completed, and it is safe to finish. val clauses = this.clauses ?: return // Remove this `select` instance from all the clause object (channels, mutexes, etc.). - clauses.forEach { it.disposableHandle?.dispose() } + clauses.forEach { it.dispose() } // We do need to clean all the data to avoid memory leaks. this.internalResult = NO_RESULT this.clauses = null @@ -731,9 +763,11 @@ internal open class SelectImplementation constructor( private val processResFunc: ProcessResultFunction, private val param: Any?, // the user-specified param private val block: Any, // the user-specified block, which should be called if this clause becomes selected - @JvmField val onCancellationConstructor: OnCancellationConstructor?, - @JvmField var disposableHandle: DisposableHandle? = null + @JvmField val onCancellationConstructor: OnCancellationConstructor? ) { + @JvmField var disposableHandleOrSegment: Any? = null + @JvmField var indexInSegment: Int = -1 + /** * Tries to register the specified [select] instance in [clauseObject] and check * whether the registration succeeded or a rendezvous has happened during the registration. @@ -788,6 +822,16 @@ internal open class SelectImplementation constructor( } } + fun dispose() { + with(disposableHandleOrSegment) { + if (this is Segment<*>) { + this.onCancellation(indexInSegment, null) + } else { + (this as? DisposableHandle)?.dispose() + } + } + } + fun createOnCancellationAction(select: SelectInstance<*>, internalResult: Any?) = onCancellationConstructor?.invoke(select, param, internalResult) } diff --git a/kotlinx-coroutines-core/common/src/sync/Mutex.kt b/kotlinx-coroutines-core/common/src/sync/Mutex.kt index fbd1fe55f4..a64b4b7204 100644 --- a/kotlinx-coroutines-core/common/src/sync/Mutex.kt +++ b/kotlinx-coroutines-core/common/src/sync/Mutex.kt @@ -7,8 +7,10 @@ package kotlinx.coroutines.sync import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.internal.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.ResumeMode.* import kotlinx.coroutines.selects.* import kotlin.contracts.* +import kotlin.coroutines.* import kotlin.jvm.* /** @@ -131,7 +133,9 @@ public suspend inline fun Mutex.withLock(owner: Any? = null, action: () -> T } -internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 else 0), Mutex { +internal open class MutexImpl(locked: Boolean) : CancellableQueueSynchronizer(), Mutex { + override val resumeMode get() = SYNC_BLOCKING + /** * After the lock is acquired, the corresponding owner is stored in this field. * The [unlock] operation checks the owner and either re-sets it to [NO_OWNER], @@ -140,13 +144,22 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 */ private val owner = atomic(if (locked) null else NO_OWNER) + private val availablePermits = atomic(if (locked) 0 else 1) + private val onSelectCancellationUnlockConstructor: OnCancellationConstructor = { _: SelectInstance<*>, owner: Any?, _: Any? -> { unlock(owner) } } - override val isLocked: Boolean get() = - availablePermits == 0 + override val isLocked: Boolean get() { + while (true) { + val p = availablePermits.value + if (p == 1) return false + assert { p <= 0 } + if (owner.value === NO_OWNER) continue + return true + } + } override fun holdsLock(owner: Any): Boolean { while (true) { @@ -161,23 +174,124 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 } override suspend fun lock(owner: Any?) { - if (tryLock(owner)) return +// if (tryLock(owner)) return lockSuspend(owner) } - private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable { cont -> + private suspend fun lockSuspend(owner: Any?) = suspendCancellableCoroutineReusable { cont -> + cont as CancellableContinuationImpl val contWithOwner = CancellableContinuationWithOwner(cont, owner) - acquire(contWithOwner) + lockImpl(contWithOwner, owner) } - override fun tryLock(owner: Any?): Boolean = - if (tryAcquire()) { - assert { this.owner.value === NO_OWNER } - this.owner.value = owner - true - } else { - false + private fun lockImpl(waiter: Waiter, owner: Any?) { + xxx@ while (true) { + // Get the current number of available permits. + val p = availablePermits.getAndDecrement() + // Try to decrement the number of available + // permits if it is greater than zero. + if (p <= 0) { + // The semaphore permit acquisition has failed. + // However, we need to check that this mutex is not + // locked by our owner. + if (owner != null) { + // Is this mutex locked by our owner? + val curOwner = this.owner.value + if (curOwner === owner) { + if (suspendCancelled() != null) { + when (waiter) { + is CancellableContinuation<*> -> { + @Suppress("UNCHECKED_CAST") + waiter as CancellableContinuation + waiter.resume(Unit, null) + } + is SelectInstance<*> -> { + waiter.selectInRegistrationPhase(Unit) + } + } + return + } + when (waiter) { + is CancellableContinuation<*> -> { + waiter.resumeWithException(IllegalStateException("ERROR")) + } + is SelectInstance<*> -> { + waiter.selectInRegistrationPhase(ON_LOCK_ALREADY_LOCKED_BY_OWNER) + } + } + return + } else if (curOwner === NO_OWNER) { + if (suspendCancelled() == null) continue@xxx + when (waiter) { + is CancellableContinuation<*> -> { + @Suppress("UNCHECKED_CAST") + waiter as CancellableContinuation + waiter.resume(Unit, null) + } + is SelectInstance<*> -> { + waiter.selectInRegistrationPhase(Unit) + } + } + return + } + + // This mutex is either locked by another owner or unlocked. + // In the latter case, it is possible that it WAS locked by + // our owner when the semaphore permit acquisition has failed. + // To preserve linearizability, the operation restarts in this case. +// if (!isLocked) continue + } + if (suspend(waiter)) return + } else { + assert { p == 1 } + assert { this.owner.value === NO_OWNER } + when (waiter) { + is CancellableContinuation<*> -> { + @Suppress("UNCHECKED_CAST") + waiter as CancellableContinuation + waiter.resume(Unit, null) + } + is SelectInstance<*> -> { + waiter.selectInRegistrationPhase(Unit) + } + } + return + } + } + } + + override fun tryLock(owner: Any?): Boolean = when (tryLockImpl(owner)) { + TRY_LOCK_SUCCESS -> true + TRY_LOCK_FAILED -> false + TRY_LOCK_ALREADY_LOCKED_BY_OWNER -> error("This mutex is already locked by the specified owner: $owner") + else -> error("unexpected") + } + + private fun tryLockImpl(owner: Any?): Int { + while (true) { + // Get the current number of available permits. + val p = availablePermits.value + // Try to decrement the number of available + // permits if it is greater than zero. + if (p <= 0) { + // The semaphore permit acquisition has failed. + // However, we need to check that this mutex is not + // locked by our owner. + if (owner != null) { + // Is this mutex locked by our owner? + val curOwner = this.owner.value + if (curOwner === NO_OWNER) continue + if (curOwner === owner) return TRY_LOCK_ALREADY_LOCKED_BY_OWNER + } + return TRY_LOCK_FAILED + } + if (availablePermits.compareAndSet(p, p - 1)) { + assert { this.owner.value === NO_OWNER } + this.owner.value = owner + return TRY_LOCK_SUCCESS + } } + } override fun unlock(owner: Any?) { while (true) { @@ -196,6 +310,27 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 } } + fun release() { + while (true) { + // Increment the number of available permits. + val p = availablePermits.value + // Is this `release` call correct and does not + // exceed the maximal number of permits? + if (p >= 1) { + error("This mutex is not locked") + } + if (availablePermits.compareAndSet(p, p + 1)) { + // Is there a waiter that should be resumed? + if (p == 0) return + // Try to resume the first waiter, and + // restart the operation if either this + // first waiter is cancelled or + // due to `SYNC` resumption mode. + if (resume(Unit)) return + } + } + } + @Suppress("UNCHECKED_CAST", "OverridingDeprecatedMember", "OVERRIDE_DEPRECATION") override val onLock: SelectClause2 get() = SelectClause2Impl( clauseObject = this, @@ -205,19 +340,22 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 ) protected open fun onLockRegFunction(select: SelectInstance<*>, owner: Any?) { - onAcquireRegFunction(SelectInstanceWithOwner(select, owner), owner) + lockImpl(SelectInstanceWithOwner(select as SelectInstanceInternal<*>, owner), owner) } protected open fun onLockProcessResult(owner: Any?, result: Any?): Any? { + if (result == ON_LOCK_ALREADY_LOCKED_BY_OWNER) { + error("This mutex is already locked by the specified owner: $owner") + } return this } private inner class CancellableContinuationWithOwner( @JvmField - val cont: CancellableContinuation, + val cont: CancellableContinuationImpl, @JvmField val owner: Any? - ) : CancellableContinuation by cont { + ) : CancellableContinuation by cont, Waiter by cont { override fun tryResume(value: Unit, idempotent: Any?, onCancellation: ((cause: Throwable) -> Unit)?): Any? { assert { this@MutexImpl.owner.value === NO_OWNER } val token = cont.tryResume(value, idempotent) { @@ -241,10 +379,10 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 private inner class SelectInstanceWithOwner( @JvmField - val select: SelectInstance, + val select: SelectInstanceInternal, @JvmField val owner: Any? - ) : SelectInstanceInternal by select as SelectInstanceInternal { + ) : SelectInstanceInternal by select { override fun trySelect(clauseObject: Any, result: Any?): Boolean { assert { this@MutexImpl.owner.value === NO_OWNER } return select.trySelect(clauseObject, result).also { success -> @@ -253,13 +391,22 @@ internal open class MutexImpl(locked: Boolean) : SemaphoreImpl(1, if (locked) 1 } override fun selectInRegistrationPhase(internalResult: Any?) { - assert { this@MutexImpl.owner.value === NO_OWNER } - this@MutexImpl.owner.value = owner + if (internalResult !== ON_LOCK_ALREADY_LOCKED_BY_OWNER) { + assert { this@MutexImpl.owner.value === NO_OWNER } + this@MutexImpl.owner.value = owner + } select.selectInRegistrationPhase(internalResult) } } + internal val debugStateRepresentation: String get() = "p=${availablePermits.value},owner=${owner.value},SQS=${super.toString()}" + override fun toString() = "Mutex@${hexAddress}[isLocked=$isLocked,owner=${owner.value}]" } private val NO_OWNER = Symbol("NO_OWNER") +private val ON_LOCK_ALREADY_LOCKED_BY_OWNER = Symbol("ALREADY_LOCKED_BY_OWNER") + +private const val TRY_LOCK_SUCCESS = 0 +private const val TRY_LOCK_FAILED = 1 +private const val TRY_LOCK_ALREADY_LOCKED_BY_OWNER = 2 diff --git a/kotlinx-coroutines-core/common/src/sync/ReadWriteMutex.kt b/kotlinx-coroutines-core/common/src/sync/ReadWriteMutex.kt new file mode 100644 index 0000000000..0752a8136c --- /dev/null +++ b/kotlinx-coroutines-core/common/src/sync/ReadWriteMutex.kt @@ -0,0 +1,613 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.sync + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.internal.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.CancellationMode.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.ResumeMode.* +import kotlinx.coroutines.selects.* +import kotlinx.coroutines.sync.ReadWriteMutexImpl.WriteUnlockPolicy.* +import kotlin.contracts.* +import kotlin.js.* + +/** + * This readers-writer mutex maintains a logical pair of locks, one for read-only + * operations that can be processed concurrently (see [readLock()][readLock] and [readUnlock()][readUnlock]), + * and another one for write operations that require an exclusive access (see [write]). + * It is guaranteed that write and read operations never interfere. + * + * The table below shows which locks can be held simultaneously. + * +-------------+-------------+-------------+ + * | | reader lock | writer lock | + * +-------------+-------------+-------------+ + * | reader lock | ALLOWED | FORBIDDEN | + * +-------------+-------------+-------------+ + * | writer lock | FORBIDDEN | FORBIDDEN | + * +-------------+-------------+-------------+ + * + * Similar to [Mutex], this readers-writer mutex is **non-reentrant**, + * so invoking [readLock()][readLock] or [write.lock()][write] even from the coroutine that + * currently holds the corresponding lock may suspend the invoker. Likewise, invoking + * [readLock()][readLock] from the holder of the writer lock also suspends the invoker. + * + * Typical usage of [ReadWriteMutex] is wrapping each read invocation with + * [read { ... }][read] and each write invocation with [write { ... }][write]. + * These wrapper functions guarantee that the readers-writer mutex is used correctly + * and safely. However, one can use `lock()` and `unlock()` operations directly. + * + * The advantage of using [ReadWriteMutex] compared to the plain [Mutex] is the ability + * to parallelize read operations and, therefore, increase the level of concurrency. + * This is extremely useful for the workloads with dominating read operations so they can be + * executed in parallel, improving the performance and scalability. However, depending on the + * updates frequency, the execution cost of read and write operations, and the contention, + * it can be simpler and cheaper to use the plain [Mutex]. Therefore, it is highly recommended + * to measure the performance difference to make the right choice. + */ +@ExperimentalCoroutinesApi +public interface ReadWriteMutex { + /** + * // TODO: how to reference `val write: Mutex` instead of the extension function? + * Acquires a reader lock of this mutex if the [writer lock][write] is not held and there is no writer + * waiting for it. Suspends the caller otherwise until the writer lock is released and this reader is resumed. + * Please note, that in this case the next waiting writer instead of this reader can be resumed after + * the currently active writer releases the lock. + * + * This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this + * function is suspended, this function immediately resumes with [CancellationException]. + * There is a **prompt cancellation guarantee**. If the job was cancelled while this function was + * suspended, it will not resume successfully. See [suspendCancellableCoroutine] documentation for low-level details. + * This function releases the lock if it was already acquired by this function before the [CancellationException] + * was thrown. + * + * Note that this function does not check for cancellation when it is not suspended. + * Use [yield] or [CoroutineScope.isActive] to periodically check for cancellation in tight loops if needed. + * + * It is recommended to use [read { ... }][read] block for safety reasons, so the acquired reader lock + * is always released at the end of the critical section, and [readUnlock()][readUnlock] is never invoked + * before a successful [readLock()][readLock]. + */ + @ExperimentalCoroutinesApi + public suspend fun readLock() + + /** + * Releases a reader lock of this mutex and resumes the first waiting writer + * if this operation has released the last acquired reader lock. + * + * It is recommended to use [read { ... }][read] block for safety reasons, so the acquired reader lock + * is always released at the end of the critical section, and [readUnlock()][readUnlock] is never invoked + * before a successful [readLock()][readLock]. + */ + @ExperimentalCoroutinesApi + public fun readUnlock() + + /** + * Returns a [mutex][Mutex] which manipulates with the writer lock of this [ReadWriteMutex]. + * + * When acquires the writer lock, the operation completes immediately if neither the writer lock nor + * a reader lock is held. Otherwise, the acquisition suspends the caller until the exclusive access + * is granted by either [readUnlock()][readUnlock] or [write.unlock()][Mutex.unlock]. Please note that + * all suspended writers are processed in first-in-first-out (FIFO) order. + * + * When releasing the writer lock, the operation resumes the first waiting writer or waiting readers. + * Note that different fairness policies can be applied by an implementation, such as + * prioritizing readers or writers and attempting to always resume them at first, + * choosing the prioritization policy by flipping a coin, or providing a truly fair + * strategy where all waiters, both readers and writers, form a single FIFO queue. + * + * This [Mutex] implementation for writers does not support owners in [lock()][Mutex.lock] + * and [withLock { ... }][Mutex.withLock] functions as well as the [onLock][Mutex.onLock] select clause. + * + * It is also recommended to use [write { ... }][write] block for safety reasons, so the acquired writer lock + * is always released at the end of the critical section, and [write.unlock()][Mutex.unlock] is never invoked + * before a successful [write.lock()][Mutex.lock]. + */ + @ExperimentalCoroutinesApi + public val write: Mutex +} + +/** + * Creates a new [ReadWriteMutex] instance, both reader and writer locks are not acquired. + * + * Instead of ensuring the strict fairness, when all waiting readers and writers form + * a single queue, this implementation provides a slightly relaxed but more efficient guarantee. + * In this version, two separate queues for waiting readers and waiting writers are maintained. + * When the last reader lock is released, the first waiting writer is released -- this behaviour + * respects the strict fairness property. However, when the writer lock is released, the implementation + * either releases all the waiting readers or the first waiting writer, choosing the policy by the + * round-robin strategy. Thus, if the choice differs from the strict fairness, it is guaranteed that + * the proper waiter(s) will be resumed on the next step. Simultaneously, we find it more efficient to + * resume all waiting readers even if it violates the strict fairness. + */ +@JsName("_ReadWriteMutex") +public fun ReadWriteMutex(): ReadWriteMutex = ReadWriteMutexImpl() + +/** + * Executes the given [action] under a _reader_ lock of this readers-writer mutex. + * + * @return the return value of the [action]. + */ +@OptIn(ExperimentalContracts::class) +public suspend inline fun ReadWriteMutex.read(action: () -> T): T { + contract { + callsInPlace(action, InvocationKind.EXACTLY_ONCE) + } + + readLock() + try { + return action() + } finally { + readUnlock() + } +} + +/** + * Executes the given [action] under the _writer_ lock of this readers-writer mutex. + * + * @return the return value of the [action]. + */ +public suspend inline fun ReadWriteMutex.write(action: () -> T): T = + write.withLock(null, action) + +/** + * This readers-writer mutex maintains the numbers of active and waiting readers, + * a flag on whether the writer lock is acquired, and the number of writers waiting + * for the lock. This tuple represents the current state of the readers-writer mutex and + * is split into [waitingReaders] and [state] fields -- it is impossible to store everything + * in a single register since its maximal capacity is 64 bit, and this is not sufficient + * for three counters and several flags. Additionally, separate [CancellableQueueSynchronizer]-s + * are used for waiting readers and writers. + * + * To acquire a reader lock, the algorithm checks whether the writer lock is held or there is a writer + * waiting for it, increasing the number of _active_ readers and grabbing a read lock immediately if not. + * Otherwise, it atomically decreases the number of _active_ readers and increases the number of _waiting_ + * readers and suspends. + * As for the writer lock acquisition, the idea is the same -- the algorithm checks whether both reader and + * writer locks are not acquired and takes the lock immediately in this case. Otherwise, if the writer should + * wait for the lock, the algorithm increases the counter of waiting writers and suspends. + * + * When releasing a reader lock, the algorithm decrements the number of active readers. + * If the counter reaches zero, it checks whether a writer is waiting for the lock + * and resumes the first waiting one. + * On the writer lock release, the algorithm resumes either the next waiting writer + * (decrementing the counter of them) or all waiting readers (decrementing the counter of waiting + * readers and incrementing the counter of active ones). + * + * When there are both readers and writers waiting for a lock at the point of the writer lock release, + * the truly fair implementation would form a single queue where all waiters, both readers and writers, + * are stored. Instead of ensuring the strict fairness, this implementation provides a slightly relaxed + * but more efficient guarantee. In short, it maintains two separate queues, for waiting readers and + * waiting writers. When the writer lock is released, the algorithm either releases all the waiting readers + * or the first waiting writer, choosing the policy by the round-robin strategy. Thus, if the choice differs + * from the strict fairness, it is guaranteed that the proper waiter(s) will be resumed on the next step. + * Simultaneously, we find it more efficient to resume all waiting readers even if it violates the strict fairness. + * + * As for cancellation, the main idea is to revert the state update. However, possible logical races + * should be managed carefully, which makes the revert part non-trivial. The details are discussed in the code + * comments and appear almost everywhere. + */ +internal class ReadWriteMutexImpl : ReadWriteMutex, Mutex { + // The number of coroutines waiting for a reader lock in `cqsReaders`. + private val waitingReaders = atomic(0) + // This state field contains several counters and is always updated atomically by `CAS`: + // - `AR` (active readers) is a 30-bit counter which represents the number + // of coroutines holding a read lock; + // - `WLA` (writer lock acquired) is a flag which is `true` when + // the writer lock is acquired; + // - `WW` (waiting writers) is a 30-bit counter which represents the number + // of coroutines waiting for the writer lock in `cqsWriters`; + // - `RWR` (resuming waiting readers) is a flag which is `true` when waiting readers + // resumption is in progress. + private val state = atomic(0L) + + private val cqsReaders = ReadersCQS() // the place where readers should suspend and be resumed + private val cqsWriters = WritersCQS() // the place where writers should suspend and be resumed + + private var curUnlockPolicy = false // false -- prioritize readers on the writer lock release + // true -- prioritize writers on the writer lock release + + @ExperimentalCoroutinesApi + override val write: Mutex get() = this // we do not create an extra object this way. + override val isLocked: Boolean get() = state.value.wla + override fun tryLock(owner: Any?): Boolean = error("ReadWriteMutex.write does not support `tryLock()`") + override suspend fun lock(owner: Any?) { + if (owner != null) error("ReadWriteMutex.write does not support owners") + writeLock() + } + @Suppress("OVERRIDE_DEPRECATION") + override val onLock: SelectClause2 get() = error("ReadWriteMutex.write does not support `onLock`") + override fun holdsLock(owner: Any) = error("ReadWriteMutex.write does not support owners") + override fun unlock(owner: Any?) { + if (owner != null) error("ReadWriteMutex.write does not support owners") + writeUnlock() + } + + override suspend fun readLock() { + // Try to acquire a reader lock without suspension. + if (tryReadLock()) return + // The attempt fails, invoke the slow-path. This slow-path + // part is implemented in a separate function to guarantee + // that the tail call optimization is applied here. + readLockSlowPath() + } + + private fun tryReadLock(): Boolean { + while (true) { + // Read the current state. + val s = state.value + // Is the writer lock acquired or is there a waiting writer? + if (!s.wla && s.ww <= 0) { + // A reader lock is available to acquire, try to do it! + // Note that there can be a concurrent `write.unlock()` which is + // resuming readers now, so the `RWR` flag is set in this case. + if (state.compareAndSet(s, state(s.ar + 1, false, 0, s.rwr))) + return true + // CAS failed => the state has changed. + // Re-read it and try to acquire a reader lock again. + continue + } else return false + } + } + + private suspend fun readLockSlowPath() { + // Increment the number of waiting readers at first. + // If the current invocation should not suspend, + // the counter will be decremented back later. + waitingReaders.incrementAndGet() + // Check whether this operation should suspend. If not, try + // to decrement the counter of waiting readers and restart. + while (true) { + // Read the current state. + val s = state.value + // Is there a writer holding the lock or waiting for it? + if (s.wla || s.ww > 0) { + // The number of waiting readers was incremented + // correctly, wait for a reader lock in `cqsReaders`. + suspendCancellableCoroutineReusable { cont -> + cqsReaders.suspend(cont as Waiter) + } + return + } else { + // A race has been detected! The increment of the counter of + // waiting readers was wrong, try to decrement it back. However, + // it could already become zero due to a concurrent `write.unlock()` + // which reads the number of waiting readers, replaces it with `0`, + // and resumes all these readers. In this case, it is guaranteed + // that a reader lock will be provided via `cqsReaders`. + while (true) { + // Read the current number of waiting readers. + val wr = waitingReaders.value + // Is our invocation already handled by a concurrent + // `write.unlock()` and a reader lock is going to be + // passed via `cqsReaders`? Suspend in this case -- + // it is guaranteed that the lock will be provided + // when this concurrent `write.unlock()` completes. + if (wr == 0) { + suspendCancellableCoroutineReusable { cont -> + cqsReaders.suspend(cont as Waiter) + } + return + } + // Otherwise, try to decrement the number of waiting + // readers and retry the operation from the beginning. + if (waitingReaders.compareAndSet(wr, wr - 1)) { + // Try again starting from the fast path. + readLock() + return + } + } + } + } + } + + override fun readUnlock() { + // When releasing a reader lock, the algorithm checks whether + // this reader lock is the last acquired one and resumes + // the first waiting writer (if applicable) in this case. + while (true) { + // Read the current state. + val s = state.value + check(!s.wla) { "Invalid `readUnlock` invocation: the writer lock is acquired. $INVALID_UNLOCK_INVOCATION_TIP" } + check(s.ar > 0) { "Invalid `readUnlock` invocation: no reader lock is acquired. $INVALID_UNLOCK_INVOCATION_TIP" } + // Is this reader the last one and is the `RWR` flag unset (=> it is valid to resume the next writer)? + if (s.ar == 1 && !s.rwr) { + // Check whether there is a waiting writer and resume it. + // Otherwise, simply change the state and finish. + if (s.ww > 0) { + // Try to decrement the number of waiting writers and set the `WLA` flag. + // Resume the first waiting writer on success. + if (state.compareAndSet(s, state(0, true, s.ww - 1, false))) { + cqsWriters.resume(Unit) + return + } + } else { + // There is no waiting writer according to the state. + // Try to clear the number of active readers and finish. + if (state.compareAndSet(s, state(0, false, 0, false))) + return + } + } else { + // Try to decrement the number of active readers and finish. + // Please note that the `RWR` flag can be set here if there is + // a concurrent unfinished `write.unlock()` operation which + // has resumed the current reader but the corresponding + // `readUnlock()` happened before this `write.unlock()` completion. + if (state.compareAndSet(s, state(s.ar - 1, false, s.ww, s.rwr))) + return + } + } + } + + /** + * This customization of [CancellableQueueSynchronizer] for waiting readers + * use the asynchronous resumption mode and smart cancellation mode, + * so neither [suspend] nor [resume] fail. However, to support + * `tryReadLock()` the synchronous resumption mode should be used. + */ + private inner class ReadersCQS : CancellableQueueSynchronizer() { + override val resumeMode get() = ASYNC + override val cancellationMode get() = SMART + + override fun onCancellation(): Boolean { + // The cancellation logic here is pretty similar to + // the one in `readLock()` when the number of waiting + // readers has been incremented incorrectly. + while (true) { + // First, read the current number of waiting readers. + val wr = waitingReaders.value + // Check whether it has already reached zero -- in this + // case a concurrent `write.unlock()` will invoke `resume()` + // for this cancelled operation eventually, so `onCancellation()` + // should return `false` to refuse the granted lock. + if (wr == 0) return false + // Otherwise, try to decrement the number of waiting readers keeping + // the counter non-negative and successfully finish the cancellation. + if (waitingReaders.compareAndSet(wr, wr - 1)) return true + } + } + + // When `onCancellation()` fails, the state keeps unchanged. Therefore, + // the reader lock should be returned back to the mutex in `returnValue(..)`. + override fun tryReturnRefusedValue(value: Unit) = false + + // Returns the reader lock back to the mutex. + // This function is also used for prompt cancellation. + override fun returnValue(value: Unit) = readUnlock() + } + + internal suspend fun writeLock() { + // The algorithm is straightforward -- it reads the current state, + // checks that there is no reader or writer lock acquired, and + // tries to change the state by atomically setting the `WLA` flag. + // Otherwise, if the writer lock cannot be acquired immediatelly, + // it increments the number of waiting writers and suspends in + // `cqsWriters` waiting for the lock. + while (true) { + // Read the current state. + val s = state.value + // Is there an active writer (the WLA flag is set), a concurrent `writeUnlock` operation, + // which is releasing readers now (the RWR flag is set), or an active reader (AR >= 0)? + if (!s.wla && !s.rwr && s.ar == 0) { + // Try to acquire the writer lock, re-try the operation if this CAS fails. + assert { s.ww == 0 } + if (state.compareAndSet(s, state(0, true, 0, false))) + return + } else { + // The lock cannot be acquired immediately, and this operation has to suspend. + // Try to increment the number of waiting writers and suspend in `cqsWriters`. + if (state.compareAndSet(s, state(s.ar, s.wla, s.ww + 1, s.rwr))) { + suspendCancellableCoroutineReusable { cont -> + cqsWriters.suspend(cont as Waiter) + } + return + } + } + } + } + + internal fun writeUnlock() { + // Since we store waiting readers and writers separately, it is not easy + // to determine whether the next readers or the next writer should be resumed. + // However, it is natural to have the following policies: + // - PRIORITIZE_READERS -- always resume all waiting readers at first; + // the next waiting writer is resumed only if no reader is waiting for a lock. + // - PRIORITIZE_WRITERS -- always resumed the next writer first; + // - ROUND_ROBIN -- switch between the policies above on every invocation. + // + // We find the round-robin strategy fair enough in practice, but the others are used + // in Lincheck tests. However, it could be useful to have `PRIORITIZE_WRITERS` policy + // in the public API for the cases when the writer lock is used for UI updates. + writeUnlock(ROUND_ROBIN) + } + + internal fun writeUnlock(policy: WriteUnlockPolicy) { + // The algorithm for releasing the writer lock is straightforward by design, + // but has a lot of corner cases that should be properly managed. + // If the next writer should be resumed (see `PRIORITIZE_WRITERS` policy), + // the algorithm tries to atomically decrement the number of waiting writers + // and keep the `WLA` flag, resuming the first writer in `cqsWriters` after that. + // Otherwise, if the `PRIORITIZE_READERS` policy is used or there is no waiting writer, + // the algorithm sets the `RWR` (resuming waiting readers) flag and invokes a special + // `completeWaitingReadersResumption()` to resume all the waiting readers. + while (true) { + // Read the current state at first. + val s = state.value + check(s.wla) { "Invalid `writeUnlock` invocation: the writer lock is not acquired. $INVALID_UNLOCK_INVOCATION_TIP" } + assertNot { s.rwr } + assert { s.ar == 0 } + // Should we resume the next writer? + curUnlockPolicy = !curUnlockPolicy // change the unlock policy for the `ROUND_ROBIN` strategy + val resumeWriter = (s.ww > 0) && (policy == PRIORITIZE_WRITERS || policy == ROUND_ROBIN && curUnlockPolicy) + if (resumeWriter) { + // Resume the next writer - try to decrement the number of waiting + // writers and resume the first one in `cqsWriters` on success. + if (state.compareAndSet(s, state(0, true, s.ww - 1, false))) { + cqsWriters.resume(Unit) + return + } + } else { + // Resume waiting readers. Reset the `WLA` flag and set the `RWR` flag atomically, + // completing the resumption via `completeWaitingReadersResumption()` after that. + // Note that this function also checks whether the next waiting writer should be resumed + // on completion and does it if required. It also resets the `RWR` flag at the end. + // While it is possible that no reader is waiting for a lock, so that this CAS can be omitted, + // we do not add the corresponding code for simplicity since it does not improve the performance + // significantly but reduces the code readability. + if (state.compareAndSet(s, state(0, false, s.ww, true))) { + completeWaitingReadersResumption() + return + } + } + } + } + + private fun completeWaitingReadersResumption() { + // This function is called after the `RWR` flag is set + // and completes the readers resumption process. Note that + // it also checks whether the next waiting writer should be + // resumed on completion and performs this resumption if needed. + assert { state.value.rwr } + // At first, atomically replace the number of waiting + // readers (to be resumed) with 0, retrieving the old value. + val wr = waitingReaders.getAndSet(0) + // After that, these waiting readers should be logically resumed + // by incrementing the corresponding counter in the `state` field. + // We also skip this step if the obtained number of waiting readers is zero. + if (wr > 0) { // should we update the state? + state.update { s -> + assertNot { s.wla } // the writer lock cannot be acquired now. + assert { s.rwr } // the `RWR` flag should still be set. + state(s.ar + wr, false, s.ww, true) + } + } + // After the readers are resumed logically, they should be resumed physically in `cqsReaders`. + repeat(wr) { + cqsReaders.resume(Unit) + } + // Once all the waiting readers are resumed, the `RWR` flag should be reset. + // It is possible that all the resumed readers have already completed their + // work and successfully invoked `readUnlock()` at this point, but since + // the `RWR` flag was set, they were unable to resume the next waiting writer. + // Similarly, it is possible that there were no waiting readers at all. + // Therefore, in the end, we check whether the number of active readers is 0 + // and resume the next waiting writer in this case (if there exists one). + var resumeWriter = false + state.getAndUpdate { s -> + resumeWriter = s.ar == 0 && s.ww > 0 + val wwUpd = if (resumeWriter) s.ww - 1 else s.ww + state(s.ar, resumeWriter, wwUpd, false) + } + if (resumeWriter) { + // Resume the next writer physically and finish + cqsWriters.resume(Unit) + return + } + // Meanwhile, it could be possible for a writer to come and suspend due to the `RWR` flag. + // After that, all the following readers suspend since a writer is waiting for the lock. + // However, if the writer becomes canceled, it cannot resume these suspended readers if the `RWR` flag + // is still set, so we have to help him with the resumption process. To detect such a situation, we re-read + // the number of waiting readers and try to start the resumption process again if the writer lock is not acquired. + if (waitingReaders.value > 0) { // Is there a waiting reader? + while (true) { + val s = state.value // Read the current state. + if (s.wla || s.ww > 0 || s.rwr) return // Check whether the readers resumption is valid. + // Try to set the `RWR` flag again and resume the waiting readers. + if (state.compareAndSet(s, state(s.ar, false, 0, true))) { + completeWaitingReadersResumption() + return + } + } + } + } + + /** + * This customization of [CancellableQueueSynchronizer] for waiting writers + * uses the asynchronous resumption mode and smart cancellation mode, + * so neither [suspend] nor [resume] fail. However, in order to support + * `tryWriteLock()` the synchronous resumption mode should be used instead. + */ + private inner class WritersCQS : CancellableQueueSynchronizer() { + override val resumeMode get() = ASYNC + override val cancellationMode get() = SMART + + override fun onCancellation(): Boolean { + // In general, on cancellation, the algorithm tries to decrement the number of waiting writers. + // Similarly to the cancellation logic for readers, if the number of waiting writers has already reached 0, + // the current canceling writer will be resumed in `cqsWriters`. In this case, the function returns + // `false`, and the permit will be returned via `returnValue()`. Otherwise, if the number of waiting + // writers >= 1, the decrement is sufficient. However, if this canceling writer is the last waiting one, + // the algorithm sets the `RWR` flag and resumes waiting readers. This logic is similar to `writeUnlock(..)`. + while (true) { + val s = state.value // Read the current state. + if (s.ww == 0) return false // Is this writer going to be resumed in `cqsWriters`? + // Is this writer the last one and is the readers resumption valid? + if (s.ww == 1 && !s.wla && !s.rwr) { + // Set the `RWR` flag and resume the waiting readers. + // While it is possible that no reader is waiting for a lock, so that this CAS can be omitted, + // we do not add the corresponding code for simplicity since it does not improve the performance + // significantly but reduces the code readability. Note that the same logic appears in `writeUnlock(..)`, + // and the cancellation performance is less critical since the cancellation itself does not come for free. + if (state.compareAndSet(s, state(s.ar, false, 0, true))) { + completeWaitingReadersResumption() + return true + } + } else { + // There are multiple writers waiting for the lock. Try to decrement the number of them. + if (state.compareAndSet(s, state(s.ar, s.wla, s.ww - 1, s.rwr))) + return true + } + } + } + + // Resumes the next waiting writer if the current `writeLock()` operation + // is already cancelled but the next writer is logically resumed + override fun tryReturnRefusedValue(value: Unit): Boolean { + writeUnlock(PRIORITIZE_WRITERS) + return true + } + + // Returns the writer lock back to the mutex. + // This function is also used for prompt cancellation. + override fun returnValue(value: Unit) = writeUnlock() + } + + // This state representation is used in Lincheck tests. + internal val stateRepresentation: String get() = + "" + + internal enum class WriteUnlockPolicy { PRIORITIZE_READERS, PRIORITIZE_WRITERS, ROUND_ROBIN } +} + +/** + * Constructs a value for [ReadWriteMutexImpl.state] field. + * The created state can be parsed via the extension functions below. + */ +private fun state(activeReaders: Int, writeLockAcquired: Boolean, waitingWriters: Int, resumingWaitingReaders: Boolean): Long = + (if (writeLockAcquired) WLA_BIT else 0) + + (if (resumingWaitingReaders) RWR_BIT else 0) + + activeReaders * AR_MULTIPLIER + + waitingWriters * WW_MULTIPLIER + +// Equals `true` if the `WLA` flag is set in this state. +private val Long.wla: Boolean get() = this or WLA_BIT == this +// Equals `true` if the `RWR` flag is set in this state. +private val Long.rwr: Boolean get() = this or RWR_BIT == this +// The number of waiting writers specified in this state. +private val Long.ww: Int get() = ((this % AR_MULTIPLIER) / WW_MULTIPLIER).toInt() +// The number of active readers specified in this state. +private val Long.ar: Int get() = (this / AR_MULTIPLIER).toInt() + +private const val WLA_BIT = 1L +private const val RWR_BIT = 1L shl 1 +private const val WW_MULTIPLIER = 1L shl 2 +private const val AR_MULTIPLIER = 1L shl 33 + +private const val INVALID_UNLOCK_INVOCATION_TIP = "This can be caused by releasing the lock without acquiring it at first, " + + "or incorrectly putting the acquisition inside the \"try\" block of the \"try-finally\" section that safely releases " + + "the lock in the \"finally\" block - the acquisition should be performed right before this \"try\" block." \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt index 4db8ae3ca6..147c6dbdc7 100644 --- a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt +++ b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt @@ -7,9 +7,7 @@ package kotlinx.coroutines.sync import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.internal.* -import kotlinx.coroutines.selects.* import kotlin.contracts.* -import kotlin.js.* import kotlin.math.* /** @@ -18,7 +16,7 @@ import kotlin.math.* * Each [release] adds a permit, potentially releasing a suspended acquirer. * Semaphore is fair and maintains a FIFO order of acquirers. * - * Semaphores are mostly used to limit the number of coroutines that have access to particular resource. + * Semaphores are mostly used to limit the number of coroutines that have an access to particular resource. * Semaphore with `permits = 1` is essentially a [Mutex]. **/ public interface Semaphore { @@ -33,6 +31,9 @@ public interface Semaphore { * * This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this * function is suspended, this function immediately resumes with [CancellationException]. + * + * This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this + * function is suspended, this function immediately resumes with [CancellationException]. * There is a **prompt cancellation guarantee**. If the job was cancelled while this function was * suspended, it will not resume successfully. See [suspendCancellableCoroutine] documentation for low-level details. * This function releases the semaphore if it was already acquired by this function before the [CancellationException] @@ -42,13 +43,21 @@ public interface Semaphore { * Use [CoroutineScope.isActive] or [CoroutineScope.ensureActive] to periodically * check for cancellation in tight loops if needed. * - * Use [tryAcquire] to try to acquire a permit of this semaphore without suspension. + * Use [tryAcquire] to try acquire a permit of this semaphore without suspension. + * + * It is recommended to use [withPermit] for safety reasons, so that the acquired permit is always + * released at the end of your critical section and [release] is never invoked before a successful + * permit acquisition. */ public suspend fun acquire() /** * Tries to acquire a permit from this semaphore without suspension. * + * It is recommended to use [withPermit] for safety reasons, so that the acquired permit is always + * released at the end of your critical section and [release] is never invoked before a successful + * permit acquisition. + * * @return `true` if a permit was acquired, `false` otherwise. */ public fun tryAcquire(): Boolean @@ -57,6 +66,10 @@ public interface Semaphore { * Releases a permit, returning it into this semaphore. Resumes the first * suspending acquirer if there is one at the point of invocation. * Throws [IllegalStateException] if the number of [release] invocations is greater than the number of preceding [acquire]. + * + * It is recommended to use [withPermit] for safety reasons, so that the acquired permit is always + * released at the end of your critical section and [release] is never invoked before a successful + * permit acquisition. */ public fun release() } @@ -90,54 +103,13 @@ public suspend inline fun Semaphore.withPermit(action: () -> T): T { } } -@Suppress("UNCHECKED_CAST") -internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int) : Semaphore { - /* - The queue of waiting acquirers is essentially an infinite array based on the list of segments - (see `SemaphoreSegment`); each segment contains a fixed number of slots. To determine a slot for each enqueue - and dequeue operation, we increment the corresponding counter at the beginning of the operation - and use the value before the increment as a slot number. This way, each enqueue-dequeue pair - works with an individual cell. We use the corresponding segment pointers to find the required ones. - - Here is a state machine for cells. Note that only one `acquire` and at most one `release` operation - can deal with each cell, and that `release` uses `getAndSet(PERMIT)` to perform transitions for performance reasons - so that the state `PERMIT` represents different logical states. - - +------+ `acquire` suspends +------+ `release` tries +--------+ // if `cont.tryResume(..)` succeeds, then - | NULL | -------------------> | cont | -------------------> | PERMIT | (cont RETRIEVED) // the corresponding `acquire` operation gets - +------+ +------+ to resume `cont` +--------+ // a permit and the `release` one completes. - | | - | | `acquire` request is cancelled and the continuation is - | `release` comes | replaced with a special `CANCEL` token to avoid memory leaks - | to the slot before V - | `acquire` and puts +-----------+ `release` has +--------+ - | a permit into the | CANCELLED | -----------------> | PERMIT | (RElEASE FAILED) - | slot, waiting for +-----------+ failed +--------+ - | `acquire` after - | that. - | - | `acquire` gets +-------+ - | +-----------------> | TAKEN | (ELIMINATION HAPPENED) - V | the permit +-------+ - +--------+ | - | PERMIT | -< - +--------+ | - | `release` has waited a bounded time, +--------+ - +---------------------------------------> | BROKEN | (BOTH RELEASE AND ACQUIRE FAILED) - but `acquire` has not come +--------+ - */ - - private val head: AtomicRef - private val deqIdx = atomic(0L) - private val tail: AtomicRef - private val enqIdx = atomic(0L) - +internal open class SemaphoreImpl( + private val permits: Int, + acquiredPermits: Int +) : CancellableQueueSynchronizer(), Semaphore { init { - require(permits > 0) { "Semaphore should have at least 1 permit, but had $permits" } - require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..$permits" } - val s = SemaphoreSegment(0, null, 2) - head = atomic(s) - tail = atomic(s) + require(permits > 0) { "Semaphore must have at least 1 permit, but is initialized with $permits" } + require(acquiredPermits in 0..permits) { "The number of acquired permits should be in range [0..$permits]" } } /** @@ -150,8 +122,6 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int private val _availablePermits = atomic(permits - acquiredPermits) override val availablePermits: Int get() = max(_availablePermits.value, 0) - private val onCancellationRelease = { _: Throwable -> release() } - override fun tryAcquire(): Boolean { while (true) { // Get the current number of available permits. @@ -172,57 +142,38 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int } override suspend fun acquire() { - // Decrement the number of available permits. - val p = decPermits() - // Is the permit acquired? - if (p > 0) return // permit acquired - // Try to suspend otherwise. - // While it looks better when the following function is inlined, - // it is important to make `suspend` function invocations in a way - // so that the tail-call optimization can be applied here. - acquireSlowPath() - } - - private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable sc@ { cont -> - // Try to suspend. - if (addAcquireToQueue(cont)) return@sc - // The suspension has been failed - // due to the synchronous resumption mode. - // Restart the whole `acquire`. - acquire(cont) + while (true) { + // Decrement the number of available permits. + val p = decPermits() + // Is the permit acquired? + if (p > 0) return + // Try to suspend otherwise. + // While it looks better when the following function is inlined, + // it is important to make `suspend` function invocations in a way + // so that the tail-call optimization can be applied here. + return acquireSlowPath() + } } - @JsName("acquireCont") - protected fun acquire(waiter: CancellableContinuation) = acquire( - waiter = waiter, - suspend = { cont -> addAcquireToQueue(cont) }, - onAcquired = { cont -> cont.resume(Unit, onCancellationRelease) } - ) - - @JsName("acquireInternal") - private inline fun acquire(waiter: W, suspend: (waiter: W) -> Boolean, onAcquired: (waiter: W) -> Unit) { + private suspend fun acquireSlowPath() = suspendCancellableCoroutineReusable sc@{ cont -> while (true) { - // Decrement the number of available permits at first. + // Try to suspend. + if (suspend(cont as Waiter)) return@sc + // The suspension has been failed + // due to the synchronous resumption mode. + // Restart the whole `acquire`, and decrement + // the number of available permits at first. val p = decPermits() // Is the permit acquired? if (p > 0) { - onAcquired(waiter) - return + cont.resume(Unit) { release() } + return@sc } - // Permit has not been acquired, try to suspend. - if (suspend(waiter)) return + // Permit has not been acquired, go to + // the beginning of the loop and suspend. } } - // We do not fully support `onAcquire` as it is needed only for `Mutex.onLock`. - @Suppress("UNUSED_PARAMETER") - protected fun onAcquireRegFunction(select: SelectInstance<*>, ignoredParam: Any?) = - acquire( - waiter = select, - suspend = { s -> addAcquireToQueue(s) }, - onAcquired = { s -> s.selectInRegistrationPhase(Unit) } - ) - /** * Decrements the number of available permits * and ensures that it is not greater than [permits] @@ -261,10 +212,20 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int // restart the operation if either this // first waiter is cancelled or // due to `SYNC` resumption mode. - if (tryResumeNextFromQueue()) return + if (resume(Unit)) return } } + override fun returnValue(value: Unit) { + // Return the permit if the current continuation + // is cancelled after the `tryResume` invocation + // because of the prompt cancellation. + // Note that this `release()` call can throw + // exception if there was a successful concurrent + // `release()` invoked without acquiring a permit. + release() + } + /** * Changes the number of available permits to * [permits] if it became greater due to an @@ -277,138 +238,4 @@ internal open class SemaphoreImpl(private val permits: Int, acquiredPermits: Int if (_availablePermits.compareAndSet(cur, permits)) break } } - - /** - * Returns `false` if the received permit cannot be used and the calling operation should restart. - */ - private fun addAcquireToQueue(waiter: Any): Boolean { - val curTail = this.tail.value - val enqIdx = enqIdx.getAndIncrement() - val segment = this.tail.findSegmentAndMoveForward(id = enqIdx / SEGMENT_SIZE, startFrom = curTail, - createNewSegment = ::createSegment).segment // cannot be closed - val i = (enqIdx % SEGMENT_SIZE).toInt() - // the regular (fast) path -- if the cell is empty, try to install continuation - if (segment.cas(i, null, waiter)) { // installed continuation successfully - when (waiter) { - is CancellableContinuation<*> -> { - waiter.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(segment, i).asHandler) - } - is SelectInstance<*> -> { - waiter.disposeOnCompletion(CancelSemaphoreAcquisitionHandler(segment, i)) - } - else -> error("unexpected: $waiter") - } - return true - } - // On CAS failure -- the cell must be either PERMIT or BROKEN - // If the cell already has PERMIT from tryResumeNextFromQueue, try to grab it - if (segment.cas(i, PERMIT, TAKEN)) { // took permit thus eliminating acquire/release pair - /// This continuation is not yet published, but still can be cancelled via outer job - when (waiter) { - is CancellableContinuation<*> -> { - waiter as CancellableContinuation - waiter.resume(Unit, onCancellationRelease) - } - is SelectInstance<*> -> { - waiter.selectInRegistrationPhase(Unit) - } - else -> error("unexpected: $waiter") - } - return true - } - assert { segment.get(i) === BROKEN } // it must be broken in this case, no other way around it - return false // broken cell, need to retry on a different cell - } - - @Suppress("UNCHECKED_CAST") - private fun tryResumeNextFromQueue(): Boolean { - val curHead = this.head.value - val deqIdx = deqIdx.getAndIncrement() - val id = deqIdx / SEGMENT_SIZE - val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead, - createNewSegment = ::createSegment).segment // cannot be closed - segment.cleanPrev() - if (segment.id > id) return false - val i = (deqIdx % SEGMENT_SIZE).toInt() - val cellState = segment.getAndSet(i, PERMIT) // set PERMIT and retrieve the prev cell state - when { - cellState === null -> { - // Acquire has not touched this cell yet, wait until it comes for a bounded time - // The cell state can only transition from PERMIT to TAKEN by addAcquireToQueue - repeat(MAX_SPIN_CYCLES) { - if (segment.get(i) === TAKEN) return true - } - // Try to break the slot in order not to wait - return !segment.cas(i, PERMIT, BROKEN) - } - cellState === CANCELLED -> return false // the acquirer has already been cancelled - else -> return cellState.tryResumeAcquire() - } - } - - private fun Any.tryResumeAcquire(): Boolean = when(this) { - is CancellableContinuation<*> -> { - this as CancellableContinuation - val token = tryResume(Unit, null, onCancellationRelease) - if (token != null) { - completeResume(token) - true - } else false - } - is SelectInstance<*> -> { - trySelect(this@SemaphoreImpl, Unit) - } - else -> error("unexpected: $this") - } -} - -private class CancelSemaphoreAcquisitionHandler( - private val segment: SemaphoreSegment, - private val index: Int -) : CancelHandler(), DisposableHandle { - override fun invoke(cause: Throwable?) = dispose() - - override fun dispose() { - segment.cancel(index) - } - - override fun toString() = "CancelSemaphoreAcquisitionHandler[$segment, $index]" -} - -private fun createSegment(id: Long, prev: SemaphoreSegment) = SemaphoreSegment(id, prev, 0) - -private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment(id, prev, pointers) { - val acquirers = atomicArrayOfNulls(SEGMENT_SIZE) - override val numberOfSlots: Int get() = SEGMENT_SIZE - - @Suppress("NOTHING_TO_INLINE") - inline fun get(index: Int): Any? = acquirers[index].value - - @Suppress("NOTHING_TO_INLINE") - inline fun set(index: Int, value: Any?) { - acquirers[index].value = value - } - - @Suppress("NOTHING_TO_INLINE") - inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value) - - @Suppress("NOTHING_TO_INLINE") - inline fun getAndSet(index: Int, value: Any?) = acquirers[index].getAndSet(value) - - // Cleans the acquirer slot located by the specified index - // and removes this segment physically if all slots are cleaned. - fun cancel(index: Int) { - // Clean the slot - set(index, CANCELLED) - // Remove this segment if needed - onSlotCleaned() - } - - override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]" -} -private val MAX_SPIN_CYCLES = systemProp("kotlinx.coroutines.semaphore.maxSpinCycles", 100) -private val PERMIT = Symbol("PERMIT") -private val TAKEN = Symbol("TAKEN") -private val BROKEN = Symbol("BROKEN") -private val CANCELLED = Symbol("CANCELLED") -private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.semaphore.segmentSize", 16) +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt b/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt index 3c11182e00..bd6a44fff8 100644 --- a/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt +++ b/kotlinx-coroutines-core/common/test/CancellableContinuationHandlersTest.kt @@ -6,6 +6,7 @@ package kotlinx.coroutines +import kotlinx.coroutines.internal.* import kotlin.coroutines.* import kotlin.test.* @@ -159,4 +160,31 @@ class CancellableContinuationHandlersTest : TestBase() { } finish(3) } + + @Test + fun testSegmentAsHandler() = runTest { + class MySegment : Segment(0, null, 0) { + override val numberOfSlots: Int get() = 0 + + var invokeOnCancellationCalled = false + override fun onCancellation(index: Int, cause: Throwable?) { + invokeOnCancellationCalled = true + } + } + val s = MySegment() + expect(1) + try { + suspendCancellableCoroutine { c -> + expect(2) + c as CancellableContinuationImpl<*> + c.invokeOnCancellation(s, 0) + c.cancel() + } + } catch (e: CancellationException) { + expect(3) + } + expect(4) + check(s.invokeOnCancellationCalled) + finish(5) + } } diff --git a/kotlinx-coroutines-core/common/test/sync/MutexTest.kt b/kotlinx-coroutines-core/common/test/sync/MutexTest.kt index 6a60387672..8f11695a37 100644 --- a/kotlinx-coroutines-core/common/test/sync/MutexTest.kt +++ b/kotlinx-coroutines-core/common/test/sync/MutexTest.kt @@ -4,8 +4,8 @@ package kotlinx.coroutines.sync -import kotlinx.atomicfu.* import kotlinx.coroutines.* +import kotlinx.coroutines.selects.* import kotlin.test.* class MutexTest : TestBase() { @@ -138,4 +138,15 @@ class MutexTest : TestBase() { assertTrue(mutex.holdsLock(owner2)) finish(4) } + + @Test + @Suppress("DEPRECATION") + fun testIllegalStateInvariant() = runTest { + val mutex = Mutex() + val owner = Any() + assertTrue(mutex.tryLock(owner)) + assertFailsWith { mutex.tryLock(owner) } + assertFailsWith { mutex.lock(owner) } + assertFailsWith { select { mutex.onLock(owner) {} } } + } } diff --git a/kotlinx-coroutines-core/common/test/sync/ReadWriteMutexTest.kt b/kotlinx-coroutines-core/common/test/sync/ReadWriteMutexTest.kt new file mode 100644 index 0000000000..010e2ef28b --- /dev/null +++ b/kotlinx-coroutines-core/common/test/sync/ReadWriteMutexTest.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.sync + +import kotlinx.coroutines.* +import kotlin.test.* + +class ReadWriteMutexTest : TestBase() { + @Test + fun simpleSingleCoroutineTest() = runTest { + val m = ReadWriteMutex() + m.readLock() + m.readLock() + m.readUnlock() + m.readUnlock() + m.write.lock() + m.write.unlock() + m.readLock() + } + + @Test + fun multipleCoroutinesTest() = runTest { + val m = ReadWriteMutex() + m.readLock() + expect(1) + launch { + expect(2) + m.readLock() + expect(3) + } + yield() + expect(4) + launch { + expect(5) + m.write.lock() + expect(8) + } + yield() + expect(6) + m.readUnlock() + yield() + expect(7) + m.readUnlock() + yield() + finish(9) + } + + @Test + fun acquireReadSucceedsAfterCancelledAcquireWrite() = runTest { + val m = ReadWriteMutex() + m.readLock() + val wJob = launch { + expect(1) + m.write.lock() + expectUnreached() + } + yield() + expect(2) + wJob.cancel() + launch { + expect(3) + m.readLock() + expect(4) + } + yield() + finish(5) + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/TestBase.kt b/kotlinx-coroutines-core/jvm/test/TestBase.kt index 6a013fa1da..d29164c220 100644 --- a/kotlinx-coroutines-core/jvm/test/TestBase.kt +++ b/kotlinx-coroutines-core/jvm/test/TestBase.kt @@ -255,6 +255,14 @@ public actual open class TestBase(private var disableOutCheck: Boolean) { protected suspend fun currentDispatcher() = coroutineContext[ContinuationInterceptor]!! } +fun CancellableContinuation.tryResume0(value: T, onCancellation: (Throwable?) -> Unit): Boolean { + tryResume(value, null, onCancellation).let { + if (it == null) return false + completeResume(it) + return true + } +} + /* * We ignore tests that test **real** non-virtualized tests with time on Windows, because * our CI Windows is virtualized itself (oh, the irony) and its clock resolution is dozens of ms, diff --git a/kotlinx-coroutines-core/jvm/test/lincheck/CancellableQueueSynchronizerLincheckTests.kt b/kotlinx-coroutines-core/jvm/test/lincheck/CancellableQueueSynchronizerLincheckTests.kt new file mode 100644 index 0000000000..22862275d0 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/lincheck/CancellableQueueSynchronizerLincheckTests.kt @@ -0,0 +1,822 @@ +/* + * Copyright 2016-2023 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:Suppress("unused", "MemberVisibilityCanBePrivate") + +package kotlinx.coroutines.lincheck + +import kotlinx.atomicfu.* +import kotlinx.coroutines.* +import kotlinx.coroutines.internal.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.CancellationMode.* +import kotlinx.coroutines.internal.CancellableQueueSynchronizer.ResumeMode.* +import kotlinx.coroutines.sync.Semaphore +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.annotations.* +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* +import org.jetbrains.kotlinx.lincheck.verifier.* +import kotlin.collections.ArrayList +import kotlin.coroutines.* +import kotlin.reflect.* + +// This test suit serves two purposes. First of all, it tests the `SegmentQueueSynchronizer` +// implementation under different use-cases and workloads. On the other side, this test suite +// provides different well-known synchronization and communication primitive implementations +// via `SegmentQueueSynchronizer`, which can be considered as an API richness check as well as +// a collection of examples on how to use `SegmentQueueSynchronizer` to build new primitives. + +// ############## +// # SEMAPHORES # +// ############## + +/** + * This [Semaphore] implementation is similar to the one in the library, + * but uses the [asynchronous][ASYNC] resumption mode. However, it is non-trivial + * to make [tryAcquire] linearizable in this case, so it is not supported here. + */ +internal class AsyncSemaphore(permits: Int) : CancellableQueueSynchronizer(), Semaphore { + override val resumeMode get() = ASYNC + + private val _availablePermits = atomic(permits) + override val availablePermits get() = _availablePermits.value.coerceAtLeast(0) + + override fun tryAcquire() = error("Not implemented") // Not supported in the ASYNC version + + override suspend fun acquire() { + // Decrement the number of available permits. + val p = _availablePermits.getAndDecrement() + // Is the permit successfully acquired? + if (p > 0) return + // Suspend otherwise. + suspendCancellableCoroutine { cont -> + check(suspend(cont as Waiter)) { "Should not fail in ASYNC mode" } + } + } + + override fun release() { + while (true) { + // Increment the number of available permits. + val p = _availablePermits.getAndIncrement() + // Is there a waiter that should be resumed? + if (p >= 0) return + // Try to resume the first waiter, and + // restart the operation if it is cancelled. + if (resume(Unit)) return + } + } + + // For prompt cancellation. + override fun returnValue(value: Unit) = release() +} + +/** + * This semaphore implementation is correct only if [release] is always + * invoked after a successful [acquire]; in other words, when semaphore + * is used properly, without unexpected [release] invocations. The main + * advantage is using smart cancellation, so [release] always works + * in constant time under no contention, and the cancelled [acquire] + * requests do not play any role. It is worth noting, that it is possible + * to make this implementation correct under the prompt cancellation model + * even with unexpected [release]-s. + */ +internal class AsyncSemaphoreSmart(permits: Int) : CancellableQueueSynchronizer(), Semaphore { + override val resumeMode get() = ASYNC + override val cancellationMode get() = SMART + + private val _availablePermits = atomic(permits) + override val availablePermits get() = _availablePermits.value.coerceAtLeast(0) + + override fun tryAcquire() = error("Not implemented") // Not supported in the ASYNC version. + + override suspend fun acquire() { + // Decrement the number of available permits. + val p = _availablePermits.getAndDecrement() + // Is the permit acquired? + if (p > 0) return + // Suspend otherwise. + suspendCancellableCoroutine { cont -> + check(suspend(cont as Waiter)) { "Should not fail in ASYNC mode" } + } + } + + override fun release() { + // Increment the number of available permits. + val p = _availablePermits.getAndIncrement() + // Is there a waiter that should be resumed? + if (p >= 0) return + // Resume the first waiter. Due to the smart + // cancellation mode it is possible that this + // permit will be refused, so this release + // invocation can take effect with a small lag + // and with an extra suspension, but it is guaranteed + // that the permit will be refused eventually. + resume(Unit) + } + + override fun onCancellation(): Boolean { + // Increment the number of available permits. + val p = _availablePermits.getAndIncrement() + // Return `true` if there is no `release` that + // is going to resume this cancelling `acquire()`, + // or `false` if there is one, and this permit + // should be refused. + return p < 0 + } + + // For prompt cancellation. + override fun returnValue(value: Unit) = release() +} + +/** + * This implementation is similar to the previous one, but with [synchronous][SYNC] + * resumption mode, so it is possible to implement [tryAcquire] correctly. + * The only notable difference happens when a permit to be released is refused, + * and the following [resume] attempt in the cancellation handler fails due to + * the synchronization on resumption, so the permit is going to be returned + * to the semaphore in [returnValue] function. It is worth noting, that it + * is possible to make this implementation correct with prompt cancellation. + */ +internal class SyncSemaphoreSmart(permits: Int) : CancellableQueueSynchronizer(), Semaphore { + override val resumeMode get() = SYNC + override val cancellationMode get() = SMART + + private val _availablePermits = atomic(permits) + override val availablePermits get() = _availablePermits.value.coerceAtLeast(0) + + override suspend fun acquire() { + while (true) { + // Decrement the number of available permits. + val p = _availablePermits.getAndDecrement() + // Is the permit acquired? + if (p > 0) return + // Try to suspend otherwise. + val acquired = suspendCancellableCoroutine { cont -> + if (!suspend(cont as Waiter)) cont.resume(false) + } + if (acquired) return + } + } + + override fun tryAcquire(): Boolean = _availablePermits.loop { cur -> + // Try to decrement the number of available + // permits if it is greater than zero. + if (cur <= 0) return false + if (_availablePermits.compareAndSet(cur, cur -1)) return true + } + + override fun release() { + while (true) { + // Increment the number of available permits. + val p = _availablePermits.getAndIncrement() + // Is there a waiter that should be resumed? + if (p >= 0) return + // Try to resume the first waiter, can fail + // according to the SYNC mode contract. + if (resume(true)) return + } + } + + override fun onCancellation(): Boolean { + // Increment the number of available permits. + val p = _availablePermits.getAndIncrement() + // Return `true` if there is no `release` which + // is going to resume us and cannot skip us and + // resume the next waiter. + return p < 0 + } + + override fun returnValue(value: Boolean) { + // Simply release the permit. + release() + } +} + +class SemaphoreUnboundedSequential1 : SemaphoreSequential(1, false) +class SemaphoreUnboundedSequential2 : SemaphoreSequential(2, false) + +// Comparing to `SemaphoreLincheckTestBase`, it does not support `tryAcquire()`. +abstract class AsyncSemaphoreLincheckTestBase( + semaphore: Semaphore, + private val seqSpec: KClass<*> +) : AbstractLincheckTest() { + private val s = semaphore + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun acquire() = s.acquire() + + @Operation(handleExceptionsAsResult = [IllegalStateException::class]) + fun release() = s.release() + + override fun > O.customize(isStressTest: Boolean): O = + actorsBefore(0) + .sequentialSpecification(seqSpec.java) + + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + checkObstructionFreedom() +} + +class AsyncSemaphore1LincheckTest : AsyncSemaphoreLincheckTestBase(AsyncSemaphore(1), SemaphoreUnboundedSequential1::class) +class AsyncSemaphore2LincheckTest : AsyncSemaphoreLincheckTestBase(AsyncSemaphore(2), SemaphoreUnboundedSequential2::class) + +class AsyncSemaphoreSmart1LincheckTest : AsyncSemaphoreLincheckTestBase(AsyncSemaphoreSmart(1), SemaphoreUnboundedSequential1::class) +class AsyncSemaphoreSmart2LincheckTest : AsyncSemaphoreLincheckTestBase(AsyncSemaphoreSmart(2), SemaphoreUnboundedSequential2::class) + +class SyncSemaphoreSmart1LincheckTest : SemaphoreLincheckTestBase(SyncSemaphoreSmart(1), SemaphoreUnboundedSequential1::class) { + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = checkObstructionFreedom(false) +} +class SyncSemaphoreSmart2LincheckTest : SemaphoreLincheckTestBase(SyncSemaphoreSmart(2), SemaphoreUnboundedSequential2::class) { + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = checkObstructionFreedom(false) +} + + +// #################################### +// # COUNT-DOWN-LATCH SYNCHRONIZATION # +// #################################### + +/** + * This primitive allows to wait until several operation are completed. + * It is initialized with a given count, and each [countDown] invocation + * decrements the count of remaining operations to be completed. At the + * same time, [await] suspends until this count reaches zero. + * + * This implementation uses simple cancellation, so the [countDown] invocation + * that reaches the counter zero works in linear of the number of [await] + * invocations, including the ones that are already cancelled. + */ +internal open class CountDownLatch(count: Int) : CancellableQueueSynchronizer() { + override val resumeMode get() = ASYNC + + private val count = atomic(count) + // The number of suspended `await` invocations. + // `DONE_MARK` should be set when the count reaches zero, + // so the following suspension attempts will detect this change by + // checking the mark and complete immediately in this case. + private val waiters = atomic(0) + + protected fun decWaiters() = waiters.decrementAndGet() + + /** + * Decrements the count and resumes waiting + * [await] invocations if it reaches zero. + */ + fun countDown() { + // Decrement the count. + val r = count.decrementAndGet() + // Should the waiters be resumed? + if (r <= 0) resumeWaiters() + } + + private fun resumeWaiters() { + val w = waiters.getAndUpdate { cur -> + // Is the done mark set? + if (cur and DONE_MARK != 0) return + cur or DONE_MARK + } + // This thread has successfully set + // the mark, resume the waiters. + repeat(w) { resume(Unit) } + } + + /** + * Waits until the count reaches zero, + * completes immediately if it is already zero. + */ + suspend fun await() { + // Check whether the count has already reached zero; + // this check can be considered as an optimization. + if (remaining() == 0) return + // Increment the number of waiters and check + // that `DONE_MARK` is not set, finish otherwise. + val w = waiters.incrementAndGet() + if (w and DONE_MARK != 0) return + // The number of waiters is + // successfully incremented, suspend. + suspendCancellableCoroutine { suspend(it as Waiter) } + } + + /** + * Returns the current count. + */ + fun remaining(): Int = count.value.coerceAtLeast(0) + + protected companion object { + const val DONE_MARK = 1 shl 31 + } +} + +/** + * This implementation uses the smart cancellation mode, so the [countDown] + * invocation that reaches the counter zero works in linear of the number + * of non-cancelled [await] invocations. This way, it does not matter + * how many [await] requests has been cancelled - they do not play any role. + */ +internal class CountDownLatchSmart(count: Int) : CountDownLatch(count) { + override val resumeMode: ResumeMode get() = ASYNC + override val cancellationMode get() = SMART + + override fun onCancellation(): Boolean { + // Decrement the number of waiters. + val w = decWaiters() + // Succeed if the `DONE_MARK` is not set yet. + return (w and DONE_MARK) == 0 + } +} + +internal abstract class CountDownLatchLincheckTestBase( + private val cdl: CountDownLatch, + private val seqSpec: KClass<*> +) : AbstractLincheckTest() { + @Operation + fun countDown() = cdl.countDown() + + @Operation + fun remaining() = cdl.remaining() + + @Operation(promptCancellation = false) + suspend fun await() = cdl.await() + + override fun > O.customize(isStressTest: Boolean): O = + actorsBefore(0) + .actorsAfter(0) + .sequentialSpecification(seqSpec.java) + + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + checkObstructionFreedom() +} + +class CountDownLatchSequential1 : CountDownLatchSequential(1) +class CountDownLatchSequential2 : CountDownLatchSequential(2) + +internal class CountDownLatch1LincheckTest : CountDownLatchLincheckTestBase(CountDownLatch(1), CountDownLatchSequential1::class) +internal class CountDownLatch2LincheckTest : CountDownLatchLincheckTestBase(CountDownLatch(2), CountDownLatchSequential2::class) + +internal class CountDownLatchSmart1LincheckTest : CountDownLatchLincheckTestBase(CountDownLatchSmart(1), CountDownLatchSequential1::class) +internal class CountDownLatchSmart2LincheckTest : CountDownLatchLincheckTestBase(CountDownLatchSmart(2), CountDownLatchSequential2::class) + +open class CountDownLatchSequential(initialCount: Int) : VerifierState() { + private var count = initialCount + private val waiters = ArrayList>() + + fun countDown() { + if (--count == 0) { + waiters.forEach { it.tryResume0(Unit) {} } + waiters.clear() + } + } + + suspend fun await() { + if (count <= 0) return + suspendCancellableCoroutine { cont -> + waiters.add(cont) + } + } + + fun remaining(): Int = count.coerceAtLeast(0) + + override fun extractState() = remaining() +} + + +// ########################### +// # BARRIER SYNCHRONIZATION # +// ########################### + +/** + * This synchronization primitive allows a set of coroutines to + * all wait for each other to reach a common barrier point. + * + * The implementation is straightforward: it maintains a counter + * of arrived coroutines and increments it in the beginning of + * [arrive] operation. The last coroutine should resume all the + * previously arrived ones. + * + * In case of cancellation, the handler decrements the counter if + * not all the parties are arrived. However, it is impossible to + * make cancellation atomic (e.g., Java's implementation simply + * does not work in case of thread interruption) since there is + * no way to resume a set of coroutines atomically. However, + * this implementation is correct with prompt cancellation. + */ +internal class Barrier(private val parties: Int) : CancellableQueueSynchronizer() { + override val resumeMode get() = ASYNC + override val cancellationMode get() = SMART + + // The number of coroutines arrived to this barrier point. + private val arrived = atomic(0L) + + /** + * Waits for other parties and returns `true`. + * Fails if this invocation exceeds the number + * of parties, returns `false` in this case. + * + * It is also possible to make this barrier + * implementation cyclic after introducing + * `resume(count, value)` operation on the + * `SegmentQueueSynchronizer`, which resumes + * the specified number of coroutines and + * with the same value atomically. + */ + suspend fun arrive(): Boolean { + // Are all parties has already arrived? + if (arrived.value > parties) + return false // fail this `arrive()`. + // Increment the number of arrived parties. + val a = arrived.incrementAndGet() + return when { + // Should we suspend? + a < parties -> { + suspendCancellableCoroutineReusable { cont -> suspend(cont as Waiter) } + true + } + // Are we the last party? + a == parties.toLong() -> { + // Resume all waiters. + repeat(parties - 1) { + resume(Unit) + } + true + } + // Should we fail? + else -> false + } + } + + override fun onCancellation(): Boolean { + // Decrement the number of arrived parties if possible. + arrived.loop { cur -> + // Are we going to be resumed? + // The resumption permit should be refused in this case. + if (cur == parties.toLong()) return false + // Successful cancellation, return `true`. + if (arrived.compareAndSet(cur, cur - 1)) return true + } + } +} + +abstract class BarrierLincheckTestBase(parties: Int, val seqSpec: KClass<*>) : AbstractLincheckTest() { + private val b = Barrier(parties) + + @Operation(cancellableOnSuspension = false) + suspend fun arrive() = b.arrive() + + override fun > O.customize(isStressTest: Boolean) = + actorsBefore(0) + .actorsAfter(0) + .threads(3) + .sequentialSpecification(seqSpec.java) + + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + checkObstructionFreedom() +} + +class BarrierSequential1 : BarrierSequential(1) +class Barrier1LincheckTest : BarrierLincheckTestBase(1, BarrierSequential1::class) +class BarrierSequential2 : BarrierSequential(2) +class Barrier2LincheckTest : BarrierLincheckTestBase(2, BarrierSequential2::class) +class BarrierSequential3 : BarrierSequential(3) +class Barrier3LincheckTest : BarrierLincheckTestBase(3, BarrierSequential3::class) + +open class BarrierSequential(parties: Int) : VerifierState() { + private var remaining = parties + private val waiters = ArrayList>() + + suspend fun arrive(): Boolean { + val r = --remaining + return when { + r > 0 -> { + suspendCancellableCoroutine { cont -> + waiters.add(cont) + cont.invokeOnCancellation { + remaining++ + waiters.remove(cont) + } + } + true + } + r == 0 -> { + waiters.forEach { it.resume(Unit) } + true + } + else -> false + } + } + + override fun extractState() = remaining > 0 +} + + +// ################## +// # BLOCKING POOLS # +// ################## + +/** + * While using resources such as database connections, sockets, etc., + * it is common to reuse them; that requires a fast and handy mechanism. + * This [BlockingPool] abstraction maintains a set of elements that can be put + * into the pool for further reuse or be retrieved to process the current operation. + * When [retrieve] comes to an empty pool, it blocks, and the following [put] operation + * resumes it; all the waiting requests are processed in the first-in-first-out (FIFO) order. + * + * In our tests we consider two pool implementations: the [queue-based][BlockingQueuePool] + * and the [stack-based][BlockingStackPool]. Intuitively, the queue-based implementation is + * faster since it is built on arrays and uses `Fetch-And-Add`-s on the contended path, + * while the stack-based pool retrieves the last inserted, thus, the "hottest", element. + * + * Please note that both these implementations are not linearizable and can retrieve elements + * out-of-order under some races. However, since pools by themselves do not guarantee + * that the stored elements are ordered (the one may consider them as bags), + * these queue- and stack-based versions should be considered as pools with specific heuristics. + */ +interface BlockingPool { + /** + * Either resumes the first waiting [retrieve] operation + * and passes the [element] to it, or simply puts the + * [element] into the pool. + */ + fun put(element: T) + + /** + * Retrieves one of the elements from the pool + * (the order is not specified), or suspends if it is + * empty -- the following [put] operations resume + * waiting [retrieve]-s in the first-in-first-out order. + */ + suspend fun retrieve(): T +} + +/** + * This pool uses queue under the hood and is implemented with simple cancellation. + */ +internal class BlockingQueuePool : CancellableQueueSynchronizer(), BlockingPool { + override val resumeMode get() = ASYNC + + // > 0 -- number of elements; + // = 0 -- empty pool; + // < 0 -- number of waiters. + private val availableElements = atomic(0L) + + // This is an infinite array by design, a plain array is used for simplicity. + private val elements = atomicArrayOfNulls(100) + + // Indices in `elements` for the next `tryInsert()` and `tryRetrieve()` operations. + // Each `tryInsert()`/`tryRetrieve()` pair works on a separate slot. When `tryRetrieve()` + // comes earlier, it marks the slot as `BROKEN` so both this operation and the + // corresponding `tryInsert()` fail. + private val insertIdx = atomic(0) + private val retrieveIdx = atomic(0) + + override fun put(element: T) { + while (true) { + // Increment the number of elements in advance. + val b = availableElements.getAndIncrement() + // Is there a waiting `retrieve`? + if (b < 0) { + // Try to resume the first waiter, + // can fail if it is already cancelled. + if (resume(element)) return + } else { + // Try to insert the element into the + // queue, can fail if the slot is broken. + if (tryInsert(element)) return + } + } + } + + /** + * Tries to insert the [element] into the next + * [elements] array slot. Returns `true` if + * succeeds, or `false` if the slot is [broken][BROKEN]. + */ + private fun tryInsert(element: T): Boolean { + val i = insertIdx.getAndIncrement() + return elements[i].compareAndSet(null, element) + } + + override suspend fun retrieve(): T { + while (true) { + // Decrements the number of elements. + val b = availableElements.getAndDecrement() + // Is there an element in the pool? + if (b > 0) { + // Try to retrieve the first element, + // can fail if the first slot + // is empty due to a race. + val x = tryRetrieve() + if (x != null) return x + } else { + // The pool is empty, suspend. + return suspendCancellableCoroutine { cont -> + suspend(cont as Waiter) + } + } + } + } + + /** + * Tries to retrieve the first element from + * the [elements] array. Returns the element if + * succeeds, or `null` if the first slot is empty + * due to a race -- it marks the slot as [broken][BROKEN] + * in this case, so the corresponding [tryInsert] + * invocation fails. + */ + @Suppress("UNCHECKED_CAST") + private fun tryRetrieve(): T? { + val i = retrieveIdx.getAndIncrement() + return elements[i].getAndSet(BROKEN) as T? + } + + fun stateRepresentation(): String { + val elementsBetweenIndices = mutableListOf() + val first = kotlin.math.min(retrieveIdx.value, insertIdx.value) + val last = kotlin.math.max(retrieveIdx.value, insertIdx.value) + for (i in first until last) { + elementsBetweenIndices.add(elements[i].value) + } + return "availableElements=${availableElements.value}," + + "insertIdx=${insertIdx.value}," + + "retrieveIdx=${retrieveIdx.value}," + + "elements=$elementsBetweenIndices," + + "cqs=<${super.toString()}>" + } + + companion object { + @JvmStatic + val BROKEN = Symbol("BROKEN") + } +} + +/** + * This pool uses stack under the hood and shows how to use + * smart cancellation for data structures that store resources. + */ +internal class BlockingStackPool : CancellableQueueSynchronizer(), BlockingPool { + override val resumeMode get() = ASYNC + override val cancellationMode get() = SMART + + // The stack is implemented via a concurrent linked list, + // this is its head; `null` means that the stack is empty. + private val head = atomic?>(null) + + // > 0 -- number of elements; + // = 0 -- empty pool; + // < 0 -- number of waiters. + private val availableElements = atomic(0) + + override fun put(element: T) { + while (true) { + // Increment the number of elements in advance. + val b = availableElements.getAndIncrement() + // Is there a waiting retrieve? + if (b < 0) { + // Resume the first waiter, never fails + // in the smart cancellation mode. + resume(element) + return + } else { + // Try to insert the element into the + // stack, can fail if a concurrent [tryRetrieve] + // came earlier and marked it with a failure node. + if (tryInsert(element)) return + } + } + } + + /** + * Tries to insert the [element] into the stack. + * Returns `true` on success`, or `false` if the + * stack is marked with a failure node, retrieving + * it in this case. + */ + private fun tryInsert(element: T): Boolean = head.loop { h -> + // Is the stack marked with a failure node? + if (h != null && h.element == null) { + // Try to retrieve the failure node. + if (head.compareAndSet(h, h.next)) return false + } else { + // Try to insert the element. + val newHead = StackNode(element, h) + if (head.compareAndSet(h, newHead)) return true + } + } + + override suspend fun retrieve(): T { + while (true) { + // Decrement the number of elements. + val b = availableElements.getAndDecrement() + // Is there an element in the pool? + if (b > 0) { + // Try to retrieve the top element, + // can fail if the stack is empty + // due to a race. + val x = tryRetrieve() + if (x != null) return x + } else { + // The pool is empty, suspend. + return suspendCancellableCoroutine { cont -> + suspend(cont as Waiter) + } + } + } + } + + /** + * Tries to retrieve the top (last) element and return `true` + * if the stack is not empty; returns `false` and + * inserts a failure node otherwise. + */ + @Suppress("NullChecksToSafeCall") + private fun tryRetrieve(): T? = head.loop { h -> + // Is the queue empty or full of failure nodes? + if (h == null || h.element == null) { + // Try to add one more failure node and fail. + val failNode = StackNode(null, h) + if (head.compareAndSet(h, failNode)) return null + } else { + // Try to retrieve the top element. + if (head.compareAndSet(h, h.next)) return h.element + } + } + + // The logic of cancellation is very similar to the one + // in semaphore, with the only difference that elements + // should be physically returned to the pool. + override fun onCancellation(): Boolean { + val b = availableElements.getAndIncrement() + return b < 0 + } + + // If an element is refused, it should be inserted back to the stack. + override fun tryReturnRefusedValue(value: T) = tryInsert(value) + + // In order to return the value back + // to the pool, [put] is naturally used. + override fun returnValue(value: T) = put(value) + + internal fun stateRepresentation(): String { + val elements = ArrayList() + var curNode = head.value + while (curNode != null) { + elements += curNode.element + curNode = curNode.next + } + return "availableElements=${availableElements.value},elements=$elements,cqs=<${super.toString()}>" + } + + class StackNode(val element: T?, val next: StackNode?) +} + +abstract class BlockingPoolLincheckTestBase(val p: BlockingPool) : AbstractLincheckTest() { + @Operation + fun put() = p.put(Unit) + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun retrieve() = p.retrieve() + + @StateRepresentation + fun stateRepresentation() = when(p) { + is BlockingStackPool<*> -> p.stateRepresentation() + is BlockingQueuePool<*> -> p.stateRepresentation() + else -> error("Unknown pool type: ${p::class.simpleName}") + } + + override fun > O.customize(isStressTest: Boolean) = + sequentialSpecification(BlockingPoolUnitSequential::class.java) + + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + checkObstructionFreedom() +} +class BlockingQueuePoolLincheckTest : BlockingPoolLincheckTestBase(BlockingQueuePool()) +class BlockingStackPoolLincheckTest : BlockingPoolLincheckTestBase(BlockingStackPool()) + +class BlockingPoolUnitSequential : VerifierState() { + private var elements = 0 + private val waiters = ArrayList>() + + fun put() { + while (true) { + if (waiters.isNotEmpty()) { + val w = waiters.removeAt(0) + @Suppress("PackageDirectoryMismatch") + if (w.tryResume0(Unit) { put() }) return + } else { + elements ++ + return + } + } + } + + suspend fun retrieve() { + if (elements > 0) { + elements-- + } else { + suspendCancellableCoroutine { cont -> + waiters.add(cont) + } + } + } + + override fun extractState() = elements +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/lincheck/MutexLincheckTest.kt b/kotlinx-coroutines-core/jvm/test/lincheck/MutexLincheckTest.kt index 02964f9793..850f651b0f 100644 --- a/kotlinx-coroutines-core/jvm/test/lincheck/MutexLincheckTest.kt +++ b/kotlinx-coroutines-core/jvm/test/lincheck/MutexLincheckTest.kt @@ -11,19 +11,23 @@ import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.annotations.* import org.jetbrains.kotlinx.lincheck.annotations.Operation import org.jetbrains.kotlinx.lincheck.paramgen.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* @Param(name = "owner", gen = IntGen::class, conf = "0:2") class MutexLincheckTest : AbstractLincheckTest() { private val mutex = Mutex() - @Operation +// @Operation(handleExceptionsAsResult = [IllegalStateException::class]) fun tryLock(@Param(name = "owner") owner: Int) = mutex.tryLock(owner.asOwnerOrNull) - @Operation(promptCancellation = true) + // TODO: `lock()` with non-null owner is non-linearizable + @Operation(promptCancellation = true, handleExceptionsAsResult = [IllegalStateException::class]) suspend fun lock(@Param(name = "owner") owner: Int) = mutex.lock(owner.asOwnerOrNull) + // TODO: `onLock` with non-null owner is non-linearizable // onLock may suspend in case of clause re-registration. - @Operation(allowExtraSuspension = true, promptCancellation = true) + @Suppress("DEPRECATION") + @Operation(allowExtraSuspension = true, promptCancellation = true, handleExceptionsAsResult = [IllegalStateException::class]) suspend fun onLock(@Param(name = "owner") owner: Int) = select { mutex.onLock(owner.asOwnerOrNull) {} } @Operation(handleExceptionsAsResult = [IllegalStateException::class]) @@ -38,6 +42,9 @@ class MutexLincheckTest : AbstractLincheckTest() { override fun > O.customize(isStressTest: Boolean): O = actorsBefore(0) + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + verboseTrace() + // state[i] == true <=> mutex.holdsLock(i) with the only exception for 0 that specifies `null`. override fun extractState() = (1..2).map { mutex.holdsLock(it) } + mutex.isLocked diff --git a/kotlinx-coroutines-core/jvm/test/lincheck/ReadWriteMutexLincheckTests.kt b/kotlinx-coroutines-core/jvm/test/lincheck/ReadWriteMutexLincheckTests.kt new file mode 100644 index 0000000000..742e1c30ae --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/lincheck/ReadWriteMutexLincheckTests.kt @@ -0,0 +1,209 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ +@file:Suppress("unused") + +package kotlinx.coroutines.lincheck + +import kotlinx.coroutines.* +import kotlinx.coroutines.sync.* +import kotlinx.coroutines.sync.ReadWriteMutexImpl.WriteUnlockPolicy.* +import org.jetbrains.kotlinx.lincheck.* +import org.jetbrains.kotlinx.lincheck.annotations.* +import org.jetbrains.kotlinx.lincheck.annotations.Operation +import org.jetbrains.kotlinx.lincheck.paramgen.* +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* +import org.jetbrains.kotlinx.lincheck.verifier.* + +class ReadWriteMutexLincheckTest : AbstractLincheckTest() { + private val m = ReadWriteMutexImpl() + private val readLockAcquired = IntArray(6) + private val writeLockAcquired = BooleanArray(6) + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun readLock(@Param(gen = ThreadIdGen::class) threadId: Int) { + m.readLock() + readLockAcquired[threadId]++ + } + + @Operation + fun readUnlock(@Param(gen = ThreadIdGen::class) threadId: Int): Boolean { + if (readLockAcquired[threadId] == 0) return false + m.readUnlock() + readLockAcquired[threadId]-- + return true + } + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun writeLock(@Param(gen = ThreadIdGen::class) threadId: Int) { + m.writeLock() + assert(!writeLockAcquired[threadId]) { + "The mutex is not reentrant, this `writeLock()` invocation had to suspend" + } + writeLockAcquired[threadId] = true + } + + @Operation + fun writeUnlock(@Param(gen = ThreadIdGen::class) threadId: Int, prioritizeWriters: Boolean): Boolean { + if (!writeLockAcquired[threadId]) return false + m.writeUnlock(if (prioritizeWriters) PRIORITIZE_WRITERS else PRIORITIZE_READERS) + writeLockAcquired[threadId] = false + return true + } + + @StateRepresentation + fun stateRepresentation() = m.stateRepresentation + + override fun > O.customize(isStressTest: Boolean) = + actorsBefore(0) + .actorsAfter(0) + .sequentialSpecification(ReadWriteMutexLincheckTestSequential::class.java) + + override fun ModelCheckingOptions.customize(isStressTest: Boolean) = + checkObstructionFreedom() +} + +class ReadWriteMutexLincheckTestSequential : VerifierState() { + private val m = ReadWriteMutexSequential() + private val readLockAcquired = IntArray(6) + private val writeLockAcquired = BooleanArray(6) + + fun tryReadLock(threadId: Int): Boolean = + m.tryReadLock().also { success -> + if (success) readLockAcquired[threadId]++ + } + + suspend fun readLock(threadId: Int) { + m.readLock() + readLockAcquired[threadId]++ + } + + fun readUnlock(threadId: Int): Boolean { + if (readLockAcquired[threadId] == 0) return false + m.readUnlock() + readLockAcquired[threadId]-- + return true + } + + fun tryWriteLock(threadId: Int): Boolean = + m.tryWriteLock().also { success -> + if (success) writeLockAcquired[threadId] = true + } + + suspend fun writeLock(threadId: Int) { + m.writeLock() + writeLockAcquired[threadId] = true + } + + fun writeUnlock(threadId: Int, prioritizeWriters: Boolean): Boolean { + if (!writeLockAcquired[threadId]) return false + m.writeUnlock(prioritizeWriters) + writeLockAcquired[threadId] = false + return true + } + + override fun extractState() = + "mutex=${m.state},rlaPerThread=${readLockAcquired.contentToString()},wlaPerThread=${writeLockAcquired.contentToString()}" +} + +internal class ReadWriteMutexSequential { + private var ar = 0 + private var wla = false + private val wr = ArrayList>() + private val ww = ArrayList>() + + fun tryReadLock(): Boolean { + if (wla || ww.isNotEmpty()) return false + ar++ + return true + } + + suspend fun readLock() { + if (wla || ww.isNotEmpty()) { + suspendCancellableCoroutine { cont -> + wr += cont + cont.invokeOnCancellation { wr -= cont } + } + } else { + ar++ + } + } + + fun readUnlock() { + ar-- + if (ar == 0 && ww.isNotEmpty()) { + wla = true + val w = ww.removeAt(0) + w.resume(Unit) { writeUnlock(true) } + } + } + + fun tryWriteLock(): Boolean { + if (wla || ar > 0) return false + wla = true + return true + } + + suspend fun writeLock() { + if (wla || ar > 0) { + suspendCancellableCoroutine { cont -> + ww += cont + cont.invokeOnCancellation { + ww -= cont + if (!wla && ww.isEmpty()) { + ar += wr.size + wr.forEach { it.resume(Unit) { readUnlock() } } + wr.clear() + } + } + } + } else { + wla = true + } + } + + fun writeUnlock(prioritizeWriters: Boolean) { + if (ww.isNotEmpty() && prioritizeWriters) { + val w = ww.removeAt(0) + w.resume(Unit) { writeUnlock(prioritizeWriters) } + } else { + wla = false + ar = wr.size + wr.forEach { it.resume(Unit) { readUnlock() } } + wr.clear() + } + } + + val state get() = "ar=$ar,wla=$wla,wr=${wr.size},ww=${ww.size}" +} + +// This is an additional test to check the [ReadWriteMutex] synchronization contract. +internal class ReadWriteMutexCounterLincheckTest : AbstractLincheckTest() { + private val m = ReadWriteMutexImpl() + private var c = 0 + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun inc(): Int = m.write { c++ } + + @Operation(allowExtraSuspension = true, promptCancellation = false) + suspend fun get(): Int = m.read { c } + + @StateRepresentation + fun stateRepresentation(): String = "$c + ${m.stateRepresentation}" + + override fun > O.customize(isStressTest: Boolean): O = + actorsBefore(0) + .actorsAfter(0) + .sequentialSpecification(ReadWriteMutexCounterSequential::class.java) +} + +@Suppress("RedundantSuspendModifier") +class ReadWriteMutexCounterSequential : VerifierState() { + private var c = 0 + + fun incViaTryLock() = c++ + suspend fun inc() = c++ + suspend fun get() = c + + override fun extractState() = c +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/lincheck/SemaphoreLincheckTest.kt b/kotlinx-coroutines-core/jvm/test/lincheck/SemaphoreLincheckTest.kt index 09dee56c51..0bc7a7b302 100644 --- a/kotlinx-coroutines-core/jvm/test/lincheck/SemaphoreLincheckTest.kt +++ b/kotlinx-coroutines-core/jvm/test/lincheck/SemaphoreLincheckTest.kt @@ -2,6 +2,7 @@ * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. */ @file:Suppress("unused") + package kotlinx.coroutines.lincheck import kotlinx.coroutines.* @@ -9,27 +10,68 @@ import kotlinx.coroutines.sync.* import org.jetbrains.kotlinx.lincheck.* import org.jetbrains.kotlinx.lincheck.annotations.Operation import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.* +import org.jetbrains.kotlinx.lincheck.verifier.* +import kotlin.reflect.* -abstract class SemaphoreLincheckTestBase(permits: Int) : AbstractLincheckTest() { - private val semaphore = SemaphoreImpl(permits = permits, acquiredPermits = 0) - +abstract class SemaphoreLincheckTestBase( + private val semaphore: Semaphore, + private val seqSpec: KClass<*> +) : AbstractLincheckTest() { @Operation - fun tryAcquire() = semaphore.tryAcquire() + fun tryAcquire() = this.semaphore.tryAcquire() - @Operation(promptCancellation = true, allowExtraSuspension = true) - suspend fun acquire() = semaphore.acquire() + @Operation(promptCancellation = false, allowExtraSuspension = true) + suspend fun acquire() = this.semaphore.acquire() @Operation(handleExceptionsAsResult = [IllegalStateException::class]) - fun release() = semaphore.release() + fun release() = this.semaphore.release() override fun > O.customize(isStressTest: Boolean): O = actorsBefore(0) - - override fun extractState() = semaphore.availablePermits + .sequentialSpecification(seqSpec.java) override fun ModelCheckingOptions.customize(isStressTest: Boolean) = checkObstructionFreedom() } -class Semaphore1LincheckTest : SemaphoreLincheckTestBase(1) -class Semaphore2LincheckTest : SemaphoreLincheckTestBase(2) +open class SemaphoreSequential( + private val permits: Int, + private val boundMaxPermits: Boolean +) : VerifierState() { + private var availablePermits = permits + private val waiters = ArrayList>() + + open fun tryAcquire() = tryAcquireImpl() + + private fun tryAcquireImpl(): Boolean { + if (availablePermits <= 0) return false + availablePermits-- + return true + } + + suspend fun acquire() { + if (tryAcquireImpl()) return + availablePermits-- + suspendCancellableCoroutine { cont -> + waiters.add(cont) + } + } + + fun release() { + while (true) { + if (boundMaxPermits) check(availablePermits < permits) + availablePermits++ + if (availablePermits > 0) return + val w = waiters.removeAt(0) + if (w.tryResume0(Unit, { release() })) return + } + } + + override fun extractState() = availablePermits.coerceAtLeast(0) +} + +class SemaphoreSequential1 : SemaphoreSequential(1, true) +class Semaphore1LincheckTest : SemaphoreLincheckTestBase(Semaphore(1), SemaphoreSequential1::class) + +class SemaphoreSequential2 : SemaphoreSequential(2, true) +class Semaphore2LincheckTest : SemaphoreLincheckTestBase(Semaphore(2), SemaphoreSequential2::class) \ No newline at end of file