From 8ab710ef4bb2d09d483f905f2621e7bc6481a758 Mon Sep 17 00:00:00 2001 From: sakex Date: Fri, 17 Jun 2022 18:35:46 +0200 Subject: [PATCH] Add onUndeliveredElement callback to ArrayChannels --- .../common/src/channels/ArrayChannel.kt | 23 +++++++++++++++---- .../common/src/channels/Channel.kt | 9 ++++---- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt b/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt index 7e6c0e68c5..75f6905ef4 100644 --- a/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt +++ b/kotlinx-coroutines-core/common/src/channels/ArrayChannel.kt @@ -25,7 +25,8 @@ internal open class ArrayChannel( */ private val capacity: Int, private val onBufferOverflow: BufferOverflow, - onUndeliveredElement: OnUndeliveredElement? + onUndeliveredElement: OnUndeliveredElement?, + private val onDroppedElement: ((E) -> Unit)? = null ) : AbstractChannel(onUndeliveredElement) { init { // This check is actually used by the Channel(...) constructor function which checks only for known @@ -63,7 +64,8 @@ internal open class ArrayChannel( // check for receivers that were waiting on empty queue if (size == 0) { loop@ while (true) { - receive = takeFirstReceiveOrPeekClosed() ?: break@loop // break when no receivers queued + receive = takeFirstReceiveOrPeekClosed() + ?: break@loop // break when no receivers queued if (receive is Closed) { this.size.value = size // restore size return receive!! @@ -153,6 +155,11 @@ internal open class ArrayChannel( } else { // buffer is full assert { onBufferOverflow == BufferOverflow.DROP_OLDEST } // the only way we can get here + val dropped = buffer[head % buffer.size] + if (dropped != null) { + @Suppress("UNCHECKED_CAST") + onDroppedElement?.let { it(dropped as E) } + } buffer[head % buffer.size] = null // drop oldest element buffer[(head + currentSize) % buffer.size] = element // actually queue element head = (head + 1) % buffer.size @@ -180,7 +187,8 @@ internal open class ArrayChannel( var result: Any? = null lock.withLock { val size = this.size.value - if (size == 0) return closedForSend ?: POLL_FAILED // when nothing can be read from buffer + if (size == 0) return closedForSend + ?: POLL_FAILED // when nothing can be read from buffer // size > 0: not empty -- retrieve element result = buffer[head] buffer[head] = null @@ -282,13 +290,18 @@ internal open class ArrayChannel( override fun onCancelIdempotent(wasClosed: Boolean) { // clear buffer first, but do not wait for it in helpers val onUndeliveredElement = onUndeliveredElement - var undeliveredElementException: UndeliveredElementException? = null // first cancel exception, others suppressed + var undeliveredElementException: UndeliveredElementException? = + null // first cancel exception, others suppressed lock.withLock { repeat(size.value) { val value = buffer[head] if (onUndeliveredElement != null && value !== EMPTY) { @Suppress("UNCHECKED_CAST") - undeliveredElementException = onUndeliveredElement.callUndeliveredElementCatchingException(value as E, undeliveredElementException) + undeliveredElementException = + onUndeliveredElement.callUndeliveredElementCatchingException( + value as E, + undeliveredElementException + ) } buffer[head] = EMPTY head = (head + 1) % buffer.size diff --git a/kotlinx-coroutines-core/common/src/channels/Channel.kt b/kotlinx-coroutines-core/common/src/channels/Channel.kt index 5ad79fdcff..1c748d2b1a 100644 --- a/kotlinx-coroutines-core/common/src/channels/Channel.kt +++ b/kotlinx-coroutines-core/common/src/channels/Channel.kt @@ -768,14 +768,15 @@ public interface Channel : SendChannel, ReceiveChannel { public fun Channel( capacity: Int = RENDEZVOUS, onBufferOverflow: BufferOverflow = BufferOverflow.SUSPEND, - onUndeliveredElement: ((E) -> Unit)? = null + onUndeliveredElement: ((E) -> Unit)? = null, + onDroppedElement: ((E) -> Unit)? = null ): Channel = when (capacity) { RENDEZVOUS -> { if (onBufferOverflow == BufferOverflow.SUSPEND) RendezvousChannel(onUndeliveredElement) // an efficient implementation of rendezvous channel else - ArrayChannel(1, onBufferOverflow, onUndeliveredElement) // support buffer overflow with buffered channel + ArrayChannel(1, onBufferOverflow, onUndeliveredElement, onDroppedElement) // support buffer overflow with buffered channel } CONFLATED -> { require(onBufferOverflow == BufferOverflow.SUSPEND) { @@ -786,13 +787,13 @@ public fun Channel( UNLIMITED -> LinkedListChannel(onUndeliveredElement) // ignores onBufferOverflow: it has buffer, but it never overflows BUFFERED -> ArrayChannel( // uses default capacity with SUSPEND if (onBufferOverflow == BufferOverflow.SUSPEND) CHANNEL_DEFAULT_CAPACITY else 1, - onBufferOverflow, onUndeliveredElement + onBufferOverflow, onUndeliveredElement, onDroppedElement ) else -> { if (capacity == 1 && onBufferOverflow == BufferOverflow.DROP_OLDEST) ConflatedChannel(onUndeliveredElement) // conflated implementation is more efficient but appears to work in the same way else - ArrayChannel(capacity, onBufferOverflow, onUndeliveredElement) + ArrayChannel(capacity, onBufferOverflow, onUndeliveredElement, onDroppedElement) } }