Skip to content

Commit 5f316de

Browse files
committed
Fix context support in Publisher.asFlow.flowOn
* When using asFlow().flowOn(...) context is now properly tracked and taken into account for both execution context of the reactive subscription and for injection into Reactor context. * Publisher.asFlow slow-path implementation is simplified. It does not sure specialized openSubscription anymore, but always uses the same flow request logic. Fixes #1765
1 parent f18e0e4 commit 5f316de

File tree

5 files changed

+162
-33
lines changed

5 files changed

+162
-33
lines changed

reactive/kotlinx-coroutines-reactive/src/ReactiveFlow.kt

+26-27
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import kotlin.coroutines.*
2727
* see its documentation for additional details.
2828
*/
2929
public fun <T : Any> Publisher<T>.asFlow(): Flow<T> =
30-
PublisherAsFlow(this, 1)
30+
PublisherAsFlow(this)
3131

3232
/**
3333
* Transforms the given flow to a reactive specification compliant [Publisher].
@@ -39,30 +39,11 @@ public fun <T : Any> Flow<T>.asPublisher(): Publisher<T> = FlowAsPublisher(this)
3939

4040
private class PublisherAsFlow<T : Any>(
4141
private val publisher: Publisher<T>,
42-
capacity: Int
43-
) : ChannelFlow<T>(EmptyCoroutineContext, capacity) {
42+
context: CoroutineContext = EmptyCoroutineContext,
43+
capacity: Int = 1
44+
) : ChannelFlow<T>(context, capacity) {
4445
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
45-
PublisherAsFlow(publisher, capacity)
46-
47-
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
48-
// use another channel for conflation (cannot do openSubscription)
49-
if (capacity < 0) return super.produceImpl(scope)
50-
// Open subscription channel directly
51-
val channel = publisher
52-
.injectCoroutineContext(scope.coroutineContext)
53-
.openSubscription(capacity)
54-
val handle = scope.coroutineContext[Job]?.invokeOnCompletion(onCancelling = true) { cause ->
55-
channel.cancel(cause?.let {
56-
it as? CancellationException ?: CancellationException("Job was cancelled", it)
57-
})
58-
}
59-
if (handle != null && handle !== NonDisposableHandle) {
60-
(channel as SendChannel<*>).invokeOnClose {
61-
handle.dispose()
62-
}
63-
}
64-
return channel
65-
}
46+
PublisherAsFlow(publisher, context, capacity)
6647

6748
private val requestSize: Long
6849
get() = when (capacity) {
@@ -73,8 +54,26 @@ private class PublisherAsFlow<T : Any>(
7354
}
7455

7556
override suspend fun collect(collector: FlowCollector<T>) {
57+
val collectContext = coroutineContext
58+
val newDispatcher = context[ContinuationInterceptor]
59+
if (newDispatcher == null || newDispatcher == collectContext[ContinuationInterceptor]) {
60+
// fast path -- subscribe directly in this dispatcher
61+
return collectImpl(collectContext + context, collector)
62+
}
63+
// slow path -- produce in a separate dispatcher
64+
collectSlowPath(collector)
65+
}
66+
67+
private suspend fun collectSlowPath(collector: FlowCollector<T>) {
68+
coroutineScope {
69+
collector.emitAll(produceImpl(this + context))
70+
}
71+
}
72+
73+
private suspend fun collectImpl(injectContext: CoroutineContext, collector: FlowCollector<T>) {
7674
val subscriber = ReactiveSubscriber<T>(capacity, requestSize)
77-
publisher.injectCoroutineContext(coroutineContext).subscribe(subscriber)
75+
// inject subscribe context into publisher
76+
publisher.injectCoroutineContext(injectContext).subscribe(subscriber)
7877
try {
7978
var consumed = 0L
8079
while (true) {
@@ -90,9 +89,9 @@ private class PublisherAsFlow<T : Any>(
9089
}
9190
}
9291

93-
// The second channel here is used only for broadcast
92+
// The second channel here is used for produceIn/broadcastIn and slow-path (dispatcher change)
9493
override suspend fun collectTo(scope: ProducerScope<T>) =
95-
collect(SendingCollector(scope.channel))
94+
collectImpl(scope.coroutineContext, SendingCollector(scope.channel))
9695
}
9796

9897
@Suppress("SubscriberImplementation")

reactive/kotlinx-coroutines-reactive/test/PublisherAsFlowTest.kt

+2-2
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class PublisherAsFlowTest : TestBase() {
120120
7 -> try {
121121
send(value)
122122
} catch (e: CancellationException) {
123-
finish(6)
123+
expect(5)
124124
throw e
125125
}
126126
else -> expectUnreached()
@@ -143,6 +143,6 @@ class PublisherAsFlowTest : TestBase() {
143143
}
144144
}
145145
}
146-
expect(5)
146+
finish(6)
147147
}
148148
}

reactive/kotlinx-coroutines-reactor/test/FlowAsFluxTest.kt

+48-4
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@ import kotlinx.coroutines.*
44
import kotlinx.coroutines.flow.*
55
import kotlinx.coroutines.reactive.*
66
import org.junit.Test
7-
import reactor.core.publisher.Mono
7+
import reactor.core.publisher.*
88
import reactor.util.context.Context
9-
import kotlin.test.assertEquals
9+
import kotlin.test.*
1010

1111
class FlowAsFluxTest : TestBase() {
1212
@Test
13-
fun testFlowToFluxContextPropagation() {
13+
fun testFlowAsFluxContextPropagation() {
1414
val flux = flow<String> {
1515
(1..4).forEach { i -> emit(createMono(i).awaitFirst()) }
16-
} .asFlux()
16+
}
17+
.asFlux()
1718
.subscriberContext(Context.of(1, "1"))
1819
.subscriberContext(Context.of(2, "2", 3, "3", 4, "4"))
1920
val list = flux.collectList().block()!!
@@ -24,4 +25,47 @@ class FlowAsFluxTest : TestBase() {
2425
val ctx = coroutineContext[ReactorContext]!!.context
2526
ctx.getOrDefault(i, "noValue")
2627
}
28+
29+
@Test
30+
fun testFluxAsFlowContextPropagationWithFlowOn() = runTest {
31+
expect(1)
32+
Flux.create<String> {
33+
it.next("OK")
34+
it.complete()
35+
}
36+
.subscriberContext { ctx ->
37+
expect(2)
38+
assertEquals("CTX", ctx.get(1))
39+
ctx
40+
}
41+
.asFlow()
42+
.flowOn(ReactorContext(Context.of(1, "CTX")))
43+
.collect {
44+
expect(3)
45+
assertEquals("OK", it)
46+
}
47+
finish(4)
48+
}
49+
50+
@Test
51+
fun testFluxAsFlowContextPropagationFromScope() = runTest {
52+
expect(1)
53+
withContext(ReactorContext(Context.of(1, "CTX"))) {
54+
Flux.create<String> {
55+
it.next("OK")
56+
it.complete()
57+
}
58+
.subscriberContext { ctx ->
59+
expect(2)
60+
assertEquals("CTX", ctx.get(1))
61+
ctx
62+
}
63+
.asFlow()
64+
.collect {
65+
expect(3)
66+
assertEquals("OK", it)
67+
}
68+
}
69+
finish(4)
70+
}
2771
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.reactor
6+
7+
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
9+
import kotlinx.coroutines.reactive.*
10+
import org.junit.*
11+
import org.junit.Test
12+
import reactor.core.publisher.*
13+
import kotlin.test.*
14+
15+
class FluxContextTest : TestBase() {
16+
private val dispatcher = newSingleThreadContext("FluxContextTest")
17+
18+
@After
19+
fun tearDown() {
20+
dispatcher.close()
21+
}
22+
23+
@Test
24+
fun testFluxCreateAsFlowThread() = runTest {
25+
expect(1)
26+
val mainThread = Thread.currentThread()
27+
val dispatcherThread = withContext(dispatcher) { Thread.currentThread() }
28+
assertTrue(dispatcherThread != mainThread)
29+
Flux.create<String> {
30+
assertEquals(dispatcherThread, Thread.currentThread())
31+
it.next("OK")
32+
it.complete()
33+
}
34+
.asFlow()
35+
.flowOn(dispatcher)
36+
.collect {
37+
expect(2)
38+
assertEquals("OK", it)
39+
assertEquals(mainThread, Thread.currentThread())
40+
}
41+
finish(3)
42+
}
43+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/*
2+
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.rx2
6+
7+
import io.reactivex.*
8+
import kotlinx.coroutines.*
9+
import kotlinx.coroutines.flow.*
10+
import kotlinx.coroutines.reactive.*
11+
import org.junit.*
12+
import org.junit.Test
13+
import kotlin.test.*
14+
15+
class FlowableContextTest : TestBase() {
16+
private val dispatcher = newSingleThreadContext("FlowableContextTest")
17+
18+
@After
19+
fun tearDown() {
20+
dispatcher.close()
21+
}
22+
23+
@Test
24+
fun testFlowableCreateAsFlowThread() = runTest {
25+
expect(1)
26+
val mainThread = Thread.currentThread()
27+
val dispatcherThread = withContext(dispatcher) { Thread.currentThread() }
28+
assertTrue(dispatcherThread != mainThread)
29+
Flowable.create<String>({
30+
assertEquals(dispatcherThread, Thread.currentThread())
31+
it.onNext("OK")
32+
it.onComplete()
33+
}, BackpressureStrategy.BUFFER)
34+
.asFlow()
35+
.flowOn(dispatcher)
36+
.collect {
37+
expect(2)
38+
assertEquals("OK", it)
39+
assertEquals(mainThread, Thread.currentThread())
40+
}
41+
finish(3)
42+
}
43+
}

0 commit comments

Comments
 (0)