Skip to content

Flow.transformWhile operator #2066

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/flow/TakeWhileBenchmark.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER")

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.flow.internal.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.TimeUnit

@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class TakeWhileBenchmark {
@Param("1", "10", "100", "1000")
private var size: Int = 0

private suspend inline fun Flow<Long>.consume() =
filter { it % 2L != 0L }
.map { it * it }.count()

@Benchmark
fun baseline() = runBlocking<Int> {
(0L until size).asFlow().consume()
}

@Benchmark
fun takeWhileDirect() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().takeWhileDirect { it < size }.consume()
}

@Benchmark
fun takeWhileViaCollectWhile() = runBlocking<Int> {
(0L..Long.MAX_VALUE).asFlow().takeWhileViaCollectWhile { it < size }.consume()
}

// Direct implemenatation by checking predicate and throwing AbortFlowException
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo

private fun <T> Flow<T>.takeWhileDirect(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
try {
collect { value ->
if (predicate(value)) emit(value)
else throw AbortFlowException(this)
}
} catch (e: AbortFlowException) {
e.checkOwnership(owner = this)
}
}

// Essentially the same code, but reusing the logic via collectWhile function
private fun <T> Flow<T>.takeWhileViaCollectWhile(predicate: suspend (T) -> Boolean): Flow<T> = unsafeFlow {
// This return is needed to work around a bug in JS BE: KT-39227
return@unsafeFlow collectWhile { value ->
if (predicate(value)) {
emit(value)
true
} else {
false
}
}
}
}
1 change: 1 addition & 0 deletions kotlinx-coroutines-core/api/kotlinx-coroutines-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -995,6 +995,7 @@ public final class kotlinx/coroutines/flow/FlowKt {
public static synthetic fun toSet$default (Lkotlinx/coroutines/flow/Flow;Ljava/util/Set;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object;
public static final fun transform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun transformLatest (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun transformWhile (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun unsafeTransform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
public static final fun withIndex (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow;
public static final fun zip (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow;
Expand Down
7 changes: 4 additions & 3 deletions kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@ import kotlin.jvm.*
/**
* Applies [transform] function to each value of the given flow.
*
* The receiver of the [transform] is [FlowCollector] and thus `transform` is a
* generic function that may transform emitted element, skip it or emit it multiple times.
* The receiver of the `transform` is [FlowCollector] and thus `transform` is a
* flexible function that may transform emitted element, skip it or emit it multiple times.
*
* This operator can be used as a building block for other operators, for example:
* This operator generalizes [filter] and [map] operators and
* can be used as a building block for other operators, for example:
*
* ```
* fun Flow<Int>.skipOddAndDuplicateEven(): Flow<Int> = transform { value ->
Expand Down
69 changes: 64 additions & 5 deletions kotlinx-coroutines-core/common/src/flow/operators/Limit.kt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.internal.*
import kotlin.jvm.*
import kotlinx.coroutines.flow.flow as safeFlow
import kotlinx.coroutines.flow.internal.unsafeFlow as flow

/**
Expand Down Expand Up @@ -51,6 +53,10 @@ public fun <T> Flow<T>.take(count: Int): Flow<T> {
var consumed = 0
try {
collect { value ->
// Note: this for take is not written via collectWhile on purpose.
// It checks condition first and then makes a tail-call to either emit or emitAbort.
// This way normal execution does not require a state machine, only a termination (emitAbort).
// See "TakeBenchmark" for comparision of different approaches.
if (++consumed < count) {
return@collect emit(value)
} else {
Expand All @@ -70,14 +76,67 @@ private suspend fun <T> FlowCollector<T>.emitAbort(value: T) {

/**
* Returns a flow that contains first elements satisfying the given [predicate].
*
* Note, that the resulting flow does not contain the element on which the [predicate] returned `false`.
* See [transformWhile] for a more flexible operator.
*/
public fun <T> Flow<T>.takeWhile(predicate: suspend (T) -> Boolean): Flow<T> = flow {
try {
collect { value ->
if (predicate(value)) emit(value)
else throw AbortFlowException(this)
// This return is needed to work around a bug in JS BE: KT-39227
return@flow collectWhile { value ->
if (predicate(value)) {
emit(value)
true
} else {
false
}
}
}

/**
* Applies [transform] function to each value of the given flow while this
* function returns `true`.
*
* The receiver of the `transformWhile` is [FlowCollector] and thus `transformWhile` is a
* flexible function that may transform emitted element, skip it or emit it multiple times.
*
* This operator generalizes [takeWhile] and can be used as a building block for other operators.
* For example, a flow of download progress messages can be completed when the
* download is done but emit this last message (unlike `takeWhile`):
*
* ```
* fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
* transformWhile { progress ->
* emit(progress) // always emit progress
* !progress.isDone() // continue while download is not done
* }
* }
* ```
*/
@ExperimentalCoroutinesApi
public fun <T, R> Flow<T>.transformWhile(
@BuilderInference transform: suspend FlowCollector<R>.(value: T) -> Boolean
): Flow<R> =
safeFlow { // Note: safe flow is used here, because collector is exposed to transform on each operation
// This return is needed to work around a bug in JS BE: KT-39227
return@safeFlow collectWhile { value ->
transform(value)
}
}

// Internal building block for non-tailcalling flow-truncating operators
internal suspend inline fun <T> Flow<T>.collectWhile(crossinline predicate: suspend (value: T) -> Boolean) {
val collector = object : FlowCollector<T> {
override suspend fun emit(value: T) {
// Note: we are checking predicate first, then throw. If the predicate does suspend (calls emit, for example)
// the the resulting code is never tail-suspending and produces a state-machine
if (!predicate(value)) {
throw AbortFlowException(this)
}
}
}
try {
collect(collector)
} catch (e: AbortFlowException) {
e.checkOwnership(owner = this)
e.checkOwnership(collector)
}
}
35 changes: 10 additions & 25 deletions kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ public suspend fun <T: Any> Flow<T>.singleOrNull(): T? {
*/
public suspend fun <T> Flow<T>.first(): T {
var result: Any? = NULL
collectUntil {
collectWhile {
result = it
true
false
}
if (result === NULL) throw NoSuchElementException("Expected at least one element")
return result as T
Expand All @@ -96,12 +96,12 @@ public suspend fun <T> Flow<T>.first(): T {
*/
public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
var result: Any? = NULL
collectUntil {
collectWhile {
if (predicate(it)) {
result = it
true
} else {
false
} else {
true
}
}
if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate")
Expand All @@ -114,9 +114,9 @@ public suspend fun <T> Flow<T>.first(predicate: suspend (T) -> Boolean): T {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
var result: T? = null
collectUntil {
collectWhile {
result = it
true
false
}
return result
}
Expand All @@ -127,28 +127,13 @@ public suspend fun <T : Any> Flow<T>.firstOrNull(): T? {
*/
public suspend fun <T : Any> Flow<T>.firstOrNull(predicate: suspend (T) -> Boolean): T? {
var result: T? = null
collectUntil {
collectWhile {
if (predicate(it)) {
result = it
true
} else {
false
} else {
true
}
}
return result
}

internal suspend inline fun <T> Flow<T>.collectUntil(crossinline block: suspend (value: T) -> Boolean) {
val collector = object : FlowCollector<T> {
override suspend fun emit(value: T) {
if (block(value)) {
throw AbortFlowException(this)
}
}
}
try {
collect(collector)
} catch (e: AbortFlowException) {
e.checkOwnership(collector)
}
}
34 changes: 32 additions & 2 deletions kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class FlowInvariantsTest : TestBase() {
}

@Test
fun testEmptyCoroutineContext() = runTest {
fun testEmptyCoroutineContextMap() = runTest {
emptyContextTest {
map {
expect(it)
Expand All @@ -213,7 +213,18 @@ class FlowInvariantsTest : TestBase() {
}

@Test
fun testEmptyCoroutineContextViolation() = runTest {
fun testEmptyCoroutineContextTransformWhile() = runTest {
emptyContextTest {
transformWhile {
expect(it)
emit(it + 1)
true
}
}
}

@Test
fun testEmptyCoroutineContextViolationTransform() = runTest {
try {
emptyContextTest {
transform {
Expand All @@ -230,6 +241,25 @@ class FlowInvariantsTest : TestBase() {
}
}

@Test
fun testEmptyCoroutineContextViolationTransformWhile() = runTest {
try {
emptyContextTest {
transformWhile {
expect(it)
withContext(Dispatchers.Unconfined) {
emit(it + 1)
}
true
}
}
expectUnreached()
} catch (e: IllegalStateException) {
assertTrue(e.message!!.contains("Flow invariant is violated"))
finish(2)
}
}

private suspend fun emptyContextTest(block: Flow<Int>.() -> Flow<Int>) {
suspend fun collector(): Int {
var result: Int = -1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow

import kotlinx.coroutines.*
import kotlin.test.*

class TransformWhileTest : TestBase() {
@Test
fun testSimple() = runTest {
val flow = (0..10).asFlow()
val expected = listOf("A", "B", "C", "D")
val actual = flow.transformWhile { value ->
when(value) {
0 -> { emit("A"); true }
1 -> true
2 -> { emit("B"); emit("C"); true }
3 -> { emit("D"); false }
else -> { expectUnreached(); false }
}
}.toList()
assertEquals(expected, actual)
}

@Test
fun testCancelUpstream() = runTest {
var cancelled = false
val flow = flow {
coroutineScope {
launch(start = CoroutineStart.ATOMIC) {
hang { cancelled = true }
}
emit(1)
emit(2)
emit(3)
}
}
val transformed = flow.transformWhile {
emit(it)
it < 2
}
assertEquals(listOf(1, 2), transformed.toList())
assertTrue(cancelled)
}

@Test
fun testExample() = runTest {
val source = listOf(
DownloadProgress(0),
DownloadProgress(50),
DownloadProgress(100),
DownloadProgress(147)
)
val expected = source.subList(0, 3)
val actual = source.asFlow().completeWhenDone().toList()
assertEquals(expected, actual)
}

private fun Flow<DownloadProgress>.completeWhenDone(): Flow<DownloadProgress> =
transformWhile { progress ->
emit(progress) // always emit progress
!progress.isDone() // continue while download is not done
}

private data class DownloadProgress(val percent: Int) {
fun isDone() = percent >= 100
}
}