Skip to content

Restore context preservation invariant in flatMapMerge #1452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions benchmarks/src/jmh/kotlin/benchmarks/YieldRelativeCostBenchmark.kt

This file was deleted.

47 changes: 47 additions & 0 deletions benchmarks/src/jmh/kotlin/benchmarks/flow/FlatMapMergeBenchmark.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import org.openjdk.jmh.annotations.*
import java.util.concurrent.*

@Warmup(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 7, time = 1, timeUnit = TimeUnit.SECONDS)
@Fork(value = 1)
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class FlatMapMergeBenchmark {

// Note: tests only absence of contention on downstream

@Param("10", "100", "1000")
private var iterations = 100

@Benchmark
fun flatMapUnsafe() = runBlocking {
benchmarks.flow.scrabble.flow {
repeat(iterations) { emit(it) }
}.flatMapMerge { value ->
flowOf(value)
}.collect {
if (it == -1) error("")
}
}

@Benchmark
fun flatMapSafe() = runBlocking {
kotlinx.coroutines.flow.flow {
repeat(iterations) { emit(it) }
}.flatMapMerge { value ->
flowOf(value)
}.collect {
if (it == -1) error("")
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
*/


package benchmarks.flow.misc
package benchmarks.flow

import benchmarks.flow.scrabble.flow
import io.reactivex.*
Expand Down Expand Up @@ -35,7 +35,7 @@ import java.util.concurrent.*
@BenchmarkMode(Mode.AverageTime)
@OutputTimeUnit(TimeUnit.MICROSECONDS)
@State(Scope.Benchmark)
open class Numbers {
open class NumbersBenchmark {

companion object {
private const val primes = 100
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package benchmarks.flow.misc
package benchmarks.flow

import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ public final class kotlinx/coroutines/flow/internal/SafeCollectorKt {
public static final fun unsafeFlow (Lkotlin/jvm/functions/Function2;)Lkotlinx/coroutines/flow/Flow;
}

public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/internal/ConcurrentFlowCollector {
public final class kotlinx/coroutines/flow/internal/SendingCollector : kotlinx/coroutines/flow/FlowCollector {
public fun <init> (Lkotlinx/coroutines/channels/SendChannel;)V
public fun emit (Ljava/lang/Object;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}
Expand Down
2 changes: 1 addition & 1 deletion kotlinx-coroutines-core/common/src/channels/Produce.kt
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ public fun <E> CoroutineScope.produce(
return coroutine
}

private class ProducerCoroutine<E>(
internal open class ProducerCoroutine<E>(
parentContext: CoroutineContext, channel: Channel<E>
) : ChannelCoroutine<E>(parentContext, channel, active = true), ProducerScope<E> {
override val isActive: Boolean
Expand Down
12 changes: 2 additions & 10 deletions kotlinx-coroutines-core/common/src/flow/internal/ChannelFlow.kt
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public abstract class ChannelFlow<T>(
protected abstract suspend fun collectTo(scope: ProducerScope<T>)

// shared code to create a suspend lambda from collectTo function in one place
private val collectToFun: suspend (ProducerScope<T>) -> Unit
internal val collectToFun: suspend (ProducerScope<T>) -> Unit
get() = { collectTo(it) }

private val produceCapacity: Int
Expand Down Expand Up @@ -140,13 +140,11 @@ internal class ChannelFlowOperatorImpl<T>(
private fun <T> FlowCollector<T>.withUndispatchedContextCollector(emitContext: CoroutineContext): FlowCollector<T> = when (this) {
// SendingCollector & NopCollector do not care about the context at all and can be used as is
is SendingCollector, is NopCollector -> this
// Original collector is concurrent, so wrap into ConcurrentUndispatchedContextCollector (also concurrent)
is ConcurrentFlowCollector -> ConcurrentUndispatchedContextCollector(this, emitContext)
// Otherwise just wrap into UndispatchedContextCollector interface implementation
else -> UndispatchedContextCollector(this, emitContext)
}

private open class UndispatchedContextCollector<T>(
private class UndispatchedContextCollector<T>(
downstream: FlowCollector<T>,
private val emitContext: CoroutineContext
) : FlowCollector<T> {
Expand All @@ -157,12 +155,6 @@ private open class UndispatchedContextCollector<T>(
withContextUndispatched(emitContext, countOrElement, emitRef, value)
}

// named class for a combination of UndispatchedContextCollector & ConcurrentFlowCollector interface
private class ConcurrentUndispatchedContextCollector<T>(
downstream: ConcurrentFlowCollector<T>,
emitContext: CoroutineContext
) : UndispatchedContextCollector<T>(downstream, emitContext), ConcurrentFlowCollector<T>

// Efficiently computes block(value) in the newContext
private suspend fun <T, V> withContextUndispatched(
newContext: CoroutineContext,
Expand Down
81 changes: 0 additions & 81 deletions kotlinx-coroutines-core/common/src/flow/internal/Concurrent.kt

This file was deleted.

22 changes: 22 additions & 0 deletions kotlinx-coroutines-core/common/src/flow/internal/FlowCoroutine.kt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,18 @@ internal fun <R> scopedFlow(@BuilderInference block: suspend CoroutineScope.(Flo
flowScope { block(collector) }
}

internal fun <T> CoroutineScope.flowProduce(
context: CoroutineContext,
capacity: Int = 0,
@BuilderInference block: suspend ProducerScope<T>.() -> Unit
): ReceiveChannel<T> {
val channel = Channel<T>(capacity)
val newContext = newCoroutineContext(context)
val coroutine = FlowProduceCoroutine(newContext, channel)
coroutine.start(CoroutineStart.DEFAULT, coroutine, block)
return coroutine
}

private class FlowCoroutine<T>(
context: CoroutineContext,
uCont: Continuation<T>
Expand All @@ -61,3 +73,13 @@ private class FlowCoroutine<T>(
return cancelImpl(cause)
}
}

private class FlowProduceCoroutine<T>(
parentContext: CoroutineContext,
channel: Channel<T>
) : ProducerCoroutine<T>(parentContext, channel) {
public override fun childCancelled(cause: Throwable): Boolean {
if (cause is ChildCancelledException) return true
return cancelImpl(cause)
}
}
27 changes: 9 additions & 18 deletions kotlinx-coroutines-core/common/src/flow/internal/Merge.kt
Original file line number Diff line number Diff line change
Expand Up @@ -38,17 +38,21 @@ internal class ChannelFlowTransformLatest<T, R>(
}

internal class ChannelFlowMerge<T>(
flow: Flow<Flow<T>>,
private val flow: Flow<Flow<T>>,
private val concurrency: Int,
context: CoroutineContext = EmptyCoroutineContext,
capacity: Int = Channel.OPTIONAL_CHANNEL
) : ChannelFlowOperator<Flow<T>, T>(flow, context, capacity) {
capacity: Int = Channel.BUFFERED
) : ChannelFlow<T>(context, capacity) {
override fun create(context: CoroutineContext, capacity: Int): ChannelFlow<T> =
ChannelFlowMerge(flow, concurrency, context, capacity)

// The actual merge implementation with concurrency limit
private suspend fun mergeImpl(scope: CoroutineScope, collector: ConcurrentFlowCollector<T>) {
override fun produceImpl(scope: CoroutineScope): ReceiveChannel<T> {
return scope.flowProduce(context, capacity, block = collectToFun)
}

override suspend fun collectTo(scope: ProducerScope<T>) {
val semaphore = Semaphore(concurrency)
val collector = SendingCollector(scope)
val job: Job? = coroutineContext[Job]
flow.collect { inner ->
/*
Expand All @@ -68,19 +72,6 @@ internal class ChannelFlowMerge<T>(
}
}

// Fast path in ChannelFlowOperator calls this function (channel was not created yet)
override suspend fun flowCollect(collector: FlowCollector<T>) {
// this function should not have been invoked when channel was explicitly requested
assert { capacity == Channel.OPTIONAL_CHANNEL }
flowScope {
mergeImpl(this, collector.asConcurrentFlowCollector())
}
}

// Slow path when output channel is required (and was created)
override suspend fun collectTo(scope: ProducerScope<T>) =
mergeImpl(scope, SendingCollector(scope))

override fun additionalToStringProps(): String =
"concurrency=$concurrency, "
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@

package kotlinx.coroutines.flow.internal

internal object NopCollector : ConcurrentFlowCollector<Any?> {
import kotlinx.coroutines.flow.*

internal object NopCollector : FlowCollector<Any?> {
override suspend fun emit(value: Any?) {
// does nothing
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/*
* Copyright 2016-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.coroutines.flow.internal

import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

/**
* Collection that sends to channel
* @suppress **This an internal API and should not be used from general code.**
*/
@InternalCoroutinesApi
public class SendingCollector<T>(
private val channel: SendChannel<T>
) : FlowCollector<T> {
override suspend fun emit(value: T) = channel.send(value)
}
Loading