Skip to content

Commit ddc516d

Browse files
dkhalanskyjbpablobaxter
authored andcommitted
Make the subscriber in awaitOne less permissive (Kotlin#2586)
The implementation of Reactive Streams' Subscriber used for `await*` operations was assuming that the publisher is correct. Now, the implementation detects some instances of problematic behavior for publishers and reports them. Fixes Kotlin#2079
1 parent cb682f5 commit ddc516d

File tree

2 files changed

+211
-13
lines changed

2 files changed

+211
-13
lines changed

reactive/kotlinx-coroutines-reactive/src/Await.kt

+79-13
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,11 @@
44

55
package kotlinx.coroutines.reactive
66

7-
import kotlinx.coroutines.CancellationException
8-
import kotlinx.coroutines.Job
9-
import kotlinx.coroutines.suspendCancellableCoroutine
7+
import kotlinx.coroutines.*
108
import org.reactivestreams.Publisher
119
import org.reactivestreams.Subscriber
1210
import org.reactivestreams.Subscription
11+
import java.lang.IllegalStateException
1312
import java.util.*
1413
import kotlin.coroutines.*
1514

@@ -134,31 +133,61 @@ private suspend fun <T> Publisher<T>.awaitOne(
134133
mode: Mode,
135134
default: T? = null
136135
): T = suspendCancellableCoroutine { cont ->
136+
/* This implementation must obey
137+
https://github.com/reactive-streams/reactive-streams-jvm/blob/v1.0.3/README.md#2-subscriber-code
138+
The numbers of rules are taken from there. */
137139
injectCoroutineContext(cont.context).subscribe(object : Subscriber<T> {
138-
private lateinit var subscription: Subscription
140+
// It is unclear whether 2.13 implies (T: Any), but if so, it seems that we don't break anything by not adhering
141+
private var subscription: Subscription? = null
139142
private var value: T? = null
140143
private var seenValue = false
144+
private var inTerminalState = false
141145

142146
override fun onSubscribe(sub: Subscription) {
147+
/** cancelling the new subscription due to rule 2.5, though the publisher would either have to
148+
* subscribe more than once, which would break 2.12, or leak this [Subscriber]. */
149+
if (subscription != null) {
150+
sub.cancel()
151+
return
152+
}
143153
subscription = sub
144154
cont.invokeOnCancellation { sub.cancel() }
145-
sub.request(if (mode == Mode.FIRST) 1 else Long.MAX_VALUE)
155+
sub.request(if (mode == Mode.FIRST || mode == Mode.FIRST_OR_DEFAULT) 1 else Long.MAX_VALUE)
146156
}
147157

148158
override fun onNext(t: T) {
159+
val sub = subscription.let {
160+
if (it == null) {
161+
/** Enforce rule 1.9: expect [Subscriber.onSubscribe] before any other signals. */
162+
handleCoroutineException(cont.context,
163+
IllegalStateException("'onNext' was called before 'onSubscribe'"))
164+
return
165+
} else {
166+
it
167+
}
168+
}
169+
if (inTerminalState) {
170+
gotSignalInTerminalStateException(cont.context, "onNext")
171+
return
172+
}
149173
when (mode) {
150174
Mode.FIRST, Mode.FIRST_OR_DEFAULT -> {
151-
if (!seenValue) {
152-
seenValue = true
153-
subscription.cancel()
154-
cont.resume(t)
175+
if (seenValue) {
176+
moreThanOneValueProvidedException(cont.context, mode)
177+
return
155178
}
179+
seenValue = true
180+
sub.cancel()
181+
cont.resume(t)
156182
}
157183
Mode.LAST, Mode.SINGLE, Mode.SINGLE_OR_DEFAULT -> {
158184
if ((mode == Mode.SINGLE || mode == Mode.SINGLE_OR_DEFAULT) && seenValue) {
159-
subscription.cancel()
160-
if (cont.isActive)
185+
sub.cancel()
186+
/* the check for `cont.isActive` is needed in case `sub.cancel() above calls `onComplete` or
187+
`onError` on its own. */
188+
if (cont.isActive) {
161189
cont.resumeWithException(IllegalArgumentException("More than one onNext value for $mode"))
190+
}
162191
} else {
163192
value = t
164193
seenValue = true
@@ -169,23 +198,60 @@ private suspend fun <T> Publisher<T>.awaitOne(
169198

170199
@Suppress("UNCHECKED_CAST")
171200
override fun onComplete() {
201+
if (!tryEnterTerminalState("onComplete")) {
202+
return
203+
}
172204
if (seenValue) {
173-
if (cont.isActive) cont.resume(value as T)
205+
/* the check for `cont.isActive` is needed because, otherwise, if the publisher doesn't acknowledge the
206+
call to `cancel` for modes `SINGLE*` when more than one value was seen, it may call `onComplete`, and
207+
here `cont.resume` would fail. */
208+
if (mode != Mode.FIRST_OR_DEFAULT && mode != Mode.FIRST && cont.isActive) {
209+
cont.resume(value as T)
210+
}
174211
return
175212
}
176213
when {
177214
(mode == Mode.FIRST_OR_DEFAULT || mode == Mode.SINGLE_OR_DEFAULT) -> {
178215
cont.resume(default as T)
179216
}
180217
cont.isActive -> {
218+
// the check for `cont.isActive` is just a slight optimization and doesn't affect correctness
181219
cont.resumeWithException(NoSuchElementException("No value received via onNext for $mode"))
182220
}
183221
}
184222
}
185223

186224
override fun onError(e: Throwable) {
187-
cont.resumeWithException(e)
225+
if (tryEnterTerminalState("onError")) {
226+
cont.resumeWithException(e)
227+
}
228+
}
229+
230+
/**
231+
* Enforce rule 2.4: assume that the [Publisher] is in a terminal state after [onError] or [onComplete].
232+
*/
233+
private fun tryEnterTerminalState(signalName: String): Boolean {
234+
if (inTerminalState) {
235+
gotSignalInTerminalStateException(cont.context, signalName)
236+
return false
237+
}
238+
inTerminalState = true
239+
return true
188240
}
189241
})
190242
}
191243

244+
/**
245+
* Enforce rule 2.4 (detect publishers that don't respect rule 1.7): don't process anything after a terminal
246+
* state was reached.
247+
*/
248+
private fun gotSignalInTerminalStateException(context: CoroutineContext, signalName: String) =
249+
handleCoroutineException(context,
250+
IllegalStateException("'$signalName' was called after the publisher already signalled being in a terminal state"))
251+
252+
/**
253+
* Enforce rule 1.1: it is invalid for a publisher to provide more values than requested.
254+
*/
255+
private fun moreThanOneValueProvidedException(context: CoroutineContext, mode: Mode) =
256+
handleCoroutineException(context,
257+
IllegalStateException("Only a single value was requested in '$mode', but the publisher provided more"))

reactive/kotlinx-coroutines-reactive/test/IntegrationTest.kt

+132
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ import org.junit.Test
99
import org.junit.runner.*
1010
import org.junit.runners.*
1111
import org.reactivestreams.*
12+
import java.lang.IllegalStateException
13+
import java.lang.RuntimeException
1214
import kotlin.coroutines.*
1315
import kotlin.test.*
1416

@@ -130,6 +132,136 @@ class IntegrationTest(
130132
finish(3)
131133
}
132134

135+
/**
136+
* Test that the continuation is not being resumed after it has already failed due to there having been too many
137+
* values passed.
138+
*/
139+
@Test
140+
fun testNotCompletingFailedAwait() = runTest {
141+
try {
142+
expect(1)
143+
Publisher<Int> { sub ->
144+
sub.onSubscribe(object: Subscription {
145+
override fun request(n: Long) {
146+
expect(2)
147+
sub.onNext(1)
148+
sub.onNext(2)
149+
expect(4)
150+
sub.onComplete()
151+
}
152+
153+
override fun cancel() {
154+
expect(3)
155+
}
156+
})
157+
}.awaitSingle()
158+
} catch (e: java.lang.IllegalArgumentException) {
159+
expect(5)
160+
}
161+
finish(6)
162+
}
163+
164+
/**
165+
* Test the behavior of [awaitOne] on unconforming publishers.
166+
*/
167+
@Test
168+
fun testAwaitOnNonconformingPublishers() = runTest {
169+
fun <T> publisher(block: Subscriber<in T>.(n: Long) -> Unit) =
170+
Publisher<T> { subscriber ->
171+
subscriber.onSubscribe(object: Subscription {
172+
override fun request(n: Long) {
173+
subscriber.block(n)
174+
}
175+
176+
override fun cancel() {
177+
}
178+
})
179+
}
180+
val dummyMessage = "dummy"
181+
val dummyThrowable = RuntimeException(dummyMessage)
182+
suspend fun <T> assertDetectsBadPublisher(
183+
operation: suspend Publisher<T>.() -> T,
184+
message: String,
185+
block: Subscriber<in T>.(n: Long) -> Unit,
186+
) {
187+
assertCallsExceptionHandlerWith<IllegalStateException> {
188+
try {
189+
publisher(block).operation()
190+
} catch (e: Throwable) {
191+
if (e.message != dummyMessage)
192+
throw e
193+
}
194+
}.let {
195+
assertTrue("Expected the message to contain '$message', got '${it.message}'") {
196+
it.message?.contains(message) ?: false
197+
}
198+
}
199+
}
200+
201+
// Rule 1.1 broken: the publisher produces more values than requested.
202+
assertDetectsBadPublisher<Int>({ awaitFirst() }, "provided more") {
203+
onNext(1)
204+
onNext(2)
205+
onComplete()
206+
}
207+
208+
// Rule 1.7 broken: the publisher calls a method on a subscriber after reaching the terminal state.
209+
assertDetectsBadPublisher<Int>({ awaitSingle() }, "terminal state") {
210+
onNext(1)
211+
onError(dummyThrowable)
212+
onComplete()
213+
}
214+
assertDetectsBadPublisher<Int>({ awaitSingleOrDefault(2) }, "terminal state") {
215+
onComplete()
216+
onError(dummyThrowable)
217+
}
218+
assertDetectsBadPublisher<Int>({ awaitFirst() }, "terminal state") {
219+
onNext(0)
220+
onComplete()
221+
onComplete()
222+
}
223+
assertDetectsBadPublisher<Int>({ awaitFirstOrDefault(1) }, "terminal state") {
224+
onComplete()
225+
onNext(3)
226+
}
227+
assertDetectsBadPublisher<Int>({ awaitSingle() }, "terminal state") {
228+
onError(dummyThrowable)
229+
onNext(3)
230+
}
231+
232+
// Rule 1.9 broken (the first signal to the subscriber was not 'onSubscribe')
233+
assertCallsExceptionHandlerWith<IllegalStateException> {
234+
try {
235+
Publisher<Int> { subscriber ->
236+
subscriber.onNext(3)
237+
subscriber.onComplete()
238+
}.awaitFirst()
239+
} catch (e: NoSuchElementException) {
240+
// intentionally blank
241+
}
242+
}.let { assertTrue(it.message?.contains("onSubscribe") ?: false) }
243+
}
244+
245+
private suspend inline fun <reified E: Throwable> assertCallsExceptionHandlerWith(
246+
crossinline operation: suspend () -> Unit): E
247+
{
248+
val caughtExceptions = mutableListOf<Throwable>()
249+
val exceptionHandler = object: AbstractCoroutineContextElement(CoroutineExceptionHandler),
250+
CoroutineExceptionHandler
251+
{
252+
override fun handleException(context: CoroutineContext, exception: Throwable) {
253+
caughtExceptions += exception
254+
}
255+
}
256+
return withContext(exceptionHandler) {
257+
operation()
258+
caughtExceptions.single().let {
259+
assertTrue(it is E)
260+
it
261+
}
262+
}
263+
}
264+
133265
private suspend fun checkNumbers(n: Int, pub: Publisher<Int>) {
134266
var last = 0
135267
pub.collect {

0 commit comments

Comments
 (0)