diff --git a/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeWhileBenchmark.kt b/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeWhileBenchmark.kt new file mode 100644 index 0000000000..fd3d3cdb96 --- /dev/null +++ b/benchmarks/src/jmh/kotlin/benchmarks/flow/TakeWhileBenchmark.kt @@ -0,0 +1,68 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +@file:Suppress("INVISIBLE_REFERENCE", "INVISIBLE_MEMBER") + +package benchmarks.flow + +import kotlinx.coroutines.* +import kotlinx.coroutines.flow.* +import kotlinx.coroutines.flow.internal.* +import org.openjdk.jmh.annotations.* +import java.util.concurrent.TimeUnit + +@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Measurement(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS) +@Fork(value = 1) +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MICROSECONDS) +@State(Scope.Benchmark) +open class TakeWhileBenchmark { + @Param("1", "10", "100", "1000") + private var size: Int = 0 + + private suspend inline fun Flow.consume() = + filter { it % 2L != 0L } + .map { it * it }.count() + + @Benchmark + fun baseline() = runBlocking { + (0L until size).asFlow().consume() + } + + @Benchmark + fun takeWhileDirect() = runBlocking { + (0L..Long.MAX_VALUE).asFlow().takeWhileDirect { it < size }.consume() + } + + @Benchmark + fun takeWhileViaCollectWhile() = runBlocking { + (0L..Long.MAX_VALUE).asFlow().takeWhileViaCollectWhile { it < size }.consume() + } + + // Direct implementation by checking predicate and throwing AbortFlowException + private fun Flow.takeWhileDirect(predicate: suspend (T) -> Boolean): Flow = unsafeFlow { + try { + collect { value -> + if (predicate(value)) emit(value) + else throw AbortFlowException(this) + } + } catch (e: AbortFlowException) { + e.checkOwnership(owner = this) + } + } + + // Essentially the same code, but reusing the logic via collectWhile function + private fun Flow.takeWhileViaCollectWhile(predicate: suspend (T) -> Boolean): Flow = unsafeFlow { + // This return is needed to work around a bug in JS BE: KT-39227 + return@unsafeFlow collectWhile { value -> + if (predicate(value)) { + emit(value) + true + } else { + false + } + } + } +} diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index 6a24b6a23a..9fc49ca798 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -995,6 +995,7 @@ public final class kotlinx/coroutines/flow/FlowKt { public static synthetic fun toSet$default (Lkotlinx/coroutines/flow/Flow;Ljava/util/Set;Lkotlin/coroutines/Continuation;ILjava/lang/Object;)Ljava/lang/Object; public static final fun transform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; public static final fun transformLatest (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; + public static final fun transformWhile (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; public static final fun unsafeTransform (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; public static final fun withIndex (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow; public static final fun zip (Lkotlinx/coroutines/flow/Flow;Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function3;)Lkotlinx/coroutines/flow/Flow; diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt b/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt index 1ffbf94a98..fb37da3a83 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Emitters.kt @@ -19,10 +19,11 @@ import kotlin.jvm.* /** * Applies [transform] function to each value of the given flow. * - * The receiver of the [transform] is [FlowCollector] and thus `transform` is a - * generic function that may transform emitted element, skip it or emit it multiple times. + * The receiver of the `transform` is [FlowCollector] and thus `transform` is a + * flexible function that may transform emitted element, skip it or emit it multiple times. * - * This operator can be used as a building block for other operators, for example: + * This operator generalizes [filter] and [map] operators and + * can be used as a building block for other operators, for example: * * ``` * fun Flow.skipOddAndDuplicateEven(): Flow = transform { value -> diff --git a/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt b/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt index d30a2db206..1d7ffd1db6 100644 --- a/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt +++ b/kotlinx-coroutines-core/common/src/flow/operators/Limit.kt @@ -7,8 +7,10 @@ package kotlinx.coroutines.flow +import kotlinx.coroutines.* import kotlinx.coroutines.flow.internal.* import kotlin.jvm.* +import kotlinx.coroutines.flow.flow as safeFlow import kotlinx.coroutines.flow.internal.unsafeFlow as flow /** @@ -51,6 +53,10 @@ public fun Flow.take(count: Int): Flow { var consumed = 0 try { collect { value -> + // Note: this for take is not written via collectWhile on purpose. + // It checks condition first and then makes a tail-call to either emit or emitAbort. + // This way normal execution does not require a state machine, only a termination (emitAbort). + // See "TakeBenchmark" for comparision of different approaches. if (++consumed < count) { return@collect emit(value) } else { @@ -70,14 +76,67 @@ private suspend fun FlowCollector.emitAbort(value: T) { /** * Returns a flow that contains first elements satisfying the given [predicate]. + * + * Note, that the resulting flow does not contain the element on which the [predicate] returned `false`. + * See [transformWhile] for a more flexible operator. */ public fun Flow.takeWhile(predicate: suspend (T) -> Boolean): Flow = flow { - try { - collect { value -> - if (predicate(value)) emit(value) - else throw AbortFlowException(this) + // This return is needed to work around a bug in JS BE: KT-39227 + return@flow collectWhile { value -> + if (predicate(value)) { + emit(value) + true + } else { + false } + } +} + +/** + * Applies [transform] function to each value of the given flow while this + * function returns `true`. + * + * The receiver of the `transformWhile` is [FlowCollector] and thus `transformWhile` is a + * flexible function that may transform emitted element, skip it or emit it multiple times. + * + * This operator generalizes [takeWhile] and can be used as a building block for other operators. + * For example, a flow of download progress messages can be completed when the + * download is done but emit this last message (unlike `takeWhile`): + * + * ``` + * fun Flow.completeWhenDone(): Flow = + * transformWhile { progress -> + * emit(progress) // always emit progress + * !progress.isDone() // continue while download is not done + * } + * } + * ``` + */ +@ExperimentalCoroutinesApi +public fun Flow.transformWhile( + @BuilderInference transform: suspend FlowCollector.(value: T) -> Boolean +): Flow = + safeFlow { // Note: safe flow is used here, because collector is exposed to transform on each operation + // This return is needed to work around a bug in JS BE: KT-39227 + return@safeFlow collectWhile { value -> + transform(value) + } + } + +// Internal building block for non-tailcalling flow-truncating operators +internal suspend inline fun Flow.collectWhile(crossinline predicate: suspend (value: T) -> Boolean) { + val collector = object : FlowCollector { + override suspend fun emit(value: T) { + // Note: we are checking predicate first, then throw. If the predicate does suspend (calls emit, for example) + // the the resulting code is never tail-suspending and produces a state-machine + if (!predicate(value)) { + throw AbortFlowException(this) + } + } + } + try { + collect(collector) } catch (e: AbortFlowException) { - e.checkOwnership(owner = this) + e.checkOwnership(collector) } } diff --git a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt index d99ae52c7d..d36e1bbf7b 100644 --- a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt +++ b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt @@ -82,9 +82,9 @@ public suspend fun Flow.singleOrNull(): T? { */ public suspend fun Flow.first(): T { var result: Any? = NULL - collectUntil { + collectWhile { result = it - true + false } if (result === NULL) throw NoSuchElementException("Expected at least one element") return result as T @@ -96,12 +96,12 @@ public suspend fun Flow.first(): T { */ public suspend fun Flow.first(predicate: suspend (T) -> Boolean): T { var result: Any? = NULL - collectUntil { + collectWhile { if (predicate(it)) { result = it - true - } else { false + } else { + true } } if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate") @@ -114,9 +114,9 @@ public suspend fun Flow.first(predicate: suspend (T) -> Boolean): T { */ public suspend fun Flow.firstOrNull(): T? { var result: T? = null - collectUntil { + collectWhile { result = it - true + false } return result } @@ -127,28 +127,13 @@ public suspend fun Flow.firstOrNull(): T? { */ public suspend fun Flow.firstOrNull(predicate: suspend (T) -> Boolean): T? { var result: T? = null - collectUntil { + collectWhile { if (predicate(it)) { result = it - true - } else { false + } else { + true } } return result } - -internal suspend inline fun Flow.collectUntil(crossinline block: suspend (value: T) -> Boolean) { - val collector = object : FlowCollector { - override suspend fun emit(value: T) { - if (block(value)) { - throw AbortFlowException(this) - } - } - } - try { - collect(collector) - } catch (e: AbortFlowException) { - e.checkOwnership(collector) - } -} diff --git a/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt b/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt index ce93f1fdb2..effc8744d0 100644 --- a/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/FlowInvariantsTest.kt @@ -193,7 +193,7 @@ class FlowInvariantsTest : TestBase() { } @Test - fun testEmptyCoroutineContext() = runTest { + fun testEmptyCoroutineContextMap() = runTest { emptyContextTest { map { expect(it) @@ -213,7 +213,18 @@ class FlowInvariantsTest : TestBase() { } @Test - fun testEmptyCoroutineContextViolation() = runTest { + fun testEmptyCoroutineContextTransformWhile() = runTest { + emptyContextTest { + transformWhile { + expect(it) + emit(it + 1) + true + } + } + } + + @Test + fun testEmptyCoroutineContextViolationTransform() = runTest { try { emptyContextTest { transform { @@ -230,6 +241,25 @@ class FlowInvariantsTest : TestBase() { } } + @Test + fun testEmptyCoroutineContextViolationTransformWhile() = runTest { + try { + emptyContextTest { + transformWhile { + expect(it) + withContext(Dispatchers.Unconfined) { + emit(it + 1) + } + true + } + } + expectUnreached() + } catch (e: IllegalStateException) { + assertTrue(e.message!!.contains("Flow invariant is violated")) + finish(2) + } + } + private suspend fun emptyContextTest(block: Flow.() -> Flow) { suspend fun collector(): Int { var result: Int = -1 diff --git a/kotlinx-coroutines-core/common/test/flow/operators/TransformWhileTest.kt b/kotlinx-coroutines-core/common/test/flow/operators/TransformWhileTest.kt new file mode 100644 index 0000000000..df660103c3 --- /dev/null +++ b/kotlinx-coroutines-core/common/test/flow/operators/TransformWhileTest.kt @@ -0,0 +1,70 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.flow + +import kotlinx.coroutines.* +import kotlin.test.* + +class TransformWhileTest : TestBase() { + @Test + fun testSimple() = runTest { + val flow = (0..10).asFlow() + val expected = listOf("A", "B", "C", "D") + val actual = flow.transformWhile { value -> + when(value) { + 0 -> { emit("A"); true } + 1 -> true + 2 -> { emit("B"); emit("C"); true } + 3 -> { emit("D"); false } + else -> { expectUnreached(); false } + } + }.toList() + assertEquals(expected, actual) + } + + @Test + fun testCancelUpstream() = runTest { + var cancelled = false + val flow = flow { + coroutineScope { + launch(start = CoroutineStart.ATOMIC) { + hang { cancelled = true } + } + emit(1) + emit(2) + emit(3) + } + } + val transformed = flow.transformWhile { + emit(it) + it < 2 + } + assertEquals(listOf(1, 2), transformed.toList()) + assertTrue(cancelled) + } + + @Test + fun testExample() = runTest { + val source = listOf( + DownloadProgress(0), + DownloadProgress(50), + DownloadProgress(100), + DownloadProgress(147) + ) + val expected = source.subList(0, 3) + val actual = source.asFlow().completeWhenDone().toList() + assertEquals(expected, actual) + } + + private fun Flow.completeWhenDone(): Flow = + transformWhile { progress -> + emit(progress) // always emit progress + !progress.isDone() // continue while download is not done + } + + private data class DownloadProgress(val percent: Int) { + fun isDone() = percent >= 100 + } +} \ No newline at end of file