Skip to content

Commit d33a32d

Browse files
committed
Detect missing awaitClose calls in callbackFlow and close channel with a proper diagnostic exception
Fixes #1762 Fixes #1770
1 parent f18e0e4 commit d33a32d

File tree

7 files changed

+127
-34
lines changed

7 files changed

+127
-34
lines changed

kotlinx-coroutines-core/common/src/channels/Produce.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public interface ProducerScope<in E> : CoroutineScope, SendChannel<E> {
2727

2828
/**
2929
* Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel]
30-
* and invokes the given [block] before resuming the coroutine.
30+
* and invokes the given [block] before resuming the coroutine. This suspending function is cancellable.
3131
*
3232
* Note that when the producer channel is cancelled, this function resumes with a cancellation exception.
3333
* Therefore, in case of cancellation, no code after the call to this function will be executed.

kotlinx-coroutines-core/common/src/flow/Builders.kt

+49-13
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import kotlinx.coroutines.*
1111
import kotlinx.coroutines.channels.*
1212
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
1313
import kotlinx.coroutines.flow.internal.*
14-
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
1514
import kotlin.coroutines.*
1615
import kotlin.jvm.*
16+
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
1717

1818
/**
1919
* Creates a flow from the given suspendable [block].
@@ -259,10 +259,14 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
259259
*
260260
* This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used
261261
* from any context, e.g. from a callback-based API.
262-
* The resulting flow completes as soon as the code in the [block] and all its children completes.
263-
* Use [awaitClose] as the last statement to keep it running.
264-
* The [awaitClose] argument is called either when a flow consumer cancels the flow collection
265-
* or when a callback-based API invokes [SendChannel.close] manually.
262+
* The resulting flow completes as soon as the code in the [block] completes.
263+
* [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately
264+
* when block completes.
265+
* [awaitClose] argument is called either when a flow consumer cancels the flow collection
266+
* or when a callback-based API invokes [SendChannel.close] manually and is typically used
267+
* to cleanup the resources after the completion, e.g. unregister a callback.
268+
* Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled,
269+
* otherwise the callback may keep running even when the flow collector is already completed.
266270
*
267271
* A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the
268272
* resulting flow to specify a user-defined value and to control what happens when data is produced faster
@@ -287,21 +291,20 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
287291
* override fun onCompleted() = channel.close()
288292
* }
289293
* api.register(callback)
290-
* // Suspend until either onCompleted or external cancellation are invoked
294+
* /*
295+
* * Suspends until either 'onCompleted' from the callback is invoked
296+
* * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's activity was destroyed).
297+
* * In both cases, callback will be properly unregistered.
298+
* */
291299
* awaitClose { api.unregister(callback) }
292300
* }
293301
* ```
294-
*
295-
* This function is an alias for [channelFlow], it has a separate name to reflect
296-
* the intent of the usage (integration with a callback-based API) better.
297302
*/
298-
@Suppress("NOTHING_TO_INLINE")
299303
@ExperimentalCoroutinesApi
300-
public inline fun <T> callbackFlow(@BuilderInference noinline block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
301-
channelFlow(block)
304+
public fun <T> callbackFlow(@BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> = CallbackFlowBuilder(block)
302305

303306
// ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow
304-
private class ChannelFlowBuilder<T>(
307+
private open class ChannelFlowBuilder<T>(
305308
private val block: suspend ProducerScope<T>.() -> Unit,
306309
context: CoroutineContext = EmptyCoroutineContext,
307310
capacity: Int = BUFFERED
@@ -315,3 +318,36 @@ private class ChannelFlowBuilder<T>(
315318
override fun toString(): String =
316319
"block[$block] -> ${super.toString()}"
317320
}
321+
322+
private class CallbackFlowBuilder<T>(
323+
private val block: suspend ProducerScope<T>.() -> Unit,
324+
context: CoroutineContext = EmptyCoroutineContext,
325+
capacity: Int = BUFFERED
326+
) : ChannelFlowBuilder<T>(block, context, capacity) {
327+
328+
private val collectCallback: suspend (ProducerScope<T>) -> Unit = {
329+
collectTo(it)
330+
/*
331+
* We expect user either call `awaitClose` from within a block (then the channel is closed at this moment)
332+
* or being closed/cancelled externally/manually. Otherwise "user forgot to call
333+
* awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected.
334+
*/
335+
if (it.isActive && !it.isClosedForSend) {
336+
throw IllegalStateException(
337+
"""
338+
'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block.
339+
Otherwise, a callback/listener may leak in case of cancellation external cancellation (e.g. by 'take(1)' or destroyed activity).
340+
For a more detailed explanation, please refer to callbackFlow KDoc.
341+
""".trimIndent())
342+
}
343+
}
344+
345+
override fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
346+
scope.broadcast(context, produceCapacity, start, block = collectCallback)
347+
348+
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> =
349+
scope.produce(context, produceCapacity, block = collectCallback)
350+
351+
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
352+
CallbackFlowBuilder(block, context, capacity)
353+
}

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

+11-10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ public abstract class ChannelFlow<T>(
2727
// buffer capacity between upstream and downstream context
2828
@JvmField val capacity: Int
2929
) : Flow<T> {
30+
31+
// shared code to create a suspend lambda from collectTo function in one place
32+
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
33+
get() = { collectTo(it) }
34+
35+
protected val produceCapacity: Int
36+
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity
37+
3038
public fun update(
3139
context: CoroutineContext = EmptyCoroutineContext,
3240
capacity: Int = Channel.OPTIONAL_CHANNEL
@@ -57,13 +65,6 @@ public abstract class ChannelFlow<T>(
5765

5866
protected abstract suspend fun collectTo(scope: ProducerScope<T>)
5967

60-
// shared code to create a suspend lambda from collectTo function in one place
61-
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
62-
get() = { collectTo(it) }
63-
64-
private val produceCapacity: Int
65-
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity
66-
6768
open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
6869
scope.broadcast(context, produceCapacity, start, block = collectToFun)
6970

@@ -75,11 +76,11 @@ public abstract class ChannelFlow<T>(
7576
collector.emitAll(produceImpl(this))
7677
}
7778

79+
open fun additionalToStringProps() = ""
80+
7881
// debug toString
7982
override fun toString(): String =
8083
"$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]"
81-
82-
open fun additionalToStringProps() = ""
8384
}
8485

8586
// ChannelFlow implementation that operates on another flow before it
@@ -161,7 +162,7 @@ private suspend fun <T, V> withContextUndispatched(
161162
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
162163
block: suspend (V) -> T, value: V
163164
): T =
164-
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
165+
suspendCoroutineUninterceptedOrReturn { uCont ->
165166
withCoroutineContext(newContext, countOrElement) {
166167
block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) {
167168
uCont.resumeWith(it)

kotlinx-coroutines-core/common/test/channels/ProduceTest.kt

+14-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.coroutines.channels
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
89
import kotlin.coroutines.*
910
import kotlin.test.*
1011

@@ -143,9 +144,20 @@ class ProduceTest : TestBase() {
143144

144145
@Test
145146
fun testAwaitIllegalState() = runTest {
146-
val channel = produce<Int> { }
147-
@Suppress("RemoveExplicitTypeArguments") // KT-31525
147+
val channel = produce<Int> { }
148148
assertFailsWith<IllegalStateException> { (channel as ProducerScope<*>).awaitClose() }
149+
callbackFlow<Unit> {
150+
expect(1)
151+
launch {
152+
expect(2)
153+
assertFailsWith<IllegalStateException> {
154+
awaitClose { expectUnreached() }
155+
expectUnreached()
156+
}
157+
}
158+
close()
159+
}.collect()
160+
finish(3)
149161
}
150162

151163
private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply {

kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt

+34
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() {
160160

161161
finish(6)
162162
}
163+
164+
@Test
165+
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
166+
val outerScope = this
167+
val flow = channelFlow {
168+
// ~ callback-based API, no children
169+
outerScope.launch(Job()) {
170+
expect(2)
171+
send(1)
172+
expectUnreached()
173+
}
174+
expect(1)
175+
}
176+
assertEquals(emptyList(), flow.toList())
177+
finish(3)
178+
}
179+
180+
@Test
181+
fun testNotClosedPrematurely() = runTest {
182+
val outerScope = this
183+
val flow = channelFlow {
184+
// ~ callback-based API
185+
outerScope.launch(Job()) {
186+
expect(2)
187+
send(1)
188+
close()
189+
}
190+
expect(1)
191+
awaitClose()
192+
}
193+
194+
assertEquals(listOf(1), flow.toList())
195+
finish(3)
196+
}
163197
}

kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt

+17-7
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,35 @@ import kotlin.test.*
1212

1313
class FlowCallbackTest : TestBase() {
1414
@Test
15-
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
15+
fun testClosedPrematurely() = runTest {
1616
val outerScope = this
17-
val flow = channelFlow {
17+
val flow = callbackFlow {
1818
// ~ callback-based API
1919
outerScope.launch(Job()) {
2020
expect(2)
21-
send(1)
22-
expectUnreached()
21+
try {
22+
send(1)
23+
expectUnreached()
24+
} catch (e: IllegalStateException) {
25+
expect(3)
26+
assertTrue(e.message!!.contains("awaitClose"))
27+
}
2328
}
2429
expect(1)
2530
}
26-
assertEquals(emptyList(), flow.toList())
27-
finish(3)
31+
try {
32+
flow.collect()
33+
} catch (e: IllegalStateException) {
34+
expect(4)
35+
assertTrue(e.message!!.contains("awaitClose"))
36+
}
37+
finish(5)
2838
}
2939

3040
@Test
3141
fun testNotClosedPrematurely() = runTest {
3242
val outerScope = this
33-
val flow = channelFlow<Int> {
43+
val flow = callbackFlow {
3444
// ~ callback-based API
3545
outerScope.launch(Job()) {
3646
expect(2)

kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() {
3939
runCatching { it.offer(++i) }
4040
}
4141

42-
val flow = channelFlow<Int> {
42+
val flow = callbackFlow<Int> {
4343
api.start(channel)
4444
awaitClose {
4545
api.stop()

0 commit comments

Comments
 (0)