Skip to content

Commit 0ed0fd4

Browse files
committed
Add chunked and windowed operators
1 parent 835ed4d commit 0ed0fd4

File tree

4 files changed

+339
-0
lines changed

4 files changed

+339
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package kotlinx.coroutines.flow.operators
2+
3+
import kotlinx.coroutines.flow.Flow
4+
import kotlinx.coroutines.flow.collect
5+
import kotlinx.coroutines.flow.flow
6+
import kotlinx.coroutines.internal.RingBuffer
7+
import kotlin.math.max
8+
import kotlin.math.min
9+
10+
/**
11+
* Returns a flow of lists each not exceeding the given [size].
12+
*The last list in the resulting flow may have less elements than the given [size].
13+
*
14+
* @param size the number of elements to take in each list, must be positive and can be greater than the number of elements in this flow.
15+
*/
16+
fun <T> Flow<T>.chunked(size: Int): Flow<List<T>> = chunked(size) { it.toList() }
17+
18+
/**
19+
* Chunks a flow of elements into flow of lists, each not exceeding the given [size]
20+
* and applies the given [transform] function to an each.
21+
*
22+
* Note that the list passed to the [transform] function is ephemeral and is valid only inside that function.
23+
* You should not store it or allow it to escape in some way, unless you made a snapshot of it.
24+
* The last list may have less elements than the given [size].
25+
*
26+
* This is slightly faster, than using flow.chunked(n).map { ... }
27+
*
28+
* @param size the number of elements to take in each list, must be positive and can be greater than the number of elements in this flow.
29+
*/
30+
fun <T, R> Flow<T>.chunked(size: Int, transform: suspend (List<T>) -> R): Flow<R> {
31+
require(size > 0) { "Size should be greater than 0, but was $size" }
32+
return windowed(size, size, true, transform)
33+
}
34+
35+
/**
36+
* Returns a flow of snapshots of the window of the given [size]
37+
* sliding along this flow with the given [step], where each
38+
* snapshot is a list.
39+
*
40+
* Several last lists may have less elements than the given [size].
41+
*
42+
* Both [size] and [step] must be positive and can be greater than the number of elements in this flow.
43+
* @param size the number of elements to take in each window
44+
* @param step the number of elements to move the window forward by on an each step
45+
* @param partialWindows controls whether or not to keep partial windows in the end if any.
46+
*/
47+
fun <T> Flow<T>.windowed(size: Int, step: Int, partialWindows: Boolean): Flow<List<T>> =
48+
windowed(size, step, partialWindows) { it.toList() }
49+
50+
/**
51+
* Returns a flow of results of applying the given [transform] function to
52+
* an each list representing a view over the window of the given [size]
53+
* sliding along this collection with the given [step].
54+
*
55+
* Note that the list passed to the [transform] function is ephemeral and is valid only inside that function.
56+
* You should not store it or allow it to escape in some way, unless you made a snapshot of it.
57+
* Several last lists may have less elements than the given [size].
58+
*
59+
* This is slightly faster, than using flow.windowed(...).map { ... }
60+
*
61+
* Both [size] and [step] must be positive and can be greater than the number of elements in this collection.
62+
* @param size the number of elements to take in each window
63+
* @param step the number of elements to move the window forward by on an each step.
64+
* @param partialWindows controls whether or not to keep partial windows in the end if any.
65+
*/
66+
fun <T, R> Flow<T>.windowed(size: Int, step: Int, partialWindows: Boolean, transform: suspend (List<T>) -> R): Flow<R> {
67+
require(size > 0 && step > 0) { "Size and step should be greater than 0, but was size: $size, step: $step" }
68+
69+
return flow {
70+
val buffer = RingBuffer<T>(size)
71+
val toDrop = min(step, size)
72+
val toSkip = max(step - size, 0)
73+
var skipped = toSkip
74+
75+
collect { value ->
76+
if(toSkip == skipped) buffer.add(value)
77+
else skipped++
78+
79+
if (buffer.isFull()) {
80+
emit(transform(buffer))
81+
buffer.removeFirst(toDrop)
82+
skipped = 0
83+
}
84+
}
85+
86+
while (partialWindows && buffer.isNotEmpty()) {
87+
emit(transform(buffer))
88+
buffer.removeFirst(min(toDrop, buffer.size))
89+
}
90+
}
91+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package kotlinx.coroutines.internal
2+
3+
internal class RingBuffer<T>(val capacity: Int) : AbstractList<T>(), RandomAccess {
4+
init {
5+
require(capacity >= 0) { "ring buffer capacity should not be negative but it is $capacity" }
6+
}
7+
8+
private val buffer = arrayOfNulls<Any?>(capacity)
9+
private var startIndex: Int = 0
10+
11+
override var size: Int = 0
12+
private set
13+
14+
override fun get(index: Int): T {
15+
require(index in 0 until size)
16+
@Suppress("UNCHECKED_CAST")
17+
return buffer[startIndex.forward(index)] as T
18+
}
19+
20+
fun isFull() = size == capacity
21+
22+
override fun iterator(): Iterator<T> = object : AbstractIterator<T>() {
23+
private var count = size
24+
private var index = startIndex
25+
26+
override fun computeNext() {
27+
if (count == 0) {
28+
done()
29+
} else {
30+
@Suppress("UNCHECKED_CAST")
31+
setNext(buffer[index] as T)
32+
index = index.forward(1)
33+
count--
34+
}
35+
}
36+
}
37+
38+
@Suppress("UNCHECKED_CAST")
39+
override fun <T> toArray(array: Array<T>): Array<T> {
40+
val result: Array<T?> =
41+
if (array.size < this.size) array.copyOf(this.size) else array as Array<T?>
42+
43+
val size = this.size
44+
45+
var widx = 0
46+
var idx = startIndex
47+
48+
while (widx < size && idx < capacity) {
49+
result[widx] = buffer[idx] as T
50+
widx++
51+
idx++
52+
}
53+
54+
idx = 0
55+
while (widx < size) {
56+
result[widx] = buffer[idx] as T
57+
widx++
58+
idx++
59+
}
60+
if (result.size > this.size) result[this.size] = null
61+
62+
return result as Array<T>
63+
}
64+
65+
override fun toArray(): Array<Any?> {
66+
return toArray(arrayOfNulls(size))
67+
}
68+
69+
/**
70+
* Add [element] to the buffer or fail with [IllegalStateException] if no free space available in the buffer
71+
*/
72+
fun add(element: T) {
73+
if (isFull()) {
74+
throw IllegalStateException("ring buffer is full")
75+
}
76+
77+
buffer[startIndex.forward(size)] = element
78+
size++
79+
}
80+
81+
/**
82+
* Removes [n] first elements from the buffer or fails with [IllegalArgumentException] if not enough elements in the buffer to remove
83+
*/
84+
fun removeFirst(n: Int) {
85+
require(n >= 0) { "n shouldn't be negative but it is $n" }
86+
require(n <= size) { "n shouldn't be greater than the buffer size: n = $n, size = $size" }
87+
88+
if (n > 0) {
89+
val start = startIndex
90+
val end = start.forward(n)
91+
92+
if (start > end) {
93+
buffer.fill(null, start, capacity)
94+
buffer.fill(null, 0, end)
95+
} else {
96+
buffer.fill(null, start, end)
97+
}
98+
99+
startIndex = end
100+
size -= n
101+
}
102+
}
103+
104+
105+
@Suppress("NOTHING_TO_INLINE")
106+
private inline fun Int.forward(n: Int): Int = (this + n) % capacity
107+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package kotlinx.coroutines.flow.operators
2+
3+
import kotlinx.coroutines.*
4+
import kotlinx.coroutines.channels.Channel
5+
import kotlinx.coroutines.flow.*
6+
import kotlin.test.Test
7+
import kotlin.test.assertEquals
8+
9+
class ChunkedTest : TestBase() {
10+
11+
private val flow = flow {
12+
emit(1)
13+
emit(2)
14+
emit(3)
15+
emit(4)
16+
}
17+
18+
@Test
19+
fun `Chunks correct number of emissions with possible partial window at the end`() = runTest {
20+
assertEquals(2, flow.chunked(2).count())
21+
assertEquals(2, flow.chunked(3).count())
22+
assertEquals(1, flow.chunked(5).count())
23+
}
24+
25+
@Test
26+
fun `Throws IllegalArgumentException for chunk of size less than 1`() {
27+
assertFailsWith<IllegalArgumentException> { flow.chunked(0) }
28+
assertFailsWith<IllegalArgumentException> { flow.chunked(-1) }
29+
}
30+
31+
@Test
32+
fun `No emissions with empty flow`() = runTest {
33+
assertEquals(0, flowOf<Int>().chunked(2).count())
34+
}
35+
36+
@Test
37+
fun testErrorCancelsUpstream() = runTest {
38+
val latch = Channel<Unit>()
39+
val flow = flow {
40+
coroutineScope {
41+
launch(start = CoroutineStart.ATOMIC) {
42+
latch.send(Unit)
43+
hang { expect(3) }
44+
}
45+
emit(1)
46+
expect(1)
47+
emit(2)
48+
expectUnreached()
49+
}
50+
}.chunked<Int, Int>(2) { chunk ->
51+
expect(2) // 2
52+
latch.receive()
53+
throw TestException()
54+
}.catch { emit(42) }
55+
56+
assertEquals(42, flow.single())
57+
finish(4)
58+
}
59+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package kotlinx.coroutines.flow.operators
2+
3+
import kotlinx.coroutines.*
4+
import kotlinx.coroutines.channels.Channel
5+
import kotlinx.coroutines.flow.*
6+
import kotlin.test.Test
7+
import kotlin.test.assertEquals
8+
9+
class WindowedTest : TestBase() {
10+
11+
private val flow = flow {
12+
emit(1)
13+
emit(2)
14+
emit(3)
15+
emit(4)
16+
}
17+
18+
@Test
19+
fun `Throws IllegalArgumentException for window of size or step less than 1`() {
20+
assertFailsWith<IllegalArgumentException> { flow.windowed(0, 1, false) }
21+
assertFailsWith<IllegalArgumentException> { flow.windowed(-1, 2, false) }
22+
assertFailsWith<IllegalArgumentException> { flow.windowed(2, 0, false) }
23+
assertFailsWith<IllegalArgumentException> { flow.windowed(5, -2, false) }
24+
}
25+
26+
@Test
27+
fun `No emissions with empty flow`() = runTest {
28+
assertEquals(0, flowOf<Int>().windowed(2, 2, false).count())
29+
}
30+
31+
@Test
32+
fun `Emits correct sum with overlapping non partial windows`() = runTest {
33+
assertEquals(15, flow.windowed(3, 1, false) { window ->
34+
window.sum()
35+
}.sum())
36+
}
37+
38+
@Test
39+
fun `Emits correct sum with overlapping partial windows`() = runTest {
40+
assertEquals(13, flow.windowed(3, 2, true) { window ->
41+
window.sum()
42+
}.sum())
43+
}
44+
45+
@Test
46+
fun `Emits correct number of overlapping windows for long sequence of overlapping partial windows`() = runTest {
47+
val elements = generateSequence(1) { it + 1 }.take(100)
48+
val flow = elements.asFlow().windowed(100, 1, true) { }
49+
assertEquals(100, flow.count())
50+
}
51+
52+
@Test
53+
fun `Emits correct sum with partial windows set apart`() = runTest {
54+
assertEquals(7, flow.windowed(2, 3, true) { window ->
55+
window.sum()
56+
}.sum())
57+
}
58+
59+
@Test
60+
fun testErrorCancelsUpstream() = runTest {
61+
val latch = Channel<Unit>()
62+
val flow = flow {
63+
coroutineScope {
64+
launch(start = CoroutineStart.ATOMIC) {
65+
latch.send(Unit)
66+
hang { expect(3) }
67+
}
68+
emit(1)
69+
expect(1)
70+
emit(2)
71+
expectUnreached()
72+
}
73+
}.windowed<Int, Int>(2, 3, false) { window ->
74+
expect(2) // 2
75+
latch.receive()
76+
throw TestException()
77+
}.catch { emit(42) }
78+
79+
assertEquals(42, flow.single())
80+
finish(4)
81+
}
82+
}

0 commit comments

Comments
 (0)