Skip to content

Non-conflating subscription count in SharedFlow and StateFlow #2872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions kotlinx-coroutines-core/common/src/flow/SharedFlow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ public interface MutableSharedFlow<T> : SharedFlow<T>, FlowCollector<T> {
* }
* .launchIn(scope) // launch it
* ```
*
* Implementation note: the resulting flow **does not** conflate subscription count.
*/
public val subscriptionCount: StateFlow<Int>

Expand Down Expand Up @@ -253,7 +255,7 @@ public fun <T> MutableSharedFlow(

// ------------------------------------ Implementation ------------------------------------

private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
internal class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
@JvmField
var index = -1L // current "to-be-emitted" index, -1 means the slot is free now

Expand All @@ -275,7 +277,7 @@ private class SharedFlowSlot : AbstractSharedFlowSlot<SharedFlowImpl<*>>() {
}
}

private class SharedFlowImpl<T>(
internal class SharedFlowImpl<T>(
private val replay: Int,
private val bufferCapacity: Int,
private val onBufferOverflow: BufferOverflow
Expand Down Expand Up @@ -334,6 +336,13 @@ private class SharedFlowImpl<T>(
result
}

/*
* A tweak for SubscriptionCountStateFlow to get the latest value.
*/
@Suppress("UNCHECKED_CAST")
val lastReplayedLocked: T
get() = buffer!!.getBufferAt(replayIndex + replaySize - 1) as T

@Suppress("UNCHECKED_CAST")
override suspend fun collect(collector: FlowCollector<T>) {
val slot = allocateSlot()
Expand Down
4 changes: 0 additions & 4 deletions kotlinx-coroutines-core/common/src/flow/StateFlow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -415,10 +415,6 @@ private class StateFlowImpl<T>(
fuseStateFlow(context, capacity, onBufferOverflow)
}

internal fun MutableStateFlow<Int>.increment(delta: Int) {
update { it + delta }
}

internal fun <T> StateFlow<T>.fuseStateFlow(
context: CoroutineContext,
capacity: Int,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.internal.*
import kotlin.coroutines.*
Expand All @@ -26,12 +27,12 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
protected var nCollectors = 0 // number of allocated (!free) slots
private set
private var nextIndex = 0 // oracle for the next free slot index
private var _subscriptionCount: MutableStateFlow<Int>? = null // init on first need
private var _subscriptionCount: SubscriptionCountStateFlow? = null // init on first need

val subscriptionCount: StateFlow<Int>
get() = synchronized(this) {
// allocate under lock in sync with nCollectors variable
_subscriptionCount ?: MutableStateFlow(nCollectors).also {
_subscriptionCount ?: SubscriptionCountStateFlow(nCollectors).also {
_subscriptionCount = it
}
}
Expand All @@ -43,7 +44,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
@Suppress("UNCHECKED_CAST")
protected fun allocateSlot(): S {
// Actually create slot under lock
var subscriptionCount: MutableStateFlow<Int>? = null
var subscriptionCount: SubscriptionCountStateFlow? = null
val slot = synchronized(this) {
val slots = when (val curSlots = slots) {
null -> createSlotArray(2).also { slots = it }
Expand Down Expand Up @@ -74,7 +75,7 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
@Suppress("UNCHECKED_CAST")
protected fun freeSlot(slot: S) {
// Release slot under lock
var subscriptionCount: MutableStateFlow<Int>? = null
var subscriptionCount: SubscriptionCountStateFlow? = null
val resumes = synchronized(this) {
nCollectors--
subscriptionCount = _subscriptionCount // retrieve under lock if initialized
Expand All @@ -83,10 +84,10 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
(slot as AbstractSharedFlowSlot<Any>).freeLocked(this)
}
/*
Resume suspended coroutines.
This can happens when the subscriber that was freed was a slow one and was holding up buffer.
When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
*/
* Resume suspended coroutines.
* This can happen when the subscriber that was freed was a slow one and was holding up buffer.
* When this subscriber was freed, previously queued emitted can now wake up and are resumed here.
*/
for (cont in resumes) cont?.resume(Unit)
// decrement subscription count
subscriptionCount?.increment(-1)
Expand All @@ -99,3 +100,43 @@ internal abstract class AbstractSharedFlow<S : AbstractSharedFlowSlot<*>> : Sync
}
}
}

/**
* [StateFlow] that represents the number of subscriptions.
*
* It is exposed as a regular [StateFlow] in our public API, but it is implemented as [SharedFlow] undercover to
* avoid conflations of consecutive updates because the subscription count is very sensitive to it.
*
* The importance of non-conflating can be demonstrated with the following example:
* ```
* val shared = flowOf(239).stateIn(this, SharingStarted.Lazily, 42) // stateIn for the sake of the initial value
* println(shared.first())
* yield()
* println(shared.first())
* ```
* If the flow is shared within the same dispatcher (e.g. Main) or with a slow/throttled one,
* the `SharingStarted.Lazily` will never be able to start the source: `first` sees the initial value and immediately
* unsubscribes, leaving the asynchronous `SharingStarted` with conflated zero.
*
* To avoid that (especially in a more complex scenarios), we do not conflate subscription updates.
*/
private class SubscriptionCountStateFlow(initialValue: Int) : StateFlow<Int> {
private val sharedFlow = SharedFlowImpl<Int>(1, Int.MAX_VALUE, BufferOverflow.DROP_OLDEST)
.also { it.tryEmit(initialValue) }

override val replayCache: List<Int>
get() = sharedFlow.replayCache

override val value: Int
get() = synchronized(sharedFlow) {
sharedFlow.lastReplayedLocked
}

fun increment(delta: Int) = synchronized(sharedFlow) {
sharedFlow.tryEmit(sharedFlow.lastReplayedLocked + delta)
}

override suspend fun collect(collector: FlowCollector<Int>) {
sharedFlow.collect(collector)
}
}
13 changes: 11 additions & 2 deletions kotlinx-coroutines-core/common/src/flow/operators/Share.kt
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,16 @@ private fun <T> CoroutineScope.launchSharing(
shared: MutableSharedFlow<T>,
started: SharingStarted,
initialValue: T
): Job =
launch(context) { // the single coroutine to rule the sharing
): Job {
/*
* Conditional start: in the case when sharing and subscribing happens in the same dispatcher, we want to
* have the following invariants preserved:
* * Delayed sharing strategies have a chance to immediately observe consecutive subscriptions.
* E.g. in the cases like `flow.shareIn(...); flow.take(1)` we want sharing strategy to see the initial subscription
* * Eager sharing does not start immediately, so the subscribers have actual chance to subscribe _prior_ to sharing.
*/
val start = if (started == SharingStarted.Eagerly) CoroutineStart.DEFAULT else CoroutineStart.UNDISPATCHED
return launch(context, start = start) { // the single coroutine to rule the sharing
// Optimize common built-in started strategies
when {
started === SharingStarted.Eagerly -> {
Expand Down Expand Up @@ -230,6 +238,7 @@ private fun <T> CoroutineScope.launchSharing(
}
}
}
}

// -------------------------------- stateIn --------------------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class ShareInConflationTest : TestBase() {
op: suspend Flow<Int>.(CoroutineScope) -> Flow<Int>
) = runTest {
expect(1)
// emit all and conflate, then should collect bufferCapacity latest ones
// emit all and conflate, then should collect bufferCapacity the latest ones
val done = Job()
flow {
repeat(n) { i ->
Expand Down Expand Up @@ -159,4 +159,4 @@ class ShareInConflationTest : TestBase() {
checkConflation(1, BufferOverflow.DROP_LATEST) {
buffer(23).buffer(onBufferOverflow = BufferOverflow.DROP_LATEST).shareIn(it, SharingStarted.Eagerly, 0)
}
}
}
26 changes: 26 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/sharing/ShareInTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,30 @@ class ShareInTest : TestBase() {
stop()
}
}

@Test
fun testShouldStart() = runTest {
val flow = flow {
expect(2)
emit(1)
expect(3)
}.shareIn(this, SharingStarted.Lazily)

expect(1)
flow.onSubscription { throw CancellationException("") }
.catch { e -> assertTrue { e is CancellationException } }
.collect()
yield()
finish(4)
}

@Test
fun testShouldStartScalar() = runTest {
val j = Job()
val shared = flowOf(239).stateIn(this + j, SharingStarted.Lazily, 42)
assertEquals(42, shared.first())
yield()
assertEquals(239, shared.first())
j.cancel()
}
}
20 changes: 20 additions & 0 deletions kotlinx-coroutines-core/common/test/flow/sharing/SharedFlowTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -798,4 +798,24 @@ class SharedFlowTest : TestBase() {
job.join()
finish(5)
}

@Test
fun testSubscriptionCount() = runTest {
val flow = MutableSharedFlow<Int>()
fun startSubscriber() = launch(start = CoroutineStart.UNDISPATCHED) { flow.collect() }

assertEquals(0, flow.subscriptionCount.first())

val j1 = startSubscriber()
assertEquals(1, flow.subscriptionCount.first())

val j2 = startSubscriber()
assertEquals(2, flow.subscriptionCount.first())

j1.cancelAndJoin()
assertEquals(1, flow.subscriptionCount.first())

j2.cancelAndJoin()
assertEquals(0, flow.subscriptionCount.first())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,38 @@ class SharingStartedWhileSubscribedTest : TestBase() {
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = 7000), SharingStarted.WhileSubscribed(replayExpiration = 7.seconds))
assertEquals(SharingStarted.WhileSubscribed(replayExpirationMillis = Long.MAX_VALUE), SharingStarted.WhileSubscribed(replayExpiration = Duration.INFINITE))
}
}

@Test
fun testShouldRestart() = runTest {
var started = 0
val flow = flow {
expect(1 + ++started)
emit(1)
hang { }
}.shareIn(this, SharingStarted.WhileSubscribed(100 /* ms */))

expect(1)
flow.first()
delay(200)
flow.first()
finish(4)
coroutineContext.job.cancelChildren()
}

@Test
fun testImmediateUnsubscribe() = runTest {
val flow = flow {
expect(2)
emit(1)
hang { finish(4) }
}.shareIn(this, SharingStarted.WhileSubscribed(400, 0 /* ms */), 1)

expect(1)
repeat(5) {
flow.first()
delay(100)
}
expect(3)
coroutineContext.job.cancelChildren()
}
}
8 changes: 6 additions & 2 deletions kotlinx-coroutines-core/jvm/test/flow/SharingStressTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -189,5 +189,9 @@ class SharingStressTest : TestBase() {
var count = 0L
}

private fun log(msg: String) = println("${testStarted.elapsedNow().toLongMilliseconds()} ms: $msg")
}
private fun log(msg: String) = println("${testStarted.elapsedNow().inWholeMilliseconds} ms: $msg")

private fun MutableStateFlow<Int>.increment(delta: Int) {
update { it + delta }
}
}