Skip to content

Commit 21fd774

Browse files
committed
Re-implement flatMapMerge via the channel to have context preservation property
Fixes #1440
1 parent 62a51e9 commit 21fd774

File tree

10 files changed

+86
-123
lines changed

10 files changed

+86
-123
lines changed

benchmarks/src/jmh/kotlin/benchmarks/YieldRelativeCostBenchmark.kt

-35
This file was deleted.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package benchmarks.flow
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
9+
import org.openjdk.jmh.annotations.*
10+
import java.util.concurrent.*
11+
12+
@Warmup(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
13+
@Measurement(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
14+
@Fork(value = 1)
15+
@BenchmarkMode(Mode.AverageTime)
16+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
17+
@State(Scope.Benchmark)
18+
open class FlatMapMergeBenchmark {
19+
20+
// Note: tests only absence of contention on downstream
21+
22+
@Param("10", "100", "1000")
23+
private var iterations = 100
24+
25+
@Benchmark
26+
fun flatMapUnsafe() = runBlocking {
27+
benchmarks.flow.scrabble.flow {
28+
repeat(iterations) { emit(it) }
29+
}.flatMapMerge { value ->
30+
flowOf(value)
31+
}.collect {
32+
if (it == -1) error("")
33+
}
34+
}
35+
36+
@Benchmark
37+
fun flatMapSafe() = runBlocking {
38+
kotlinx.coroutines.flow.flow {
39+
repeat(iterations) { emit(it) }
40+
}.flatMapMerge { value ->
41+
flowOf(value)
42+
}.collect {
43+
if (it == -1) error("")
44+
}
45+
}
46+
47+
}

benchmarks/src/jmh/kotlin/benchmarks/flow/misc/Numbers.kt renamed to benchmarks/src/jmh/kotlin/benchmarks/flow/NumbersBenchmark.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
*/
44

55

6-
package benchmarks.flow.misc
6+
package benchmarks.flow
77

88
import benchmarks.flow.scrabble.flow
99
import io.reactivex.*
@@ -35,7 +35,7 @@ import java.util.concurrent.*
3535
@BenchmarkMode(Mode.AverageTime)
3636
@OutputTimeUnit(TimeUnit.MICROSECONDS)
3737
@State(Scope.Benchmark)
38-
open class Numbers {
38+
open class NumbersBenchmark {
3939

4040
companion object {
4141
private const val primes = 100

benchmarks/src/jmh/kotlin/benchmarks/flow/misc/SafeFlowBenchmark.kt renamed to benchmarks/src/jmh/kotlin/benchmarks/flow/SafeFlowBenchmark.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

5-
package benchmarks.flow.misc
5+
package benchmarks.flow
66

77
import kotlinx.coroutines.*
88
import kotlinx.coroutines.flow.*

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -992,7 +992,7 @@ public final class kotlinx/coroutines/flow/internal/SafeCollectorKt {
992992
public static final fun unsafeFlow (Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
993993
}
994994

995-
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/internal/ConcurrentFlowCollector {
995+
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/FlowCollector {
996996
public fun <init> (Lkotlinx/coroutines/channels/SendChannel;)V
997997
public fun emit (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
998998
}

kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt

+2-10
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ public abstract class ChannelFlow<T>(
6161
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
6262
get() = { collectTo(it) }
6363

64-
internal val produceCapacity: Int
64+
private val produceCapacity: Int
6565
get() = if (capacity == Channel.OPTIONAL_CHANNEL) Channel.BUFFERED else capacity
6666

6767
open fun broadcastImpl(scope: CoroutineScope, start: CoroutineStart): BroadcastChannel<T> =
@@ -140,13 +140,11 @@ internal class ChannelFlowOperatorImpl<T>(
140140
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> = when (this) {
141141
// SendingCollector & NopCollector do not care about the context at all and can be used as is
142142
is SendingCollector, is NopCollector -> this
143-
// Original collector is concurrent, so wrap into ConcurrentUndispatchedContextCollector (also concurrent)
144-
is ConcurrentFlowCollector -> ConcurrentUndispatchedContextCollector(this, emitContext)
145143
// Otherwise just wrap into UndispatchedContextCollector interface implementation
146144
else -> UndispatchedContextCollector(this, emitContext)
147145
}
148146

149-
private open class UndispatchedContextCollector<T>(
147+
private class UndispatchedContextCollector<T>(
150148
downstream: FlowCollector<T>,
151149
private val emitContext: CoroutineContext
152150
) : FlowCollector<T> {
@@ -157,12 +155,6 @@ private open class UndispatchedContextCollector<T>(
157155
withContextUndispatched(emitContext, countOrElement, emitRef, value)
158156
}
159157

160-
// named class for a combination of UndispatchedContextCollector & ConcurrentFlowCollector interface
161-
private class ConcurrentUndispatchedContextCollector<T>(
162-
downstream: ConcurrentFlowCollector<T>,
163-
emitContext: CoroutineContext
164-
) : UndispatchedContextCollector<T>(downstream, emitContext), ConcurrentFlowCollector<T>
165-
166158
// Efficiently computes block(value) in the newContext
167159
private suspend fun <T, V> withContextUndispatched(
168160
newContext: CoroutineContext,

kotlinx-coroutines-core/common/src/flow/internal/Concurrent.kt

+2-63
Original file line numberDiff line numberDiff line change
@@ -4,78 +4,17 @@
44

55
package kotlinx.coroutines.flow.internal
66

7-
import kotlinx.atomicfu.*
87
import kotlinx.coroutines.*
98
import kotlinx.coroutines.channels.*
10-
import kotlinx.coroutines.channels.ArrayChannel
119
import kotlinx.coroutines.flow.*
1210

13-
internal fun <T> FlowCollector<T>.asConcurrentFlowCollector(): ConcurrentFlowCollector<T> =
14-
this as? ConcurrentFlowCollector<T> ?: SerializingCollector(this)
15-
16-
// Flow collector that supports concurrent emit calls.
17-
// It is internal for now but may be public in the future.
18-
// Two basic implementations are here: SendingCollector and ConcurrentFlowCollector
19-
internal interface ConcurrentFlowCollector<T> : FlowCollector<T>
20-
2111
/**
22-
* Collection that sends to channel. It is marked as [ConcurrentFlowCollector] because it can be used concurrently.
23-
*
12+
* Collection that sends to channel
2413
* @suppress **This an internal API and should not be used from general code.**
2514
*/
2615
@InternalCoroutinesApi
2716
public class SendingCollector<T>(
2817
private val channel: SendChannel<T>
29-
) : ConcurrentFlowCollector<T> {
18+
) : FlowCollector<T> {
3019
override suspend fun emit(value: T) = channel.send(value)
3120
}
32-
33-
// Effectively serializes access to downstream collector for merging
34-
// This is basically a converted from FlowCollector interface to ConcurrentFlowCollector
35-
private class SerializingCollector<T>(
36-
private val downstream: FlowCollector<T>
37-
) : ConcurrentFlowCollector<T> {
38-
// Let's try to leverage the fact that merge is never contended
39-
// Should be Any, but KT-30796
40-
private val _channel = atomic<ArrayChannel<Any?>?>(null)
41-
private val inProgressLock = atomic(false)
42-
43-
private val channel: ArrayChannel<Any?>
44-
get() = _channel.updateAndGet { value ->
45-
if (value != null) return value
46-
ArrayChannel(Channel.CHANNEL_DEFAULT_CAPACITY)
47-
}!!
48-
49-
public override suspend fun emit(value: T) {
50-
if (!inProgressLock.tryAcquire()) {
51-
channel.send(value ?: NULL)
52-
if (inProgressLock.tryAcquire()) {
53-
helpEmit()
54-
}
55-
return
56-
}
57-
downstream.emit(value)
58-
helpEmit()
59-
}
60-
61-
@Suppress("UNCHECKED_CAST")
62-
private suspend fun helpEmit() {
63-
while (true) {
64-
while (true) {
65-
val element = _channel.value?.poll() ?: break // todo: pollOrClosed
66-
downstream.emit(NULL.unbox(element))
67-
}
68-
inProgressLock.release()
69-
// Enforce liveness
70-
if (_channel.value?.isEmpty != false || !inProgressLock.tryAcquire()) break
71-
}
72-
}
73-
}
74-
75-
@Suppress("NOTHING_TO_INLINE")
76-
private inline fun AtomicBoolean.tryAcquire(): Boolean = compareAndSet(false, true)
77-
78-
@Suppress("NOTHING_TO_INLINE")
79-
private inline fun AtomicBoolean.release() {
80-
value = false
81-
}

kotlinx-coroutines-core/common/src/flow/internal/Merge.kt

+12-9
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,21 @@ internal class ChannelFlowTransformLatest<T, R>(
3737
}
3838
}
3939

40-
internal class ChannelFlowMerge<T>(
40+
internal class ChannelFlowMerge<T> (
4141
flow: Flow<Flow<T>>,
4242
private val concurrency: Int,
4343
context: CoroutineContext = EmptyCoroutineContext,
44-
capacity: Int = Channel.OPTIONAL_CHANNEL
44+
capacity: Int = Channel.BUFFERED
4545
) : ChannelFlowOperator<Flow<T>, T>(flow, context, capacity) {
4646
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
4747
ChannelFlowMerge(flow, concurrency, context, capacity)
4848

4949
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
50-
return scope.flowProduce(context, produceCapacity, block = collectToFun)
50+
return scope.flowProduce(context, capacity, block = collectToFun)
5151
}
5252

5353
// The actual merge implementation with concurrency limit
54-
private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector<T>) {
54+
private suspend fun mergeImpl(scope: CoroutineScope, collector: SendingCollector<T>) {
5555
val semaphore = Semaphore(concurrency)
5656
val job: Job? = coroutineContext[Job]
5757
flow.collect { inner ->
@@ -72,12 +72,15 @@ internal class ChannelFlowMerge<T>(
7272
}
7373
}
7474

75-
// Fast path in ChannelFlowOperator calls this function (channel was not created yet)
7675
override suspend fun flowCollect(collector: FlowCollector<T>) {
77-
// this function should not have been invoked when channel was explicitly requested
78-
assert { capacity == Channel.OPTIONAL_CHANNEL }
79-
flowScope {
80-
mergeImpl(this, collector.asConcurrentFlowCollector())
76+
assert { collector !is SendingCollector<*> }
77+
coroutineScope {
78+
val output = produce<T>(capacity = capacity) {
79+
mergeImpl(this, SendingCollector(this))
80+
}
81+
output.consumeEach {
82+
collector.emit(it)
83+
}
8184
}
8285
}
8386

kotlinx-coroutines-core/common/src/flow/internal/NopCollector.kt

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
package kotlinx.coroutines.flow.internal
66

7-
internal object NopCollector : ConcurrentFlowCollector<Any?> {
7+
import kotlinx.coroutines.flow.*
8+
9+
internal object NopCollector : FlowCollector<Any?> {
810
override suspend fun emit(value: Any?) {
911
// does nothing
1012
}

kotlinx-coroutines-core/common/test/flow/operators/FlattenMergeTest.kt

+16-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ class FlattenMergeTest : FlatMapMergeBaseTest() {
1414
@Test
1515
override fun testFlatMapConcurrency() = runTest {
1616
var concurrentRequests = 0
17-
val flow = (1..100).asFlow().map() { value ->
17+
val flow = (1..100).asFlow().map { value ->
1818
flow {
1919
++concurrentRequests
2020
emit(value)
@@ -36,4 +36,19 @@ class FlattenMergeTest : FlatMapMergeBaseTest() {
3636
consumer.cancelAndJoin()
3737
finish(3)
3838
}
39+
40+
@Test
41+
fun testContextPreservationAcrossFlows() = runTest {
42+
val result = flow {
43+
flowOf(1, 2).flatMapMerge {
44+
flow {
45+
yield()
46+
emit(it)
47+
}
48+
}.collect {
49+
emit(it)
50+
}
51+
}.toList()
52+
assertEquals(listOf(1, 2), result)
53+
}
3954
}

0 commit comments

Comments
 (0)