Skip to content

Commit 0905c62

Browse files
authored
Properly enforce flow invariant when flow is used from "suspend fun m… (#1426)
* Properly enforce flow invariant when flow is used from "suspend fun main" or artificially started coroutine (e.g. by block.startCoroutine(...)) Fixes #1421
1 parent 1681cad commit 0905c62

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt

+7-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,13 @@ internal class SafeCollector<T>(
8181
"FlowCollector is not thread-safe and concurrent emissions are prohibited. To mitigate this restriction please use 'channelFlow' builder instead of 'flow'"
8282
)
8383
}
84-
count + 1
84+
85+
/*
86+
* If collect job is null (-> EmptyCoroutineContext, probably run from `suspend fun main`), then invariant is maintained
87+
* (common transitive parent is "null"), but count check will fail, so just do not count job context element when
88+
* flow is collected from EmptyCoroutineContext
89+
*/
90+
if (collectJob == null) count else count + 1
8591
}
8692
if (result != collectContextSize) {
8793
error(

kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt

+65
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ package kotlinx.coroutines.flow
66

77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.channels.*
9+
import kotlinx.coroutines.intrinsics.*
910
import kotlin.coroutines.*
1011
import kotlin.reflect.*
1112
import kotlin.test.*
@@ -214,4 +215,68 @@ class FlowInvariantsTest : TestBase() {
214215
}
215216
}
216217
}
218+
219+
@Test
220+
fun testEmptyCoroutineContext() = runTest {
221+
emptyContextTest {
222+
map {
223+
expect(it)
224+
it + 1
225+
}
226+
}
227+
}
228+
229+
@Test
230+
fun testEmptyCoroutineContextTransform() = runTest {
231+
emptyContextTest {
232+
transform {
233+
expect(it)
234+
emit(it + 1)
235+
}
236+
}
237+
}
238+
239+
@Test
240+
fun testEmptyCoroutineContextViolation() = runTest {
241+
try {
242+
emptyContextTest {
243+
transform {
244+
expect(it)
245+
kotlinx.coroutines.withContext(Dispatchers.Unconfined) {
246+
emit(it + 1)
247+
}
248+
}
249+
}
250+
expectUnreached()
251+
} catch (e: IllegalStateException) {
252+
assertTrue(e.message!!.contains("Flow invariant is violated"))
253+
finish(2)
254+
}
255+
}
256+
257+
private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
258+
suspend fun collector(): Int {
259+
var result: Int = -1
260+
channelFlow {
261+
send(1)
262+
}.block()
263+
.collect {
264+
expect(it)
265+
result = it
266+
}
267+
return result
268+
}
269+
270+
val result = runSuspendFun { collector() }
271+
assertEquals(2, result)
272+
finish(3)
273+
}
274+
275+
private suspend fun runSuspendFun(block: suspend () -> Int): Int {
276+
val baseline = Result.failure<Int>(IllegalStateException("Block was suspended"))
277+
var result: Result<Int> = baseline
278+
block.startCoroutineUnintercepted(Continuation(EmptyCoroutineContext) { result = it })
279+
while (result == baseline) yield()
280+
return result.getOrThrow()
281+
}
217282
}

0 commit comments

Comments
 (0)