Skip to content

Commit 905b0cb

Browse files
qwwdfsadpablobaxter
authored andcommitted
Non-conflating subscription count in SharedFlow and StateFlow (Kotlin#2872)
* Non-conflating subscription count in SharedFlow and StateFlow Sharing strategies are too sensitive to conflation around extrema and may miss the necessity to start or not to stop the sharing. For more particular examples see Kotlin#2863 and Kotlin#2488 Fixes Kotlin#2488 Fixes Kotlin#2863 Fixes Kotlin#2871
1 parent 6f00cdf commit 905b0cb

File tree

9 files changed

+151
-21
lines changed

9 files changed

+151
-21
lines changed

kotlinx-coroutines-core/common/src/flow/SharedFlow.kt

+11-2
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,8 @@ public interface MutableSharedFlow<T> : SharedFlow<T>, FlowCollector<T> {
198198
* }
199199
* .launchIn(scope) // launch it
200200
* ```
201+
*
202+
* Implementation note: the resulting flow **does not** conflate subscription count.
201203
*/
202204
public val subscriptionCount: StateFlow<Int>
203205

@@ -253,7 +255,7 @@ public fun <T> MutableSharedFlow(
253255

254256
// ------------------------------------ Implementation ------------------------------------
255257

256-
private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
258+
internal class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
257259
@JvmField
258260
var index = -1L // current "to-be-emitted" index, -1 means the slot is free now
259261

@@ -275,7 +277,7 @@ private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
275277
}
276278
}
277279

278-
private class SharedFlowImpl<T>(
280+
internal open class SharedFlowImpl<T>(
279281
private val replay: Int,
280282
private val bufferCapacity: Int,
281283
private val onBufferOverflow: BufferOverflow
@@ -334,6 +336,13 @@ private class SharedFlowImpl<T>(
334336
result
335337
}
336338

339+
/*
340+
* A tweak for SubscriptionCountStateFlow to get the latest value.
341+
*/
342+
@Suppress("UNCHECKED_CAST")
343+
protected val lastReplayedLocked: T
344+
get() = buffer!!.getBufferAt(replayIndex + replaySize - 1) as T
345+
337346
@Suppress("UNCHECKED_CAST")
338347
override suspend fun collect(collector: FlowCollector<T>) {
339348
val slot = allocateSlot()

kotlinx-coroutines-core/common/src/flow/StateFlow.kt

-4
Original file line numberDiff line numberDiff line change
@@ -415,10 +415,6 @@ private class StateFlowImpl<T>(
415415
fuseStateFlow(context, capacity, onBufferOverflow)
416416
}
417417

418-
internal fun MutableStateFlow<Int>.increment(delta: Int) {
419-
update { it + delta }
420-
}
421-
422418
internal fun <T> StateFlow<T>.fuseStateFlow(
423419
context: CoroutineContext,
424420
capacity: Int,

kotlinx-coroutines-core/common/src/flow/internal/AbstractSharedFlow.kt

+41-8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.coroutines.flow.internal
66

7+
import kotlinx.coroutines.channels.*
78
import kotlinx.coroutines.flow.*
89
import kotlinx.coroutines.internal.*
910
import kotlin.coroutines.*
@@ -26,12 +27,12 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
2627
protected var nCollectors = 0 // number of allocated (!free) slots
2728
private set
2829
private var nextIndex = 0 // oracle for the next free slot index
29-
private var _subscriptionCount: MutableStateFlow<Int>? = null // init on first need
30+
private var _subscriptionCount: SubscriptionCountStateFlow? = null // init on first need
3031

3132
val subscriptionCount: StateFlow<Int>
3233
get() = synchronized(this) {
3334
// allocate under lock in sync with nCollectors variable
34-
_subscriptionCount ?: MutableStateFlow(nCollectors).also {
35+
_subscriptionCount ?: SubscriptionCountStateFlow(nCollectors).also {
3536
_subscriptionCount = it
3637
}
3738
}
@@ -43,7 +44,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
4344
@Suppress("UNCHECKED_CAST")
4445
protected fun allocateSlot(): S {
4546
// Actually create slot under lock
46-
var subscriptionCount: MutableStateFlow<Int>? = null
47+
var subscriptionCount: SubscriptionCountStateFlow? = null
4748
val slot = synchronized(this) {
4849
val slots = when (val curSlots = slots) {
4950
null -> createSlotArray(2).also { slots = it }
@@ -74,7 +75,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
7475
@Suppress("UNCHECKED_CAST")
7576
protected fun freeSlot(slot: S) {
7677
// Release slot under lock
77-
var subscriptionCount: MutableStateFlow<Int>? = null
78+
var subscriptionCount: SubscriptionCountStateFlow? = null
7879
val resumes = synchronized(this) {
7980
nCollectors--
8081
subscriptionCount = _subscriptionCount // retrieve under lock if initialized
@@ -83,10 +84,10 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
8384
(slot as AbstractSharedFlowSlot<Any>).freeLocked(this)
8485
}
8586
/*
86-
Resume suspended coroutines.
87-
This can happens when the subscriber that was freed was a slow one and was holding up buffer.
88-
When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
89-
*/
87+
* Resume suspended coroutines.
88+
* This can happen when the subscriber that was freed was a slow one and was holding up buffer.
89+
* When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
90+
*/
9091
for (cont in resumes) cont?.resume(Unit)
9192
// decrement subscription count
9293
subscriptionCount?.increment(-1)
@@ -99,3 +100,35 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
99100
}
100101
}
101102
}
103+
104+
/**
105+
* [StateFlow] that represents the number of subscriptions.
106+
*
107+
* It is exposed as a regular [StateFlow] in our public API, but it is implemented as [SharedFlow] undercover to
108+
* avoid conflations of consecutive updates because the subscription count is very sensitive to it.
109+
*
110+
* The importance of non-conflating can be demonstrated with the following example:
111+
* ```
112+
* val shared = flowOf(239).stateIn(this, SharingStarted.Lazily, 42) // stateIn for the sake of the initial value
113+
* println(shared.first())
114+
* yield()
115+
* println(shared.first())
116+
* ```
117+
* If the flow is shared within the same dispatcher (e.g. Main) or with a slow/throttled one,
118+
* the `SharingStarted.Lazily` will never be able to start the source: `first` sees the initial value and immediately
119+
* unsubscribes, leaving the asynchronous `SharingStarted` with conflated zero.
120+
*
121+
* To avoid that (especially in a more complex scenarios), we do not conflate subscription updates.
122+
*/
123+
private class SubscriptionCountStateFlow(initialValue: Int) : StateFlow<Int>,
124+
SharedFlowImpl<Int>(1, Int.MAX_VALUE, BufferOverflow.DROP_OLDEST)
125+
{
126+
init { tryEmit(initialValue) }
127+
128+
override val value: Int
129+
get() = synchronized(this) { lastReplayedLocked }
130+
131+
fun increment(delta: Int) = synchronized(this) {
132+
tryEmit(lastReplayedLocked + delta)
133+
}
134+
}

kotlinx-coroutines-core/common/src/flow/operators/Share.kt

+11-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,16 @@ private fun <T> CoroutineScope.launchSharing(
197197
shared: MutableSharedFlow<T>,
198198
started: SharingStarted,
199199
initialValue: T
200-
): Job =
201-
launch(context) { // the single coroutine to rule the sharing
200+
): Job {
201+
/*
202+
* Conditional start: in the case when sharing and subscribing happens in the same dispatcher, we want to
203+
* have the following invariants preserved:
204+
* * Delayed sharing strategies have a chance to immediately observe consecutive subscriptions.
205+
* E.g. in the cases like `flow.shareIn(...); flow.take(1)` we want sharing strategy to see the initial subscription
206+
* * Eager sharing does not start immediately, so the subscribers have actual chance to subscribe _prior_ to sharing.
207+
*/
208+
val start = if (started == SharingStarted.Eagerly) CoroutineStart.DEFAULT else CoroutineStart.UNDISPATCHED
209+
return launch(context, start = start) { // the single coroutine to rule the sharing
202210
// Optimize common built-in started strategies
203211
when {
204212
started === SharingStarted.Eagerly -> {
@@ -230,6 +238,7 @@ private fun <T> CoroutineScope.launchSharing(
230238
}
231239
}
232240
}
241+
}
233242

234243
// -------------------------------- stateIn --------------------------------
235244

kotlinx-coroutines-core/common/test/flow/sharing/ShareInConflationTest.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class ShareInConflationTest : TestBase() {
2121
op: suspend Flow<Int>.(CoroutineScope) -> Flow<Int>
2222
) = runTest {
2323
expect(1)
24-
// emit all and conflate, then should collect bufferCapacity latest ones
24+
// emit all and conflate, then should collect bufferCapacity the latest ones
2525
val done = Job()
2626
flow {
2727
repeat(n) { i ->
@@ -159,4 +159,4 @@ class ShareInConflationTest : TestBase() {
159159
checkConflation(1, BufferOverflow.DROP_LATEST) {
160160
buffer(23).buffer(onBufferOverflow = BufferOverflow.DROP_LATEST).shareIn(it, SharingStarted.Eagerly, 0)
161161
}
162-
}
162+
}

kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt

+26
Original file line numberDiff line numberDiff line change
@@ -210,4 +210,30 @@ class ShareInTest : TestBase() {
210210
stop()
211211
}
212212
}
213+
214+
@Test
215+
fun testShouldStart() = runTest {
216+
val flow = flow {
217+
expect(2)
218+
emit(1)
219+
expect(3)
220+
}.shareIn(this, SharingStarted.Lazily)
221+
222+
expect(1)
223+
flow.onSubscription { throw CancellationException("") }
224+
.catch { e -> assertTrue { e is CancellationException } }
225+
.collect()
226+
yield()
227+
finish(4)
228+
}
229+
230+
@Test
231+
fun testShouldStartScalar() = runTest {
232+
val j = Job()
233+
val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42)
234+
assertEquals(42, shared.first())
235+
yield()
236+
assertEquals(239, shared.first())
237+
j.cancel()
238+
}
213239
}

kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt

+20
Original file line numberDiff line numberDiff line change
@@ -798,4 +798,24 @@ class SharedFlowTest : TestBase() {
798798
job.join()
799799
finish(5)
800800
}
801+
802+
@Test
803+
fun testSubscriptionCount() = runTest {
804+
val flow = MutableSharedFlow<Int>()
805+
fun startSubscriber() = launch(start = CoroutineStart.UNDISPATCHED) { flow.collect() }
806+
807+
assertEquals(0, flow.subscriptionCount.first())
808+
809+
val j1 = startSubscriber()
810+
assertEquals(1, flow.subscriptionCount.first())
811+
812+
val j2 = startSubscriber()
813+
assertEquals(2, flow.subscriptionCount.first())
814+
815+
j1.cancelAndJoin()
816+
assertEquals(1, flow.subscriptionCount.first())
817+
818+
j2.cancelAndJoin()
819+
assertEquals(0, flow.subscriptionCount.first())
820+
}
801821
}

kotlinx-coroutines-core/common/test/flow/sharing/SharingStartedWhileSubscribedTest.kt

+34-1
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,38 @@ class SharingStartedWhileSubscribedTest : TestBase() {
4040
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = 7000), SharingStarted.WhileSubscribed(replayExpiration = 7.seconds))
4141
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = Long.MAX_VALUE), SharingStarted.WhileSubscribed(replayExpiration = Duration.INFINITE))
4242
}
43-
}
4443

44+
@Test
45+
fun testShouldRestart() = runTest {
46+
var started = 0
47+
val flow = flow {
48+
expect(1 + ++started)
49+
emit(1)
50+
hang { }
51+
}.shareIn(this, SharingStarted.WhileSubscribed(100 /* ms */))
52+
53+
expect(1)
54+
flow.first()
55+
delay(200)
56+
flow.first()
57+
finish(4)
58+
coroutineContext.job.cancelChildren()
59+
}
60+
61+
@Test
62+
fun testImmediateUnsubscribe() = runTest {
63+
val flow = flow {
64+
expect(2)
65+
emit(1)
66+
hang { finish(4) }
67+
}.shareIn(this, SharingStarted.WhileSubscribed(400, 0 /* ms */), 1)
68+
69+
expect(1)
70+
repeat(5) {
71+
flow.first()
72+
delay(100)
73+
}
74+
expect(3)
75+
coroutineContext.job.cancelChildren()
76+
}
77+
}

kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt

+6-2
Original file line numberDiff line numberDiff line change
@@ -189,5 +189,9 @@ class SharingStressTest : TestBase() {
189189
var count = 0L
190190
}
191191

192-
private fun log(msg: String) = println("${testStarted.elapsedNow().toLongMilliseconds()} ms: $msg")
193-
}
192+
private fun log(msg: String) = println("${testStarted.elapsedNow().inWholeMilliseconds} ms: $msg")
193+
194+
private fun MutableStateFlow<Int>.increment(delta: Int) {
195+
update { it + delta }
196+
}
197+
}

0 commit comments

Comments
 (0)