Skip to content

Commit bc34a2e

Browse files
committed
Introduce ReceiveChannel.consumeAsFlow with efficient implementation
* This is a consuming conversion -- the resulting flow can be collected just once and the channel is closed after the first collect. * The implementation is made efficient (without iterators) using a new internal ReceiveChannel.consumeEachTo function which also ensure that the reference to the last emitted value is not retained. * AbstractChannel implementation is optimized to avoid code duplication in different receive methods. Fixes #1340 Fixes #1333
1 parent db0ef0c commit bc34a2e

File tree

7 files changed

+158
-50
lines changed

7 files changed

+158
-50
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+2
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,7 @@ public abstract interface class kotlinx/coroutines/channels/ReceiveChannel {
742742
public abstract synthetic fun cancel ()V
743743
public abstract synthetic fun cancel (Ljava/lang/Throwable;)Z
744744
public abstract fun cancel (Ljava/util/concurrent/CancellationException;)V
745+
public abstract fun consumeEachTo (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
745746
public abstract fun getOnReceive ()Lkotlinx/coroutines/selects/SelectClause1;
746747
public abstract fun getOnReceiveOrNull ()Lkotlinx/coroutines/selects/SelectClause1;
747748
public abstract fun isClosedForReceive ()Z
@@ -827,6 +828,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
827828
public static final fun combineLatest (Lkotlinx/coroutines/flow/Flow;[Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
828829
public static final synthetic fun combineLatest (Lkotlinx/coroutines/flow/Flow;[Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
829830
public static final fun conflate (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
831+
public static final fun consumeAsFlow (Lkotlinx/coroutines/channels/ReceiveChannel;)Lkotlinx/coroutines/flow/Flow;
830832
public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
831833
public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
832834
public static final fun debounce (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow;

kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt

+69-36
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package kotlinx.coroutines.channels
66

77
import kotlinx.atomicfu.*
88
import kotlinx.coroutines.*
9+
import kotlinx.coroutines.flow.*
910
import kotlinx.coroutines.internal.*
1011
import kotlinx.coroutines.intrinsics.*
1112
import kotlinx.coroutines.selects.*
@@ -548,7 +549,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
548549
val result = pollInternal()
549550
if (result !== POLL_FAILED) return receiveResult(result)
550551
// slow-path does suspend
551-
return receiveSuspend()
552+
return receiveSuspend(RECEIVE_THROWS_ON_CLOSE)
552553
}
553554

554555
@Suppress("UNCHECKED_CAST")
@@ -558,8 +559,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
558559
}
559560

560561
@Suppress("UNCHECKED_CAST")
561-
private suspend fun receiveSuspend(): E = suspendAtomicCancellableCoroutine sc@ { cont ->
562-
val receive = ReceiveElement(cont as CancellableContinuation<E?>, nullOnClose = false)
562+
private suspend fun receiveSuspend(onClose: Int): E = suspendAtomicCancellableCoroutine sc@ { cont ->
563+
val receive = ReceiveElement<E>(cont as CancellableContinuation<Any?>, onClose)
563564
while (true) {
564565
if (enqueueReceive(receive)) {
565566
removeReceiveOnCancel(cont, receive)
@@ -568,7 +569,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
568569
// hm... something is not right. try to poll
569570
val result = pollInternal()
570571
if (result is Closed<*>) {
571-
cont.resumeWithException(result.receiveException)
572+
receive.resumeReceiveClosed(result)
572573
return@sc
573574
}
574575
if (result !== POLL_FAILED) {
@@ -592,7 +593,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
592593
val result = pollInternal()
593594
if (result !== POLL_FAILED) return receiveOrNullResult(result)
594595
// slow-path does suspend
595-
return receiveOrNullSuspend()
596+
return receiveSuspend(RECEIVE_NULL_ON_CLOSE)
596597
}
597598

598599
@Suppress("UNCHECKED_CAST")
@@ -604,30 +605,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
604605
return result as E
605606
}
606607

607-
@Suppress("UNCHECKED_CAST")
608-
private suspend fun receiveOrNullSuspend(): E? = suspendAtomicCancellableCoroutine sc@ { cont ->
609-
val receive = ReceiveElement(cont, nullOnClose = true)
610-
while (true) {
611-
if (enqueueReceive(receive)) {
612-
removeReceiveOnCancel(cont, receive)
613-
return@sc
614-
}
615-
// hm... something is not right. try to poll
616-
val result = pollInternal()
617-
if (result is Closed<*>) {
618-
if (result.closeCause == null)
619-
cont.resume(null)
620-
else
621-
cont.resumeWithException(result.closeCause)
622-
return@sc
623-
}
624-
if (result !== POLL_FAILED) {
625-
cont.resume(result as E)
626-
return@sc
627-
}
628-
}
629-
}
630-
631608
@Suppress("UNCHECKED_CAST")
632609
public final override fun poll(): E? {
633610
val result = pollInternal()
@@ -663,6 +640,57 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
663640

664641
public final override fun iterator(): ChannelIterator<E> = Itr(this)
665642

643+
// ------ efficient consumeEach implementation ------
644+
645+
@Suppress("UNCHECKED_CAST")
646+
public override suspend fun consumeEachTo(collector: FlowCollector<E>) {
647+
// Manually inlined "consumeEach" implementation that does not use iterator but works via "receiveInternal".
648+
// It has smaller and more efficient spilled state which also allows to implement a manual kludge to
649+
// fix retention of the last emitted value.
650+
// See https://youtrack.jetbrains.com/issue/KT-16222
651+
// See https://github.com/Kotlin/kotlinx.coroutines/issues/1333
652+
var cause: Throwable? = null
653+
try {
654+
while (true) {
655+
// :KLUDGE: This "run" call is resolved to an extension function "run" and forces the size of
656+
// spilled state to increase by an additional slot, so there are 4 object local variables spilled here
657+
// which makes the size of spill state equal to the 4 slots that are spilled around subsequent "emit"
658+
// call, ensuring that the previously emitted value is not retained in the state while receiving
659+
// the next one.
660+
// L$0 <- this
661+
// L$1 <- collector
662+
// L$2 <- cause
663+
// L$3 <- this$run (actually equal to this)
664+
val result = run { receiveInternal() }
665+
if (result is Closed<*>) {
666+
result.closeCause?.let { throw it }
667+
break // returns normally when result.closeCause == null
668+
}
669+
// result is spilled here to the coroutine state and retained after the call, even though
670+
// it is not actually needed in the next loop iteration.
671+
// L$0 <- this
672+
// L$1 <- collector
673+
// L$2 <- cause
674+
// L$3 <- result
675+
collector.emit(result as E)
676+
}
677+
} catch (e: Throwable) {
678+
cause = e
679+
throw e
680+
} finally {
681+
cancelConsumed(cause)
682+
}
683+
}
684+
685+
// Return type is `E | Closed`
686+
private suspend fun receiveInternal(): Any? {
687+
// fast path -- try poll non-blocking
688+
val result = pollInternal()
689+
if (result !== POLL_FAILED) return result
690+
// slow-path does suspend
691+
return receiveSuspend(RECEIVE_TOKEN_ON_CLOSE)
692+
}
693+
666694
// ------ registerSelectReceive ------
667695

668696
/**
@@ -884,18 +912,19 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
884912
}
885913

886914
private class ReceiveElement<in E>(
887-
@JvmField val cont: CancellableContinuation<E?>,
888-
@JvmField val nullOnClose: Boolean
915+
@JvmField val cont: CancellableContinuation<Any?>,
916+
@JvmField val onClose: Int
889917
) : Receive<E>() {
890918
override fun tryResumeReceive(value: E, idempotent: Any?): Any? = cont.tryResume(value, idempotent)
891919
override fun completeResumeReceive(token: Any) = cont.completeResume(token)
892920
override fun resumeReceiveClosed(closed: Closed<*>) {
893-
if (closed.closeCause == null && nullOnClose)
894-
cont.resume(null)
895-
else
896-
cont.resumeWithException(closed.receiveException)
921+
when {
922+
onClose == RECEIVE_NULL_ON_CLOSE && closed.closeCause == null -> cont.resume(null)
923+
onClose == RECEIVE_TOKEN_ON_CLOSE -> cont.resume(closed)
924+
else -> cont.resumeWithException(closed.receiveException)
925+
}
897926
}
898-
override fun toString(): String = "ReceiveElement[$cont,nullOnClose=$nullOnClose]"
927+
override fun toString(): String = "ReceiveElement[$cont,onClose=$onClose]"
899928
}
900929

901930
private class ReceiveHasNext<E>(
@@ -982,6 +1011,10 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
9821011
)
9831012
}
9841013

1014+
internal const val RECEIVE_THROWS_ON_CLOSE = 0
1015+
internal const val RECEIVE_NULL_ON_CLOSE = 1
1016+
internal const val RECEIVE_TOKEN_ON_CLOSE = 2
1017+
9851018
@JvmField
9861019
@SharedImmutable
9871020
internal val OFFER_SUCCESS: Any = Symbol("OFFER_SUCCESS")

kotlinx-coroutines-core/common/src/channels/Channel.kt

+17-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
package kotlinx.coroutines.channels
88

99
import kotlinx.coroutines.*
10+
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
11+
import kotlinx.coroutines.channels.Channel.Factory.CHANNEL_DEFAULT_CAPACITY
1012
import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
1113
import kotlinx.coroutines.channels.Channel.Factory.RENDEZVOUS
1214
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
13-
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
14-
import kotlinx.coroutines.channels.Channel.Factory.CHANNEL_DEFAULT_CAPACITY
15-
import kotlinx.coroutines.internal.systemProp
15+
import kotlinx.coroutines.flow.*
16+
import kotlinx.coroutines.internal.*
1617
import kotlinx.coroutines.selects.*
1718
import kotlin.jvm.*
1819

@@ -242,6 +243,19 @@ public interface ReceiveChannel<out E> {
242243
*/
243244
public operator fun iterator(): ChannelIterator<E>
244245

246+
/**
247+
* Consumes all elements from this channels by [emitting][FlowCollector.emit] them to the given [collector]
248+
* and [cancels][cancel] the channel after the execution of the block.
249+
* If you need to iterate over the channel without consuming it, a regular `for` loop should be used instead.
250+
*
251+
* This function provides a more efficient shorthand for `consumeEach { value -> collector.emit(value) }`.
252+
* See [consumeEach].
253+
*
254+
* @suppress **This an internal API and should not be used from general code.**
255+
*/
256+
@InternalCoroutinesApi
257+
public suspend fun consumeEachTo(collector: FlowCollector<E>)
258+
245259
/**
246260
* Cancels reception of remaining elements from this channel with an optional [cause].
247261
* This function closes the channel and removes all buffered sent elements from it.

kotlinx-coroutines-core/common/src/flow/Channels.kt

+24-8
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,33 @@
77

88
package kotlinx.coroutines.flow
99

10+
import kotlinx.atomicfu.*
1011
import kotlinx.coroutines.*
1112
import kotlinx.coroutines.channels.*
12-
import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
13-
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
14-
import kotlinx.coroutines.channels.Channel.Factory.OPTIONAL_CHANNEL
1513
import kotlinx.coroutines.flow.internal.*
16-
import kotlin.coroutines.*
1714
import kotlin.jvm.*
15+
import kotlinx.coroutines.flow.unsafeFlow as flow
16+
17+
/**
18+
* Represents the given receive channel as a hot flow and [consumes][ReceiveChannel.consume] the channel
19+
* on the first collection from this flow. The resulting flow can be collected just once and throws
20+
* [IllegalStateException] when trying to collect it more than once.
21+
*
22+
* ### Cancellation semantics
23+
* 1) Flow consumer is cancelled when the original channel is cancelled.
24+
* 2) Flow consumer completes normally when the original channel completes (~is closed) normally.
25+
* 3) If the flow consumer fails with an exception, channel is cancelled.
26+
*
27+
*/
28+
@FlowPreview
29+
public fun <T> ReceiveChannel<T>.consumeAsFlow(): Flow<T> = object : Flow<T> {
30+
val collected = atomic(false)
31+
32+
override suspend fun collect(collector: FlowCollector<T>) {
33+
check(!collected.getAndSet(true)) { "ReceiveChannel.consumeAsFlow can be collected just once" }
34+
this@consumeAsFlow.consumeEachTo(collector)
35+
}
36+
}
1837

1938
/**
2039
* Represents the given broadcast channel as a hot flow.
@@ -27,10 +46,7 @@ import kotlin.jvm.*
2746
*/
2847
@FlowPreview
2948
public fun <T> BroadcastChannel<T>.asFlow(): Flow<T> = flow {
30-
val subscription = openSubscription()
31-
subscription.consumeEach { value ->
32-
emit(value)
33-
}
49+
openSubscription().consumeEachTo(this)
3450
}
3551

3652
/**

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ public abstract class ChannelFlow<T>(
7272

7373
override suspend fun collect(collector: FlowCollector<T>) =
7474
coroutineScope {
75-
val channel = produceImpl(this)
76-
channel.consumeEach { collector.emit(it) }
75+
produceImpl(this).consumeEachTo(collector)
7776
}
7877

7978
// debug toString

kotlinx-coroutines-core/common/test/channels/TestChannelKind.kt

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.coroutines.channels
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
89
import kotlinx.coroutines.selects.*
910

1011
enum class TestChannelKind {
@@ -60,7 +61,8 @@ private class ChannelViaBroadcast<E>(
6061
override suspend fun receiveOrNull(): E? = sub.receiveOrNull()
6162
override fun poll(): E? = sub.poll()
6263
override fun iterator(): ChannelIterator<E> = sub.iterator()
63-
64+
override suspend fun consumeEachTo(collector: FlowCollector<E>) = sub.consumeEachTo(collector)
65+
6466
override fun cancel(cause: CancellationException?) = sub.cancel(cause)
6567

6668
// implementing hidden method anyway, so can cast to an internal class

kotlinx-coroutines-core/common/test/flow/channels/ChannelBuildersFlowTest.kt

+42
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,48 @@ import kotlinx.coroutines.channels.*
99
import kotlin.test.*
1010

1111
class ChannelBuildersFlowTest : TestBase() {
12+
@Test
13+
fun testChannelConsumeAsFlow() = runTest {
14+
val channel = produce {
15+
repeat(10) {
16+
send(it + 1)
17+
}
18+
}
19+
val sum = channel.consumeAsFlow().sum()
20+
assertEquals(55, sum)
21+
}
22+
23+
@Test
24+
fun testConsumeAsFlowCancellation() = runTest {
25+
expect(1)
26+
val channel = produce(NonCancellable) { // otherwise failure will cancel scope as well
27+
repeat(10) {
28+
send(it + 1)
29+
}
30+
throw TestException()
31+
}
32+
val flow = channel.consumeAsFlow()
33+
assertEquals(15, flow.take(5).sum())
34+
// the channel should have been canceled, even though took only 5 elements
35+
assertTrue(channel.isClosedForReceive)
36+
assertFailsWith<IllegalStateException> { flow.collect() }
37+
finish(2)
38+
}
39+
40+
@Test
41+
fun testConsumeAsFlowException() = runTest {
42+
expect(1)
43+
val channel = produce(NonCancellable) { // otherwise failure will cancel scope as well
44+
repeat(10) {
45+
send(it + 1)
46+
}
47+
throw TestException()
48+
}
49+
assertFailsWith<TestException> { channel.consumeAsFlow().sum() }
50+
finish(2)
51+
}
52+
53+
1254
@Test
1355
fun testBroadcastChannelAsFlow() = runTest {
1456
val channel = broadcast {

0 commit comments

Comments
 (0)