Skip to content

Commit 4aa3880

Browse files
authored
Detect missing awaitClose calls in callbackFlow and close channel wit… (#1771)
Fixes #1762 Fixes #1770
1 parent 0126dba commit 4aa3880

File tree

10 files changed

+139
-71
lines changed

10 files changed

+139
-71
lines changed

kotlinx-coroutines-core/common/src/channels/Channel.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -586,4 +586,4 @@ public class ClosedSendChannelException(message: String?) : IllegalStateExceptio
586586
*
587587
* This exception is a subclass of [NoSuchElementException] to be consistent with plain collections.
588588
*/
589-
public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message)
589+
public class ClosedReceiveChannelException(message: String?) : NoSuchElementException(message)

kotlinx-coroutines-core/common/src/channels/Produce.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public interface ProducerScope<in E> : CoroutineScope, SendChannel<E> {
2727

2828
/**
2929
* Suspends the current coroutine until the channel is either [closed][SendChannel.close] or [cancelled][ReceiveChannel.cancel]
30-
* and invokes the given [block] before resuming the coroutine.
30+
* and invokes the given [block] before resuming the coroutine. This suspending function is cancellable.
3131
*
3232
* Note that when the producer channel is cancelled, this function resumes with a cancellation exception.
3333
* Therefore, in case of cancellation, no code after the call to this function will be executed.

kotlinx-coroutines-core/common/src/flow/Builders.kt

+53-16
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ import kotlinx.coroutines.*
1111
import kotlinx.coroutines.channels.*
1212
import kotlinx.coroutines.channels.Channel.Factory.BUFFERED
1313
import kotlinx.coroutines.flow.internal.*
14-
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
1514
import kotlin.coroutines.*
1615
import kotlin.jvm.*
16+
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
1717

1818
/**
1919
* Creates a flow from the given suspendable [block].
@@ -259,10 +259,16 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
259259
*
260260
* This builder ensures thread-safety and context preservation, thus the provided [ProducerScope] can be used
261261
* from any context, e.g. from a callback-based API.
262-
* The resulting flow completes as soon as the code in the [block] and all its children completes.
263-
* Use [awaitClose] as the last statement to keep it running.
264-
* The [awaitClose] argument is called either when a flow consumer cancels the flow collection
265-
* or when a callback-based API invokes [SendChannel.close] manually.
262+
* The resulting flow completes as soon as the code in the [block] completes.
263+
* [awaitClose] should be used to keep the flow running, otherwise the channel will be closed immediately
264+
* when block completes.
265+
* [awaitClose] argument is called either when a flow consumer cancels the flow collection
266+
* or when a callback-based API invokes [SendChannel.close] manually and is typically used
267+
* to cleanup the resources after the completion, e.g. unregister a callback.
268+
* Using [awaitClose] is mandatory in order to prevent memory leaks when the flow collection is cancelled,
269+
* otherwise the callback may keep running even when the flow collector is already completed.
270+
* To avoid such leaks, this method throws [IllegalStateException] if block returns, but the channel
271+
* is not closed yet.
266272
*
267273
* A channel with the [default][Channel.BUFFERED] buffer size is used. Use the [buffer] operator on the
268274
* resulting flow to specify a user-defined value and to control what happens when data is produced faster
@@ -277,31 +283,34 @@ public fun <T> channelFlow(@BuilderInference block: suspend ProducerScope<T>.()
277283
* fun flowFrom(api: CallbackBasedApi): Flow<T> = callbackFlow {
278284
* val callback = object : Callback { // implementation of some callback interface
279285
* override fun onNextValue(value: T) {
280-
* // Note: offer drops value when buffer is full
281-
* // Use either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill
282-
* offer(value)
286+
* // To avoid blocking you can configure channel capacity using
287+
* // either buffer(Channel.CONFLATED) or buffer(Channel.UNLIMITED) to avoid overfill
288+
* try {
289+
* sendBlocking(value)
290+
* } catch (e: Exception) {
291+
* // Handle exception from the channel: failure in flow or premature closing
292+
* }
283293
* }
284294
* override fun onApiError(cause: Throwable) {
285295
* cancel(CancellationException("API Error", cause))
286296
* }
287297
* override fun onCompleted() = channel.close()
288298
* }
289299
* api.register(callback)
290-
* // Suspend until either onCompleted or external cancellation are invoked
300+
* /*
301+
* * Suspends until either 'onCompleted'/'onApiError' from the callback is invoked
302+
* * or flow collector is cancelled (e.g. by 'take(1)' or because a collector's coroutine was cancelled).
303+
* * In both cases, callback will be properly unregistered.
304+
* */
291305
* awaitClose { api.unregister(callback) }
292306
* }
293307
* ```
294-
*
295-
* This function is an alias for [channelFlow], it has a separate name to reflect
296-
* the intent of the usage (integration with a callback-based API) better.
297308
*/
298-
@Suppress("NOTHING_TO_INLINE")
299309
@ExperimentalCoroutinesApi
300-
public inline fun <T> callbackFlow(@BuilderInference noinline block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
301-
channelFlow(block)
310+
public fun <T> callbackFlow(@BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> = CallbackFlowBuilder(block)
302311

303312
// ChannelFlow implementation that is the first in the chain of flow operations and introduces (builds) a flow
304-
private class ChannelFlowBuilder<T>(
313+
private open class ChannelFlowBuilder<T>(
305314
private val block: suspend ProducerScope<T>.() -> Unit,
306315
context: CoroutineContext = EmptyCoroutineContext,
307316
capacity: Int = BUFFERED
@@ -315,3 +324,31 @@ private class ChannelFlowBuilder<T>(
315324
override fun toString(): String =
316325
"block[$block] -> ${super.toString()}"
317326
}
327+
328+
private class CallbackFlowBuilder<T>(
329+
private val block: suspend ProducerScope<T>.() -> Unit,
330+
context: CoroutineContext = EmptyCoroutineContext,
331+
capacity: Int = BUFFERED
332+
) : ChannelFlowBuilder<T>(block, context, capacity) {
333+
334+
override suspend fun collectTo(scope: ProducerScope<T>) {
335+
super.collectTo(scope)
336+
/*
337+
* We expect user either call `awaitClose` from within a block (then the channel is closed at this moment)
338+
* or being closed/cancelled externally/manually. Otherwise "user forgot to call
339+
* awaitClose and receives unhelpful ClosedSendChannelException exceptions" situation is detected.
340+
*/
341+
if (!scope.isClosedForSend) {
342+
throw IllegalStateException(
343+
"""
344+
'awaitClose { yourCallbackOrListener.cancel() }' should be used in the end of callbackFlow block.
345+
Otherwise, a callback/listener may leak in case of external cancellation.
346+
See callbackFlow API documentation for the details.
347+
""".trimIndent()
348+
)
349+
}
350+
}
351+
352+
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
353+
CallbackFlowBuilder(block, context, capacity)
354+
}

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

+11-10
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ public abstract class ChannelFlow<T>(
2727
// buffer capacity between upstream and downstream context
2828
@JvmField val capacity: Int
2929
) : Flow<T> {
30+
31+
// shared code to create a suspend lambda from collectTo function in one place
32+
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
33+
get() = { collectTo(it) }
34+
35+
private val produceCapacity: Int
36+
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity
37+
3038
public fun update(
3139
context: CoroutineContext = EmptyCoroutineContext,
3240
capacity: Int = Channel.OPTIONAL_CHANNEL
@@ -57,13 +65,6 @@ public abstract class ChannelFlow<T>(
5765

5866
protected abstract suspend fun collectTo(scope: ProducerScope<T>)
5967

60-
// shared code to create a suspend lambda from collectTo function in one place
61-
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
62-
get() = { collectTo(it) }
63-
64-
private val produceCapacity: Int
65-
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity
66-
6768
open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
6869
scope.broadcast(context, produceCapacity, start, block = collectToFun)
6970

@@ -75,11 +76,11 @@ public abstract class ChannelFlow<T>(
7576
collector.emitAll(produceImpl(this))
7677
}
7778

79+
open fun additionalToStringProps() = ""
80+
7881
// debug toString
7982
override fun toString(): String =
8083
"$classSimpleName[${additionalToStringProps()}context=$context, capacity=$capacity]"
81-
82-
open fun additionalToStringProps() = ""
8384
}
8485

8586
// ChannelFlow implementation that operates on another flow before it
@@ -161,7 +162,7 @@ private suspend fun <T, V> withContextUndispatched(
161162
countOrElement: Any = threadContextElements(newContext), // can be precomputed for speed
162163
block: suspend (V) -> T, value: V
163164
): T =
164-
suspendCoroutineUninterceptedOrReturn sc@{ uCont ->
165+
suspendCoroutineUninterceptedOrReturn { uCont ->
165166
withCoroutineContext(newContext, countOrElement) {
166167
block.startCoroutineUninterceptedOrReturn(value, Continuation(newContext) {
167168
uCont.resumeWith(it)

kotlinx-coroutines-core/common/test/channels/ProduceTest.kt

+14-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
package kotlinx.coroutines.channels
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
89
import kotlin.coroutines.*
910
import kotlin.test.*
1011

@@ -143,9 +144,20 @@ class ProduceTest : TestBase() {
143144

144145
@Test
145146
fun testAwaitIllegalState() = runTest {
146-
val channel = produce<Int> { }
147-
@Suppress("RemoveExplicitTypeArguments") // KT-31525
147+
val channel = produce<Int> { }
148148
assertFailsWith<IllegalStateException> { (channel as ProducerScope<*>).awaitClose() }
149+
callbackFlow<Unit> {
150+
expect(1)
151+
launch {
152+
expect(2)
153+
assertFailsWith<IllegalStateException> {
154+
awaitClose { expectUnreached() }
155+
expectUnreached()
156+
}
157+
}
158+
close()
159+
}.collect()
160+
finish(3)
149161
}
150162

151163
private suspend fun cancelOnCompletion(coroutineContext: CoroutineContext) = CoroutineScope(coroutineContext).apply {

kotlinx-coroutines-core/common/test/flow/channels/ChannelFlowTest.kt

+34
Original file line numberDiff line numberDiff line change
@@ -160,4 +160,38 @@ class ChannelFlowTest : TestBase() {
160160

161161
finish(6)
162162
}
163+
164+
@Test
165+
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
166+
val outerScope = this
167+
val flow = channelFlow {
168+
// ~ callback-based API, no children
169+
outerScope.launch(Job()) {
170+
expect(2)
171+
send(1)
172+
expectUnreached()
173+
}
174+
expect(1)
175+
}
176+
assertEquals(emptyList(), flow.toList())
177+
finish(3)
178+
}
179+
180+
@Test
181+
fun testNotClosedPrematurely() = runTest {
182+
val outerScope = this
183+
val flow = channelFlow {
184+
// ~ callback-based API
185+
outerScope.launch(Job()) {
186+
expect(2)
187+
send(1)
188+
close()
189+
}
190+
expect(1)
191+
awaitClose()
192+
}
193+
194+
assertEquals(listOf(1), flow.toList())
195+
finish(3)
196+
}
163197
}

kotlinx-coroutines-core/common/test/flow/channels/FlowCallbackTest.kt

+17-7
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,35 @@ import kotlin.test.*
1212

1313
class FlowCallbackTest : TestBase() {
1414
@Test
15-
fun testClosedPrematurely() = runTest(unhandled = listOf({ e -> e is ClosedSendChannelException })) {
15+
fun testClosedPrematurely() = runTest {
1616
val outerScope = this
17-
val flow = channelFlow {
17+
val flow = callbackFlow {
1818
// ~ callback-based API
1919
outerScope.launch(Job()) {
2020
expect(2)
21-
send(1)
22-
expectUnreached()
21+
try {
22+
send(1)
23+
expectUnreached()
24+
} catch (e: IllegalStateException) {
25+
expect(3)
26+
assertTrue(e.message!!.contains("awaitClose"))
27+
}
2328
}
2429
expect(1)
2530
}
26-
assertEquals(emptyList(), flow.toList())
27-
finish(3)
31+
try {
32+
flow.collect()
33+
} catch (e: IllegalStateException) {
34+
expect(4)
35+
assertTrue(e.message!!.contains("awaitClose"))
36+
}
37+
finish(5)
2838
}
2939

3040
@Test
3141
fun testNotClosedPrematurely() = runTest {
3242
val outerScope = this
33-
val flow = channelFlow<Int> {
43+
val flow = callbackFlow {
3444
// ~ callback-based API
3545
outerScope.launch(Job()) {
3646
expect(2)

kotlinx-coroutines-core/jvm/src/channels/Channels.kt

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

55
@file:JvmMultifileClass
@@ -9,8 +9,6 @@ package kotlinx.coroutines.channels
99

1010
import kotlinx.coroutines.*
1111

12-
// -------- Operations on SendChannel --------
13-
1412
/**
1513
* Adds [element] into to this channel, **blocking** the caller while this channel [Channel.isFull],
1614
* or throws exception if the channel [Channel.isClosedForSend] (see [Channel.close] for details).

kotlinx-coroutines-core/jvm/test/AsyncJvmTest.kt

+5-29
Original file line numberDiff line numberDiff line change
@@ -10,36 +10,12 @@ class AsyncJvmTest : TestBase() {
1010
// This must be a common test but it fails on JS because of KT-21961
1111
@Test
1212
fun testAsyncWithFinally() = runTest {
13-
expect(1)
13+
launch(Dispatchers.Default) {
14+
15+
}
16+
17+
launch(Dispatchers.IO) {
1418

15-
@Suppress("UNREACHABLE_CODE")
16-
val d = async {
17-
expect(3)
18-
try {
19-
yield() // to main, will cancel
20-
} finally {
21-
expect(6) // will go there on await
22-
return@async "Fail" // result will not override cancellation
23-
}
24-
expectUnreached()
25-
"Fail2"
26-
}
27-
expect(2)
28-
yield() // to async
29-
expect(4)
30-
check(d.isActive && !d.isCompleted && !d.isCancelled)
31-
d.cancel()
32-
check(!d.isActive && !d.isCompleted && d.isCancelled)
33-
check(!d.isActive && !d.isCompleted && d.isCancelled)
34-
expect(5)
35-
try {
36-
d.await() // awaits
37-
expectUnreached() // does not complete normally
38-
} catch (e: Throwable) {
39-
expect(7)
40-
check(e is CancellationException)
4119
}
42-
check(!d.isActive && d.isCompleted && d.isCancelled)
43-
finish(8)
4420
}
4521
}

kotlinx-coroutines-core/jvm/test/flow/CallbackFlowTest.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CallbackFlowTest : TestBase() {
3939
runCatching { it.offer(++i) }
4040
}
4141

42-
val flow = channelFlow<Int> {
42+
val flow = callbackFlow<Int> {
4343
api.start(channel)
4444
awaitClose {
4545
api.stop()
@@ -118,7 +118,7 @@ class CallbackFlowTest : TestBase() {
118118
}
119119
}
120120

121-
private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = callbackFlow {
121+
private fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
122122
launch {
123123
collect { send(it) }
124124
}

0 commit comments

Comments
 (0)