Skip to content

Commit c0d91e6

Browse files
committed
Draft optimized async iterator via CoroutineIntrinsics
1 parent bae6325 commit c0d91e6

File tree

2 files changed

+60
-91
lines changed

2 files changed

+60
-91
lines changed

kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt

+39-87
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,14 @@ import rx.Observable
44
import rx.Subscriber
55
import java.util.concurrent.atomic.AtomicReference
66
import kotlin.coroutines.Continuation
7-
import kotlin.coroutines.suspendCoroutine
7+
import kotlin.coroutines.CoroutineIntrinsics.SUSPENDED
8+
import kotlin.coroutines.CoroutineIntrinsics.suspendCoroutineOrReturn
89

9-
// supports suspending iteration on observables
10+
/**
11+
* Suspending iteration extension. It does not have its own buffer and applies back-pressure to the source.
12+
* If iterating coroutine does not have a dispatcher with its own thread, then the iterating coroutine
13+
* is resumed in the thread that invokes [Subscriber.onNext].
14+
*/
1015
suspend operator fun <V : Any> Observable<V>.iterator(): ObserverIterator<V> {
1116
val iterator = ObserverIterator<V>()
1217
subscribe(iterator.subscriber)
@@ -17,101 +22,45 @@ private sealed class Waiter<in T>(val c: Continuation<T>)
1722
private class HasNextWaiter(c: Continuation<Boolean>) : Waiter<Boolean>(c)
1823
private class NextWaiter<V>(c: Continuation<V>) : Waiter<V>(c)
1924

20-
private object Complete
21-
private class Fail(val e: Throwable)
25+
private object Completed
26+
private class CompletedWith(val v: Any)
27+
private class Error(val e: Throwable)
2228

2329
class ObserverIterator<V : Any> {
2430
internal val subscriber = Sub()
25-
// Contains either null, Complete, Fail(exception), Waiter, or next value
31+
// Contains either null, Completed, CompletedWith, Error(exception), Waiter, or next value
2632
private val rendezvous = AtomicReference<Any?>()
2733

2834
@Suppress("UNCHECKED_CAST")
29-
private suspend fun awaitHasNext(): Boolean = suspendCoroutine sc@ { c ->
35+
suspend operator fun hasNext(): Boolean = suspendCoroutineOrReturn sc@ { c ->
3036
while (true) { // lock-free loop for rendezvous
3137
val cur = rendezvous.get()
3238
when (cur) {
33-
null -> {
34-
if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc
35-
}
36-
Complete -> {
37-
c.resume(false)
38-
return@sc
39-
}
40-
is Fail -> {
41-
c.resumeWithException(cur.e)
42-
return@sc
43-
}
44-
is Waiter<*> -> {
45-
c.resumeWithException(IllegalStateException("Concurrent iteration"))
46-
return@sc
47-
}
48-
else -> {
49-
c.resume(true)
50-
return@sc
51-
}
52-
}
53-
}
54-
}
55-
56-
private suspend fun awaitNext(): V = suspendCoroutine sc@ { c ->
57-
while (true) { // lock-free loop for rendezvous
58-
val cur = rendezvous.get()
59-
when (cur) {
60-
null -> {
61-
if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc
62-
}
63-
Complete -> {
64-
c.resumeWithException(NoSuchElementException())
65-
return@sc
66-
}
67-
is Fail -> {
68-
c.resumeWithException(cur.e)
69-
return@sc
70-
}
71-
is Waiter<*> -> {
72-
c.resumeWithException(IllegalStateException("Concurrent iteration"))
73-
return@sc
74-
}
75-
else -> {
76-
if (rendezvous.compareAndSet(cur, null)) {
77-
c.resume(consumeValue(cur))
78-
return@sc
79-
}
80-
}
39+
null -> if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc SUSPENDED
40+
Completed -> return@sc false
41+
is CompletedWith -> return@sc true
42+
is Error -> throw cur.e
43+
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
44+
else -> return@sc true
8145
}
8246
}
8347
}
8448

85-
suspend operator fun hasNext(): Boolean {
86-
val cur = rendezvous.get()
87-
return when (cur) {
88-
null -> awaitHasNext()
89-
Complete -> false
90-
is Fail -> throw cur.e
91-
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
92-
else -> true
93-
}
94-
}
95-
96-
suspend operator fun next(): V {
49+
@Suppress("UNCHECKED_CAST")
50+
suspend operator fun next(): V = suspendCoroutineOrReturn sc@ { c ->
9751
while (true) { // lock-free loop for rendezvous
9852
val cur = rendezvous.get()
9953
when (cur) {
100-
null -> return awaitNext()
101-
Complete -> throw NoSuchElementException()
102-
is Fail -> throw cur.e
54+
null -> if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc SUSPENDED
55+
Completed -> throw NoSuchElementException()
56+
is CompletedWith -> if (rendezvous.compareAndSet(cur, Completed)) return@sc cur.v as V
57+
is Error -> throw cur.e
10358
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
104-
else -> if (rendezvous.compareAndSet(cur, null)) return consumeValue(cur)
59+
else -> if (rendezvous.compareAndSet(cur, null)) return (cur as V).apply { subscriber.requestOne() }
10560
}
10661
}
10762
}
10863

109-
@Suppress("UNCHECKED_CAST")
110-
private fun consumeValue(cur: Any?): V {
111-
subscriber.requestOne()
112-
return cur as V
113-
}
114-
11564
internal inner class Sub : Subscriber<V>() {
11665
fun requestOne() {
11766
request(1)
@@ -125,9 +74,10 @@ class ObserverIterator<V : Any> {
12574
while (true) { // lock-free loop for rendezvous
12675
val cur = rendezvous.get()
12776
when (cur) {
128-
null -> if (rendezvous.compareAndSet(null, Fail(e))) return
129-
Complete -> throw IllegalStateException("onError after onComplete")
130-
is Fail -> throw IllegalStateException("onError after onError")
77+
null -> if (rendezvous.compareAndSet(null, Error(e))) return
78+
Completed -> throw IllegalStateException("onError after onCompleted")
79+
is CompletedWith -> throw IllegalStateException("onError after onCompleted")
80+
is Error -> throw IllegalStateException("onError after onError")
13181
is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) {
13282
cur.c.resumeWithException(e)
13383
return
@@ -143,8 +93,9 @@ class ObserverIterator<V : Any> {
14393
val cur = rendezvous.get()
14494
when (cur) {
14595
null -> if (rendezvous.compareAndSet(null, v)) return
146-
Complete -> throw IllegalStateException("onNext after onComplete")
147-
is Fail -> throw IllegalStateException("onNext after onError")
96+
Completed -> throw IllegalStateException("onNext after onCompleted")
97+
is CompletedWith -> throw IllegalStateException("onNext after onCompleted")
98+
is Error -> throw IllegalStateException("onNext after onError")
14899
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) {
149100
cur.c.resume(true)
150101
return
@@ -162,18 +113,19 @@ class ObserverIterator<V : Any> {
162113
while (true) { // lock-free loop for rendezvous
163114
val cur = rendezvous.get()
164115
when (cur) {
165-
null -> if (rendezvous.compareAndSet(null, Complete)) return
166-
Complete -> throw IllegalStateException("onComplete after onComplete")
167-
is Fail -> throw IllegalStateException("onComplete after onError")
168-
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Complete)) {
116+
null -> if (rendezvous.compareAndSet(null, Completed)) return
117+
Completed -> throw IllegalStateException("onCompleted after onCompleted")
118+
is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted")
119+
is Error -> throw IllegalStateException("onCompleted after onError")
120+
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Completed)) {
169121
cur.c.resume(false)
170122
return
171123
}
172-
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Complete)) {
124+
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Completed)) {
173125
cur.c.resumeWithException(NoSuchElementException())
174126
return
175127
}
176-
else -> throw IllegalStateException("onComplete after onNext before request(1) was called")
128+
else -> if (rendezvous.compareAndSet(cur, CompletedWith(cur))) return
177129
}
178130
}
179131
}

kotlinx-coroutines-rx/src/test/kotlin/AsyncRxTest.kt

+21-4
Original file line numberDiff line numberDiff line change
@@ -130,17 +130,34 @@ class AsyncRxTest {
130130
@Test
131131
fun testAsyncIterator() {
132132
val observable = asyncRx {
133-
val sb = StringBuilder()
134-
for (s in Observable.just("O", "K"))
135-
sb.append(s)
136-
sb.toString()
133+
var result = ""
134+
for (s in Observable.just("O", "K")) {
135+
result += s
136+
}
137+
result
137138
}
138139

139140
checkObservableWithSingleValue(observable) {
140141
assertEquals("OK", it)
141142
}
142143
}
143144

145+
@Test
146+
fun testAsyncIteratorException() {
147+
val observable = asyncRx {
148+
var result = ""
149+
for (s in Observable.error<String>(RuntimeException("OK"))) {
150+
result += s
151+
}
152+
result
153+
}
154+
155+
checkErroneousObservable(observable) {
156+
assert(it is RuntimeException)
157+
assertEquals("OK", it.message)
158+
}
159+
}
160+
144161
private fun checkErroneousObservable(
145162
observable: Observable<*>,
146163
checker: (Throwable) -> Unit

0 commit comments

Comments
 (0)