diff --git a/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt new file mode 100644 index 0000000000..0fc563a89e --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/SemaphoreBenchmark.kt @@ -0,0 +1,97 @@ +package benchmarks + +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.Channel +import kotlinx.coroutines.scheduling.ExperimentalCoroutineDispatcher +import kotlinx.coroutines.sync.Semaphore +import kotlinx.coroutines.sync.withPermit +import org.openjdk.jmh.annotations.* +import java.util.concurrent.ForkJoinPool +import java.util.concurrent.ThreadLocalRandom +import java.util.concurrent.TimeUnit + +@Warmup(iterations = 3, time = 500, timeUnit = TimeUnit.MICROSECONDS) +@Measurement(iterations = 10, time = 500, timeUnit = TimeUnit.MICROSECONDS) +@Fork(value = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +open class SemaphoreBenchmark { + @Param + private var _1_dispatcher: SemaphoreBenchDispatcherCreator = SemaphoreBenchDispatcherCreator.FORK_JOIN + + @Param("0", "1000") + private var _2_coroutines: Int = 0 + + @Param("1", "2", "4", "8", "32", "128", "100000") + private var _3_maxPermits: Int = 0 + + @Param("1", "2", "4") // local machine +// @Param("1", "2", "4", "8", "16", "32", "64", "128", "144") // dasquad +// @Param("1", "2", "4", "8", "16", "32", "64", "96") // Google Cloud + private var _4_parallelism: Int = 0 + + private lateinit var dispatcher: CoroutineDispatcher + private var coroutines = 0 + + @InternalCoroutinesApi + @Setup + fun setup() { + dispatcher = _1_dispatcher.create(_4_parallelism) + coroutines = if (_2_coroutines == 0) _4_parallelism else _2_coroutines + } + + @Benchmark + fun semaphore() = runBlocking { + val n = BATCH_SIZE / coroutines + val semaphore = Semaphore(_3_maxPermits) + val jobs = ArrayList(coroutines) + repeat(coroutines) { + jobs += GlobalScope.launch { + repeat(n) { + semaphore.withPermit { + doWork(WORK_INSIDE) + } + doWork(WORK_OUTSIDE) + } + } + } + jobs.forEach { it.join() } + } + + @Benchmark + fun channelAsSemaphore() = runBlocking { + val n = BATCH_SIZE / coroutines + val semaphore = Channel(_3_maxPermits) + val jobs = ArrayList(coroutines) + repeat(coroutines) { + jobs += GlobalScope.launch { + repeat(n) { + semaphore.send(Unit) // acquire + doWork(WORK_INSIDE) + semaphore.receive() // release + doWork(WORK_OUTSIDE) + } + } + } + jobs.forEach { it.join() } + } +} + +enum class SemaphoreBenchDispatcherCreator(val create: (parallelism: Int) -> CoroutineDispatcher) { + FORK_JOIN({ parallelism -> ForkJoinPool(parallelism).asCoroutineDispatcher() }), + EXPERIMENTAL({ parallelism -> ExperimentalCoroutineDispatcher(corePoolSize = parallelism, maxPoolSize = parallelism) }) +} + +private fun doWork(work: Int) { + // We use geometric distribution here + val p = 1.0 / work + val r = ThreadLocalRandom.current() + while (true) { + if (r.nextDouble() < p) break + } +} + +private const val WORK_INSIDE = 80 +private const val WORK_OUTSIDE = 40 +private const val BATCH_SIZE = 1000000 \ No newline at end of file diff --git a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt index 1dcad707b1..86a2203aba 100644 --- a/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt +++ b/binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt @@ -1008,6 +1008,19 @@ public final class kotlinx/coroutines/sync/MutexKt { public static synthetic fun withLock$default (Lkotlinx/coroutines/sync/Mutex;Ljava/lang/Object;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; } +public abstract interface class kotlinx/coroutines/sync/Semaphore { + public abstract fun acquire (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public abstract fun getAvailablePermits ()I + public abstract fun release ()V + public abstract fun tryAcquire ()Z +} + +public final class kotlinx/coroutines/sync/SemaphoreKt { + public static final fun Semaphore (II)Lkotlinx/coroutines/sync/Semaphore; + public static synthetic fun Semaphore$default (IIILjava/lang/Object;)Lkotlinx/coroutines/sync/Semaphore; + public static final fun withPermit (Lkotlinx/coroutines/sync/Semaphore;Lkotlin/jvm/functions/Function0;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + public final class kotlinx/coroutines/test/TestCoroutineContext : kotlin/coroutines/CoroutineContext { public fun ()V public fun (Ljava/lang/String;)V diff --git a/gradle.properties b/gradle.properties index 13b510dcef..387d56e166 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,7 +5,7 @@ kotlin_version=1.3.31 # Dependencies junit_version=4.12 -atomicfu_version=0.12.7 +atomicfu_version=0.12.8 html_version=0.6.8 lincheck_version=2.0 dokka_version=0.9.16-rdev-2-mpp-hacks diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt b/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt index 1ca6c4b6c6..3ed2c0bf63 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Merge.kt @@ -13,6 +13,7 @@ import kotlinx.coroutines.channels.* import kotlinx.coroutines.channels.Channel.Factory.OPTIONAL_CHANNEL import kotlinx.coroutines.flow.internal.* import kotlinx.coroutines.internal.* +import kotlinx.coroutines.sync.* import kotlin.coroutines.* import kotlin.jvm.* import kotlinx.coroutines.flow.unsafeFlow as flow @@ -149,16 +150,15 @@ private class ChannelFlowMerge( // The actual merge implementation with concurrency limit private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector) { - val semaphore = Channel(concurrency) + val semaphore = Semaphore(concurrency) @Suppress("UNCHECKED_CAST") flow.collect { inner -> - // TODO real semaphore (#94) - semaphore.send(Unit) // Acquire concurrency permit + semaphore.acquire() // Acquire concurrency permit scope.launch { try { inner.collect(collector) } finally { - semaphore.receive() // Release concurrency permit + semaphore.release() // Release concurrency permit } } } diff --git a/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt b/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt new file mode 100644 index 0000000000..4ad554fd38 --- /dev/null +++ b/kotlinx-coroutines-core/common/src/internal/SegmentQueue.kt @@ -0,0 +1,176 @@ +package kotlinx.coroutines.internal + +import kotlinx.atomicfu.AtomicRef +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.loop + +/** + * Essentially, this segment queue is an infinite array of segments, which is represented as + * a Michael-Scott queue of them. All segments are instances of [Segment] class and + * follow in natural order (see [Segment.id]) in the queue. + */ +internal abstract class SegmentQueue>() { + private val _head: AtomicRef + /** + * Returns the first segment in the queue. + */ + protected val head: S get() = _head.value + + private val _tail: AtomicRef + /** + * Returns the last segment in the queue. + */ + protected val tail: S get() = _tail.value + + init { + val initialSegment = newSegment(0) + _head = atomic(initialSegment) + _tail = atomic(initialSegment) + } + + /** + * The implementation should create an instance of segment [S] with the specified id + * and initial reference to the previous one. + */ + abstract fun newSegment(id: Long, prev: S? = null): S + + /** + * Finds a segment with the specified [id] following by next references from the + * [startFrom] segment. The typical use-case is reading [tail] or [head], doing some + * synchronization, and invoking [getSegment] or [getSegmentAndMoveHead] correspondingly + * to find the required segment. + */ + protected fun getSegment(startFrom: S, id: Long): S? { + // Go through `next` references and add new segments if needed, + // similarly to the `push` in the Michael-Scott queue algorithm. + // The only difference is that `CAS failure` means that the + // required segment has already been added, so the algorithm just + // uses it. This way, only one segment with each id can be in the queue. + var cur: S = startFrom + while (cur.id < id) { + var curNext = cur.next + if (curNext == null) { + // Add a new segment. + val newTail = newSegment(cur.id + 1, cur) + curNext = if (cur.casNext(null, newTail)) { + if (cur.removed) { + cur.remove() + } + moveTailForward(newTail) + newTail + } else { + cur.next!! + } + } + cur = curNext + } + if (cur.id != id) return null + return cur + } + + /** + * Invokes [getSegment] and replaces [head] with the result if its [id] is greater. + */ + protected fun getSegmentAndMoveHead(startFrom: S, id: Long): S? { + @Suppress("LeakingThis") + if (startFrom.id == id) return startFrom + val s = getSegment(startFrom, id) ?: return null + moveHeadForward(s) + return s + } + + /** + * Updates [head] to the specified segment + * if its `id` is greater. + */ + private fun moveHeadForward(new: S) { + _head.loop { curHead -> + if (curHead.id > new.id) return + if (_head.compareAndSet(curHead, new)) { + new.prev.value = null + return + } + } + } + + /** + * Updates [tail] to the specified segment + * if its `id` is greater. + */ + private fun moveTailForward(new: S) { + _tail.loop { curTail -> + if (curTail.id > new.id) return + if (_tail.compareAndSet(curTail, new)) return + } + } +} + +/** + * Each segment in [SegmentQueue] has a unique id and is created by [SegmentQueue.newSegment]. + * Essentially, this is a node in the Michael-Scott queue algorithm, but with + * maintaining [prev] pointer for efficient [remove] implementation. + */ +internal abstract class Segment>(val id: Long, prev: S?) { + // Pointer to the next segment, updates similarly to the Michael-Scott queue algorithm. + private val _next = atomic(null) + val next: S? get() = _next.value + fun casNext(expected: S?, value: S?): Boolean = _next.compareAndSet(expected, value) + // Pointer to the previous segment, updates in [remove] function. + val prev = atomic(null) + + /** + * Returns `true` if this segment is logically removed from the queue. + * The [remove] function should be called right after it becomes logically removed. + */ + abstract val removed: Boolean + + init { + this.prev.value = prev + } + + /** + * Removes this segment physically from the segment queue. The segment should be + * logically removed (so [removed] returns `true`) at the point of invocation. + */ + fun remove() { + check(removed) { " The segment should be logically removed at first "} + // Read `next` and `prev` pointers. + var next = this._next.value ?: return // tail cannot be removed + var prev = prev.value ?: return // head cannot be removed + // Link `next` and `prev`. + prev.moveNextToRight(next) + while (prev.removed) { + prev = prev.prev.value ?: break + prev.moveNextToRight(next) + } + next.movePrevToLeft(prev) + while (next.removed) { + next = next.next ?: break + next.movePrevToLeft(prev) + } + } + + /** + * Updates [next] pointer to the specified segment if + * the [id] of the specified segment is greater. + */ + private fun moveNextToRight(next: S) { + while (true) { + val curNext = this._next.value as S + if (next.id <= curNext.id) return + if (this._next.compareAndSet(curNext, next)) return + } + } + + /** + * Updates [prev] pointer to the specified segment if + * the [id] of the specified segment is lower. + */ + private fun movePrevToLeft(prev: S) { + while (true) { + val curPrev = this.prev.value ?: return + if (curPrev.id <= prev.id) return + if (this.prev.compareAndSet(curPrev, prev)) return + } + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/src/sync/Semaphore.kt b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt new file mode 100644 index 0000000000..0ffb99006b --- /dev/null +++ b/kotlinx-coroutines-core/common/src/sync/Semaphore.kt @@ -0,0 +1,213 @@ +package kotlinx.coroutines.sync + +import kotlinx.atomicfu.atomic +import kotlinx.atomicfu.atomicArrayOfNulls +import kotlinx.atomicfu.getAndUpdate +import kotlinx.atomicfu.loop +import kotlinx.coroutines.* +import kotlinx.coroutines.internal.* +import kotlin.coroutines.resume +import kotlin.math.max + +/** + * A counting semaphore for coroutines. It maintains a number of available permits. + * Each [acquire] suspends if necessary until a permit is available, and then takes it. + * Each [release] adds a permit, potentially releasing a suspended acquirer. + * + * Semaphore with `permits = 1` is essentially a [Mutex]. + **/ +public interface Semaphore { + /** + * Returns the current number of permits available in this semaphore. + */ + public val availablePermits: Int + + /** + * Acquires a permit from this semaphore, suspending until one is available. + * All suspending acquirers are processed in first-in-first-out (FIFO) order. + * + * This suspending function is cancellable. If the [Job] of the current coroutine is cancelled or completed while this + * function is suspended, this function immediately resumes with [CancellationException]. + * + * *Cancellation of suspended semaphore acquisition` is atomic* -- when this function + * throws [CancellationException] it means that the semaphore was not acquired. + * + * Note, that this function does not check for cancellation when it is not suspended. + * Use [yield] or [CoroutineScope.isActive] to periodically check for cancellation in tight loops if needed. + * + * Use [tryAcquire] to try acquire a permit of this semaphore without suspension. + */ + public suspend fun acquire() + + /** + * Tries to acquire a permit from this semaphore without suspension. + * + * @return `true` if a permit was acquired, `false` otherwise. + */ + public fun tryAcquire(): Boolean + + /** + * Releases a permit, returning it into this semaphore. Resumes the first + * suspending acquirer if there is one at the point of invocation. + * Throws [IllegalStateException] if there is no acquired permit + * at the point of invocation. + */ + public fun release() +} + +/** + * Creates new [Semaphore] instance. + * @param permits the number of permits available in this semaphore. + * @param acquiredPermits the number of already acquired permits, + * should be between `0` and `permits` (inclusively). + */ +@Suppress("FunctionName") +public fun Semaphore(permits: Int, acquiredPermits: Int = 0): Semaphore = SemaphoreImpl(permits, acquiredPermits) + +/** + * Executes the given [action], acquiring a permit from this semaphore at the beginning + * and releasing it after the [action] is completed. + * + * @return the return value of the [action]. + */ +public suspend inline fun Semaphore.withPermit(action: () -> T): T { + acquire() + try { + return action() + } finally { + release() + } +} + +private class SemaphoreImpl( + private val permits: Int, acquiredPermits: Int +) : Semaphore, SegmentQueue() { + init { + require(permits > 0) { "Semaphore should have at least 1 permit" } + require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..permits" } + } + + override fun newSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev) + + /** + * This counter indicates a number of available permits if it is non-negative, + * or the size with minus sign otherwise. Note, that 32-bit counter is enough here + * since the maximal number of available permits is [permits] which is [Int], + * and the maximum number of waiting acquirers cannot be greater than 2^31 in any + * real application. + */ + private val _availablePermits = atomic(permits) + override val availablePermits: Int get() = max(_availablePermits.value, 0) + + // The queue of waiting acquirers is essentially an infinite array based on `SegmentQueue`; + // each segment contains a fixed number of slots. To determine a slot for each enqueue + // and dequeue operation, we increment the corresponding counter at the beginning of the operation + // and use the value before the increment as a slot number. This way, each enqueue-dequeue pair + // works with an individual cell. + private val enqIdx = atomic(0L) + private val deqIdx = atomic(0L) + + override fun tryAcquire(): Boolean { + _availablePermits.loop { p -> + if (p <= 0) return false + if (_availablePermits.compareAndSet(p, p - 1)) return true + } + } + + override suspend fun acquire() { + val p = _availablePermits.getAndDecrement() + if (p > 0) return // permit acquired + addToQueueAndSuspend() + } + + override fun release() { + val p = _availablePermits.getAndUpdate { cur -> + check(cur < permits) { "The number of acquired permits cannot be greater than `permits`" } + cur + 1 + } + if (p >= 0) return // no waiters + resumeNextFromQueue() + } + + private suspend fun addToQueueAndSuspend() = suspendAtomicCancellableCoroutine sc@ { cont -> + val last = this.tail + val enqIdx = enqIdx.getAndIncrement() + val segment = getSegment(last, enqIdx / SEGMENT_SIZE) + val i = (enqIdx % SEGMENT_SIZE).toInt() + if (segment === null || segment.get(i) === RESUMED || !segment.cas(i, null, cont)) { + // already resumed + cont.resume(Unit) + return@sc + } + cont.invokeOnCancellation(CancelSemaphoreAcquisitionHandler(this, segment, i).asHandler) + } + + @Suppress("UNCHECKED_CAST") + private fun resumeNextFromQueue() { + val first = this.head + val deqIdx = deqIdx.getAndIncrement() + val segment = getSegmentAndMoveHead(first, deqIdx / SEGMENT_SIZE) ?: return + val i = (deqIdx % SEGMENT_SIZE).toInt() + val cont = segment.getAndUpdate(i) { + // Cancelled continuation invokes `release` + // and resumes next suspended acquirer if needed. + if (it === CANCELLED) return + RESUMED + } + if (cont === null) return // just resumed + (cont as CancellableContinuation).resume(Unit) + } +} + +private class CancelSemaphoreAcquisitionHandler( + private val semaphore: Semaphore, + private val segment: SemaphoreSegment, + private val index: Int +) : CancelHandler() { + override fun invoke(cause: Throwable?) { + segment.cancel(index) + semaphore.release() + } + + override fun toString() = "CancelSemaphoreAcquisitionHandler[$semaphore, $segment, $index]" +} + +private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?): Segment(id, prev) { + private val acquirers = atomicArrayOfNulls(SEGMENT_SIZE) + + @Suppress("NOTHING_TO_INLINE") + inline fun get(index: Int): Any? = acquirers[index].value + + @Suppress("NOTHING_TO_INLINE") + inline fun cas(index: Int, expected: Any?, value: Any?): Boolean = acquirers[index].compareAndSet(expected, value) + + inline fun getAndUpdate(index: Int, function: (Any?) -> Any?): Any? { + while (true) { + val cur = acquirers[index].value + val upd = function(cur) + if (cas(index, cur, upd)) return cur + } + } + + private val cancelledSlots = atomic(0) + override val removed get() = cancelledSlots.value == SEGMENT_SIZE + + // Cleans the acquirer slot located by the specified index + // and removes this segment physically if all slots are cleaned. + fun cancel(index: Int) { + // Clean the specified waiter + acquirers[index].value = CANCELLED + // Remove this segment if needed + if (cancelledSlots.incrementAndGet() == SEGMENT_SIZE) + remove() + } + + override fun toString() = "SemaphoreSegment[id=$id, hashCode=${hashCode()}]" +} + +@SharedImmutable +private val RESUMED = Symbol("RESUMED") +@SharedImmutable +private val CANCELLED = Symbol("CANCELLED") +@SharedImmutable +private val SEGMENT_SIZE = systemProp("kotlinx.coroutines.semaphore.segmentSize", 16) \ No newline at end of file diff --git a/kotlinx-coroutines-core/common/test/sync/SemaphoreTest.kt b/kotlinx-coroutines-core/common/test/sync/SemaphoreTest.kt new file mode 100644 index 0000000000..a6aaf24cb3 --- /dev/null +++ b/kotlinx-coroutines-core/common/test/sync/SemaphoreTest.kt @@ -0,0 +1,119 @@ +package kotlinx.coroutines.sync + +import kotlinx.coroutines.TestBase +import kotlinx.coroutines.cancelAndJoin +import kotlinx.coroutines.launch +import kotlinx.coroutines.yield +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertFalse +import kotlin.test.assertTrue + +class SemaphoreTest : TestBase() { + + @Test + fun testSimple() = runTest { + val semaphore = Semaphore(2) + launch { + expect(3) + semaphore.release() + expect(4) + } + expect(1) + semaphore.acquire() + semaphore.acquire() + expect(2) + semaphore.acquire() + finish(5) + } + + @Test + fun testSimpleAsMutex() = runTest { + val semaphore = Semaphore(1) + expect(1) + launch { + expect(4) + semaphore.acquire() // suspends + expect(7) // now got lock + semaphore.release() + expect(8) + } + expect(2) + semaphore.acquire() // locked + expect(3) + yield() // yield to child + expect(5) + semaphore.release() + expect(6) + yield() // now child has lock + finish(9) + } + + @Test + fun tryAcquireTest() = runTest { + val semaphore = Semaphore(2) + assertTrue(semaphore.tryAcquire()) + assertTrue(semaphore.tryAcquire()) + assertFalse(semaphore.tryAcquire()) + assertEquals(0, semaphore.availablePermits) + semaphore.release() + assertEquals(1, semaphore.availablePermits) + assertTrue(semaphore.tryAcquire()) + assertEquals(0, semaphore.availablePermits) + } + + @Test + fun withSemaphoreTest() = runTest { + val semaphore = Semaphore(1) + assertEquals(1, semaphore.availablePermits) + semaphore.withPermit { + assertEquals(0, semaphore.availablePermits) + } + assertEquals(1, semaphore.availablePermits) + } + + @Test + fun fairnessTest() = runTest { + val semaphore = Semaphore(1) + semaphore.acquire() + launch(coroutineContext) { + // first to acquire + expect(2) + semaphore.acquire() // suspend + expect(6) + } + launch(coroutineContext) { + // second to acquire + expect(3) + semaphore.acquire() // suspend + expect(9) + } + expect(1) + yield() + expect(4) + semaphore.release() + expect(5) + yield() + expect(7) + semaphore.release() + expect(8) + yield() + finish(10) + } + + @Test + fun testCancellationReleasesSemaphore() = runTest { + val semaphore = Semaphore(1) + semaphore.acquire() + assertEquals(0, semaphore.availablePermits) + val job = launch { + assertFalse(semaphore.tryAcquire()) + semaphore.acquire() + } + yield() + job.cancelAndJoin() + assertEquals(0, semaphore.availablePermits) + semaphore.release() + assertEquals(1, semaphore.availablePermits) + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt new file mode 100644 index 0000000000..293be7a59e --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentBasedQueue.kt @@ -0,0 +1,72 @@ +package kotlinx.coroutines.internal + +import kotlinx.atomicfu.atomic + +/** + * This queue implementation is based on [SegmentQueue] for testing purposes and is organized as follows. Essentially, + * the [SegmentBasedQueue] is represented as an infinite array of segments, each stores one element (see [OneElementSegment]). + * Both [enqueue] and [dequeue] operations increment the corresponding global index ([enqIdx] for [enqueue] and + * [deqIdx] for [dequeue]) and work with the indexed by this counter cell. Since both operations increment the indices + * at first, there could be a race: [enqueue] increments [enqIdx], then [dequeue] checks that the queue is not empty + * (that's true) and increments [deqIdx], looking into the corresponding cell after that; however, the cell is empty + * because the [enqIdx] operation has not been put its element yet. To make the queue non-blocking, [dequeue] can mark + * the cell with [BROKEN] token and retry the operation, [enqueue] at the same time should restart as well; this way, + * the queue is obstruction-free. + */ +internal class SegmentBasedQueue : SegmentQueue>() { + override fun newSegment(id: Long, prev: OneElementSegment?): OneElementSegment = OneElementSegment(id, prev) + + private val enqIdx = atomic(0L) + private val deqIdx = atomic(0L) + + // Returns the segments associated with the enqueued element. + fun enqueue(element: T): OneElementSegment { + while (true) { + var tail = this.tail + val enqIdx = this.enqIdx.getAndIncrement() + tail = getSegment(tail, enqIdx) ?: continue + if (tail.element.value === BROKEN) continue + if (tail.element.compareAndSet(null, element)) return tail + } + } + + fun dequeue(): T? { + while (true) { + if (this.deqIdx.value >= this.enqIdx.value) return null + var firstSegment = this.head + val deqIdx = this.deqIdx.getAndIncrement() + firstSegment = getSegmentAndMoveHead(firstSegment, deqIdx) ?: continue + var el = firstSegment.element.value + if (el === null) { + if (firstSegment.element.compareAndSet(null, BROKEN)) continue + else el = firstSegment.element.value + } + if (el === REMOVED) continue + return el as T + } + } + + val numberOfSegments: Int get() { + var s: OneElementSegment? = head + var i = 0 + while (s != null) { + s = s.next + i++ + } + return i + } +} + +internal class OneElementSegment(id: Long, prev: OneElementSegment?) : Segment>(id, prev) { + val element = atomic(null) + + override val removed get() = element.value === REMOVED + + fun removeSegment() { + element.value = REMOVED + remove() + } +} + +private val BROKEN = Symbol("BROKEN") +private val REMOVED = Symbol("REMOVED") \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueLFTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueLFTest.kt new file mode 100644 index 0000000000..b6faf683fb --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueLFTest.kt @@ -0,0 +1,26 @@ +package kotlinx.coroutines.internal + +import com.devexperts.dxlab.lincheck.LinChecker +import com.devexperts.dxlab.lincheck.annotations.Operation +import com.devexperts.dxlab.lincheck.annotations.Param +import com.devexperts.dxlab.lincheck.paramgen.IntGen +import com.devexperts.dxlab.lincheck.strategy.stress.StressCTest +import org.junit.Test + +@StressCTest +class SegmentQueueLFTest { + private val q = SegmentBasedQueue() + + @Operation + fun add(@Param(gen = IntGen::class) x: Int) { + q.enqueue(x) + } + + @Operation + fun poll(): Int? = q.dequeue() + + @Test + fun test() { + LinChecker.check(SegmentQueueLFTest::class.java) + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt new file mode 100644 index 0000000000..9a6ee42aa0 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/internal/SegmentQueueTest.kt @@ -0,0 +1,99 @@ +package kotlinx.coroutines.internal + +import kotlinx.coroutines.TestBase +import org.junit.Test +import java.util.* +import java.util.concurrent.CyclicBarrier +import java.util.concurrent.atomic.AtomicInteger +import kotlin.concurrent.thread +import kotlin.random.Random +import kotlin.test.assertEquals + +class SegmentQueueTest : TestBase() { + + @Test + fun simpleTest() { + val q = SegmentBasedQueue() + assertEquals( 1, q.numberOfSegments) + assertEquals(null, q.dequeue()) + q.enqueue(1) + assertEquals(1, q.numberOfSegments) + q.enqueue(2) + assertEquals(2, q.numberOfSegments) + assertEquals(1, q.dequeue()) + assertEquals(2, q.numberOfSegments) + assertEquals(2, q.dequeue()) + assertEquals(1, q.numberOfSegments) + assertEquals(null, q.dequeue()) + } + + @Test + fun testSegmentRemoving() { + val q = SegmentBasedQueue() + q.enqueue(1) + val s = q.enqueue(2) + q.enqueue(3) + assertEquals(3, q.numberOfSegments) + s.removeSegment() + assertEquals(2, q.numberOfSegments) + assertEquals(1, q.dequeue()) + assertEquals(3, q.dequeue()) + assertEquals(null, q.dequeue()) + } + + @Test + fun testRemoveHeadSegment() { + val q = SegmentBasedQueue() + q.enqueue(1) + val s = q.enqueue(2) + assertEquals(1, q.dequeue()) + q.enqueue(3) + s.removeSegment() + assertEquals(3, q.dequeue()) + assertEquals(null, q.dequeue()) + } + + @Test + fun stressTest() { + val q = SegmentBasedQueue() + val expectedQueue = ArrayDeque() + val r = Random(0) + repeat(1_000_000 * stressTestMultiplier) { + if (r.nextBoolean()) { // add + val el = r.nextInt() + q.enqueue(el) + expectedQueue.add(el) + } else { // remove + assertEquals(expectedQueue.poll(), q.dequeue()) + } + } + } + + @Test + fun stressTestRemoveSegmentsSerial() = stressTestRemoveSegments(false) + + @Test + fun stressTestRemoveSegmentsRandom() = stressTestRemoveSegments(true) + + private fun stressTestRemoveSegments(random: Boolean) { + val N = 100_000 * stressTestMultiplier + val T = 10 + val q = SegmentBasedQueue() + val segments = (1..N).map { q.enqueue(it) }.toMutableList() + if (random) segments.shuffle() + assertEquals(N, q.numberOfSegments) + val nextSegmentIndex = AtomicInteger() + val barrier = CyclicBarrier(T) + (1..T).map { + thread { + barrier.await() + while (true) { + val i = nextSegmentIndex.getAndIncrement() + if (i >= N) break + segments[i].removeSegment() + } + } + }.forEach { it.join() } + assertEquals(2, q.numberOfSegments) + } +} \ No newline at end of file diff --git a/kotlinx-coroutines-core/jvm/test/sync/SemaphoreStressTest.kt b/kotlinx-coroutines-core/jvm/test/sync/SemaphoreStressTest.kt new file mode 100644 index 0000000000..cdfcc6bded --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/sync/SemaphoreStressTest.kt @@ -0,0 +1,60 @@ +package kotlinx.coroutines.sync + +import kotlinx.coroutines.* +import org.junit.Test +import kotlin.test.assertEquals + +class SemaphoreStressTest : TestBase() { + + @Test + fun stressTestAsMutex() = runTest { + val n = 10_000 * stressTestMultiplier + val k = 100 + var shared = 0 + val semaphore = Semaphore(1) + val jobs = List(n) { + launch { + repeat(k) { + semaphore.acquire() + shared++ + semaphore.release() + } + } + } + jobs.forEach { it.join() } + assertEquals(n * k, shared) + } + + @Test + fun stressTest() = runTest { + val n = 10_000 * stressTestMultiplier + val k = 100 + val semaphore = Semaphore(10) + val jobs = List(n) { + launch { + repeat(k) { + semaphore.acquire() + semaphore.release() + } + } + } + jobs.forEach { it.join() } + } + + @Test + fun stressCancellation() = runTest { + val n = 10_000 * stressTestMultiplier + val semaphore = Semaphore(1) + semaphore.acquire() + repeat(n) { + val job = launch { + semaphore.acquire() + } + yield() + job.cancelAndJoin() + } + assertEquals(0, semaphore.availablePermits) + semaphore.release() + assertEquals(1, semaphore.availablePermits) + } +} \ No newline at end of file