Skip to content

Commit bea1f5b

Browse files
committed
Support of context passing added for Flux
Fixes #284
1 parent 931c36e commit bea1f5b

File tree

2 files changed

+251
-3
lines changed

2 files changed

+251
-3
lines changed

reactive/kotlinx-coroutines-reactor/src/Flux.kt

+241-2
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,20 @@
44

55
package kotlinx.coroutines.reactor
66

7+
import kotlinx.atomicfu.atomic
78
import kotlinx.coroutines.*
89
import kotlinx.coroutines.channels.*
910
import kotlinx.coroutines.reactive.*
11+
import kotlinx.coroutines.selects.SelectClause2
12+
import kotlinx.coroutines.selects.SelectInstance
13+
import kotlinx.coroutines.sync.Mutex
14+
import org.reactivestreams.Publisher
15+
import org.reactivestreams.Subscriber
16+
import org.reactivestreams.Subscription
17+
import reactor.core.CoreSubscriber
18+
import reactor.core.Disposable
1019
import reactor.core.publisher.*
20+
import reactor.util.context.Context
1121
import kotlin.coroutines.*
1222

1323
/**
@@ -32,9 +42,238 @@ import kotlin.coroutines.*
3242
* **Note: This is an experimental api.** Behaviour of publishers that work as children in a parent scope with respect
3343
* to cancellation and error handling may change in the future.
3444
*/
45+
3546
@ExperimentalCoroutinesApi
3647
fun <T> CoroutineScope.flux(
3748
context: CoroutineContext = EmptyCoroutineContext,
3849
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
39-
): Flux<T> =
40-
Flux.from(publish(newCoroutineContext(context), block = block))
50+
): Flux<T> = Flux.from(reactorPublish(newCoroutineContext(context), block = block))
51+
52+
@ExperimentalCoroutinesApi
53+
public fun <T> CoroutineScope.reactorPublish(
54+
context: CoroutineContext = EmptyCoroutineContext,
55+
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
56+
): Publisher<T> = Publisher { subscriber ->
57+
// specification requires NPE on null subscriber
58+
if (subscriber == null) throw NullPointerException("Subscriber cannot be null")
59+
val currentContext = if (subscriber is CoreSubscriber) subscriber.currentContext() else Context.empty()
60+
val reactorContext = (coroutineContext[ReactorContext]?.context?.putAll(currentContext) ?: currentContext).asCoroutineContext()
61+
val newContext = newCoroutineContext(context + reactorContext)
62+
val coroutine = PublisherCoroutine(newContext, subscriber)
63+
subscriber.onSubscribe(coroutine) // do it first (before starting coroutine), to avoid unnecessary suspensions
64+
coroutine.start(CoroutineStart.DEFAULT, coroutine, block)
65+
}
66+
67+
private const val CLOSED = -1L // closed, but have not signalled onCompleted/onError yet
68+
private const val SIGNALLED = -2L // already signalled subscriber onCompleted/onError
69+
70+
@Suppress("CONFLICTING_JVM_DECLARATIONS", "RETURN_TYPE_MISMATCH_ON_INHERITANCE")
71+
private class PublisherCoroutine<in T>(
72+
parentContext: CoroutineContext,
73+
private val subscriber: Subscriber<T>
74+
) : AbstractCoroutine<Unit>(parentContext, true), ProducerScope<T>, Subscription, SelectClause2<T, SendChannel<T>> {
75+
override val channel: SendChannel<T> get() = this
76+
77+
// Mutex is locked when either nRequested == 0 or while subscriber.onXXX is being invoked
78+
private val mutex = Mutex(locked = true)
79+
80+
private val _nRequested = atomic(0L) // < 0 when closed (CLOSED or SIGNALLED)
81+
82+
@Volatile
83+
private var cancelled = false // true when Subscription.cancel() is invoked
84+
85+
private var shouldHandleException = false // when handleJobException is invoked
86+
87+
override val isClosedForSend: Boolean get() = isCompleted
88+
override val isFull: Boolean = mutex.isLocked
89+
override fun close(cause: Throwable?): Boolean = cancelCoroutine(cause)
90+
override fun invokeOnClose(handler: (Throwable?) -> Unit) =
91+
throw UnsupportedOperationException("PublisherCoroutine doesn't support invokeOnClose")
92+
93+
override fun offer(element: T): Boolean {
94+
if (!mutex.tryLock()) return false
95+
doLockedNext(element)
96+
return true
97+
}
98+
99+
public override suspend fun send(element: T) {
100+
// fast-path -- try send without suspension
101+
if (offer(element)) return
102+
// slow-path does suspend
103+
return sendSuspend(element)
104+
}
105+
106+
private suspend fun sendSuspend(element: T) {
107+
mutex.lock()
108+
doLockedNext(element)
109+
}
110+
111+
override val onSend: SelectClause2<T, SendChannel<T>>
112+
get() = this
113+
114+
// registerSelectSend
115+
@Suppress("PARAMETER_NAME_CHANGED_ON_OVERRIDE")
116+
override fun <R> registerSelectClause2(select: SelectInstance<R>, element: T, block: suspend (SendChannel<T>) -> R) {
117+
mutex.onLock.registerSelectClause2(select, null) {
118+
doLockedNext(element)
119+
block(this)
120+
}
121+
}
122+
123+
/*
124+
* This code is not trivial because of the two properties:
125+
* 1. It ensures conformance to the reactive specification that mandates that onXXX invocations should not
126+
* be concurrent. It uses Mutex to protect all onXXX invocation and ensure conformance even when multiple
127+
* coroutines are invoking `send` function.
128+
* 2. Normally, `onComplete/onError` notification is sent only when coroutine and all its children are complete.
129+
* However, nothing prevents `publish` coroutine from leaking reference to it send channel to some
130+
* globally-scoped coroutine that is invoking `send` outside of this context. Without extra precaution this may
131+
* lead to `onNext` that is concurrent with `onComplete/onError`, so that is why signalling for
132+
* `onComplete/onError` is also done under the same mutex.
133+
*/
134+
135+
// assert: mutex.isLocked()
136+
private fun doLockedNext(elem: T) {
137+
// check if already closed for send, note, that isActive become false as soon as cancel() is invoked,
138+
// because the job is cancelled, so this check also ensure conformance to the reactive specification's
139+
// requirement that after cancellation requested we don't call onXXX
140+
if (!isActive) {
141+
unlockAndCheckCompleted()
142+
throw getCancellationException()
143+
}
144+
// notify subscriber
145+
try {
146+
subscriber.onNext(elem)
147+
} catch (e: Throwable) {
148+
// If onNext fails with exception, then we cancel coroutine (with this exception) and then rethrow it
149+
// to abort the corresponding send/offer invocation. From the standpoint of coroutines machinery,
150+
// this failure is essentially equivalent to a failure of a child coroutine.
151+
cancelCoroutine(e)
152+
unlockAndCheckCompleted()
153+
throw e
154+
}
155+
// now update nRequested
156+
while (true) { // lock-free loop on nRequested
157+
val cur = _nRequested.value
158+
if (cur < 0) break // closed from inside onNext => unlock
159+
if (cur == Long.MAX_VALUE) break // no back-pressure => unlock
160+
val upd = cur - 1
161+
if (_nRequested.compareAndSet(cur, upd)) {
162+
if (upd == 0L) {
163+
// return to keep locked due to back-pressure
164+
return
165+
}
166+
break // unlock if upd > 0
167+
}
168+
}
169+
unlockAndCheckCompleted()
170+
}
171+
172+
private fun unlockAndCheckCompleted() {
173+
/*
174+
* There is no sense to check completion before doing `unlock`, because completion might
175+
* happen after this check and before `unlock` (see `signalCompleted` that does not do anything
176+
* if it fails to acquire the lock that we are still holding).
177+
* We have to recheck `isCompleted` after `unlock` anyway.
178+
*/
179+
mutex.unlock()
180+
// check isCompleted and and try to regain lock to signal completion
181+
if (isCompleted && mutex.tryLock()) doLockedSignalCompleted()
182+
}
183+
184+
// assert: mutex.isLocked() & isCompleted
185+
private fun doLockedSignalCompleted() {
186+
try {
187+
if (_nRequested.value >= CLOSED) {
188+
_nRequested.value = SIGNALLED // we'll signal onError/onCompleted (that the final state -- no CAS needed)
189+
val cause = getCompletionCause()
190+
// Specification requires that after cancellation requested we don't call onXXX
191+
if (cancelled) {
192+
// If the parent had failed to handle our exception (handleJobException was invoked), then
193+
// we must not loose this exception
194+
if (shouldHandleException && cause != null) {
195+
handleCoroutineException(context, cause)
196+
}
197+
} else {
198+
try {
199+
if (cause != null && cause !is CancellationException) {
200+
subscriber.onError(cause)
201+
}
202+
else {
203+
subscriber.onComplete()
204+
}
205+
} catch (e: Throwable) {
206+
handleCoroutineException(context, e)
207+
}
208+
}
209+
}
210+
} finally {
211+
mutex.unlock()
212+
}
213+
}
214+
215+
override fun request(n: Long) {
216+
if (n <= 0) {
217+
// Specification requires IAE for n <= 0
218+
cancelCoroutine(IllegalArgumentException("non-positive subscription request $n"))
219+
return
220+
}
221+
while (true) { // lock-free loop for nRequested
222+
val cur = _nRequested.value
223+
if (cur < 0) return // already closed for send, ignore requests
224+
var upd = cur + n
225+
if (upd < 0 || n == Long.MAX_VALUE)
226+
upd = Long.MAX_VALUE
227+
if (cur == upd) return // nothing to do
228+
if (_nRequested.compareAndSet(cur, upd)) {
229+
// unlock the mutex when we don't have back-pressure anymore
230+
if (cur == 0L) {
231+
unlockAndCheckCompleted()
232+
}
233+
return
234+
}
235+
}
236+
}
237+
238+
// assert: isCompleted
239+
private fun signalCompleted() {
240+
while (true) { // lock-free loop for nRequested
241+
val cur = _nRequested.value
242+
if (cur == SIGNALLED) return // some other thread holding lock already signalled cancellation/completion
243+
check(cur >= 0) // no other thread could have marked it as CLOSED, because onCompleted[Exceptionally] is invoked once
244+
if (!_nRequested.compareAndSet(cur, CLOSED)) continue // retry on failed CAS
245+
// Ok -- marked as CLOSED, now can unlock the mutex if it was locked due to backpressure
246+
if (cur == 0L) {
247+
doLockedSignalCompleted()
248+
} else {
249+
// otherwise mutex was either not locked or locked in concurrent onNext... try lock it to signal completion
250+
if (mutex.tryLock()) doLockedSignalCompleted()
251+
// Note: if failed `tryLock`, then `doLockedNext` will signal after performing `unlock`
252+
}
253+
return // done anyway
254+
}
255+
}
256+
257+
// Note: It is invoked when parent fails to handle an exception and strictly before onCompleted[Exception]
258+
// so here we just raise a flag (and it need NOT be volatile!) to handle this exception.
259+
// This way we defer decision to handle this exception based on our ability to send this exception
260+
// to the subscriber (see doLockedSignalCompleted)
261+
override fun handleJobException(exception: Throwable, handled: Boolean) {
262+
if (!handled) shouldHandleException = true
263+
}
264+
265+
override fun onCompletedExceptionally(exception: Throwable) {
266+
signalCompleted()
267+
}
268+
269+
override fun onCompleted(value: Unit) {
270+
signalCompleted()
271+
}
272+
273+
override fun cancel() {
274+
// Specification requires that after cancellation publisher stops signalling
275+
// This flag distinguishes subscription cancellation request from the job crash
276+
cancelled = true
277+
super.cancel(null)
278+
}
279+
}

reactive/kotlinx-coroutines-reactor/test/ReactorContextTest.kt

+10-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ import reactor.util.context.Context
77
import kotlin.test.assertEquals
88

99
class ReactorContextTest {
10-
1110
@Test
1211
fun monoHookedContext() = runBlocking(Context.of(1, "1", 7, "7").asCoroutineContext()) {
1312
val mono = mono {
@@ -20,4 +19,14 @@ class ReactorContextTest {
2019
assertEquals(mono.awaitFirst(), "1234567")
2120
}
2221

22+
@Test
23+
fun fluxTest() = runBlocking<Unit>(Context.of(1, "1", 7, "7").asCoroutineContext()) {
24+
val flux = flux<String?> {
25+
val ctx = coroutineContext[ReactorContext]!!.context
26+
(1 .. 7).forEach { send(ctx.getOrDefault(it, "noValue")) }
27+
} .subscriberContext(Context.of(2, "2", 3, "3", 4, "4", 5, "5"))
28+
.subscriberContext { ctx -> ctx.put(6, "6") }
29+
var i = 0
30+
flux.subscribe { str -> i++; assertEquals(str, i.toString()) }
31+
}
2332
}

0 commit comments

Comments
 (0)