Skip to content

Commit 9cd934d

Browse files
committed
Replace generate implementaion
The new one is more efficient and contains `yieldAll`
1 parent d794c9d commit 9cd934d

File tree

2 files changed

+215
-26
lines changed

2 files changed

+215
-26
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,107 @@
11
package kotlinx.coroutines
22

3-
import kotlin.coroutines.*
4-
import kotlin.coroutines.CoroutineIntrinsics.SUSPENDED
3+
import kotlin.coroutines.Continuation
4+
import kotlin.coroutines.CoroutineIntrinsics
5+
import kotlin.coroutines.RestrictsSuspension
6+
import kotlin.coroutines.createCoroutine
57

68
/**
7-
* Creates a Sequence object based on received coroutine [c].
8-
*
9-
* Each call of 'yield' suspend function within the coroutine lambda generates
10-
* next element of resulting sequence.
9+
* Scope of [generate] block.
1110
*/
12-
interface Generator<in T> {
13-
suspend fun yield(value: T)
14-
}
11+
@RestrictsSuspension
12+
public abstract class Generator<in T> internal constructor() {
13+
/**
14+
* Yields a value in [generate] block.
15+
*/
16+
public abstract suspend fun yield(value: T)
17+
18+
/**
19+
* Yields potentially infinite sequence of iterator values in [generate] block.
20+
*/
21+
public abstract suspend fun yieldAll(iterator: Iterator<T>)
22+
23+
/**
24+
* Yields a collections of values in [generate] block.
25+
*/
26+
public suspend fun yieldAll(elements: Iterable<T>) = yieldAll(elements.iterator())
1527

16-
fun <T> generate(block: suspend Generator<T>.() -> Unit): Sequence<T> = GeneratedSequence(block)
28+
/**
29+
* Yields potentially infinite sequence of values in [generate] block.
30+
*/
31+
public suspend fun yieldAll(sequence: Sequence<T>) = yieldAll(sequence.iterator())
32+
}
1733

18-
private class GeneratedSequence<out T>(private val block: suspend Generator<T>.() -> Unit) : Sequence<T> {
19-
override fun iterator(): Iterator<T> = GeneratedIterator(block)
34+
/**
35+
* Generates lazy sequence.
36+
*/
37+
public fun <T> generate(block: suspend Generator<T>.() -> Unit): Sequence<T> = object : Sequence<T> {
38+
override fun iterator(): Iterator<T> {
39+
val iterator = GeneratorIterator<T>()
40+
iterator.nextStep = block.createCoroutine(receiver = iterator, completion = iterator)
41+
return iterator
42+
}
2043
}
2144

22-
private class GeneratedIterator<T>(block: suspend Generator<T>.() -> Unit) : AbstractIterator<T>(), Generator<T> {
23-
private var nextStep: Continuation<Unit> = block.createCoroutine(this, object : Continuation<Unit> {
24-
override fun resume(data: Unit) {
25-
done()
45+
private class GeneratorIterator<T>: Generator<T>(), Iterator<T>, Continuation<Unit> {
46+
var computedNext = false
47+
var nextStep: Continuation<Unit>? = null
48+
var nextValue: T? = null
49+
50+
override fun hasNext(): Boolean {
51+
if (!computedNext) {
52+
val step = nextStep!!
53+
computedNext = true
54+
nextStep = null
55+
step.resume(Unit) // leaves it in "done" state if crashes
2656
}
57+
return nextStep != null
58+
}
2759

28-
override fun resumeWithException(exception: Throwable) {
29-
throw exception
60+
override fun next(): T {
61+
if (!hasNext()) throw NoSuchElementException()
62+
computedNext = false
63+
return nextValue as T
64+
}
65+
66+
// Completion continuation implementation
67+
override fun resume(value: Unit) {
68+
// nothing to do here -- leave null in nextStep
69+
}
70+
71+
override fun resumeWithException(exception: Throwable) {
72+
throw exception // just rethrow
73+
}
74+
75+
// Generator implementation
76+
override suspend fun yield(value: T) {
77+
nextValue = value
78+
return CoroutineIntrinsics.suspendCoroutineOrReturn { c ->
79+
nextStep = c
80+
CoroutineIntrinsics.SUSPENDED
3081
}
31-
})
82+
}
3283

33-
override fun computeNext() {
34-
nextStep.resume(Unit)
84+
override suspend fun yieldAll(iterator: Iterator<T>) {
85+
if (!iterator.hasNext()) return
86+
nextValue = iterator.next()
87+
return CoroutineIntrinsics.suspendCoroutineOrReturn { c ->
88+
nextStep = IteratorContinuation(c, iterator)
89+
CoroutineIntrinsics.SUSPENDED
90+
}
3591
}
36-
suspend override fun yield(value: T) = CoroutineIntrinsics.suspendCoroutineOrReturn <Unit> { c ->
37-
setNext(value)
38-
nextStep = c
3992

40-
SUSPENDED
93+
inner class IteratorContinuation(val completion: Continuation<Unit>, val iterator: Iterator<T>) : Continuation<Unit> {
94+
override fun resume(value: Unit) {
95+
if (!iterator.hasNext()) {
96+
completion.resume(Unit)
97+
return
98+
}
99+
nextValue = iterator.next()
100+
nextStep = this
101+
}
102+
103+
override fun resumeWithException(exception: Throwable) {
104+
throw exception // just rethrow
105+
}
41106
}
42-
}
107+
}

kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt

+124
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,128 @@ class GenerateTest {
147147

148148
assertEquals(listOf(Pair(1, 2), Pair(6, 8), Pair(15, 18)), result.zip(result).toList())
149149
}
150+
151+
@Test
152+
fun testYieldAllIterator() {
153+
val result = generate {
154+
yieldAll(listOf(1, 2, 3).iterator())
155+
}
156+
assertEquals(listOf(1, 2, 3), result.toList())
157+
}
158+
159+
@Test
160+
fun testYieldAllSequence() {
161+
val result = generate {
162+
yieldAll(sequenceOf(1, 2, 3))
163+
}
164+
assertEquals(listOf(1, 2, 3), result.toList())
165+
}
166+
167+
@Test
168+
fun testYieldAllCollection() {
169+
val result = generate {
170+
yieldAll(listOf(1, 2, 3))
171+
}
172+
assertEquals(listOf(1, 2, 3), result.toList())
173+
}
174+
175+
@Test
176+
fun testYieldAllCollectionMixedFirst() {
177+
val result = generate {
178+
yield(0)
179+
yieldAll(listOf(1, 2, 3))
180+
}
181+
assertEquals(listOf(0, 1, 2, 3), result.toList())
182+
}
183+
184+
@Test
185+
fun testYieldAllCollectionMixedLast() {
186+
val result = generate {
187+
yieldAll(listOf(1, 2, 3))
188+
yield(4)
189+
}
190+
assertEquals(listOf(1, 2, 3, 4), result.toList())
191+
}
192+
193+
@Test
194+
fun testYieldAllCollectionMixedBoth() {
195+
val result = generate {
196+
yield(0)
197+
yieldAll(listOf(1, 2, 3))
198+
yield(4)
199+
}
200+
assertEquals(listOf(0, 1, 2, 3, 4), result.toList())
201+
}
202+
203+
@Test
204+
fun testYieldAllCollectionMixedLong() {
205+
val result = generate {
206+
yield(0)
207+
yieldAll(listOf(1, 2, 3))
208+
yield(4)
209+
yield(5)
210+
yieldAll(listOf(6))
211+
yield(7)
212+
yieldAll(listOf())
213+
yield(8)
214+
}
215+
assertEquals(listOf(0, 1, 2, 3, 4, 5, 6, 7, 8), result.toList())
216+
}
217+
218+
@Test
219+
fun testYieldAllCollectionOneEmpty() {
220+
val result = generate<Int> {
221+
yieldAll(listOf())
222+
}
223+
assertEquals(listOf(), result.toList())
224+
}
225+
226+
@Test
227+
fun testYieldAllCollectionManyEmpty() {
228+
val result = generate<Int> {
229+
yieldAll(listOf())
230+
yieldAll(listOf())
231+
yieldAll(listOf())
232+
}
233+
assertEquals(listOf(), result.toList())
234+
}
235+
236+
@Test
237+
fun testYieldAllSideEffects() {
238+
val effects = arrayListOf<Any>()
239+
val result = generate {
240+
effects.add("a")
241+
yieldAll(listOf(1, 2))
242+
effects.add("b")
243+
yieldAll(listOf())
244+
effects.add("c")
245+
yieldAll(listOf(3))
246+
effects.add("d")
247+
yield(4)
248+
effects.add("e")
249+
yieldAll(listOf())
250+
effects.add("f")
251+
yield(5)
252+
}
253+
254+
for (res in result) {
255+
effects.add("(") // marks step start
256+
effects.add(res)
257+
effects.add(")") // marks step end
258+
}
259+
assertEquals(
260+
listOf(
261+
"a",
262+
"(", 1, ")",
263+
"(", 2, ")",
264+
"b", "c",
265+
"(", 3, ")",
266+
"d",
267+
"(", 4, ")",
268+
"e", "f",
269+
"(", 5, ")"
270+
),
271+
effects.toList()
272+
)
273+
}
150274
}

0 commit comments

Comments
 (0)