Skip to content

Commit af62397

Browse files
committed
Optimized version of asyncRx iterator
1 parent 9216705 commit af62397

File tree

1 file changed

+77
-76
lines changed

1 file changed

+77
-76
lines changed

kotlinx-coroutines-rx/src/main/kotlin/asyncRxIterator.kt

+77-76
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@ package kotlinx.coroutines
22

33
import rx.Observable
44
import rx.Subscriber
5-
import java.util.concurrent.atomic.AtomicReference
5+
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater
66
import kotlin.coroutines.Continuation
7-
import kotlin.coroutines.CoroutineIntrinsics.SUSPENDED
8-
import kotlin.coroutines.CoroutineIntrinsics.suspendCoroutineOrReturn
7+
import kotlin.coroutines.CoroutineIntrinsics
98

109
/**
11-
* Suspending iteration extension. It does not have its own buffer and applies back-pressure to the source.
12-
* If iterating coroutine does not have a dispatcher with its own thread, then the iterating coroutine
13-
* is resumed in the thread that invokes [Subscriber.onNext].
10+
* Suspending iteration extension. It does not have its own buffer and works by arranging rendezvous between
11+
* producer and consumer. It applies back-pressure to the producer as needed. If iterating coroutine does not have a
12+
* dispatcher with its own thread, then the iterating coroutine is resumed and works in the thread that governs
13+
* producer observable.
1414
*/
15-
suspend operator fun <V : Any> Observable<V>.iterator(): ObserverIterator<V> {
16-
val iterator = ObserverIterator<V>()
17-
subscribe(iterator.subscriber)
15+
public suspend operator fun <V : Any> Observable<V>.iterator(): ObservableIterator<V> {
16+
val iterator = ObservableIterator<V>()
17+
subscribe(iterator)
1818
return iterator
1919
}
2020

@@ -26,17 +26,24 @@ private object Completed
2626
private class CompletedWith(val v: Any)
2727
private class Error(val e: Throwable)
2828

29-
class ObserverIterator<V : Any> {
30-
internal val subscriber = Sub()
29+
public class ObservableIterator<V : Any> : Subscriber<V>() {
3130
// Contains either null, Completed, CompletedWith, Error(exception), Waiter, or next value
32-
private val rendezvous = AtomicReference<Any?>()
31+
@Volatile
32+
private var rendezvous: Any? = null
33+
34+
companion object {
35+
private val RENDEZVOUS_UPDATER = AtomicReferenceFieldUpdater
36+
.newUpdater(ObservableIterator::class.java, Any::class.java, "rendezvous")
37+
}
38+
39+
private fun cas(expect: Any?, update: Any?) = RENDEZVOUS_UPDATER.compareAndSet(this, expect, update)
3340

3441
@Suppress("UNCHECKED_CAST")
35-
suspend operator fun hasNext(): Boolean = suspendCoroutineOrReturn sc@ { c ->
42+
public suspend operator fun hasNext(): Boolean = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c ->
3643
while (true) { // lock-free loop for rendezvous
37-
val cur = rendezvous.get()
44+
val cur = rendezvous
3845
when (cur) {
39-
null -> if (rendezvous.compareAndSet(null, HasNextWaiter(c))) return@sc SUSPENDED
46+
null -> if (cas(null, HasNextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED
4047
Completed -> return@sc false
4148
is CompletedWith -> return@sc true
4249
is Error -> throw cur.e
@@ -47,86 +54,80 @@ class ObserverIterator<V : Any> {
4754
}
4855

4956
@Suppress("UNCHECKED_CAST")
50-
suspend operator fun next(): V = suspendCoroutineOrReturn sc@ { c ->
57+
public suspend operator fun next(): V = CoroutineIntrinsics.suspendCoroutineOrReturn sc@ { c ->
5158
while (true) { // lock-free loop for rendezvous
52-
val cur = rendezvous.get()
59+
val cur = rendezvous
5360
when (cur) {
54-
null -> if (rendezvous.compareAndSet(null, NextWaiter(c))) return@sc SUSPENDED
61+
null -> if (cas(null, NextWaiter(c))) return@sc CoroutineIntrinsics.SUSPENDED
5562
Completed -> throw NoSuchElementException()
56-
is CompletedWith -> if (rendezvous.compareAndSet(cur, Completed)) return@sc cur.v as V
63+
is CompletedWith -> if (cas(cur, Completed)) return@sc cur.v as V
5764
is Error -> throw cur.e
5865
is Waiter<*> -> throw IllegalStateException("Concurrent iteration")
59-
else -> if (rendezvous.compareAndSet(cur, null)) return (cur as V).apply { subscriber.requestOne() }
66+
else -> if (cas(cur, null)) return (cur as V).apply { request(1) }
6067
}
6168
}
6269
}
6370

64-
internal inner class Sub : Subscriber<V>() {
65-
fun requestOne() {
66-
request(1)
67-
}
68-
69-
override fun onStart() {
70-
requestOne()
71-
}
71+
public override fun onStart() {
72+
request(1)
73+
}
7274

73-
override fun onError(e: Throwable) {
74-
while (true) { // lock-free loop for rendezvous
75-
val cur = rendezvous.get()
76-
when (cur) {
77-
null -> if (rendezvous.compareAndSet(null, Error(e))) return
78-
Completed -> throw IllegalStateException("onError after onCompleted")
79-
is CompletedWith -> throw IllegalStateException("onError after onCompleted")
80-
is Error -> throw IllegalStateException("onError after onError")
81-
is Waiter<*> -> if (rendezvous.compareAndSet(cur, null)) {
82-
cur.c.resumeWithException(e)
83-
return
84-
}
85-
else -> throw IllegalStateException("onError after onNext before request(1) was called")
75+
public override fun onError(e: Throwable) {
76+
while (true) { // lock-free loop for rendezvous
77+
val cur = rendezvous
78+
when (cur) {
79+
null -> if (cas(null, Error(e))) return
80+
Completed -> throw IllegalStateException("onError after onCompleted")
81+
is CompletedWith -> throw IllegalStateException("onError after onCompleted")
82+
is Error -> throw IllegalStateException("onError after onError")
83+
is Waiter<*> -> if (cas(cur, null)) {
84+
cur.c.resumeWithException(e)
85+
return
8686
}
87+
else -> throw IllegalStateException("onError after onNext before request(1) was called")
8788
}
8889
}
90+
}
8991

90-
@Suppress("UNCHECKED_CAST")
91-
override fun onNext(v: V) {
92-
while (true) { // lock-free loop for rendezvous
93-
val cur = rendezvous.get()
94-
when (cur) {
95-
null -> if (rendezvous.compareAndSet(null, v)) return
96-
Completed -> throw IllegalStateException("onNext after onCompleted")
97-
is CompletedWith -> throw IllegalStateException("onNext after onCompleted")
98-
is Error -> throw IllegalStateException("onNext after onError")
99-
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, v)) {
100-
cur.c.resume(true)
101-
return
102-
}
103-
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, null)) {
104-
(cur as NextWaiter<V>).c.resume(v)
105-
return
106-
}
107-
else -> throw IllegalStateException("onNext after onNext before request(1) was called")
92+
@Suppress("UNCHECKED_CAST")
93+
public override fun onNext(v: V) {
94+
while (true) { // lock-free loop for rendezvous
95+
val cur = rendezvous
96+
when (cur) {
97+
null -> if (cas(null, v)) return
98+
Completed -> throw IllegalStateException("onNext after onCompleted")
99+
is CompletedWith -> throw IllegalStateException("onNext after onCompleted")
100+
is Error -> throw IllegalStateException("onNext after onError")
101+
is HasNextWaiter -> if (cas(cur, v)) {
102+
cur.c.resume(true)
103+
return
104+
}
105+
is NextWaiter<*> -> if (cas(cur, null)) {
106+
(cur as NextWaiter<V>).c.resume(v)
107+
return
108108
}
109+
else -> throw IllegalStateException("onNext after onNext before request(1) was called")
109110
}
110111
}
112+
}
111113

112-
override fun onCompleted() {
113-
while (true) { // lock-free loop for rendezvous
114-
val cur = rendezvous.get()
115-
when (cur) {
116-
null -> if (rendezvous.compareAndSet(null, Completed)) return
117-
Completed -> throw IllegalStateException("onCompleted after onCompleted")
118-
is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted")
119-
is Error -> throw IllegalStateException("onCompleted after onError")
120-
is HasNextWaiter -> if (rendezvous.compareAndSet(cur, Completed)) {
121-
cur.c.resume(false)
122-
return
123-
}
124-
is NextWaiter<*> -> if (rendezvous.compareAndSet(cur, Completed)) {
125-
cur.c.resumeWithException(NoSuchElementException())
126-
return
127-
}
128-
else -> if (rendezvous.compareAndSet(cur, CompletedWith(cur))) return
114+
public override fun onCompleted() {
115+
while (true) { // lock-free loop for rendezvous
116+
val cur = rendezvous
117+
when (cur) {
118+
null -> if (cas(null, Completed)) return
119+
Completed -> throw IllegalStateException("onCompleted after onCompleted")
120+
is CompletedWith -> throw IllegalStateException("onCompleted after onCompleted")
121+
is Error -> throw IllegalStateException("onCompleted after onError")
122+
is HasNextWaiter -> if (cas(cur, Completed)) {
123+
cur.c.resume(false)
124+
return
125+
}
126+
is NextWaiter<*> -> if (cas(cur, Completed)) {
127+
cur.c.resumeWithException(NoSuchElementException())
128+
return
129129
}
130+
else -> if (cas(cur, CompletedWith(cur))) return
130131
}
131132
}
132133
}

0 commit comments

Comments
 (0)