Skip to content

yieldAll through nested iterators #8

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 110 additions & 10 deletions kotlinx-coroutines-generate/src/main/kotlin/generate.kt
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
package kotlinx.coroutines

import java.util.*

/**
* Creates a Sequence object based on received coroutine [c].
*
* Each call of 'yield' suspend function within the coroutine lambda generates
* next element of resulting sequence.
* Each call of yield suspend function within the coroutine lambda generates
* next element of resulting sequence. Calling yieldAll suspend function can be used to
* yield sequence of values efficiently.
*/
fun <T> generate(
coroutine c: GeneratorController<T>.() -> Continuation<Unit>
): Sequence<T> =
object : Sequence<T> {
override fun iterator(): Iterator<T> {
fun <T> generate(coroutine c: GeneratorController<T>.() -> Continuation<Unit>): Sequence<T> =
object : NestedIterable<T> {
override fun nestedIterator(): NestedIterator<T> {
val iterator = GeneratorController<T>()
iterator.setNextStep(c(iterator))
iterator.setNextStep(iterator.c())
return iterator
}
}

class GeneratorController<T> internal constructor() : AbstractIterator<T>() {
class GeneratorController<T> internal constructor() : NestedIterator<T>() {
private lateinit var nextStep: Continuation<Unit>

override fun computeNext() {
override fun computeNextItemOrIterator() {
nextStep.resume(Unit)
}

Expand All @@ -33,7 +34,106 @@ class GeneratorController<T> internal constructor() : AbstractIterator<T>() {
setNextStep(c)
}

private fun yieldFromIterator(iterator: Iterator<T>, c: Continuation<Unit>) {
setNextIterator(iterator)
setNextStep(c)
}

suspend fun yieldAll(values: Sequence<T>, c: Continuation<Unit>) {
yieldFromIterator(if (values is NestedIterable<T>)
values.nestedIterator() else
values.iterator(), c)
}

suspend fun yieldAll(values: Iterable<T>, c: Continuation<Unit>) {
yieldFromIterator(values.iterator(), c)
}

suspend fun yieldAll(vararg values: T, c: Continuation<Unit>) {
yieldFromIterator(values.iterator(), c)
}

operator fun handleResult(result: Unit, c: Continuation<Nothing>) {
done()
}
}

/**
* Extends [AbstractIterator] adding it the ability to produce a next nested iterator instead of a next item.
* If [hasNext] is true, then either a next item or a next nested iterator is produced.
*
* If [nextNestedIterator] != null, the consumer should use the items from [nextNestedIterator]
* and then continue iterating over this iterator. Also if a next nested iterator is produced,
* the value returned by [next] should not be used as it is invalid, but [next] should still be called.
* Therefore [NestedIterator] cannot be used as a normal iterator.
*/
abstract class NestedIterator<T> internal constructor() : AbstractIterator<T>() {
/**
* Either `null` or the iterator that should be used before continuing iteration over this iterator.
* If [hasNext] is true and nextNestedIterator is null, then there's no next iterator.
*/
internal var nextNestedIterator: Iterator<T>? = null
private set

final override fun computeNext(): Unit {
nextNestedIterator = null
computeNextItemOrIterator()
}

/**
* Computes the next item or a next nested iterator of this iterator.
*
* This callback method should call one of these three methods:
*
* * [setNext] with the next value of the iteration
* * [setNextIterator] with the next nested iterator of the iteration
* * [done] to indicate there are no more elements
*
* Failure to call either method will result in the iteration terminating with a failed state
*/
internal abstract fun computeNextItemOrIterator()

protected fun setNextIterator(iterator: Iterator<T>) {
nextNestedIterator = iterator
setNext(null as T) //state transfer to Ready
}
}

/**
* Manages [NestedIterator]s, arranging them in a stack and switching between them as they produce next
* nested iterators or finish.
* The concept is taken from
* [this article](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/specsharp-iterators.pdf).
*/
private class RootIterator<T> constructor(iterator: Iterator<T>) : AbstractIterator<T>() {
private val stack = Stack<Iterator<T>>().apply { push(iterator) }

override fun computeNext() {
while (true) {
if (stack.isEmpty()) {
done()
return
}
val i = stack.peek()
if (!i.hasNext()) {
stack.pop()
} else {
if (i is NestedIterator<T> && i.nextNestedIterator != null) {
stack.push(i.nextNestedIterator)
i.next() //state transfer to NotReady
} else {
setNext(i.next())
return
}
}
}
}
}

/**
* Default implementation of [Sequence] for [NestedIterator] providers.
*/
private interface NestedIterable<T> : Sequence<T> {
fun nestedIterator(): NestedIterator<T>
override fun iterator(): Iterator<T> = RootIterator(nestedIterator())
}
77 changes: 76 additions & 1 deletion kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,79 @@ class GenerateTest {

assertEquals(listOf(Pair(1, 2), Pair(6, 8), Pair(15, 18)), result.zip(result).toList())
}
}

@Test
fun testYieldAll() {
val sequence = generate<Int> {
yield(1)
yieldAll(1, 2)
yieldAll(listOf(1, 2))
yieldAll(sequenceOf(1, 2))
yield(3)
}

assertEquals(listOf(1, 1, 2, 1, 2, 1, 2, 3), sequence.toList())
}

@Test
fun testYieldsAllContinuation() {
var continuationCalled = false
val sequence = generate<Int> {
yieldAll(sequenceOf(1, 2, 3))
continuationCalled = true
yield(4)
}
val iterator = sequence.iterator()
assertEquals(1, iterator.next())
assertEquals(2, iterator.next())
assertEquals(3, iterator.next())
assertFalse(continuationCalled)
assertEquals(4, iterator.next())
assertTrue(continuationCalled)
assertFalse(iterator.hasNext())
}

@Test
fun testYieldAllEmpty() {
var continuationCalled = false
val sequence = generate<Int> {
yieldAll(emptyList())
continuationCalled = true
yield(1)
}
val iterator = sequence.iterator()
assertEquals(1, iterator.next())
assertTrue(continuationCalled)
}

@Test
fun testYieldAllHasNext() {
val sequence = generate<Int> {
yieldAll(1, 2, 3)
}
val iterator = sequence.iterator()
assertEquals(1, iterator.next())
assertEquals(2, iterator.next())
assertEquals(3, iterator.next())
assertFalse(iterator.hasNext())
}

@Test
fun testFromToAndBack() {
fun fromToAndBack(from: Int, to: Int): Sequence<Int> = generate {
if (from > to)
return@generate
yield(from)
yieldAll(fromToAndBack(from + 1, to))
yield(from)
}

val sequence = fromToAndBack(1, 10000)
val list1 = sequence.toList()
val list2 = sequence.toList() //test repeated call

val expected = (1..10000).toList() + (1..10000).reversed().toList()
assertEquals(expected, list1)
assertEquals(expected, list2)
}
}