From a13d35ac178e4a073edb6f8f7449c75281cbe004 Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Wed, 21 Dec 2016 13:54:31 +0300 Subject: [PATCH 1/3] Draft async iterator for Rx --- .../src/main/kotlin/asyncRxIterator.kt | 181 ++++++++++++++++++ .../src/test/kotlin/AsyncRxTest.kt | 15 ++ 2 files changed, 196 insertions(+) create mode 100644 kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt diff --git a/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt new file mode 100644 index 0000000000..3b664658c5 --- /dev/null +++ b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt @@ -0,0 +1,181 @@ +package kotlinx.coroutines + +import rx.Observable +import rx.Subscriber +import java.util.concurrent.atomic.AtomicReference +import kotlin.coroutines.Continuation +import kotlin.coroutines.suspendCoroutine + +// supports suspending iteration on observables +suspend operator fun Observable.iterator(): ObserverIterator { + val iterator = ObserverIterator() + subscribe(iterator.subscriber) + return iterator +} + +private sealed class Waiter(val c: Continuation) +private class HasNextWaiter(c: Continuation) : Waiter(c) +private class NextWaiter(c: Continuation) : Waiter(c) + +private object Complete +private class Fail(val e: Throwable) + +class ObserverIterator { + internal val subscriber = Sub() + // Contains either null, Complete, Fail(exception), Waiter, or next value + private val rendezvous = AtomicReference() + + @Suppress("UNCHECKED_CAST") + private suspend fun awaitHasNext(): Boolean = suspendCoroutine sc@ { c -> + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> { + if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc + } + Complete -> { + c.resume(false) + return@sc + } + is Fail -> { + c.resumeWithException(cur.e) + return@sc + } + is Waiter<*> -> { + c.resumeWithException(IllegalStateException("Concurrent iteration")) + return@sc + } + else -> { + c.resume(true) + return@sc + } + } + } + } + + private suspend fun awaitNext(): V = suspendCoroutine sc@ { c -> + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> { + if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc + } + Complete -> { + c.resumeWithException(NoSuchElementException()) + return@sc + } + is Fail -> { + c.resumeWithException(cur.e) + return@sc + } + is Waiter<*> -> { + c.resumeWithException(IllegalStateException("Concurrent iteration")) + return@sc + } + else -> { + if (rendezvous.compareAndSet(cur, null)) { + c.resume(consumeValue(cur)) + return@sc + } + } + } + } + } + + suspend operator fun hasNext(): Boolean { + val cur = rendezvous.get() + return when (cur) { + null -> awaitHasNext() + Complete -> false + is Fail -> throw cur.e + is Waiter<*> -> throw IllegalStateException("Concurrent iteration") + else -> true + } + } + + suspend operator fun next(): V { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> return awaitNext() + Complete -> throw NoSuchElementException() + is Fail -> throw cur.e + is Waiter<*> -> throw IllegalStateException("Concurrent iteration") + else -> if (rendezvous.compareAndSet(cur, null)) return consumeValue(cur) + } + } + } + + @Suppress("UNCHECKED_CAST") + private fun consumeValue(cur: Any?): V { + subscriber.requestOne() + return cur as V + } + + internal inner class Sub : Subscriber() { + fun requestOne() { + request(1) + } + + override fun onStart() { + requestOne() + } + + override fun onError(e: Throwable) { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> if (rendezvous.compareAndSet(null, Fail(e))) return + Complete -> throw IllegalStateException("onError after onComplete") + is Fail -> throw IllegalStateException("onError after onError") + is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) { + cur.c.resumeWithException(e) + return + } + else -> throw IllegalStateException("onError after onNext before request(1) was called") + } + } + } + + @Suppress("UNCHECKED_CAST") + override fun onNext(v: V) { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> if (rendezvous.compareAndSet(null, v)) return + Complete -> throw IllegalStateException("onNext after onComplete") + is Fail -> throw IllegalStateException("onNext after onError") + is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) { + cur.c.resume(true) + return + } + is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, null)) { + (cur as NextWaiter).c.resume(v) + return + } + else -> throw IllegalStateException("onNext after onNext before request(1) was called") + } + } + } + + override fun onCompleted() { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous.get() + when (cur) { + null -> if (rendezvous.compareAndSet(null, Complete)) return + Complete -> throw IllegalStateException("onComplete after onComplete") + is Fail -> throw IllegalStateException("onComplete after onError") + is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Complete)) { + cur.c.resume(false) + return + } + is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Complete)) { + cur.c.resumeWithException(NoSuchElementException()) + return + } + else -> throw IllegalStateException("onComplete after onNext before request(1) was called") + } + } + } + } +} diff --git a/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt b/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt index c9677a567b..aa394caa76 100644 --- a/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt +++ b/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt @@ -126,6 +126,21 @@ class AsyncRxTest { } } + + @Test + fun testAsyncIterator() { + val observable = asyncRx { + val sb = StringBuilder() + for (s in Observable.just("O", "K")) + sb.append(s) + sb.toString() + } + + checkObservableWithSingleValue(observable) { + assertEquals("OK", it) + } + } + private fun checkErroneousObservable( observable: Observable<*>, checker: (Throwable) -> Unit From 9216705cc8572227e36821591da6dab9db801288 Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Wed, 21 Dec 2016 15:11:12 +0300 Subject: [PATCH 2/3] Draft optimized async iterator via CoroutineIntrinsics --- .../src/main/kotlin/asyncRxIterator.kt | 126 ++++++------------ .../src/test/kotlin/AsyncRxTest.kt | 25 +++- 2 files changed, 60 insertions(+), 91 deletions(-) diff --git a/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt index 3b664658c5..013269f334 100644 --- a/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt +++ b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt @@ -4,9 +4,14 @@ import rx.Observable import rx.Subscriber import java.util.concurrent.atomic.AtomicReference import kotlin.coroutines.Continuation -import kotlin.coroutines.suspendCoroutine +import kotlin.coroutines.CoroutineIntrinsics.SUSPENDED +import kotlin.coroutines.CoroutineIntrinsics.suspendCoroutineOrReturn -// supports suspending iteration on observables +/** + * Suspending iteration extension. It does not have its own buffer and applies back-pressure to the source. + * If iterating coroutine does not have a dispatcher with its own thread, then the iterating coroutine + * is resumed in the thread that invokes [Subscriber.onNext]. + */ suspend operator fun Observable.iterator(): ObserverIterator { val iterator = ObserverIterator() subscribe(iterator.subscriber) @@ -17,101 +22,45 @@ private sealed class Waiter(val c: Continuation) private class HasNextWaiter(c: Continuation) : Waiter(c) private class NextWaiter(c: Continuation) : Waiter(c) -private object Complete -private class Fail(val e: Throwable) +private object Completed +private class CompletedWith(val v: Any) +private class Error(val e: Throwable) class ObserverIterator { internal val subscriber = Sub() - // Contains either null, Complete, Fail(exception), Waiter, or next value + // Contains either null, Completed, CompletedWith, Error(exception), Waiter, or next value private val rendezvous = AtomicReference() @Suppress("UNCHECKED_CAST") - private suspend fun awaitHasNext(): Boolean = suspendCoroutine sc@ { c -> + suspend operator fun hasNext(): Boolean = suspendCoroutineOrReturn sc@ { c -> while (true) { // lock-free loop for rendezvous val cur = rendezvous.get() when (cur) { - null -> { - if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc - } - Complete -> { - c.resume(false) - return@sc - } - is Fail -> { - c.resumeWithException(cur.e) - return@sc - } - is Waiter<*> -> { - c.resumeWithException(IllegalStateException("Concurrent iteration")) - return@sc - } - else -> { - c.resume(true) - return@sc - } - } - } - } - - private suspend fun awaitNext(): V = suspendCoroutine sc@ { c -> - while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() - when (cur) { - null -> { - if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc - } - Complete -> { - c.resumeWithException(NoSuchElementException()) - return@sc - } - is Fail -> { - c.resumeWithException(cur.e) - return@sc - } - is Waiter<*> -> { - c.resumeWithException(IllegalStateException("Concurrent iteration")) - return@sc - } - else -> { - if (rendezvous.compareAndSet(cur, null)) { - c.resume(consumeValue(cur)) - return@sc - } - } + null -> if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc SUSPENDED + Completed -> return@sc false + is CompletedWith -> return@sc true + is Error -> throw cur.e + is Waiter<*> -> throw IllegalStateException("Concurrent iteration") + else -> return@sc true } } } - suspend operator fun hasNext(): Boolean { - val cur = rendezvous.get() - return when (cur) { - null -> awaitHasNext() - Complete -> false - is Fail -> throw cur.e - is Waiter<*> -> throw IllegalStateException("Concurrent iteration") - else -> true - } - } - - suspend operator fun next(): V { + @Suppress("UNCHECKED_CAST") + suspend operator fun next(): V = suspendCoroutineOrReturn sc@ { c -> while (true) { // lock-free loop for rendezvous val cur = rendezvous.get() when (cur) { - null -> return awaitNext() - Complete -> throw NoSuchElementException() - is Fail -> throw cur.e + null -> if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc SUSPENDED + Completed -> throw NoSuchElementException() + is CompletedWith -> if (rendezvous.compareAndSet(cur, Completed)) return@sc cur.v as V + is Error -> throw cur.e is Waiter<*> -> throw IllegalStateException("Concurrent iteration") - else -> if (rendezvous.compareAndSet(cur, null)) return consumeValue(cur) + else -> if (rendezvous.compareAndSet(cur, null)) return (cur as V).apply { subscriber.requestOne() } } } } - @Suppress("UNCHECKED_CAST") - private fun consumeValue(cur: Any?): V { - subscriber.requestOne() - return cur as V - } - internal inner class Sub : Subscriber() { fun requestOne() { request(1) @@ -125,9 +74,10 @@ class ObserverIterator { while (true) { // lock-free loop for rendezvous val cur = rendezvous.get() when (cur) { - null -> if (rendezvous.compareAndSet(null, Fail(e))) return - Complete -> throw IllegalStateException("onError after onComplete") - is Fail -> throw IllegalStateException("onError after onError") + null -> if (rendezvous.compareAndSet(null, Error(e))) return + Completed -> throw IllegalStateException("onError after onCompleted") + is CompletedWith -> throw IllegalStateException("onError after onCompleted") + is Error -> throw IllegalStateException("onError after onError") is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) { cur.c.resumeWithException(e) return @@ -143,8 +93,9 @@ class ObserverIterator { val cur = rendezvous.get() when (cur) { null -> if (rendezvous.compareAndSet(null, v)) return - Complete -> throw IllegalStateException("onNext after onComplete") - is Fail -> throw IllegalStateException("onNext after onError") + Completed -> throw IllegalStateException("onNext after onCompleted") + is CompletedWith -> throw IllegalStateException("onNext after onCompleted") + is Error -> throw IllegalStateException("onNext after onError") is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) { cur.c.resume(true) return @@ -162,18 +113,19 @@ class ObserverIterator { while (true) { // lock-free loop for rendezvous val cur = rendezvous.get() when (cur) { - null -> if (rendezvous.compareAndSet(null, Complete)) return - Complete -> throw IllegalStateException("onComplete after onComplete") - is Fail -> throw IllegalStateException("onComplete after onError") - is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Complete)) { + null -> if (rendezvous.compareAndSet(null, Completed)) return + Completed -> throw IllegalStateException("onCompleted after onCompleted") + is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted") + is Error -> throw IllegalStateException("onCompleted after onError") + is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Completed)) { cur.c.resume(false) return } - is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Complete)) { + is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Completed)) { cur.c.resumeWithException(NoSuchElementException()) return } - else -> throw IllegalStateException("onComplete after onNext before request(1) was called") + else -> if (rendezvous.compareAndSet(cur, CompletedWith(cur))) return } } } diff --git a/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt b/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt index aa394caa76..05d5b4d49e 100644 --- a/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt +++ b/kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt @@ -130,10 +130,11 @@ class AsyncRxTest { @Test fun testAsyncIterator() { val observable = asyncRx { - val sb = StringBuilder() - for (s in Observable.just("O", "K")) - sb.append(s) - sb.toString() + var result = "" + for (s in Observable.just("O", "K")) { + result += s + } + result } checkObservableWithSingleValue(observable) { @@ -141,6 +142,22 @@ class AsyncRxTest { } } + @Test + fun testAsyncIteratorException() { + val observable = asyncRx { + var result = "" + for (s in Observable.error(RuntimeException("OK"))) { + result += s + } + result + } + + checkErroneousObservable(observable) { + assert(it is RuntimeException) + assertEquals("OK", it.message) + } + } + private fun checkErroneousObservable( observable: Observable<*>, checker: (Throwable) -> Unit From af6239789f8574a21fa1aefb9798cce4510426e8 Mon Sep 17 00:00:00 2001 From: Roman Elizarov Date: Wed, 21 Dec 2016 15:40:09 +0300 Subject: [PATCH 3/3] Optimized version of asyncRx iterator --- .../src/main/kotlin/asyncRxIterator.kt | 153 +++++++++--------- 1 file changed, 77 insertions(+), 76 deletions(-) diff --git a/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt index 013269f334..ce4020f1c1 100644 --- a/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt +++ b/kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt @@ -2,19 +2,19 @@ package kotlinx.coroutines import rx.Observable import rx.Subscriber -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater import kotlin.coroutines.Continuation -import kotlin.coroutines.CoroutineIntrinsics.SUSPENDED -import kotlin.coroutines.CoroutineIntrinsics.suspendCoroutineOrReturn +import kotlin.coroutines.CoroutineIntrinsics /** - * Suspending iteration extension. It does not have its own buffer and applies back-pressure to the source. - * If iterating coroutine does not have a dispatcher with its own thread, then the iterating coroutine - * is resumed in the thread that invokes [Subscriber.onNext]. + * Suspending iteration extension. It does not have its own buffer and works by arranging rendezvous between + * producer and consumer. It applies back-pressure to the producer as needed. If iterating coroutine does not have a + * dispatcher with its own thread, then the iterating coroutine is resumed and works in the thread that governs + * producer observable. */ -suspend operator fun Observable.iterator(): ObserverIterator { - val iterator = ObserverIterator() - subscribe(iterator.subscriber) +public suspend operator fun Observable.iterator(): ObservableIterator { + val iterator = ObservableIterator() + subscribe(iterator) return iterator } @@ -26,17 +26,24 @@ private object Completed private class CompletedWith(val v: Any) private class Error(val e: Throwable) -class ObserverIterator { - internal val subscriber = Sub() +public class ObservableIterator : Subscriber() { // Contains either null, Completed, CompletedWith, Error(exception), Waiter, or next value - private val rendezvous = AtomicReference() + @Volatile + private var rendezvous: Any? = null + + companion object { + private val RENDEZVOUS_UPDATER = AtomicReferenceFieldUpdater + .newUpdater(ObservableIterator::class.java, Any::class.java, "rendezvous") + } + + private fun cas(expect: Any?, update: Any?) = RENDEZVOUS_UPDATER.compareAndSet(this, expect, update) @Suppress("UNCHECKED_CAST") - suspend operator fun hasNext(): Boolean = suspendCoroutineOrReturn sc@ { c -> + public suspend operator fun hasNext(): Boolean = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c -> while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() + val cur = rendezvous when (cur) { - null -> if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc SUSPENDED + null -> if (cas(null, HasNextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED Completed -> return@sc false is CompletedWith -> return@sc true is Error -> throw cur.e @@ -47,86 +54,80 @@ class ObserverIterator { } @Suppress("UNCHECKED_CAST") - suspend operator fun next(): V = suspendCoroutineOrReturn sc@ { c -> + public suspend operator fun next(): V = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c -> while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() + val cur = rendezvous when (cur) { - null -> if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc SUSPENDED + null -> if (cas(null, NextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED Completed -> throw NoSuchElementException() - is CompletedWith -> if (rendezvous.compareAndSet(cur, Completed)) return@sc cur.v as V + is CompletedWith -> if (cas(cur, Completed)) return@sc cur.v as V is Error -> throw cur.e is Waiter<*> -> throw IllegalStateException("Concurrent iteration") - else -> if (rendezvous.compareAndSet(cur, null)) return (cur as V).apply { subscriber.requestOne() } + else -> if (cas(cur, null)) return (cur as V).apply { request(1) } } } } - internal inner class Sub : Subscriber() { - fun requestOne() { - request(1) - } - - override fun onStart() { - requestOne() - } + public override fun onStart() { + request(1) + } - override fun onError(e: Throwable) { - while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() - when (cur) { - null -> if (rendezvous.compareAndSet(null, Error(e))) return - Completed -> throw IllegalStateException("onError after onCompleted") - is CompletedWith -> throw IllegalStateException("onError after onCompleted") - is Error -> throw IllegalStateException("onError after onError") - is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) { - cur.c.resumeWithException(e) - return - } - else -> throw IllegalStateException("onError after onNext before request(1) was called") + public override fun onError(e: Throwable) { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous + when (cur) { + null -> if (cas(null, Error(e))) return + Completed -> throw IllegalStateException("onError after onCompleted") + is CompletedWith -> throw IllegalStateException("onError after onCompleted") + is Error -> throw IllegalStateException("onError after onError") + is Waiter<*> -> if (cas(cur, null)) { + cur.c.resumeWithException(e) + return } + else -> throw IllegalStateException("onError after onNext before request(1) was called") } } + } - @Suppress("UNCHECKED_CAST") - override fun onNext(v: V) { - while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() - when (cur) { - null -> if (rendezvous.compareAndSet(null, v)) return - Completed -> throw IllegalStateException("onNext after onCompleted") - is CompletedWith -> throw IllegalStateException("onNext after onCompleted") - is Error -> throw IllegalStateException("onNext after onError") - is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) { - cur.c.resume(true) - return - } - is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, null)) { - (cur as NextWaiter).c.resume(v) - return - } - else -> throw IllegalStateException("onNext after onNext before request(1) was called") + @Suppress("UNCHECKED_CAST") + public override fun onNext(v: V) { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous + when (cur) { + null -> if (cas(null, v)) return + Completed -> throw IllegalStateException("onNext after onCompleted") + is CompletedWith -> throw IllegalStateException("onNext after onCompleted") + is Error -> throw IllegalStateException("onNext after onError") + is HasNextWaiter -> if (cas(cur, v)) { + cur.c.resume(true) + return + } + is NextWaiter<*> -> if (cas(cur, null)) { + (cur as NextWaiter).c.resume(v) + return } + else -> throw IllegalStateException("onNext after onNext before request(1) was called") } } + } - override fun onCompleted() { - while (true) { // lock-free loop for rendezvous - val cur = rendezvous.get() - when (cur) { - null -> if (rendezvous.compareAndSet(null, Completed)) return - Completed -> throw IllegalStateException("onCompleted after onCompleted") - is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted") - is Error -> throw IllegalStateException("onCompleted after onError") - is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Completed)) { - cur.c.resume(false) - return - } - is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Completed)) { - cur.c.resumeWithException(NoSuchElementException()) - return - } - else -> if (rendezvous.compareAndSet(cur, CompletedWith(cur))) return + public override fun onCompleted() { + while (true) { // lock-free loop for rendezvous + val cur = rendezvous + when (cur) { + null -> if (cas(null, Completed)) return + Completed -> throw IllegalStateException("onCompleted after onCompleted") + is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted") + is Error -> throw IllegalStateException("onCompleted after onError") + is HasNextWaiter -> if (cas(cur, Completed)) { + cur.c.resume(false) + return + } + is NextWaiter<*> -> if (cas(cur, Completed)) { + cur.c.resumeWithException(NoSuchElementException()) + return } + else -> if (cas(cur, CompletedWith(cur))) return } } }