diff --git a/reactive/kotlinx-coroutines-reactive/src/Await.kt b/reactive/kotlinx-coroutines-reactive/src/Await.kt index e9f6955085..9af134cb94 100644 --- a/reactive/kotlinx-coroutines-reactive/src/Await.kt +++ b/reactive/kotlinx-coroutines-reactive/src/Await.kt @@ -4,12 +4,11 @@ package kotlinx.coroutines.reactive -import kotlinx.coroutines.CancellationException -import kotlinx.coroutines.Job -import kotlinx.coroutines.suspendCancellableCoroutine +import kotlinx.coroutines.* import org.reactivestreams.Publisher import org.reactivestreams.Subscriber import org.reactivestreams.Subscription +import java.lang.IllegalStateException import java.util.* import kotlin.coroutines.* @@ -134,31 +133,61 @@ private suspend fun Publisher.awaitOne( mode: Mode, default: T? = null ): T = suspendCancellableCoroutine { cont -> + /* This implementation must obey + https://github.com/reactive-streams/reactive-streams-jvm/blob/v1.0.3/README.md#2-subscriber-code + The numbers of rules are taken from there. */ injectCoroutineContext(cont.context).subscribe(object : Subscriber { - private lateinit var subscription: Subscription + // It is unclear whether 2.13 implies (T: Any), but if so, it seems that we don't break anything by not adhering + private var subscription: Subscription? = null private var value: T? = null private var seenValue = false + private var inTerminalState = false override fun onSubscribe(sub: Subscription) { + /** cancelling the new subscription due to rule 2.5, though the publisher would either have to + * subscribe more than once, which would break 2.12, or leak this [Subscriber]. */ + if (subscription != null) { + sub.cancel() + return + } subscription = sub cont.invokeOnCancellation { sub.cancel() } - sub.request(if (mode == Mode.FIRST) 1 else Long.MAX_VALUE) + sub.request(if (mode == Mode.FIRST || mode == Mode.FIRST_OR_DEFAULT) 1 else Long.MAX_VALUE) } override fun onNext(t: T) { + val sub = subscription.let { + if (it == null) { + /** Enforce rule 1.9: expect [Subscriber.onSubscribe] before any other signals. */ + handleCoroutineException(cont.context, + IllegalStateException("'onNext' was called before 'onSubscribe'")) + return + } else { + it + } + } + if (inTerminalState) { + gotSignalInTerminalStateException(cont.context, "onNext") + return + } when (mode) { Mode.FIRST, Mode.FIRST_OR_DEFAULT -> { - if (!seenValue) { - seenValue = true - subscription.cancel() - cont.resume(t) + if (seenValue) { + moreThanOneValueProvidedException(cont.context, mode) + return } + seenValue = true + sub.cancel() + cont.resume(t) } Mode.LAST, Mode.SINGLE, Mode.SINGLE_OR_DEFAULT -> { if ((mode == Mode.SINGLE || mode == Mode.SINGLE_OR_DEFAULT) && seenValue) { - subscription.cancel() - if (cont.isActive) + sub.cancel() + /* the check for `cont.isActive` is needed in case `sub.cancel() above calls `onComplete` or + `onError` on its own. */ + if (cont.isActive) { cont.resumeWithException(IllegalArgumentException("More than one onNext value for $mode")) + } } else { value = t seenValue = true @@ -169,8 +198,16 @@ private suspend fun Publisher.awaitOne( @Suppress("UNCHECKED_CAST") override fun onComplete() { + if (!tryEnterTerminalState("onComplete")) { + return + } if (seenValue) { - if (cont.isActive) cont.resume(value as T) + /* the check for `cont.isActive` is needed because, otherwise, if the publisher doesn't acknowledge the + call to `cancel` for modes `SINGLE*` when more than one value was seen, it may call `onComplete`, and + here `cont.resume` would fail. */ + if (mode != Mode.FIRST_OR_DEFAULT && mode != Mode.FIRST && cont.isActive) { + cont.resume(value as T) + } return } when { @@ -178,14 +215,43 @@ private suspend fun Publisher.awaitOne( cont.resume(default as T) } cont.isActive -> { + // the check for `cont.isActive` is just a slight optimization and doesn't affect correctness cont.resumeWithException(NoSuchElementException("No value received via onNext for $mode")) } } } override fun onError(e: Throwable) { - cont.resumeWithException(e) + if (tryEnterTerminalState("onError")) { + cont.resumeWithException(e) + } + } + + /** + * Enforce rule 2.4: assume that the [Publisher] is in a terminal state after [onError] or [onComplete]. + */ + private fun tryEnterTerminalState(signalName: String): Boolean { + if (inTerminalState) { + gotSignalInTerminalStateException(cont.context, signalName) + return false + } + inTerminalState = true + return true } }) } +/** + * Enforce rule 2.4 (detect publishers that don't respect rule 1.7): don't process anything after a terminal + * state was reached. + */ +private fun gotSignalInTerminalStateException(context: CoroutineContext, signalName: String) = + handleCoroutineException(context, + IllegalStateException("'$signalName' was called after the publisher already signalled being in a terminal state")) + +/** + * Enforce rule 1.1: it is invalid for a publisher to provide more values than requested. + */ +private fun moreThanOneValueProvidedException(context: CoroutineContext, mode: Mode) = + handleCoroutineException(context, + IllegalStateException("Only a single value was requested in '$mode', but the publisher provided more")) \ No newline at end of file diff --git a/reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt b/reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt index 18cd012d16..a1467080ac 100644 --- a/reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt +++ b/reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt @@ -9,6 +9,8 @@ import org.junit.Test import org.junit.runner.* import org.junit.runners.* import org.reactivestreams.* +import java.lang.IllegalStateException +import java.lang.RuntimeException import kotlin.coroutines.* import kotlin.test.* @@ -130,6 +132,136 @@ class IntegrationTest( finish(3) } + /** + * Test that the continuation is not being resumed after it has already failed due to there having been too many + * values passed. + */ + @Test + fun testNotCompletingFailedAwait() = runTest { + try { + expect(1) + Publisher { sub -> + sub.onSubscribe(object: Subscription { + override fun request(n: Long) { + expect(2) + sub.onNext(1) + sub.onNext(2) + expect(4) + sub.onComplete() + } + + override fun cancel() { + expect(3) + } + }) + }.awaitSingle() + } catch (e: java.lang.IllegalArgumentException) { + expect(5) + } + finish(6) + } + + /** + * Test the behavior of [awaitOne] on unconforming publishers. + */ + @Test + fun testAwaitOnNonconformingPublishers() = runTest { + fun publisher(block: Subscriber.(n: Long) -> Unit) = + Publisher { subscriber -> + subscriber.onSubscribe(object: Subscription { + override fun request(n: Long) { + subscriber.block(n) + } + + override fun cancel() { + } + }) + } + val dummyMessage = "dummy" + val dummyThrowable = RuntimeException(dummyMessage) + suspend fun assertDetectsBadPublisher( + operation: suspend Publisher.() -> T, + message: String, + block: Subscriber.(n: Long) -> Unit, + ) { + assertCallsExceptionHandlerWith { + try { + publisher(block).operation() + } catch (e: Throwable) { + if (e.message != dummyMessage) + throw e + } + }.let { + assertTrue("Expected the message to contain '$message', got '${it.message}'") { + it.message?.contains(message) ?: false + } + } + } + + // Rule 1.1 broken: the publisher produces more values than requested. + assertDetectsBadPublisher({ awaitFirst() }, "provided more") { + onNext(1) + onNext(2) + onComplete() + } + + // Rule 1.7 broken: the publisher calls a method on a subscriber after reaching the terminal state. + assertDetectsBadPublisher({ awaitSingle() }, "terminal state") { + onNext(1) + onError(dummyThrowable) + onComplete() + } + assertDetectsBadPublisher({ awaitSingleOrDefault(2) }, "terminal state") { + onComplete() + onError(dummyThrowable) + } + assertDetectsBadPublisher({ awaitFirst() }, "terminal state") { + onNext(0) + onComplete() + onComplete() + } + assertDetectsBadPublisher({ awaitFirstOrDefault(1) }, "terminal state") { + onComplete() + onNext(3) + } + assertDetectsBadPublisher({ awaitSingle() }, "terminal state") { + onError(dummyThrowable) + onNext(3) + } + + // Rule 1.9 broken (the first signal to the subscriber was not 'onSubscribe') + assertCallsExceptionHandlerWith { + try { + Publisher { subscriber -> + subscriber.onNext(3) + subscriber.onComplete() + }.awaitFirst() + } catch (e: NoSuchElementException) { + // intentionally blank + } + }.let { assertTrue(it.message?.contains("onSubscribe") ?: false) } + } + + private suspend inline fun assertCallsExceptionHandlerWith( + crossinline operation: suspend () -> Unit): E + { + val caughtExceptions = mutableListOf() + val exceptionHandler = object: AbstractCoroutineContextElement(CoroutineExceptionHandler), + CoroutineExceptionHandler + { + override fun handleException(context: CoroutineContext, exception: Throwable) { + caughtExceptions += exception + } + } + return withContext(exceptionHandler) { + operation() + caughtExceptions.single().let { + assertTrue(it is E) + it + } + } + } + private suspend fun checkNumbers(n: Int, pub: Publisher) { var last = 0 pub.collect {