Skip to content

Commit 18e3a4a

Browse files
committed
Mark Flow.collect as internal to prevent its direct implementation and provide AbstractFlow instead that enforces context preservation guarantees
1 parent 3216825 commit 18e3a4a

File tree

3 files changed

+89
-26
lines changed

3 files changed

+89
-26
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+6
Original file line numberDiff line numberDiff line change
@@ -780,6 +780,12 @@ public final class kotlinx/coroutines/channels/TickerMode : java/lang/Enum {
780780
public static fun values ()[Lkotlinx/coroutines/channels/TickerMode;
781781
}
782782

783+
public abstract class kotlinx/coroutines/flow/AbstractFlow : kotlinx/coroutines/flow/Flow {
784+
public fun <init> ()V
785+
public final fun collect (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
786+
public abstract fun collectSafely (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
787+
}
788+
783789
public abstract interface class kotlinx/coroutines/flow/Flow {
784790
public abstract fun collect (Lkotlinx/coroutines/flow/FlowCollector;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
785791
}

kotlinx-coroutines-core/common/src/flow/Flow.kt

+47-3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
package kotlinx.coroutines.flow
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.internal.SafeCollector
9+
import kotlin.coroutines.*
810

911
/**
1012
* A cold asynchronous data stream that sequentially emits values
@@ -112,17 +114,59 @@ import kotlinx.coroutines.*
112114
@ExperimentalCoroutinesApi
113115
public interface Flow<out T> {
114116

117+
/**
118+
* Accepts the given [collector] and [emits][FlowCollector.emit] values into it.
119+
* This method should never be implemented or used directly.
120+
*
121+
* The only way to implement flow interface directly is to extend [AbstractFlow].
122+
* To collect it into the specific collector, either `collector.emitAll(flow)` or `collect { }` extension should be used.
123+
* Such limitation ensures that context preservation property is not violated and prevents most of the developer mistakes
124+
* related to concurrency, inconsistent flow dispatchers and cancellation.
125+
*/
126+
@InternalCoroutinesApi
127+
public suspend fun collect(collector: FlowCollector<T>)
128+
}
129+
130+
/**
131+
* Base class to extend to have a stateful implementation of the flow.
132+
* It tracks all the properties required for context preservation and throws [IllegalStateException] if any of the properties are violated.
133+
* Example of the implementation:
134+
* ```
135+
* // list.asFlow() + collect counter
136+
* class CountingListFlow(private val values: List<Int>) : AbstractFlow<Int>() {
137+
* private val collectedCounter = AtomicInteger(0)
138+
*
139+
* override suspend fun collectSafely(collector: FlowCollector<Int>) {
140+
* collectedCounter.incrementAndGet() // Increment collected counter
141+
* values.forEach { // Emit all the values
142+
* collector.emit(it)
143+
* }
144+
* }
145+
*
146+
* fun toDiagnosticString(): String = "Flow with values $values was collected ${collectedCounter.value} times"
147+
* }
148+
* ```
149+
*/
150+
@FlowPreview
151+
public abstract class AbstractFlow<T> : Flow<T> {
152+
153+
@InternalCoroutinesApi
154+
public final override suspend fun collect(collector: FlowCollector<T>) {
155+
collectSafely(SafeCollector(collector, collectContext = coroutineContext))
156+
}
157+
115158
/**
116159
* Accepts the given [collector] and [emits][FlowCollector.emit] values into it.
117160
*
118161
* A valid implementation of this method has the following constraints:
119162
* 1) It should not change the coroutine context (e.g. with `withContext(Dispatchers.IO)`) when emitting values.
120163
* The emission should happen in the context of the [collect] call.
121164
* Please refer to the top-level [Flow] documentation for more details.
122-
*
123165
* 2) It should serialize calls to [emit][FlowCollector.emit] as [FlowCollector] implementations are not
124-
* thread safe by default.
166+
* thread-safe by default.
125167
* To automatically serialize emissions [channelFlow] builder can be used instead of [flow]
168+
*
169+
* @throws IllegalStateException if any of the invariants are violated.
126170
*/
127-
public suspend fun collect(collector: FlowCollector<T>)
171+
public abstract suspend fun collectSafely(collector: FlowCollector<T>)
128172
}

kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt

+36-23
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,37 @@ package kotlinx.coroutines.flow
77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.channels.*
99
import kotlin.coroutines.*
10+
import kotlin.reflect.*
1011
import kotlin.test.*
1112

1213
class FlowInvariantsTest : TestBase() {
1314

15+
private fun <T> runParametrizedTest(
16+
expectedException: KClass<out Throwable>? = null,
17+
testBody: suspend (flowFactory: (suspend FlowCollector<T>.() -> Unit) -> Flow<T>) -> Unit
18+
) = runTest {
19+
val r1 = runCatching { testBody { flow(it) } }.exceptionOrNull()
20+
check(r1, expectedException)
21+
reset()
22+
23+
val r2 = runCatching { testBody { abstractFlow(it) } }.exceptionOrNull()
24+
check(r2, expectedException)
25+
}
26+
27+
private fun <T> abstractFlow(block: suspend FlowCollector<T>.() -> Unit): Flow<T> = object : AbstractFlow<T>() {
28+
override suspend fun collectSafely(collector: FlowCollector<T>) {
29+
collector.block()
30+
}
31+
}
32+
33+
private fun check(exception: Throwable?, expectedException: KClass<out Throwable>?) {
34+
if (expectedException != null && exception == null) fail("Expected $expectedException, but test completed successfully")
35+
if (expectedException != null && exception != null) assertTrue(expectedException.isInstance(exception))
36+
if (expectedException == null && exception != null) throw exception
37+
}
38+
1439
@Test
15-
fun testWithContextContract() = runTest({ it is IllegalStateException }) {
40+
fun testWithContextContract() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
1641
flow {
1742
kotlinx.coroutines.withContext(NonCancellable) {
1843
emit(1)
@@ -23,7 +48,7 @@ class FlowInvariantsTest : TestBase() {
2348
}
2449

2550
@Test
26-
fun testWithDispatcherContractViolated() = runTest({ it is IllegalStateException }) {
51+
fun testWithDispatcherContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
2752
flow {
2853
kotlinx.coroutines.withContext(NamedDispatchers("foo")) {
2954
emit(1)
@@ -34,7 +59,7 @@ class FlowInvariantsTest : TestBase() {
3459
}
3560

3661
@Test
37-
fun testCachedInvariantCheckResult() = runTest {
62+
fun testCachedInvariantCheckResult() = runParametrizedTest<Int> { flow ->
3863
flow {
3964
emit(1)
4065

@@ -55,7 +80,7 @@ class FlowInvariantsTest : TestBase() {
5580
}
5681

5782
@Test
58-
fun testWithNameContractViolated() = runTest({ it is IllegalStateException }) {
83+
fun testWithNameContractViolated() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
5984
flow {
6085
kotlinx.coroutines.withContext(CoroutineName("foo")) {
6186
emit(1)
@@ -86,25 +111,25 @@ class FlowInvariantsTest : TestBase() {
86111
}
87112

88113
@Test
89-
fun testScopedJob() = runTest({ it is IllegalStateException }) {
90-
flow { emit(1) }.buffer(EmptyCoroutineContext).collect {
114+
fun testScopedJob() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
115+
flow { emit(1) }.buffer(EmptyCoroutineContext, flow).collect {
91116
expect(1)
92117
}
93118

94119
finish(2)
95120
}
96121

97122
@Test
98-
fun testScopedJobWithViolation() = runTest({ it is IllegalStateException }) {
99-
flow { emit(1) }.buffer(Dispatchers.Unconfined).collect {
123+
fun testScopedJobWithViolation() = runParametrizedTest<Int>(IllegalStateException::class) { flow ->
124+
flow { emit(1) }.buffer(Dispatchers.Unconfined, flow).collect {
100125
expect(1)
101126
}
102127

103128
finish(2)
104129
}
105130

106131
@Test
107-
fun testMergeViolation() = runTest {
132+
fun testMergeViolation() = runParametrizedTest<Int> { flow ->
108133
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
109134
coroutineScope {
110135
launch {
@@ -130,17 +155,6 @@ class FlowInvariantsTest : TestBase() {
130155
assertFailsWith<IllegalStateException> { flow.trickyMerge(flow).toList() }
131156
}
132157

133-
// TODO merge artifact
134-
private fun <T> channelFlow(bufferSize: Int = 16, @BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
135-
flow {
136-
coroutineScope {
137-
val channel = produce(capacity = bufferSize, block = block)
138-
channel.consumeEach { value ->
139-
emit(value)
140-
}
141-
}
142-
}
143-
144158
@Test
145159
fun testNoMergeViolation() = runTest {
146160
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
@@ -167,7 +181,7 @@ class FlowInvariantsTest : TestBase() {
167181
}
168182

169183
@Test
170-
fun testScopedCoroutineNoViolation() = runTest {
184+
fun testScopedCoroutineNoViolation() = runParametrizedTest<Int> { flow ->
171185
fun Flow<Int>.buffer(): Flow<Int> = flow {
172186
coroutineScope {
173187
val channel = produce {
@@ -180,11 +194,10 @@ class FlowInvariantsTest : TestBase() {
180194
}
181195
}
182196
}
183-
184197
assertEquals(listOf(1, 1), flowOf(1, 1).buffer().toList())
185198
}
186199

187-
private fun Flow<Int>.buffer(coroutineContext: CoroutineContext): Flow<Int> = flow {
200+
private fun Flow<Int>.buffer(coroutineContext: CoroutineContext, flow: (suspend FlowCollector<Int>.() -> Unit) -> Flow<Int>): Flow<Int> = flow {
188201
coroutineScope {
189202
val channel = Channel<Int>()
190203
launch {

0 commit comments

Comments
 (0)