Skip to content

Commit 7ee2439

Browse files
committed
SharedFlow: Fix scenario with concurrent emitters and cancellation of subscriber
* Added a specific test for a problematic scenario. * Added stress test with concurrent emitters and subscribers that come and go. Fixes #2356
1 parent 4ea4078 commit 7ee2439

File tree

3 files changed

+104
-1
lines changed

3 files changed

+104
-1
lines changed

kotlinx-coroutines-core/common/src/flow/SharedFlow.kt

+7-1
Original file line numberDiff line numberDiff line change
@@ -497,7 +497,13 @@ private class SharedFlowImpl<T>(
497497
}
498498
}
499499
// Compute new buffer size -> how many values we now actually have after resume
500-
val newBufferSize1 = (newBufferEndIndex - head).toInt()
500+
var newBufferSize1 = (newBufferEndIndex - head).toInt()
501+
// Note: When nCollectors == 0 we resume all queued emitters and we might have resumed more than bufferCapacity,
502+
// if which case we need to coerce the resulting buffer size and adjust newMinCollectorIndex
503+
if (nCollectors == 0 && newBufferSize1 > bufferCapacity) {
504+
newMinCollectorIndex += newBufferSize1 - bufferCapacity // adjust minCollectorIndex, too, to skip items
505+
newBufferSize1 = bufferCapacity
506+
}
501507
// Compute new replay size -> limit to replay the number of items we need, take into account that it can only grow
502508
var newReplayIndex = maxOf(replayIndex, newBufferEndIndex - minOf(replay, newBufferSize1))
503509
// adjustment for synchronous case with cancelled emitter (NO_VALUE)

kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowScenarioTest.kt

+20
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,26 @@ class SharedFlowScenarioTest : TestBase() {
201201
emitResumes(e3); expectReplayOf(3)
202202
}
203203

204+
@Test
205+
fun testSuspendedConcurrentEmitAndCancelSubscriber() =
206+
testSharedFlow<Int>(MutableSharedFlow(1)) {
207+
val a = subscribe("a");
208+
emitRightNow(0); expectReplayOf(0)
209+
collect(a, 0)
210+
emitRightNow(1); expectReplayOf(1)
211+
val e2 = emitSuspends(2) // suspends until 1 is collected
212+
val e3 = emitSuspends(3) // suspends until 1 is collected, too
213+
cancel(a) // must resume emitters 2 & 3
214+
emitResumes(e2)
215+
emitResumes(e3)
216+
expectReplayOf(3) // but replay size is 1 so only 3 should be kept
217+
// Note: originally, SharedFlow was in a broken state here with 3 elements in the buffer
218+
val b = subscribe("b")
219+
collect(b, 3)
220+
emitRightNow(4); expectReplayOf(4)
221+
collect(b, 4)
222+
}
223+
204224
private fun <T> testSharedFlow(
205225
sharedFlow: MutableSharedFlow<T>,
206226
scenario: suspend ScenarioDsl<T>.() -> Unit
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.flow
6+
7+
import kotlinx.atomicfu.*
8+
import kotlinx.coroutines.*
9+
import org.junit.*
10+
import org.junit.Test
11+
import kotlin.collections.ArrayList
12+
import kotlin.test.*
13+
import kotlin.time.*
14+
15+
@ExperimentalTime
16+
class SharedFlowStressTest : TestBase() {
17+
private val nProducers = 5
18+
private val nConsumers = 3
19+
private val nSeconds = 3 * stressTestMultiplier
20+
21+
private val sf: MutableSharedFlow<Long> = MutableSharedFlow(1)
22+
private val view: SharedFlow<Long> = sf.asSharedFlow()
23+
24+
@get:Rule
25+
val producerDispatcher = ExecutorRule(nProducers)
26+
@get:Rule
27+
val consumerDispatcher = ExecutorRule(nConsumers)
28+
29+
private val totalProduced = atomic(0L)
30+
private val totalConsumed = atomic(0L)
31+
32+
@Test
33+
fun testStress() = runTest {
34+
val jobs = ArrayList<Job>()
35+
jobs += List(nProducers) { producerIndex ->
36+
launch(producerDispatcher) {
37+
var cur = producerIndex.toLong()
38+
while (isActive) {
39+
sf.emit(cur)
40+
totalProduced.incrementAndGet()
41+
cur += nProducers
42+
}
43+
}
44+
}
45+
jobs += List(nConsumers) { consumerIndex ->
46+
launch(consumerDispatcher) {
47+
while (isActive) {
48+
view
49+
.dropWhile { it % nConsumers != consumerIndex.toLong() }
50+
.take(1)
51+
.collect {
52+
check(it % nConsumers == consumerIndex.toLong())
53+
totalConsumed.incrementAndGet()
54+
}
55+
}
56+
}
57+
}
58+
var lastProduced = 0L
59+
var lastConsumed = 0L
60+
for (sec in 1..nSeconds) {
61+
delay(1.seconds)
62+
val produced = totalProduced.value
63+
val consumed = totalConsumed.value
64+
println("$sec sec: produced = $produced; consumed = $consumed")
65+
assertNotEquals(lastProduced, produced)
66+
assertNotEquals(lastConsumed, consumed)
67+
lastProduced = produced
68+
lastConsumed = consumed
69+
}
70+
jobs.forEach { it.cancel() }
71+
jobs.forEach { it.join() }
72+
println("total: produced = ${totalProduced.value}; consumed = ${totalConsumed.value}")
73+
}
74+
75+
private fun showStats(s: String) {
76+
}
77+
}

0 commit comments

Comments
 (0)