Skip to content

Commit b42f986

Browse files
elizarovEdwarDDayLouisCAD
authored
Flow.transformWhile operator (#2066)
Also, most flow-truncating operators are refactored via a common internal collectWhile operator that properly uses AbortFlowException and checks for its ownership, so that we don't have to look for bugs in interactions between all those operators (and zip, too, which is also flow-truncating). But `take` operator still users a custom highly-tuned implementation. Fixes #2065 Co-authored-by: EdwarDDay <[email protected]> Co-authored-by: Louis CAD <[email protected]>
1 parent 5183b62 commit b42f986

File tree

7 files changed

+249
-35
lines changed

7 files changed

+249
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")
6+
7+
package benchmarks.flow
8+
9+
import kotlinx.coroutines.*
10+
import kotlinx.coroutines.flow.*
11+
import kotlinx.coroutines.flow.internal.*
12+
import org.openjdk.jmh.annotations.*
13+
import java.util.concurrent.TimeUnit
14+
15+
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
16+
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
17+
@Fork(value = 1)
18+
@BenchmarkMode(Mode.AverageTime)
19+
@OutputTimeUnit(TimeUnit.MICROSECONDS)
20+
@State(Scope.Benchmark)
21+
open class TakeWhileBenchmark {
22+
@Param("1", "10", "100", "1000")
23+
private var size: Int = 0
24+
25+
private suspend inline fun Flow<Long>.consume() =
26+
filter { it % 2L != 0L }
27+
.map { it * it }.count()
28+
29+
@Benchmark
30+
fun baseline() = runBlocking<Int> {
31+
(0L until size).asFlow().consume()
32+
}
33+
34+
@Benchmark
35+
fun takeWhileDirect() = runBlocking<Int> {
36+
(0L..Long.MAX_VALUE).asFlow().takeWhileDirect { it < size }.consume()
37+
}
38+
39+
@Benchmark
40+
fun takeWhileViaCollectWhile() = runBlocking<Int> {
41+
(0L..Long.MAX_VALUE).asFlow().takeWhileViaCollectWhile { it < size }.consume()
42+
}
43+
44+
// Direct implementation by checking predicate and throwing AbortFlowException
45+
private fun <T> Flow<T>.takeWhileDirect(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
46+
try {
47+
collect { value ->
48+
if (predicate(value)) emit(value)
49+
else throw AbortFlowException(this)
50+
}
51+
} catch (e: AbortFlowException) {
52+
e.checkOwnership(owner = this)
53+
}
54+
}
55+
56+
// Essentially the same code, but reusing the logic via collectWhile function
57+
private fun <T> Flow<T>.takeWhileViaCollectWhile(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
58+
// This return is needed to work around a bug in JS BE: KT-39227
59+
return@unsafeFlow collectWhile { value ->
60+
if (predicate(value)) {
61+
emit(value)
62+
true
63+
} else {
64+
false
65+
}
66+
}
67+
}
68+
}

kotlinx-coroutines-core/api/kotlinx-coroutines-core.api

+1
Original file line numberDiff line numberDiff line change
@@ -995,6 +995,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
995995
public static synthetic fun toSet$default (Lkotlinx/coroutines/flow/Flow;Ljava/util/Set;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
996996
public static final fun transform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
997997
public static final fun transformLatest (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
998+
public static final fun transformWhile (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
998999
public static final fun unsafeTransform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
9991000
public static final fun withIndex (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
10001001
public static final fun zip (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;

kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt

+4-3
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@ import kotlin.jvm.*
1919
/**
2020
* Applies [transform] function to each value of the given flow.
2121
*
22-
* The receiver of the [transform] is [FlowCollector] and thus `transform` is a
23-
* generic function that may transform emitted element, skip it or emit it multiple times.
22+
* The receiver of the `transform` is [FlowCollector] and thus `transform` is a
23+
* flexible function that may transform emitted element, skip it or emit it multiple times.
2424
*
25-
* This operator can be used as a building block for other operators, for example:
25+
* This operator generalizes [filter] and [map] operators and
26+
* can be used as a building block for other operators, for example:
2627
*
2728
* ```
2829
* fun Flow<Int>.skipOddAndDuplicateEven(): Flow<Int> = transform { value ->

kotlinx-coroutines-core/common/src/flow/operators/Limit.kt

+64-5
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@
77

88
package kotlinx.coroutines.flow
99

10+
import kotlinx.coroutines.*
1011
import kotlinx.coroutines.flow.internal.*
1112
import kotlin.jvm.*
13+
import kotlinx.coroutines.flow.flow as safeFlow
1214
import kotlinx.coroutines.flow.internal.unsafeFlow as flow
1315

1416
/**
@@ -51,6 +53,10 @@ public fun <T> Flow<T>.take(count: Int): Flow<T> {
5153
var consumed = 0
5254
try {
5355
collect { value ->
56+
// Note: this for take is not written via collectWhile on purpose.
57+
// It checks condition first and then makes a tail-call to either emit or emitAbort.
58+
// This way normal execution does not require a state machine, only a termination (emitAbort).
59+
// See "TakeBenchmark" for comparision of different approaches.
5460
if (++consumed < count) {
5561
return@collect emit(value)
5662
} else {
@@ -70,14 +76,67 @@ private suspend fun <T> FlowCollector<T>.emitAbort(value: T) {
7076

7177
/**
7278
* Returns a flow that contains first elements satisfying the given [predicate].
79+
*
80+
* Note, that the resulting flow does not contain the element on which the [predicate] returned `false`.
81+
* See [transformWhile] for a more flexible operator.
7382
*/
7483
public fun <T> Flow<T>.takeWhile(predicate: suspend (T) -> Boolean): Flow<T> = flow {
75-
try {
76-
collect { value ->
77-
if (predicate(value)) emit(value)
78-
else throw AbortFlowException(this)
84+
// This return is needed to work around a bug in JS BE: KT-39227
85+
return@flow collectWhile { value ->
86+
if (predicate(value)) {
87+
emit(value)
88+
true
89+
} else {
90+
false
7991
}
92+
}
93+
}
94+
95+
/**
96+
* Applies [transform] function to each value of the given flow while this
97+
* function returns `true`.
98+
*
99+
* The receiver of the `transformWhile` is [FlowCollector] and thus `transformWhile` is a
100+
* flexible function that may transform emitted element, skip it or emit it multiple times.
101+
*
102+
* This operator generalizes [takeWhile] and can be used as a building block for other operators.
103+
* For example, a flow of download progress messages can be completed when the
104+
* download is done but emit this last message (unlike `takeWhile`):
105+
*
106+
* ```
107+
* fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
108+
* transformWhile { progress ->
109+
* emit(progress) // always emit progress
110+
* !progress.isDone() // continue while download is not done
111+
* }
112+
* }
113+
* ```
114+
*/
115+
@ExperimentalCoroutinesApi
116+
public fun <T, R> Flow<T>.transformWhile(
117+
@BuilderInference transform: suspend FlowCollector<R>.(value: T) -> Boolean
118+
): Flow<R> =
119+
safeFlow { // Note: safe flow is used here, because collector is exposed to transform on each operation
120+
// This return is needed to work around a bug in JS BE: KT-39227
121+
return@safeFlow collectWhile { value ->
122+
transform(value)
123+
}
124+
}
125+
126+
// Internal building block for non-tailcalling flow-truncating operators
127+
internal suspend inline fun <T> Flow<T>.collectWhile(crossinline predicate: suspend (value: T) -> Boolean) {
128+
val collector = object : FlowCollector<T> {
129+
override suspend fun emit(value: T) {
130+
// Note: we are checking predicate first, then throw. If the predicate does suspend (calls emit, for example)
131+
// the the resulting code is never tail-suspending and produces a state-machine
132+
if (!predicate(value)) {
133+
throw AbortFlowException(this)
134+
}
135+
}
136+
}
137+
try {
138+
collect(collector)
80139
} catch (e: AbortFlowException) {
81-
e.checkOwnership(owner = this)
140+
e.checkOwnership(collector)
82141
}
83142
}

kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt

+10-25
Original file line numberDiff line numberDiff line change
@@ -82,9 +82,9 @@ public suspend fun <T: Any> Flow<T>.singleOrNull(): T? {
8282
*/
8383
public suspend fun <T> Flow<T>.first(): T {
8484
var result: Any? = NULL
85-
collectUntil {
85+
collectWhile {
8686
result = it
87-
true
87+
false
8888
}
8989
if (result === NULL) throw NoSuchElementException("Expected at least one element")
9090
return result as T
@@ -96,12 +96,12 @@ public suspend fun <T> Flow<T>.first(): T {
9696
*/
9797
public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
9898
var result: Any? = NULL
99-
collectUntil {
99+
collectWhile {
100100
if (predicate(it)) {
101101
result = it
102-
true
103-
} else {
104102
false
103+
} else {
104+
true
105105
}
106106
}
107107
if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate")
@@ -114,9 +114,9 @@ public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
114114
*/
115115
public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
116116
var result: T? = null
117-
collectUntil {
117+
collectWhile {
118118
result = it
119-
true
119+
false
120120
}
121121
return result
122122
}
@@ -127,28 +127,13 @@ public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
127127
*/
128128
public suspend fun <T : Any> Flow<T>.firstOrNull(predicate: suspend (T) -> Boolean): T? {
129129
var result: T? = null
130-
collectUntil {
130+
collectWhile {
131131
if (predicate(it)) {
132132
result = it
133-
true
134-
} else {
135133
false
134+
} else {
135+
true
136136
}
137137
}
138138
return result
139139
}
140-
141-
internal suspend inline fun <T> Flow<T>.collectUntil(crossinline block: suspend (value: T) -> Boolean) {
142-
val collector = object : FlowCollector<T> {
143-
override suspend fun emit(value: T) {
144-
if (block(value)) {
145-
throw AbortFlowException(this)
146-
}
147-
}
148-
}
149-
try {
150-
collect(collector)
151-
} catch (e: AbortFlowException) {
152-
e.checkOwnership(collector)
153-
}
154-
}

kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt

+32-2
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class FlowInvariantsTest : TestBase() {
192192
}
193193

194194
@Test
195-
fun testEmptyCoroutineContext() = runTest {
195+
fun testEmptyCoroutineContextMap() = runTest {
196196
emptyContextTest {
197197
map {
198198
expect(it)
@@ -212,7 +212,18 @@ class FlowInvariantsTest : TestBase() {
212212
}
213213

214214
@Test
215-
fun testEmptyCoroutineContextViolation() = runTest {
215+
fun testEmptyCoroutineContextTransformWhile() = runTest {
216+
emptyContextTest {
217+
transformWhile {
218+
expect(it)
219+
emit(it + 1)
220+
true
221+
}
222+
}
223+
}
224+
225+
@Test
226+
fun testEmptyCoroutineContextViolationTransform() = runTest {
216227
try {
217228
emptyContextTest {
218229
transform {
@@ -229,6 +240,25 @@ class FlowInvariantsTest : TestBase() {
229240
}
230241
}
231242

243+
@Test
244+
fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
245+
try {
246+
emptyContextTest {
247+
transformWhile {
248+
expect(it)
249+
withContext(Dispatchers.Unconfined) {
250+
emit(it + 1)
251+
}
252+
true
253+
}
254+
}
255+
expectUnreached()
256+
} catch (e: IllegalStateException) {
257+
assertTrue(e.message!!.contains("Flow invariant is violated"))
258+
finish(2)
259+
}
260+
}
261+
232262
private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
233263
suspend fun collector(): Int {
234264
var result: Int = -1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.coroutines.flow
6+
7+
import kotlinx.coroutines.*
8+
import kotlin.test.*
9+
10+
class TransformWhileTest : TestBase() {
11+
@Test
12+
fun testSimple() = runTest {
13+
val flow = (0..10).asFlow()
14+
val expected = listOf("A", "B", "C", "D")
15+
val actual = flow.transformWhile { value ->
16+
when(value) {
17+
0 -> { emit("A"); true }
18+
1 -> true
19+
2 -> { emit("B"); emit("C"); true }
20+
3 -> { emit("D"); false }
21+
else -> { expectUnreached(); false }
22+
}
23+
}.toList()
24+
assertEquals(expected, actual)
25+
}
26+
27+
@Test
28+
fun testCancelUpstream() = runTest {
29+
var cancelled = false
30+
val flow = flow {
31+
coroutineScope {
32+
launch(start = CoroutineStart.ATOMIC) {
33+
hang { cancelled = true }
34+
}
35+
emit(1)
36+
emit(2)
37+
emit(3)
38+
}
39+
}
40+
val transformed = flow.transformWhile {
41+
emit(it)
42+
it < 2
43+
}
44+
assertEquals(listOf(1, 2), transformed.toList())
45+
assertTrue(cancelled)
46+
}
47+
48+
@Test
49+
fun testExample() = runTest {
50+
val source = listOf(
51+
DownloadProgress(0),
52+
DownloadProgress(50),
53+
DownloadProgress(100),
54+
DownloadProgress(147)
55+
)
56+
val expected = source.subList(0, 3)
57+
val actual = source.asFlow().completeWhenDone().toList()
58+
assertEquals(expected, actual)
59+
}
60+
61+
private fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
62+
transformWhile { progress ->
63+
emit(progress) // always emit progress
64+
!progress.isDone() // continue while download is not done
65+
}
66+
67+
private data class DownloadProgress(val percent: Int) {
68+
fun isDone() = percent >= 100
69+
}
70+
}

0 commit comments

Comments
 (0)