Skip to content

Commit 3e53dc5

Browse files
committed
Strengthen flow context preservation invariant
* Add additional check in SafeCollector with an error message pointing to channelFlow * Improve performance of the CoroutineId check in SafeCollector Fixes #1210
1 parent ea99784 commit 3e53dc5

File tree

5 files changed

+96
-14
lines changed

5 files changed

+96
-14
lines changed

kotlinx-coroutines-core/common/src/flow/Flow.kt

+2-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import kotlinx.coroutines.*
2727
* trigger their evaluation every time [collect] is executed) or hot ones, but, conventionally, they represent cold streams.
2828
* Transitions between hot and cold streams are supported via channels and the corresponding API: [flowViaChannel], [broadcastIn], [produceIn].
2929
*
30-
* The flow has a context preserving property: it encapsulates its own execution context and never propagates or leaks it downstream, thus making
30+
* The flow has a context preservation property: it encapsulates its own execution context and never propagates or leaks it downstream, thus making
3131
* reasoning about the execution context of particular transformations or terminal operations trivial.
3232
*
3333
* There are two ways to change the context of a flow: [flowOn][Flow.flowOn] and [flowWith][Flow.flowWith].
@@ -104,6 +104,7 @@ public interface Flow<out T> {
104104
* is a proper [Flow] implementation, but using `launch(Dispatchers.IO)` is not.
105105
*
106106
* 2) It should serialize calls to [emit][FlowCollector.emit] as [FlowCollector] implementations are not thread safe by default.
107+
* To automatically serialize emissions [channelFlow] builder can be used instead of [flow]
107108
*/
108109
public suspend fun collect(collector: FlowCollector<T>)
109110
}

kotlinx-coroutines-core/common/src/flow/internal/SafeCollector.kt

+26-12
Original file line numberDiff line numberDiff line change
@@ -6,36 +6,50 @@ package kotlinx.coroutines.flow.internal
66

77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.flow.*
9-
import kotlinx.coroutines.internal.*
9+
import kotlinx.coroutines.internal.ScopeCoroutine
1010
import kotlin.coroutines.*
1111

1212
@PublishedApi
1313
internal class SafeCollector<T>(
1414
private val collector: FlowCollector<T>,
1515
collectContext: CoroutineContext
16-
) : FlowCollector<T>, SynchronizedObject() {
16+
) : FlowCollector<T> {
1717

1818
private val collectContext = collectContext.minusKey(Job).minusId()
19-
private var lastObservedContext: CoroutineContext? = null
19+
private var lastEmissionContext: CoroutineContext? = null
2020

2121
override suspend fun emit(value: T) {
2222
/*
2323
* Benign data-race here:
2424
* We read potentially racy published coroutineContext, but we only use it for
25-
* referential comparison (=> thus safe) and are not using it for actual comparisons.
25+
* referential comparison (=> thus safe) and are not using it for structural comparisons.
2626
*/
2727
val currentContext = coroutineContext
28-
if (lastObservedContext !== currentContext) {
28+
val observedContext = lastEmissionContext
29+
if (observedContext !== currentContext) {
30+
if (observedContext !== null) checkJobs(observedContext, currentContext)
2931
val emitContext = currentContext.minusKey(Job).minusId()
3032
if (emitContext != collectContext) {
31-
error(
32-
"Flow invariant is violated: flow was collected in $collectContext, but emission happened in $emitContext. " +
33-
"Please refer to 'flow' documentation or use 'flowOn' instead"
34-
)
35-
}
36-
// Racy publication
37-
lastObservedContext = currentContext
33+
error(
34+
"Flow invariant is violated: flow was collected in $collectContext, but emission happened in $emitContext. " +
35+
"Please refer to 'flow' documentation or use 'flowOn' instead"
36+
)
37+
}
38+
lastEmissionContext = currentContext
3839
}
3940
collector.emit(value) // TCE
4041
}
42+
43+
private fun checkJobs(observedContext: CoroutineContext, currentContext: CoroutineContext) {
44+
val previousJob = observedContext[Job].transitiveCoroutineParent()
45+
val currentJob = currentContext[Job].transitiveCoroutineParent()
46+
check(previousJob === currentJob) { "Flow invariant is violated: emissions from different coroutines are detected ($currentContext and $lastEmissionContext). " +
47+
"FlowCollector is not thread-safe and concurrent emissions are prohibited. To mitigate this restriction please use 'flowChannel' builder instead of 'flow'" }
48+
}
49+
50+
private fun Job?.transitiveCoroutineParent(): Job? {
51+
if (this === null) return null
52+
if (this !is ScopeCoroutine<*>) return this
53+
return parent.transitiveCoroutineParent()
54+
}
4155
}

kotlinx-coroutines-core/common/src/internal/Scopes.kt

+2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ internal open class ScopeCoroutine<in T>(
1919
final override fun getStackTraceElement(): StackTraceElement? = null
2020
override val defaultResumeMode: Int get() = MODE_DIRECT
2121

22+
internal val parent: Job? get() = parentContext[Job]
23+
2224
override val cancelsParent: Boolean
2325
get() = false // it throws exception to parent instead of cancelling it
2426

kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt

+65
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,71 @@ class FlowInvariantsTest : TestBase() {
104104
finish(2)
105105
}
106106

107+
@Test
108+
fun testMergeViolation() = runTest {
109+
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = flow {
110+
coroutineScope {
111+
launch {
112+
collect { value -> emit(value) }
113+
}
114+
other.collect { value -> emit(value) }
115+
}
116+
}
117+
118+
fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = flow {
119+
coroutineScope {
120+
launch {
121+
collect { value ->
122+
coroutineScope { emit(value) }
123+
}
124+
}
125+
other.collect { value -> emit(value) }
126+
}
127+
}
128+
129+
val flow = flowOf(1)
130+
assertFailsWith<IllegalStateException> { flow.merge(flow).toList() }
131+
assertFailsWith<IllegalStateException> { flow.trickyMerge(flow).toList() }
132+
}
133+
134+
135+
// TODO merge artifact
136+
fun <T> channelFlow(bufferSize: Int = 16, @BuilderInference block: suspend ProducerScope<T>.() -> Unit): Flow<T> =
137+
flow {
138+
coroutineScope {
139+
val channel = produce(capacity = bufferSize, block = block)
140+
channel.consumeEach { value ->
141+
emit(value)
142+
}
143+
}
144+
}
145+
146+
@Test
147+
fun testNoMergeViolation() = runTest {
148+
fun Flow<Int>.merge(other: Flow<Int>): Flow<Int> = channelFlow {
149+
launch {
150+
collect { value -> send(value) }
151+
}
152+
other.collect { value -> send(value) }
153+
}
154+
155+
fun Flow<Int>.trickyMerge(other: Flow<Int>): Flow<Int> = channelFlow {
156+
coroutineScope {
157+
launch {
158+
collect { value ->
159+
coroutineScope { send(value) }
160+
}
161+
}
162+
other.collect { value -> send(value) }
163+
}
164+
}
165+
166+
val flow = flowOf(1)
167+
assertEquals(listOf(1, 1), flow.merge(flow).toList())
168+
assertEquals(listOf(1, 1), flow.trickyMerge(flow).toList())
169+
}
170+
171+
107172
private fun Flow<Int>.buffer(coroutineContext: CoroutineContext): Flow<Int> = flow {
108173
coroutineScope {
109174
val channel = Channel<Int>()

kotlinx-coroutines-core/jvm/src/CoroutineContext.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ internal actual val CoroutineContext.coroutineName: String? get() {
8080
}
8181

8282
@Suppress("NOTHING_TO_INLINE")
83-
internal actual inline fun CoroutineContext.minusId(): CoroutineContext = minusKey(CoroutineId)
83+
internal actual inline fun CoroutineContext.minusId(): CoroutineContext = if (DEBUG) minusKey(CoroutineId) else this
8484

8585
private const val DEBUG_THREAD_NAME_SEPARATOR = " @"
8686

0 commit comments

Comments
 (0)