diff --git a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt index 98e665f601..d99ae52c7d 100644 --- a/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt +++ b/kotlinx-coroutines-core/common/src/flow/terminal/Reduce.kt @@ -8,9 +8,7 @@ package kotlinx.coroutines.flow -import kotlinx.coroutines.* import kotlinx.coroutines.flow.internal.* -import kotlinx.coroutines.flow.internal.unsafeFlow as flow import kotlin.jvm.* /** @@ -84,15 +82,10 @@ public suspend fun Flow.singleOrNull(): T? { */ public suspend fun Flow.first(): T { var result: Any? = NULL - try { - collect { value -> - result = value - throw AbortFlowException(NopCollector) - } - } catch (e: AbortFlowException) { - // Do nothing + collectUntil { + result = it + true } - if (result === NULL) throw NoSuchElementException("Expected at least one element") return result as T } @@ -103,17 +96,14 @@ public suspend fun Flow.first(): T { */ public suspend fun Flow.first(predicate: suspend (T) -> Boolean): T { var result: Any? = NULL - try { - collect { value -> - if (predicate(value)) { - result = value - throw AbortFlowException(NopCollector) - } + collectUntil { + if (predicate(it)) { + result = it + true + } else { + false } - } catch (e: AbortFlowException) { - // Do nothing } - if (result === NULL) throw NoSuchElementException("Expected at least one element matching the predicate $predicate") return result as T } @@ -124,13 +114,9 @@ public suspend fun Flow.first(predicate: suspend (T) -> Boolean): T { */ public suspend fun Flow.firstOrNull(): T? { var result: T? = null - try { - collect { value -> - result = value - throw AbortFlowException(NopCollector) - } - } catch (e: AbortFlowException) { - // Do nothing + collectUntil { + result = it + true } return result } @@ -141,15 +127,28 @@ public suspend fun Flow.firstOrNull(): T? { */ public suspend fun Flow.firstOrNull(predicate: suspend (T) -> Boolean): T? { var result: T? = null - try { - collect { value -> - if (predicate(value)) { - result = value - throw AbortFlowException(NopCollector) + collectUntil { + if (predicate(it)) { + result = it + true + } else { + false + } + } + return result +} + +internal suspend inline fun Flow.collectUntil(crossinline block: suspend (value: T) -> Boolean) { + val collector = object : FlowCollector { + override suspend fun emit(value: T) { + if (block(value)) { + throw AbortFlowException(this) } } + } + try { + collect(collector) } catch (e: AbortFlowException) { - // Do nothing + e.checkOwnership(collector) } - return result } diff --git a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt index f737a1d0de..edb9f00fa6 100644 --- a/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt +++ b/kotlinx-coroutines-core/common/test/flow/terminal/FirstTest.kt @@ -6,6 +6,7 @@ package kotlinx.coroutines.flow import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlinx.coroutines.flow.internal.* import kotlin.test.* class FirstTest : TestBase() { @@ -160,4 +161,13 @@ class FirstTest : TestBase() { assertSame(instance, flow.first { true }) assertSame(instance, flow.firstOrNull { true }) } + + @Test + fun testAbortFlowException() = runTest { + val flow = flow { + throw AbortFlowException(NopCollector) // Emulate cancellation + } + + assertFailsWith { flow.first() } + } } diff --git a/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt b/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt new file mode 100644 index 0000000000..77ad0831f3 --- /dev/null +++ b/kotlinx-coroutines-core/jvm/test/flow/FirstJvmTest.kt @@ -0,0 +1,28 @@ +/* + * Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.coroutines.flow + +import kotlinx.coroutines.* +import org.junit.Test +import kotlin.test.* + +class FirstJvmTest : TestBase() { + + @Test + fun testTakeInterference() = runBlocking(Dispatchers.Default) { + /* + * This test tests a racy situation when outer channelFlow is being cancelled, + * inner flow starts atomically in "CANCELLING" state, sends one element and completes + * (=> cancels and drops element away), triggering NSEE in Flow.first operator + */ + val values = (0..10000).asFlow().flatMapMerge(Int.MAX_VALUE) { + channelFlow { + val value = channelFlow { send(1) }.first() + send(value) + } + }.take(1).toList() + assertEquals(listOf(1), values) + } +} \ No newline at end of file