Skip to content

Commit bae6325

Browse files
committed
Draft async iterator for Rx
1 parent ff14f7b commit bae6325

File tree

3 files changed

+196
-1
lines changed

3 files changed

+196
-1
lines changed

kotlinx-coroutines-rx/src/main/kotlin/asyncRx.kt

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ fun <T> asyncRx(
3939
return result
4040
}
4141

42-
4342
suspend fun <V> Observable<V>.awaitFirst(): V = first().awaitOne()
4443

4544
suspend fun <V> Observable<V>.awaitLast(): V = last().awaitOne()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
package kotlinx.coroutines
2+
3+
import rx.Observable
4+
import rx.Subscriber
5+
import java.util.concurrent.atomic.AtomicReference
6+
import kotlin.coroutines.Continuation
7+
import kotlin.coroutines.suspendCoroutine
8+
9+
// supports suspending iteration on observables
10+
suspend operator fun <V : Any> Observable<V>.iterator(): ObserverIterator<V> {
11+
val iterator = ObserverIterator<V>()
12+
subscribe(iterator.subscriber)
13+
return iterator
14+
}
15+
16+
private sealed class Waiter<in T>(val c: Continuation<T>)
17+
private class HasNextWaiter(c: Continuation<Boolean>) : Waiter<Boolean>(c)
18+
private class NextWaiter<V>(c: Continuation<V>) : Waiter<V>(c)
19+
20+
private object Complete
21+
private class Fail(val e: Throwable)
22+
23+
class ObserverIterator<V : Any> {
24+
internal val subscriber = Sub()
25+
// Contains either null, Complete, Fail(exception), Waiter, or next value
26+
private val rendezvous = AtomicReference<Any?>()
27+
28+
@Suppress("UNCHECKED_CAST")
29+
private suspend fun awaitHasNext(): Boolean = suspendCoroutine sc@ { c ->
30+
while (true) { // lock-free loop for rendezvous
31+
val cur = rendezvous.get()
32+
when (cur) {
33+
null -> {
34+
if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc
35+
}
36+
Complete -> {
37+
c.resume(false)
38+
return@sc
39+
}
40+
is Fail -> {
41+
c.resumeWithException(cur.e)
42+
return@sc
43+
}
44+
is Waiter<*> -> {
45+
c.resumeWithException(IllegalStateException("Concurrent iteration"))
46+
return@sc
47+
}
48+
else -> {
49+
c.resume(true)
50+
return@sc
51+
}
52+
}
53+
}
54+
}
55+
56+
private suspend fun awaitNext(): V = suspendCoroutine sc@ { c ->
57+
while (true) { // lock-free loop for rendezvous
58+
val cur = rendezvous.get()
59+
when (cur) {
60+
null -> {
61+
if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc
62+
}
63+
Complete -> {
64+
c.resumeWithException(NoSuchElementException())
65+
return@sc
66+
}
67+
is Fail -> {
68+
c.resumeWithException(cur.e)
69+
return@sc
70+
}
71+
is Waiter<*> -> {
72+
c.resumeWithException(IllegalStateException("Concurrent iteration"))
73+
return@sc
74+
}
75+
else -> {
76+
if (rendezvous.compareAndSet(cur, null)) {
77+
c.resume(consumeValue(cur))
78+
return@sc
79+
}
80+
}
81+
}
82+
}
83+
}
84+
85+
suspend operator fun hasNext(): Boolean {
86+
val cur = rendezvous.get()
87+
return when (cur) {
88+
null -> awaitHasNext()
89+
Complete -> false
90+
is Fail -> throw cur.e
91+
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
92+
else -> true
93+
}
94+
}
95+
96+
suspend operator fun next(): V {
97+
while (true) { // lock-free loop for rendezvous
98+
val cur = rendezvous.get()
99+
when (cur) {
100+
null -> return awaitNext()
101+
Complete -> throw NoSuchElementException()
102+
is Fail -> throw cur.e
103+
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
104+
else -> if (rendezvous.compareAndSet(cur, null)) return consumeValue(cur)
105+
}
106+
}
107+
}
108+
109+
@Suppress("UNCHECKED_CAST")
110+
private fun consumeValue(cur: Any?): V {
111+
subscriber.requestOne()
112+
return cur as V
113+
}
114+
115+
internal inner class Sub : Subscriber<V>() {
116+
fun requestOne() {
117+
request(1)
118+
}
119+
120+
override fun onStart() {
121+
requestOne()
122+
}
123+
124+
override fun onError(e: Throwable) {
125+
while (true) { // lock-free loop for rendezvous
126+
val cur = rendezvous.get()
127+
when (cur) {
128+
null -> if (rendezvous.compareAndSet(null, Fail(e))) return
129+
Complete -> throw IllegalStateException("onError after onComplete")
130+
is Fail -> throw IllegalStateException("onError after onError")
131+
is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) {
132+
cur.c.resumeWithException(e)
133+
return
134+
}
135+
else -> throw IllegalStateException("onError after onNext before request(1) was called")
136+
}
137+
}
138+
}
139+
140+
@Suppress("UNCHECKED_CAST")
141+
override fun onNext(v: V) {
142+
while (true) { // lock-free loop for rendezvous
143+
val cur = rendezvous.get()
144+
when (cur) {
145+
null -> if (rendezvous.compareAndSet(null, v)) return
146+
Complete -> throw IllegalStateException("onNext after onComplete")
147+
is Fail -> throw IllegalStateException("onNext after onError")
148+
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) {
149+
cur.c.resume(true)
150+
return
151+
}
152+
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, null)) {
153+
(cur as NextWaiter<V>).c.resume(v)
154+
return
155+
}
156+
else -> throw IllegalStateException("onNext after onNext before request(1) was called")
157+
}
158+
}
159+
}
160+
161+
override fun onCompleted() {
162+
while (true) { // lock-free loop for rendezvous
163+
val cur = rendezvous.get()
164+
when (cur) {
165+
null -> if (rendezvous.compareAndSet(null, Complete)) return
166+
Complete -> throw IllegalStateException("onComplete after onComplete")
167+
is Fail -> throw IllegalStateException("onComplete after onError")
168+
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Complete)) {
169+
cur.c.resume(false)
170+
return
171+
}
172+
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Complete)) {
173+
cur.c.resumeWithException(NoSuchElementException())
174+
return
175+
}
176+
else -> throw IllegalStateException("onComplete after onNext before request(1) was called")
177+
}
178+
}
179+
}
180+
}
181+
}

kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt

+15
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,21 @@ class AsyncRxTest {
126126
}
127127
}
128128

129+
130+
@Test
131+
fun testAsyncIterator() {
132+
val observable = asyncRx {
133+
val sb = StringBuilder()
134+
for (s in Observable.just("O", "K"))
135+
sb.append(s)
136+
sb.toString()
137+
}
138+
139+
checkObservableWithSingleValue(observable) {
140+
assertEquals("OK", it)
141+
}
142+
}
143+
129144
private fun checkErroneousObservable(
130145
observable: Observable<*>,
131146
checker: (Throwable) -> Unit

0 commit comments

Comments
 (0)