Skip to content

Commit f116346

Browse files
committed
Introduce ReceiveChannel.consumeAsFlow with efficient implementation
* This is a consuming conversion -- the resulting flow can be collected just once and the channel is closed after the first collect. * The implementation is made efficient (without iterators) using a new internal ReceiveChannel.consumeEachTo function which also ensures that the reference to the last emitted value is not retained (does not leak). * AbstractChannel implementation is optimized to avoid code duplication in different receive methods (receive and receiveOrNull) and also shares code with new receiveInternal that is used for an efficient consumeEachTo implementation. Fixes #1340 Fixes #1333
1 parent db0ef0c commit f116346

File tree

7 files changed

+161
-50
lines changed

7 files changed

+161
-50
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+2
Original file line numberDiff line numberDiff line change
@@ -742,6 +742,7 @@ public abstract interface class kotlinx/coroutines/channels/ReceiveChannel {
742742
public abstract synthetic fun cancel ()V
743743
public abstract synthetic fun cancel (Ljava/lang/Throwable;)Z
744744
public abstract fun cancel (Ljava/util/concurrent/CancellationException;)V
745+
public abstract fun consumeEachTo (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
745746
public abstract fun getOnReceive ()Lkotlinx/coroutines/selects/SelectClause1;
746747
public abstract fun getOnReceiveOrNull ()Lkotlinx/coroutines/selects/SelectClause1;
747748
public abstract fun isClosedForReceive ()Z
@@ -827,6 +828,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
827828
public static final fun combineLatest (Lkotlinx/coroutines/flow/Flow;[Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function0;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
828829
public static final synthetic fun combineLatest (Lkotlinx/coroutines/flow/Flow;[Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
829830
public static final fun conflate (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
831+
public static final fun consumeAsFlow (Lkotlinx/coroutines/channels/ReceiveChannel;)Lkotlinx/coroutines/flow/Flow;
830832
public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
831833
public static final fun count (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
832834
public static final fun debounce (Lkotlinx/coroutines/flow/Flow;J)Lkotlinx/coroutines/flow/Flow;

kotlinx-coroutines-core/common/src/channels/AbstractChannel.kt

+69-36
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package kotlinx.coroutines.channels
66

77
import kotlinx.atomicfu.*
88
import kotlinx.coroutines.*
9+
import kotlinx.coroutines.flow.*
910
import kotlinx.coroutines.internal.*
1011
import kotlinx.coroutines.intrinsics.*
1112
import kotlinx.coroutines.selects.*
@@ -548,7 +549,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
548549
val result = pollInternal()
549550
if (result !== POLL_FAILED) return receiveResult(result)
550551
// slow-path does suspend
551-
return receiveSuspend()
552+
return receiveSuspend(RECEIVE_THROWS_ON_CLOSE)
552553
}
553554

554555
@Suppress("UNCHECKED_CAST")
@@ -558,8 +559,8 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
558559
}
559560

560561
@Suppress("UNCHECKED_CAST")
561-
private suspend fun receiveSuspend(): E = suspendAtomicCancellableCoroutine sc@ { cont ->
562-
val receive = ReceiveElement(cont as CancellableContinuation<E?>, nullOnClose = false)
562+
private suspend fun receiveSuspend(onClose: Int): E = suspendAtomicCancellableCoroutine sc@ { cont ->
563+
val receive = ReceiveElement<E>(cont as CancellableContinuation<Any?>, onClose)
563564
while (true) {
564565
if (enqueueReceive(receive)) {
565566
removeReceiveOnCancel(cont, receive)
@@ -568,7 +569,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
568569
// hm... something is not right. try to poll
569570
val result = pollInternal()
570571
if (result is Closed<*>) {
571-
cont.resumeWithException(result.receiveException)
572+
receive.resumeReceiveClosed(result)
572573
return@sc
573574
}
574575
if (result !== POLL_FAILED) {
@@ -592,7 +593,7 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
592593
val result = pollInternal()
593594
if (result !== POLL_FAILED) return receiveOrNullResult(result)
594595
// slow-path does suspend
595-
return receiveOrNullSuspend()
596+
return receiveSuspend(RECEIVE_NULL_ON_CLOSE)
596597
}
597598

598599
@Suppress("UNCHECKED_CAST")
@@ -604,30 +605,6 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
604605
return result as E
605606
}
606607

607-
@Suppress("UNCHECKED_CAST")
608-
private suspend fun receiveOrNullSuspend(): E? = suspendAtomicCancellableCoroutine sc@ { cont ->
609-
val receive = ReceiveElement(cont, nullOnClose = true)
610-
while (true) {
611-
if (enqueueReceive(receive)) {
612-
removeReceiveOnCancel(cont, receive)
613-
return@sc
614-
}
615-
// hm... something is not right. try to poll
616-
val result = pollInternal()
617-
if (result is Closed<*>) {
618-
if (result.closeCause == null)
619-
cont.resume(null)
620-
else
621-
cont.resumeWithException(result.closeCause)
622-
return@sc
623-
}
624-
if (result !== POLL_FAILED) {
625-
cont.resume(result as E)
626-
return@sc
627-
}
628-
}
629-
}
630-
631608
@Suppress("UNCHECKED_CAST")
632609
public final override fun poll(): E? {
633610
val result = pollInternal()
@@ -663,6 +640,57 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
663640

664641
public final override fun iterator(): ChannelIterator<E> = Itr(this)
665642

643+
// ------ efficient consumeEach implementation ------
644+
645+
@Suppress("UNCHECKED_CAST")
646+
public override suspend fun consumeEachTo(collector: FlowCollector<E>) {
647+
// Manually inlined "consumeEach" implementation that does not use iterator but works via "receiveInternal".
648+
// It has smaller and more efficient spilled state which also allows to implement a manual kludge to
649+
// fix retention of the last emitted value.
650+
// See https://youtrack.jetbrains.com/issue/KT-16222
651+
// See https://github.com/Kotlin/kotlinx.coroutines/issues/1333
652+
var cause: Throwable? = null
653+
try {
654+
while (true) {
655+
// :KLUDGE: This "run" call is resolved to an extension function "run" and forces the size of
656+
// spilled state to increase by an additional slot, so there are 4 object local variables spilled here
657+
// which makes the size of spill state equal to the 4 slots that are spilled around subsequent "emit"
658+
// call, ensuring that the previously emitted value is not retained in the state while receiving
659+
// the next one.
660+
// L$0 <- this
661+
// L$1 <- collector
662+
// L$2 <- cause
663+
// L$3 <- this$run (actually equal to this)
664+
val result = run { receiveInternal() }
665+
if (result is Closed<*>) {
666+
result.closeCause?.let { throw it }
667+
break // returns normally when result.closeCause == null
668+
}
669+
// result is spilled here to the coroutine state and retained after the call, even though
670+
// it is not actually needed in the next loop iteration.
671+
// L$0 <- this
672+
// L$1 <- collector
673+
// L$2 <- cause
674+
// L$3 <- result
675+
collector.emit(result as E)
676+
}
677+
} catch (e: Throwable) {
678+
cause = e
679+
throw e
680+
} finally {
681+
cancelConsumed(cause)
682+
}
683+
}
684+
685+
// Return type is `E | Closed`
686+
private suspend fun receiveInternal(): Any? {
687+
// fast path -- try poll non-blocking
688+
val result = pollInternal()
689+
if (result !== POLL_FAILED) return result
690+
// slow-path does suspend
691+
return receiveSuspend(RECEIVE_TOKEN_ON_CLOSE)
692+
}
693+
666694
// ------ registerSelectReceive ------
667695

668696
/**
@@ -884,18 +912,19 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
884912
}
885913

886914
private class ReceiveElement<in E>(
887-
@JvmField val cont: CancellableContinuation<E?>,
888-
@JvmField val nullOnClose: Boolean
915+
@JvmField val cont: CancellableContinuation<Any?>,
916+
@JvmField val onClose: Int
889917
) : Receive<E>() {
890918
override fun tryResumeReceive(value: E, idempotent: Any?): Any? = cont.tryResume(value, idempotent)
891919
override fun completeResumeReceive(token: Any) = cont.completeResume(token)
892920
override fun resumeReceiveClosed(closed: Closed<*>) {
893-
if (closed.closeCause == null && nullOnClose)
894-
cont.resume(null)
895-
else
896-
cont.resumeWithException(closed.receiveException)
921+
when {
922+
onClose == RECEIVE_NULL_ON_CLOSE && closed.closeCause == null -> cont.resume(null)
923+
onClose == RECEIVE_TOKEN_ON_CLOSE -> cont.resume(closed)
924+
else -> cont.resumeWithException(closed.receiveException)
925+
}
897926
}
898-
override fun toString(): String = "ReceiveElement[$cont,nullOnClose=$nullOnClose]"
927+
override fun toString(): String = "ReceiveElement[$cont,onClose=$onClose]"
899928
}
900929

901930
private class ReceiveHasNext<E>(
@@ -982,6 +1011,10 @@ internal abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E
9821011
)
9831012
}
9841013

1014+
internal const val RECEIVE_THROWS_ON_CLOSE = 0
1015+
internal const val RECEIVE_NULL_ON_CLOSE = 1
1016+
internal const val RECEIVE_TOKEN_ON_CLOSE = 2
1017+
9851018
@JvmField
9861019
@SharedImmutable
9871020
internal val OFFER_SUCCESS: Any = Symbol("OFFER_SUCCESS")

kotlinx-coroutines-core/common/src/channels/Channel.kt

+17-3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,13 @@
77
package kotlinx.coroutines.channels
88

99
import kotlinx.coroutines.*
10+
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
11+
import kotlinx.coroutines.channels.Channel.Factory.CHANNEL_DEFAULT_CAPACITY
1012
import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
1113
import kotlinx.coroutines.channels.Channel.Factory.RENDEZVOUS
1214
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
13-
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
14-
import kotlinx.coroutines.channels.Channel.Factory.CHANNEL_DEFAULT_CAPACITY
15-
import kotlinx.coroutines.internal.systemProp
15+
import kotlinx.coroutines.flow.*
16+
import kotlinx.coroutines.internal.*
1617
import kotlinx.coroutines.selects.*
1718
import kotlin.jvm.*
1819

@@ -242,6 +243,19 @@ public interface ReceiveChannel<out E> {
242243
*/
243244
public operator fun iterator(): ChannelIterator<E>
244245

246+
/**
247+
* Consumes all elements from this channels by [emitting][FlowCollector.emit] them to the given [collector]
248+
* and [cancels][cancel] the channel after the execution of the block.
249+
* If you need to iterate over the channel without consuming it, a regular `for` loop should be used instead.
250+
*
251+
* This function provides a more efficient shorthand for `consumeEach { value -> collector.emit(value) }`.
252+
* See [consumeEach].
253+
*
254+
* @suppress **This an internal API and should not be used from general code.**
255+
*/
256+
@InternalCoroutinesApi
257+
public suspend fun consumeEachTo(collector: FlowCollector<E>)
258+
245259
/**
246260
* Cancels reception of remaining elements from this channel with an optional [cause].
247261
* This function closes the channel and removes all buffered sent elements from it.

kotlinx-coroutines-core/common/src/flow/Channels.kt

+24-8
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,33 @@
77

88
package kotlinx.coroutines.flow
99

10+
import kotlinx.atomicfu.*
1011
import kotlinx.coroutines.*
1112
import kotlinx.coroutines.channels.*
12-
import kotlinx.coroutines.channels.Channel.Factory.CONFLATED
13-
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
14-
import kotlinx.coroutines.channels.Channel.Factory.OPTIONAL_CHANNEL
1513
import kotlinx.coroutines.flow.internal.*
16-
import kotlin.coroutines.*
1714
import kotlin.jvm.*
15+
import kotlinx.coroutines.flow.unsafeFlow as flow
16+
17+
/**
18+
* Represents the given receive channel as a hot flow and [consumes][ReceiveChannel.consume] the channel
19+
* on the first collection from this flow. The resulting flow can be collected just once and throws
20+
* [IllegalStateException] when trying to collect it more than once.
21+
*
22+
* ### Cancellation semantics
23+
* 1) Flow consumer is cancelled when the original channel is cancelled.
24+
* 2) Flow consumer completes normally when the original channel completes (~is closed) normally.
25+
* 3) If the flow consumer fails with an exception, channel is cancelled.
26+
*
27+
*/
28+
@FlowPreview
29+
public fun <T> ReceiveChannel<T>.consumeAsFlow(): Flow<T> = object : Flow<T> {
30+
val collected = atomic(false)
31+
32+
override suspend fun collect(collector: FlowCollector<T>) {
33+
check(!collected.getAndSet(true)) { "ReceiveChannel.consumeAsFlow can be collected just once" }
34+
this@consumeAsFlow.consumeEachTo(collector)
35+
}
36+
}
1837

1938
/**
2039
* Represents the given broadcast channel as a hot flow.
@@ -27,10 +46,7 @@ import kotlin.jvm.*
2746
*/
2847
@FlowPreview
2948
public fun <T> BroadcastChannel<T>.asFlow(): Flow<T> = flow {
30-
val subscription = openSubscription()
31-
subscription.consumeEach { value ->
32-
emit(value)
33-
}
49+
openSubscription().consumeEachTo(this)
3450
}
3551

3652
/**

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,7 @@ public abstract class ChannelFlow<T>(
7272

7373
override suspend fun collect(collector: FlowCollector<T>) =
7474
coroutineScope {
75-
val channel = produceImpl(this)
76-
channel.consumeEach { collector.emit(it) }
75+
produceImpl(this).consumeEachTo(collector)
7776
}
7877

7978
// debug toString

kotlinx-coroutines-core/common/test/channels/TestChannelKind.kt

+3-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.coroutines.channels
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
89
import kotlinx.coroutines.selects.*
910

1011
enum class TestChannelKind {
@@ -60,7 +61,8 @@ private class ChannelViaBroadcast<E>(
6061
override suspend fun receiveOrNull(): E? = sub.receiveOrNull()
6162
override fun poll(): E? = sub.poll()
6263
override fun iterator(): ChannelIterator<E> = sub.iterator()
63-
64+
override suspend fun consumeEachTo(collector: FlowCollector<E>) = sub.consumeEachTo(collector)
65+
6466
override fun cancel(cause: CancellationException?) = sub.cancel(cause)
6567

6668
// implementing hidden method anyway, so can cast to an internal class

kotlinx-coroutines-core/common/test/flow/channels/ChannelBuildersFlowTest.kt

+45
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,51 @@ import kotlinx.coroutines.channels.*
99
import kotlin.test.*
1010

1111
class ChannelBuildersFlowTest : TestBase() {
12+
@Test
13+
fun testChannelConsumeAsFlow() = runTest {
14+
val channel = produce {
15+
repeat(10) {
16+
send(it + 1)
17+
}
18+
}
19+
val flow = channel.consumeAsFlow()
20+
assertEquals(55, flow.sum())
21+
assertFailsWith<IllegalStateException> { flow.collect() }
22+
}
23+
24+
@Test
25+
fun testConsumeAsFlowCancellation() = runTest {
26+
expect(1)
27+
val channel = produce(NonCancellable) { // otherwise failure will cancel scope as well
28+
repeat(10) {
29+
send(it + 1)
30+
}
31+
throw TestException()
32+
}
33+
val flow = channel.consumeAsFlow()
34+
assertEquals(15, flow.take(5).sum())
35+
// the channel should have been canceled, even though took only 5 elements
36+
assertTrue(channel.isClosedForReceive)
37+
assertFailsWith<IllegalStateException> { flow.collect() }
38+
finish(2)
39+
}
40+
41+
@Test
42+
fun testConsumeAsFlowException() = runTest {
43+
expect(1)
44+
val channel = produce(NonCancellable) { // otherwise failure will cancel scope as well
45+
repeat(10) {
46+
send(it + 1)
47+
}
48+
throw TestException()
49+
}
50+
val flow = channel.consumeAsFlow()
51+
assertFailsWith<TestException> { flow.sum() }
52+
assertFailsWith<IllegalStateException> { flow.collect() }
53+
finish(2)
54+
}
55+
56+
1257
@Test
1358
fun testBroadcastChannelAsFlow() = runTest {
1459
val channel = broadcast {

0 commit comments

Comments
 (0)