diff --git a/kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt b/kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt index 09a63781f0..8761058e71 100644 --- a/kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt +++ b/kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt @@ -81,7 +81,13 @@ internal class SafeCollector( "FlowCollector is not thread-safe and concurrent emissions are prohibited. To mitigate this restriction please use 'channelFlow' builder instead of 'flow'" ) } - count + 1 + + /* + * If collect job is null (-> EmptyCoroutineContext, probably run from `suspend fun main`), then invariant is maintained + * (common transitive parent is "null"), but count check will fail, so just do not count job context element when + * flow is collected from EmptyCoroutineContext + */ + if (collectJob == null) count else count + 1 } if (result != collectContextSize) { error( diff --git a/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt b/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt index 98406869e5..e016b031b2 100644 --- a/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt @@ -6,6 +6,7 @@ package kotlinx.coroutines.flow import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlinx.coroutines.intrinsics.* import kotlin.coroutines.* import kotlin.reflect.* import kotlin.test.* @@ -214,4 +215,68 @@ class FlowInvariantsTest : TestBase() { } } } + + @Test + fun testEmptyCoroutineContext() = runTest { + emptyContextTest { + map { + expect(it) + it + 1 + } + } + } + + @Test + fun testEmptyCoroutineContextTransform() = runTest { + emptyContextTest { + transform { + expect(it) + emit(it + 1) + } + } + } + + @Test + fun testEmptyCoroutineContextViolation() = runTest { + try { + emptyContextTest { + transform { + expect(it) + kotlinx.coroutines.withContext(Dispatchers.Unconfined) { + emit(it + 1) + } + } + } + expectUnreached() + } catch (e: IllegalStateException) { + assertTrue(e.message!!.contains("Flow invariant is violated")) + finish(2) + } + } + + private suspend fun emptyContextTest(block: Flow.() -> Flow) { + suspend fun collector(): Int { + var result: Int = -1 + channelFlow { + send(1) + }.block() + .collect { + expect(it) + result = it + } + return result + } + + val result = runSuspendFun { collector() } + assertEquals(2, result) + finish(3) + } + + private suspend fun runSuspendFun(block: suspend () -> Int): Int { + val baseline = Result.failure(IllegalStateException("Block was suspended")) + var result: Result = baseline + block.startCoroutineUnintercepted(Continuation(EmptyCoroutineContext) { result = it }) + while (result == baseline) yield() + return result.getOrThrow() + } }