diff --git a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api index db821a1e05..a6e5fd513e 100644 --- a/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api +++ b/kotlinx-coroutines-core/api/kotlinx-coroutines-core.api @@ -907,6 +907,8 @@ public final class kotlinx/coroutines/flow/FlowKt { public static final fun filterNotNull (Lkotlinx/coroutines/flow/Flow;)Lkotlinx/coroutines/flow/Flow; public static final fun first (Lkotlinx/coroutines/flow/Flow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun first (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun firstOrNull (Lkotlinx/coroutines/flow/Flow;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; + public static final fun firstOrNull (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public static final fun flatMap (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow; public static final fun flatMapConcat (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow; public static final fun flatMapLatest (Lkotlinx/coroutines/flow/Flow;Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow; diff --git a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt index ccf8241f41..674f8322f2 100644 --- a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt +++ b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt @@ -77,7 +77,6 @@ public suspend fun Flow.singleOrNull(): T? { if (result != null) error("Expected only one element") result = value } - return result } @@ -120,3 +119,39 @@ public suspend fun Flow.first(predicate: suspend (T) -> Boolean): T { if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate") return result as T } + +/** + * The terminal operator that returns the first element emitted by the flow and then cancels flow's collection. + * Returns `null` if the flow was empty. + */ +public suspend fun Flow.firstOrNull(): T? { + var result: T? = null + try { + collect { value -> + result = value + throw AbortFlowException(NopCollector) + } + } catch (e: AbortFlowException) { + // Do nothing + } + return result +} + +/** + * The terminal operator that returns the first element emitted by the flow matching the given [predicate] and then cancels flow's collection. + * Returns `null` if the flow did not contain an element matching the [predicate]. + */ +public suspend fun Flow.firstOrNull(predicate: suspend (T) -> Boolean): T? { + var result: T? = null + try { + collect { value -> + if (predicate(value)) { + result = value + throw AbortFlowException(NopCollector) + } + } + } catch (e: AbortFlowException) { + // Do nothing + } + return result +} diff --git a/kotlinx-coroutines-core/common/test/AsyncTest.kt b/kotlinx-coroutines-core/common/test/AsyncTest.kt index 6fd4ebbe04..3019ddeab1 100644 --- a/kotlinx-coroutines-core/common/test/AsyncTest.kt +++ b/kotlinx-coroutines-core/common/test/AsyncTest.kt @@ -210,12 +210,6 @@ class AsyncTest : TestBase() { finish(13) } - class BadClass { - override fun equals(other: Any?): Boolean = error("equals") - override fun hashCode(): Int = error("hashCode") - override fun toString(): String = error("toString") - } - @Test fun testDeferBadClass() = runTest { val bad = BadClass() diff --git a/kotlinx-coroutines-core/common/test/TestBase.common.kt b/kotlinx-coroutines-core/common/test/TestBase.common.kt index a6119ee8a6..0ba80ee509 100644 --- a/kotlinx-coroutines-core/common/test/TestBase.common.kt +++ b/kotlinx-coroutines-core/common/test/TestBase.common.kt @@ -80,3 +80,8 @@ public fun wrapperDispatcher(context: CoroutineContext): CoroutineContext { public suspend fun wrapperDispatcher(): CoroutineContext = wrapperDispatcher(coroutineContext) +class BadClass { + override fun equals(other: Any?): Boolean = error("equals") + override fun hashCode(): Int = error("hashCode") + override fun toString(): String = error("toString") +} diff --git a/kotlinx-coroutines-core/common/test/WithTimeoutOrNullTest.kt b/kotlinx-coroutines-core/common/test/WithTimeoutOrNullTest.kt index 3faf900cb9..40d2758daa 100644 --- a/kotlinx-coroutines-core/common/test/WithTimeoutOrNullTest.kt +++ b/kotlinx-coroutines-core/common/test/WithTimeoutOrNullTest.kt @@ -152,12 +152,6 @@ class WithTimeoutOrNullTest : TestBase() { assertSame(bad, result) } - class BadClass { - override fun equals(other: Any?): Boolean = error("Should not be called") - override fun hashCode(): Int = error("Should not be called") - override fun toString(): String = error("Should not be called") - } - @Test fun testNullOnTimeout() = runTest { expect(1) diff --git a/kotlinx-coroutines-core/common/test/WithTimeoutTest.kt b/kotlinx-coroutines-core/common/test/WithTimeoutTest.kt index ab61b9c8f3..8462c96953 100644 --- a/kotlinx-coroutines-core/common/test/WithTimeoutTest.kt +++ b/kotlinx-coroutines-core/common/test/WithTimeoutTest.kt @@ -107,12 +107,6 @@ class WithTimeoutTest : TestBase() { assertSame(bad, result) } - class BadClass { - override fun equals(other: Any?): Boolean = error("Should not be called") - override fun hashCode(): Int = error("Should not be called") - override fun toString(): String = error("Should not be called") - } - @Test fun testExceptionOnTimeout() = runTest { expect(1) diff --git a/kotlinx-coroutines-core/common/test/channels/RendezvousChannelTest.kt b/kotlinx-coroutines-core/common/test/channels/RendezvousChannelTest.kt index d036af9395..4d20d71596 100644 --- a/kotlinx-coroutines-core/common/test/channels/RendezvousChannelTest.kt +++ b/kotlinx-coroutines-core/common/test/channels/RendezvousChannelTest.kt @@ -241,12 +241,6 @@ class RendezvousChannelTest : TestBase() { finish(12) } - class BadClass { - override fun equals(other: Any?): Boolean = error("equals") - override fun hashCode(): Int = error("hashCode") - override fun toString(): String = error("toString") - } - @Test fun testProduceBadClass() = runTest { val bad = BadClass() diff --git a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt index e84d4c7b77..f737a1d0de 100644 --- a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt @@ -83,4 +83,81 @@ class FirstTest : TestBase() { assertEquals(1, flow.first()) finish(2) } + + @Test + fun testFirstOrNull() = runTest { + val flow = flowOf(1, 2, 3) + assertEquals(1, flow.firstOrNull()) + } + + @Test + fun testFirstOrNullWithPredicate() = runTest { + val flow = flowOf(1, 2, 3) + assertEquals(1, flow.firstOrNull { it > 0 }) + assertEquals(2, flow.firstOrNull { it > 1 }) + assertNull(flow.firstOrNull { it > 3 }) + } + + @Test + fun testFirstOrNullCancellation() = runTest { + val latch = Channel() + val flow = flow { + coroutineScope { + launch { + latch.send(Unit) + hang { expect(1) } + } + emit(1) + emit(2) + } + } + + + val result = flow.firstOrNull { + latch.receive() + true + } + assertEquals(1, result) + finish(2) + } + + @Test + fun testFirstOrNullWithEmptyFlow() = runTest { + assertNull(emptyFlow().firstOrNull()) + assertNull(emptyFlow().firstOrNull { true }) + } + + @Test + fun testFirstOrNullWhenErrorCancelsUpstream() = runTest { + val latch = Channel() + val flow = flow { + coroutineScope { + launch { + latch.send(Unit) + hang { expect(1) } + } + emit(1) + } + } + + assertFailsWith { + flow.firstOrNull { + latch.receive() + throw TestException() + } + } + + assertEquals(1, flow.firstOrNull()) + finish(2) + } + + @Test + fun testBadClass() = runTest { + val instance = BadClass() + val flow = flowOf(instance) + assertSame(instance, flow.first()) + assertSame(instance, flow.firstOrNull()) + assertSame(instance, flow.first { true }) + assertSame(instance, flow.firstOrNull { true }) + } } diff --git a/kotlinx-coroutines-core/common/test/flow/terminal/SingleTest.kt b/kotlinx-coroutines-core/common/test/flow/terminal/SingleTest.kt index f7205957d1..4e89b93bd7 100644 --- a/kotlinx-coroutines-core/common/test/flow/terminal/SingleTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/terminal/SingleTest.kt @@ -17,7 +17,6 @@ class SingleTest : TestBase() { assertEquals(239L, flow.single()) assertEquals(239L, flow.singleOrNull()) - } @Test @@ -63,4 +62,12 @@ class SingleTest : TestBase() { assertNull(flowOf(null).single()) assertFailsWith { flowOf().single() } } + + @Test + fun testBadClass() = runTest { + val instance = BadClass() + val flow = flowOf(instance) + assertSame(instance, flow.single()) + assertSame(instance, flow.singleOrNull()) + } }