diff --git a/kotlinx-coroutines-generate/src/main/kotlin/generate.kt b/kotlinx-coroutines-generate/src/main/kotlin/generate.kt index ca8366227d..944defc386 100644 --- a/kotlinx-coroutines-generate/src/main/kotlin/generate.kt +++ b/kotlinx-coroutines-generate/src/main/kotlin/generate.kt @@ -1,26 +1,27 @@ package kotlinx.coroutines +import java.util.* + /** * Creates a Sequence object based on received coroutine [c]. * - * Each call of 'yield' suspend function within the coroutine lambda generates - * next element of resulting sequence. + * Each call of yield suspend function within the coroutine lambda generates + * next element of resulting sequence. Calling yieldAll suspend function can be used to + * yield sequence of values efficiently. */ -fun generate( - coroutine c: GeneratorController.() -> Continuation -): Sequence = - object : Sequence { - override fun iterator(): Iterator { +fun generate(coroutine c: GeneratorController.() -> Continuation): Sequence = + object : NestedIterable { + override fun nestedIterator(): NestedIterator { val iterator = GeneratorController() - iterator.setNextStep(c(iterator)) + iterator.setNextStep(iterator.c()) return iterator } } -class GeneratorController internal constructor() : AbstractIterator() { +class GeneratorController internal constructor() : NestedIterator() { private lateinit var nextStep: Continuation - override fun computeNext() { + override fun computeNextItemOrIterator() { nextStep.resume(Unit) } @@ -33,7 +34,106 @@ class GeneratorController internal constructor() : AbstractIterator() { setNextStep(c) } + private fun yieldFromIterator(iterator: Iterator, c: Continuation) { + setNextIterator(iterator) + setNextStep(c) + } + + suspend fun yieldAll(values: Sequence, c: Continuation) { + yieldFromIterator(if (values is NestedIterable) + values.nestedIterator() else + values.iterator(), c) + } + + suspend fun yieldAll(values: Iterable, c: Continuation) { + yieldFromIterator(values.iterator(), c) + } + + suspend fun yieldAll(vararg values: T, c: Continuation) { + yieldFromIterator(values.iterator(), c) + } + operator fun handleResult(result: Unit, c: Continuation) { done() } } + +/** + * Extends [AbstractIterator] adding it the ability to produce a next nested iterator instead of a next item. + * If [hasNext] is true, then either a next item or a next nested iterator is produced. + * + * If [nextNestedIterator] != null, the consumer should use the items from [nextNestedIterator] + * and then continue iterating over this iterator. Also if a next nested iterator is produced, + * the value returned by [next] should not be used as it is invalid, but [next] should still be called. + * Therefore [NestedIterator] cannot be used as a normal iterator. + */ +abstract class NestedIterator internal constructor() : AbstractIterator() { + /** + * Either `null` or the iterator that should be used before continuing iteration over this iterator. + * If [hasNext] is true and nextNestedIterator is null, then there's no next iterator. + */ + internal var nextNestedIterator: Iterator? = null + private set + + final override fun computeNext(): Unit { + nextNestedIterator = null + computeNextItemOrIterator() + } + + /** + * Computes the next item or a next nested iterator of this iterator. + * + * This callback method should call one of these three methods: + * + * * [setNext] with the next value of the iteration + * * [setNextIterator] with the next nested iterator of the iteration + * * [done] to indicate there are no more elements + * + * Failure to call either method will result in the iteration terminating with a failed state + */ + internal abstract fun computeNextItemOrIterator() + + protected fun setNextIterator(iterator: Iterator) { + nextNestedIterator = iterator + setNext(null as T) //state transfer to Ready + } +} + +/** + * Manages [NestedIterator]s, arranging them in a stack and switching between them as they produce next + * nested iterators or finish. + * The concept is taken from + * [this article](https://www.microsoft.com/en-us/research/wp-content/uploads/2016/02/specsharp-iterators.pdf). + */ +private class RootIterator constructor(iterator: Iterator) : AbstractIterator() { + private val stack = Stack>().apply { push(iterator) } + + override fun computeNext() { + while (true) { + if (stack.isEmpty()) { + done() + return + } + val i = stack.peek() + if (!i.hasNext()) { + stack.pop() + } else { + if (i is NestedIterator && i.nextNestedIterator != null) { + stack.push(i.nextNestedIterator) + i.next() //state transfer to NotReady + } else { + setNext(i.next()) + return + } + } + } + } +} + +/** + * Default implementation of [Sequence] for [NestedIterator] providers. + */ +private interface NestedIterable : Sequence { + fun nestedIterator(): NestedIterator + override fun iterator(): Iterator = RootIterator(nestedIterator()) +} \ No newline at end of file diff --git a/kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt b/kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt index 539edfc4c0..10d98cd7f6 100644 --- a/kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt +++ b/kotlinx-coroutines-generate/src/test/kotlin/GenerateTest.kt @@ -147,4 +147,79 @@ class GenerateTest { assertEquals(listOf(Pair(1, 2), Pair(6, 8), Pair(15, 18)), result.zip(result).toList()) } -} + + @Test + fun testYieldAll() { + val sequence = generate { + yield(1) + yieldAll(1, 2) + yieldAll(listOf(1, 2)) + yieldAll(sequenceOf(1, 2)) + yield(3) + } + + assertEquals(listOf(1, 1, 2, 1, 2, 1, 2, 3), sequence.toList()) + } + + @Test + fun testYieldsAllContinuation() { + var continuationCalled = false + val sequence = generate { + yieldAll(sequenceOf(1, 2, 3)) + continuationCalled = true + yield(4) + } + val iterator = sequence.iterator() + assertEquals(1, iterator.next()) + assertEquals(2, iterator.next()) + assertEquals(3, iterator.next()) + assertFalse(continuationCalled) + assertEquals(4, iterator.next()) + assertTrue(continuationCalled) + assertFalse(iterator.hasNext()) + } + + @Test + fun testYieldAllEmpty() { + var continuationCalled = false + val sequence = generate { + yieldAll(emptyList()) + continuationCalled = true + yield(1) + } + val iterator = sequence.iterator() + assertEquals(1, iterator.next()) + assertTrue(continuationCalled) + } + + @Test + fun testYieldAllHasNext() { + val sequence = generate { + yieldAll(1, 2, 3) + } + val iterator = sequence.iterator() + assertEquals(1, iterator.next()) + assertEquals(2, iterator.next()) + assertEquals(3, iterator.next()) + assertFalse(iterator.hasNext()) + } + + @Test + fun testFromToAndBack() { + fun fromToAndBack(from: Int, to: Int): Sequence = generate { + if (from > to) + return@generate + yield(from) + yieldAll(fromToAndBack(from + 1, to)) + yield(from) + } + + val sequence = fromToAndBack(1, 10000) + val list1 = sequence.toList() + val list2 = sequence.toList() //test repeated call + + val expected = (1..10000).toList() + (1..10000).reversed().toList() + assertEquals(expected, list1) + assertEquals(expected, list2) + } +} \ No newline at end of file