Skip to content

Remove reference counters in the concurrent doubly-linked list used in BufferedChannel and Semaphore #4302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 15 commits into
base: develop
Choose a base branch
from
7 changes: 5 additions & 2 deletions benchmarks/src/jmh/kotlin/benchmarks/ChannelSinkBenchmark.kt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(1)
open class ChannelSinkBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val tl = ThreadLocal.withInitial({ 42 })
private val tl2 = ThreadLocal.withInitial({ 239 })

Expand Down Expand Up @@ -42,15 +45,15 @@ open class ChannelSinkBenchmark {
.fold(0) { a, b -> a + b }
}

private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context) {
private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) = GlobalScope.produce(context, capacity) {
for (i in start until (start + count))
send(i)
}

// Migrated from deprecated operators, are good only for stressing channels

private fun <E> ReceiveChannel<E>.filter(context: CoroutineContext = Dispatchers.Unconfined, predicate: suspend (E) -> Boolean): ReceiveChannel<E> =
GlobalScope.produce(context, onCompletion = { cancel() }) {
GlobalScope.produce(context, capacity, onCompletion = { cancel() }) {
for (e in this@filter) {
if (predicate(e)) send(e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(2)
open class ChannelSinkDepthBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val tl = ThreadLocal.withInitial({ 42 })

private val unconfinedOneElement = Dispatchers.Unconfined + tl.asContextElement()
Expand Down Expand Up @@ -45,7 +48,7 @@ open class ChannelSinkDepthBenchmark {
}

private fun Channel.Factory.range(start: Int, count: Int, context: CoroutineContext) =
GlobalScope.produce(context) {
GlobalScope.produce(context, capacity) {
for (i in start until (start + count))
send(i)
}
Expand All @@ -57,7 +60,7 @@ open class ChannelSinkDepthBenchmark {
context: CoroutineContext = Dispatchers.Unconfined,
predicate: suspend (Int) -> Boolean
): ReceiveChannel<Int> =
GlobalScope.produce(context, onCompletion = { cancel() }) {
GlobalScope.produce(context, capacity, onCompletion = { cancel() }) {
deeplyNestedFilter(this, callTraceDepth, predicate)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import kotlin.coroutines.*
@State(Scope.Benchmark)
@Fork(1)
open class ChannelSinkNoAllocationsBenchmark {
@Param("${Channel.RENDEZVOUS}", "${Channel.BUFFERED}")
var capacity: Int = 0

private val unconfined = Dispatchers.Unconfined

@Benchmark
Expand All @@ -26,7 +29,7 @@ open class ChannelSinkNoAllocationsBenchmark {
return size
}

private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context) {
private fun Channel.Factory.range(context: CoroutineContext) = GlobalScope.produce(context, capacity) {
for (i in 0 until 100_000)
send(Unit) // no allocations
}
Expand Down
146 changes: 50 additions & 96 deletions kotlinx-coroutines-core/common/src/channels/BufferedChannel.kt

Large diffs are not rendered by default.

103 changes: 74 additions & 29 deletions kotlinx-coroutines-core/common/src/internal/ConcurrentLinkedList.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package kotlinx.coroutines.internal

import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.SEGMENT_SIZE
import kotlin.coroutines.*
import kotlin.jvm.*

Expand Down Expand Up @@ -34,26 +35,57 @@ internal fun <S : Segment<S>> S.findSegmentInternal(
return SegmentOrClosed(cur)
}

/**
 * Walks the list forward from this segment and returns the first segment whose id is not less than [id].
 * If the chain of `next` references ends before such a segment is reached (for example, because it
 * was not created yet), the last visited segment is returned instead.
 *
 * Unlike [findSegmentInternal], [findSpecifiedOrLast] does not add new segments to the list.
 */
internal fun <S : Segment<S>> S.findSpecifiedOrLast(id: Long): S {
    var candidate: S = this
    while (true) {
        // The current segment is not behind the requested id: this is the answer.
        if (candidate.id >= id) return candidate
        // No successor exists, so the current segment is the last reachable one.
        candidate = candidate.next ?: return candidate
    }
}

/**
* Returns `false` if the segment `to` is logically removed, `true` on a successful update.
*/
@Suppress("NOTHING_TO_INLINE", "RedundantNullableReturnType") // Must be inline because it is an AtomicRef extension
internal inline fun <S : Segment<S>> AtomicRef<S>.moveForward(to: S): Boolean = loop { cur ->
if (cur.id >= to.id) return true
if (!to.tryIncPointers()) return false
if (compareAndSet(cur, to)) { // the segment is moved
if (cur.decPointers()) cur.remove()
if (cur.id >= to.id) return true // No need to update the pointer
if (to.isRemoved) return false // Trying to move pointer to the logically removed segment
if (compareAndSet(cur, to)) { // The segment is moved
if (to.isRemoved) return false // The segment was removed in parallel during the `CAS` operation
cleanLeftmostPrev(cur, to)
return true
}
if (to.decPointers()) to.remove() // undo tryIncPointers
}

/**
 * Cleans the `prev` reference of the leftmost segment of the sublist bounded by [from] and [to].
 *
 * The search starts at [to] and follows `prev` references towards [from]. It stops at the first segment
 * that reports `isLeftmostOrProcessed`, or once a segment whose id is not greater than `from.id` is reached.
 * `cleanPrev` is invoked only if the segment the search stopped at reports `isLeftmostOrProcessed`.
 * If a `prev` reference turns out to be `null` during the walk, the method returns without cleaning anything.
 *
 * The method is called when [moveForward] successfully updates the value stored in the `AtomicRef` reference.
 */
private inline fun <S : Segment<S>> cleanLeftmostPrev(from: S, to: S) {
var cur = to
// Find the leftmost segment on the sublist between `from` and `to` segments.
while (!cur.isLeftmostOrProcessed && cur.id > from.id) {
cur = cur.prev ?:
// The `prev` reference was cleaned in parallel.
return
}
// The flag is read again here because the loop may also have stopped on the `cur.id > from.id` bound.
if (cur.isLeftmostOrProcessed) cur.cleanPrev() // The leftmost segment is found
}

/**
* Tries to find a segment with the specified [id] following by next references from the
* [startFrom] segment and creating new ones if needed. The typical use-case is reading this `AtomicRef` values,
* doing some synchronization, and invoking this function to find the required segment and update the pointer.
* At the same time, [Segment.cleanPrev] should also be invoked if the previous segments are no longer needed
* (e.g., queues should use it in dequeue operations).
*
* Segments can be removed from the list, and the list can be closed for further segment additions.
* Therefore, this function returns the segment `s` with `s.id >= id` or `CLOSED` if all the segments in this linked list have lower `id`,
Expand All @@ -71,6 +103,27 @@ internal inline fun <S : Segment<S>> AtomicRef<S>.findSegmentAndMoveForward(
}
}

/**
 * Advances this `AtomicRef` to a segment that already exists in the list.
 *
 * The target is located with [findSpecifiedOrLast] starting from [startFrom]; logically removed
 * segments are then skipped by following `next` references until [moveForward] succeeds.
 *
 * Unlike [findSegmentAndMoveForward], [moveToSpecifiedOrLast] does not add new segments into the list.
 */
@Suppress("NOTHING_TO_INLINE")
internal inline fun <S : Segment<S>> AtomicRef<S>.moveToSpecifiedOrLast(id: Long, startFrom: S) {
    // Begin with the requested segment, or with the last existing one if the requested one is absent.
    var candidate = startFrom.findSpecifiedOrLast(id)
    do {
        // Step over logically removed segments; stop stepping if there is no successor.
        // This is expected to terminate eventually, as the tail segment is never removed.
        while (candidate.isRemoved) {
            candidate = candidate.next ?: break
        }
        // `moveForward` returns `false` when the candidate is logically removed;
        // in that case the skipping above is repeated before the next attempt.
    } while (!moveForward(candidate))
}

/**
* Closes this linked list of nodes by forbidding adding new ones,
* returns the last node in the list.
Expand Down Expand Up @@ -144,8 +197,10 @@ internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>
/**
* Removes this node physically from this linked list. The node should be
* logically removed (so [isRemoved] returns `true`) at the point of invocation.
*
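* NOTE(review): a `true`/`false` result ("the node was physically removed" or not) is documented here, but the function is declared without a return type, so it returns `Unit` — confirm the intended contract.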
*/
fun remove() {
open fun remove() {
assert { isRemoved || isTail } // The node should be logically removed at first.
// The physical tail cannot be removed. Instead, we remove it when
// a new segment is added and this segment is not the tail one anymore.
Expand All @@ -168,7 +223,7 @@ internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>
private val aliveSegmentLeft: N? get() {
var cur = prev
while (cur !== null && cur.isRemoved)
cur = cur._prev.value
cur = cur.prev
return cur
}

Expand All @@ -190,7 +245,7 @@ internal abstract class ConcurrentLinkedListNode<N : ConcurrentLinkedListNode<N>
* instance-check it and uses a separate code-path for that.
*/
internal abstract class Segment<S : Segment<S>>(
@JvmField val id: Long, prev: S?, pointers: Int
@JvmField val id: Long, prev: S?
) : ConcurrentLinkedListNode<S>(prev),
// Segments typically store waiting continuations. Thus, on cancellation, the corresponding
// slot should be cleaned and the segment should be removed if it becomes full of cancelled cells.
Expand All @@ -207,21 +262,20 @@ internal abstract class Segment<S : Segment<S>>(
abstract val numberOfSlots: Int

/**
* Numbers of cleaned slots (the lowest bits) and AtomicRef pointers to this segment (the highest bits)
* Number of cleaned slots.
*/
private val cleanedAndPointers = atomic(pointers shl POINTERS_SHIFT)
private val cleanedSlots = atomic(0)

/**
* The segment is considered as removed if all the slots are cleaned
* and it is not the tail segment of the list.
*/
override val isRemoved get() = cleanedAndPointers.value == numberOfSlots && !isTail
override val isRemoved get() = cleanedSlots.value == numberOfSlots && !isTail

// increments the number of pointers if this segment is not logically removed.
internal fun tryIncPointers() = cleanedAndPointers.addConditionally(1 shl POINTERS_SHIFT) { it != numberOfSlots || isTail }

// returns `true` if this segment is logically removed after the decrement.
internal fun decPointers() = cleanedAndPointers.addAndGet(-(1 shl POINTERS_SHIFT)) == numberOfSlots && !isTail
/**
 * Shows if all nodes going before this node have been processed.
 *
 * `cleanLeftmostPrev` reads this flag while walking `prev` references backwards: the walk stops at the
 * first segment for which the flag is `true`, and `cleanPrev` is called on that segment.
 * Each concrete segment type supplies its own definition of the flag.
 */
abstract val isLeftmostOrProcessed: Boolean

/**
* This function is invoked on continuation cancellation when this segment
Expand All @@ -240,15 +294,8 @@ internal abstract class Segment<S : Segment<S>>(
* Invoked on each slot clean-up; should not be invoked twice for the same slot.
*/
fun onSlotCleaned() {
if (cleanedAndPointers.incrementAndGet() == numberOfSlots) remove()
}
}

private inline fun AtomicInt.addConditionally(delta: Int, condition: (cur: Int) -> Boolean): Boolean {
while (true) {
val cur = this.value
if (!condition(cur)) return false
if (this.compareAndSet(cur, cur + delta)) return true
// Every slot is cleaned at most once, so the counter must never exceed this segment's own slot count.
// The bound is `numberOfSlots` rather than the channel constant `SEGMENT_SIZE`: `Segment` is generic and
// each subclass declares its own number of slots, so a fixed channel-specific bound would either miss
// double clean-ups (smaller segments) or fail spuriously (larger segments).
check(cleanedSlots.incrementAndGet() <= numberOfSlots) { "Some cell was interrupted twice." }
// Unlink the segment physically as soon as it is observed to be logically removed.
if (isRemoved) remove()
}
}

Expand All @@ -259,6 +306,4 @@ internal value class SegmentOrClosed<S : Segment<S>>(private val value: Any?) {
val segment: S get() = if (value === CLOSED) error("Does not contain segment") else value as S
}

private const val POINTERS_SHIFT = 16

private val CLOSED = Symbol("CLOSED")
21 changes: 17 additions & 4 deletions kotlinx-coroutines-core/common/src/sync/Semaphore.kt
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,16 @@ internal open class SemaphoreAndMutexImpl(private val permits: Int, acquiredPerm
private val tail: AtomicRef<SemaphoreSegment>
private val enqIdx = atomic(0L)

/**
 * The id of the segment that `head` currently refers to.
 *
 * This value is used in [SemaphoreSegment.isLeftmostOrProcessed], where a segment compares its own id
 * against it. It helps to detect when the `prev` reference of the segment should be cleaned.
 */
internal val headId: Long
    get() = head.value.id

init {
require(permits > 0) { "Semaphore should have at least 1 permit, but had $permits" }
require(acquiredPermits in 0..permits) { "The number of acquired permits should be in 0..$permits" }
val s = SemaphoreSegment(0, null, 2)
val s = SemaphoreSegment(0, null, this)
head = atomic(s)
tail = atomic(s)
}
Expand Down Expand Up @@ -317,7 +323,6 @@ internal open class SemaphoreAndMutexImpl(private val permits: Int, acquiredPerm
val createNewSegment = ::createSegment
val segment = this.head.findSegmentAndMoveForward(id, startFrom = curHead,
createNewSegment = createNewSegment).segment // cannot be closed
segment.cleanPrev()
if (segment.id > id) return false
val i = (deqIdx % SEGMENT_SIZE).toInt()
val cellState = segment.getAndSet(i, PERMIT) // set PERMIT and retrieve the prev cell state
Expand Down Expand Up @@ -356,11 +361,19 @@ private class SemaphoreImpl(
permits: Int, acquiredPermits: Int
): SemaphoreAndMutexImpl(permits, acquiredPermits), Semaphore

private fun createSegment(id: Long, prev: SemaphoreSegment?) = SemaphoreSegment(id, prev, 0)
/**
 * Builds a new [SemaphoreSegment] with the given [id] whose predecessor is [prev].
 * The new segment is bound to the same semaphore instance as [prev].
 */
private fun createSegment(id: Long, prev: SemaphoreSegment): SemaphoreSegment =
    SemaphoreSegment(id, prev, prev.semaphore)

private class SemaphoreSegment(id: Long, prev: SemaphoreSegment?, pointers: Int) : Segment<SemaphoreSegment>(id, prev, pointers) {
private class SemaphoreSegment(
id: Long, prev: SemaphoreSegment?,
val semaphore: SemaphoreAndMutexImpl
) : Segment<SemaphoreSegment>(id, prev) {
val acquirers = atomicArrayOfNulls<Any?>(SEGMENT_SIZE)
override val numberOfSlots: Int get() = SEGMENT_SIZE
override val isLeftmostOrProcessed: Boolean get() = id <= semaphore.headId

@Suppress("NOTHING_TO_INLINE")
inline fun get(index: Int): Any? = acquirers[index].value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,9 @@ class CancellableContinuationHandlersTest : TestBase() {

@Test
fun testSegmentAsHandler() = runTest {
class MySegment : Segment<MySegment>(0, null, 0) {
class MySegment : Segment<MySegment>(0, null) {
override val numberOfSlots: Int get() = 0
override val isLeftmostOrProcessed: Boolean get() = false

var invokeOnCancellationCalled = false
override fun onCancellation(index: Int, cause: Throwable?, context: CoroutineContext) {
Expand Down