Skip to content

Make the subscriber in awaitOne less permissive #2586

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 79 additions & 13 deletions reactive/kotlinx-coroutines-reactive/src/Await.kt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

package kotlinx.coroutines.reactive

import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.Job
import kotlinx.coroutines.suspendCancellableCoroutine
import kotlinx.coroutines.*
import org.reactivestreams.Publisher
import org.reactivestreams.Subscriber
import org.reactivestreams.Subscription
import java.lang.IllegalStateException
import java.util.*
import kotlin.coroutines.*

Expand Down Expand Up @@ -134,31 +133,61 @@ private suspend fun <T> Publisher<T>.awaitOne(
mode: Mode,
default: T? = null
): T = suspendCancellableCoroutine { cont ->
/* This implementation must obey
https://github.com/reactive-streams/reactive-streams-jvm/blob/v1.0.3/README.md#2-subscriber-code
The numbers of rules are taken from there. */
injectCoroutineContext(cont.context).subscribe(object : Subscriber<T> {
private lateinit var subscription: Subscription
// It is unclear whether 2.13 implies (T: Any), but if so, it seems that we don't break anything by not adhering
private var subscription: Subscription? = null
private var value: T? = null
private var seenValue = false
private var inTerminalState = false

override fun onSubscribe(sub: Subscription) {
/** cancelling the new subscription due to rule 2.5, though the publisher would either have to
* subscribe more than once, which would break 2.12, or leak this [Subscriber]. */
if (subscription != null) {
sub.cancel()
return
}
subscription = sub
cont.invokeOnCancellation { sub.cancel() }
sub.request(if (mode == Mode.FIRST) 1 else Long.MAX_VALUE)
sub.request(if (mode == Mode.FIRST || mode == Mode.FIRST_OR_DEFAULT) 1 else Long.MAX_VALUE)
}

override fun onNext(t: T) {
val sub = subscription.let {
if (it == null) {
/** Enforce rule 1.9: expect [Subscriber.onSubscribe] before any other signals. */
handleCoroutineException(cont.context,
IllegalStateException("'onNext' was called before 'onSubscribe'"))
return
} else {
it
}
}
if (inTerminalState) {
gotSignalInTerminalStateException(cont.context, "onNext")
return
}
when (mode) {
Mode.FIRST, Mode.FIRST_OR_DEFAULT -> {
if (!seenValue) {
seenValue = true
subscription.cancel()
cont.resume(t)
if (seenValue) {
moreThanOneValueProvidedException(cont.context, mode)
return
}
seenValue = true
sub.cancel()
cont.resume(t)
}
Mode.LAST, Mode.SINGLE, Mode.SINGLE_OR_DEFAULT -> {
if ((mode == Mode.SINGLE || mode == Mode.SINGLE_OR_DEFAULT) && seenValue) {
subscription.cancel()
if (cont.isActive)
sub.cancel()
/* the check for `cont.isActive` is needed in case `sub.cancel() above calls `onComplete` or
`onError` on its own. */
if (cont.isActive) {
cont.resumeWithException(IllegalArgumentException("More than one onNext value for $mode"))
}
} else {
value = t
seenValue = true
Expand All @@ -169,23 +198,60 @@ private suspend fun <T> Publisher<T>.awaitOne(

@Suppress("UNCHECKED_CAST")
override fun onComplete() {
if (!tryEnterTerminalState("onComplete")) {
return
}
if (seenValue) {
if (cont.isActive) cont.resume(value as T)
/* the check for `cont.isActive` is needed because, otherwise, if the publisher doesn't acknowledge the
call to `cancel` for modes `SINGLE*` when more than one value was seen, it may call `onComplete`, and
here `cont.resume` would fail. */
if (mode != Mode.FIRST_OR_DEFAULT && mode != Mode.FIRST && cont.isActive) {
cont.resume(value as T)
}
return
}
when {
(mode == Mode.FIRST_OR_DEFAULT || mode == Mode.SINGLE_OR_DEFAULT) -> {
cont.resume(default as T)
}
cont.isActive -> {
// the check for `cont.isActive` is just a slight optimization and doesn't affect correctness
cont.resumeWithException(NoSuchElementException("No value received via onNext for $mode"))
}
}
}

override fun onError(e: Throwable) {
cont.resumeWithException(e)
if (tryEnterTerminalState("onError")) {
cont.resumeWithException(e)
}
}

/**
* Enforce rule 2.4: assume that the [Publisher] is in a terminal state after [onError] or [onComplete].
*/
private fun tryEnterTerminalState(signalName: String): Boolean {
if (inTerminalState) {
gotSignalInTerminalStateException(cont.context, signalName)
return false
}
inTerminalState = true
return true
}
})
}

/**
* Enforce rule 2.4 (detect publishers that don't respect rule 1.7): don't process anything after a terminal
* state was reached.
*/
private fun gotSignalInTerminalStateException(context: CoroutineContext, signalName: String) =
handleCoroutineException(context,
IllegalStateException("'$signalName' was called after the publisher already signalled being in a terminal state"))

/**
* Enforce rule 1.1: it is invalid for a publisher to provide more values than requested.
*/
private fun moreThanOneValueProvidedException(context: CoroutineContext, mode: Mode) =
handleCoroutineException(context,
IllegalStateException("Only a single value was requested in '$mode', but the publisher provided more"))
132 changes: 132 additions & 0 deletions reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import org.junit.Test
import org.junit.runner.*
import org.junit.runners.*
import org.reactivestreams.*
import java.lang.IllegalStateException
import java.lang.RuntimeException
import kotlin.coroutines.*
import kotlin.test.*

Expand Down Expand Up @@ -130,6 +132,136 @@ class IntegrationTest(
finish(3)
}

/**
* Test that the continuation is not being resumed after it has already failed due to there having been too many
* values passed.
*/
@Test
fun testNotCompletingFailedAwait() = runTest {
try {
expect(1)
Publisher<Int> { sub ->
sub.onSubscribe(object: Subscription {
override fun request(n: Long) {
expect(2)
sub.onNext(1)
sub.onNext(2)
expect(4)
sub.onComplete()
}

override fun cancel() {
expect(3)
}
})
}.awaitSingle()
} catch (e: java.lang.IllegalArgumentException) {
expect(5)
}
finish(6)
}

/**
* Test the behavior of [awaitOne] on unconforming publishers.
*/
@Test
fun testAwaitOnNonconformingPublishers() = runTest {
fun <T> publisher(block: Subscriber<in T>.(n: Long) -> Unit) =
Publisher<T> { subscriber ->
subscriber.onSubscribe(object: Subscription {
override fun request(n: Long) {
subscriber.block(n)
}

override fun cancel() {
}
})
}
val dummyMessage = "dummy"
val dummyThrowable = RuntimeException(dummyMessage)
suspend fun <T> assertDetectsBadPublisher(
operation: suspend Publisher<T>.() -> T,
message: String,
block: Subscriber<in T>.(n: Long) -> Unit,
) {
assertCallsExceptionHandlerWith<IllegalStateException> {
try {
publisher(block).operation()
} catch (e: Throwable) {
if (e.message != dummyMessage)
throw e
}
}.let {
assertTrue("Expected the message to contain '$message', got '${it.message}'") {
it.message?.contains(message) ?: false
}
}
}

// Rule 1.1 broken: the publisher produces more values than requested.
assertDetectsBadPublisher<Int>({ awaitFirst() }, "provided more") {
onNext(1)
onNext(2)
onComplete()
}

// Rule 1.7 broken: the publisher calls a method on a subscriber after reaching the terminal state.
assertDetectsBadPublisher<Int>({ awaitSingle() }, "terminal state") {
onNext(1)
onError(dummyThrowable)
onComplete()
}
assertDetectsBadPublisher<Int>({ awaitSingleOrDefault(2) }, "terminal state") {
onComplete()
onError(dummyThrowable)
}
assertDetectsBadPublisher<Int>({ awaitFirst() }, "terminal state") {
onNext(0)
onComplete()
onComplete()
}
assertDetectsBadPublisher<Int>({ awaitFirstOrDefault(1) }, "terminal state") {
onComplete()
onNext(3)
}
assertDetectsBadPublisher<Int>({ awaitSingle() }, "terminal state") {
onError(dummyThrowable)
onNext(3)
}

// Rule 1.9 broken (the first signal to the subscriber was not 'onSubscribe')
assertCallsExceptionHandlerWith<IllegalStateException> {
try {
Publisher<Int> { subscriber ->
subscriber.onNext(3)
subscriber.onComplete()
}.awaitFirst()
} catch (e: NoSuchElementException) {
// intentionally blank
}
}.let { assertTrue(it.message?.contains("onSubscribe") ?: false) }
}

private suspend inline fun <reified E: Throwable> assertCallsExceptionHandlerWith(
crossinline operation: suspend () -> Unit): E
{
val caughtExceptions = mutableListOf<Throwable>()
val exceptionHandler = object: AbstractCoroutineContextElement(CoroutineExceptionHandler),
CoroutineExceptionHandler
{
override fun handleException(context: CoroutineContext, exception: Throwable) {
caughtExceptions += exception
}
}
return withContext(exceptionHandler) {
operation()
caughtExceptions.single().let {
assertTrue(it is E)
it
}
}
}

private suspend fun checkNumbers(n: Int, pub: Publisher<Int>) {
var last = 0
pub.collect {
Expand Down