Skip to content

asyncRx suspending iterator #13

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
package kotlinx.coroutines

import rx.Observable
import rx.Subscriber
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
import kotlin.coroutines.Continuation
import kotlin.coroutines.CoroutineIntrinsics

/**
* Suspending iteration extension. It does not have its own buffer and works by arranging rendezvous between
* producer and consumer. It applies back-pressure to the producer as needed. If iterating coroutine does not have a
* dispatcher with its own thread, then the iterating coroutine is resumed and works in the thread that governs
* producer observable.
*/
public suspend operator fun <V : Any> Observable<V>.iterator(): ObservableIterator<V> {
val iterator = ObservableIterator<V>()
subscribe(iterator)
return iterator
}

private sealed class Waiter<in T>(val c: Continuation<T>)
private class HasNextWaiter(c: Continuation<Boolean>) : Waiter<Boolean>(c)
private class NextWaiter<V>(c: Continuation<V>) : Waiter<V>(c)

private object Completed
private class CompletedWith(val v: Any)
private class Error(val e: Throwable)

public class ObservableIterator<V : Any> : Subscriber<V>() {
// Contains either null, Completed, CompletedWith, Error(exception), Waiter, or next value
@Volatile
private var rendezvous: Any? = null

companion object {
private val RENDEZVOUS_UPDATER = AtomicReferenceFieldUpdater
.newUpdater(ObservableIterator::class.java, Any::class.java, "rendezvous")
}

private fun cas(expect: Any?, update: Any?) = RENDEZVOUS_UPDATER.compareAndSet(this, expect, update)

@Suppress("UNCHECKED_CAST")
public suspend operator fun hasNext(): Boolean = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c ->
while (true) { // lock-free loop for rendezvous
val cur = rendezvous
when (cur) {
null -> if (cas(null, HasNextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED
Completed -> return@sc false
is CompletedWith -> return@sc true
is Error -> throw cur.e
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
else -> return@sc true
}
}
}

@Suppress("UNCHECKED_CAST")
public suspend operator fun next(): V = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c ->
while (true) { // lock-free loop for rendezvous
val cur = rendezvous
when (cur) {
null -> if (cas(null, NextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED
Completed -> throw NoSuchElementException()
is CompletedWith -> if (cas(cur, Completed)) return@sc cur.v as V
is Error -> throw cur.e
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
else -> if (cas(cur, null)) return (cur as V).apply { request(1) }
}
}
}

public override fun onStart() {
request(1)
}

public override fun onError(e: Throwable) {
while (true) { // lock-free loop for rendezvous
val cur = rendezvous
when (cur) {
null -> if (cas(null, Error(e))) return
Completed -> throw IllegalStateException("onError after onCompleted")
is CompletedWith -> throw IllegalStateException("onError after onCompleted")
is Error -> throw IllegalStateException("onError after onError")
is Waiter<*> -> if (cas(cur, null)) {
cur.c.resumeWithException(e)
return
}
else -> throw IllegalStateException("onError after onNext before request(1) was called")
}
}
}

@Suppress("UNCHECKED_CAST")
public override fun onNext(v: V) {
while (true) { // lock-free loop for rendezvous
val cur = rendezvous
when (cur) {
null -> if (cas(null, v)) return
Completed -> throw IllegalStateException("onNext after onCompleted")
is CompletedWith -> throw IllegalStateException("onNext after onCompleted")
is Error -> throw IllegalStateException("onNext after onError")
is HasNextWaiter -> if (cas(cur, v)) {
cur.c.resume(true)
return
}
is NextWaiter<*> -> if (cas(cur, null)) {
(cur as NextWaiter<V>).c.resume(v)
return
}
else -> throw IllegalStateException("onNext after onNext before request(1) was called")
}
}
}

public override fun onCompleted() {
while (true) { // lock-free loop for rendezvous
val cur = rendezvous
when (cur) {
null -> if (cas(null, Completed)) return
Completed -> throw IllegalStateException("onCompleted after onCompleted")
is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted")
is Error -> throw IllegalStateException("onCompleted after onError")
is HasNextWaiter -> if (cas(cur, Completed)) {
cur.c.resume(false)
return
}
is NextWaiter<*> -> if (cas(cur, Completed)) {
cur.c.resumeWithException(NoSuchElementException())
return
}
else -> if (cas(cur, CompletedWith(cur))) return
}
}
}
}
32 changes: 32 additions & 0 deletions kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,38 @@ class AsyncRxTest {
}
}


@Test
fun testAsyncIterator() {
val observable = asyncRx {
var result = ""
for (s in Observable.just("O", "K")) {
result += s
}
result
}

checkObservableWithSingleValue(observable) {
assertEquals("OK", it)
}
}

@Test
fun testAsyncIteratorException() {
val observable = asyncRx {
var result = ""
for (s in Observable.error<String>(RuntimeException("OK"))) {
result += s
}
result
}

checkErroneousObservable(observable) {
assert(it is RuntimeException)
assertEquals("OK", it.message)
}
}

private fun checkErroneousObservable(
observable: Observable<*>,
checker: (Throwable) -> Unit
Expand Down