Skip to content

Commit cc55f60

Browse files
committed
~linearizable combine
1 parent d196082 commit cc55f60

File tree

3 files changed

+74
-24
lines changed

3 files changed

+74
-24
lines changed

kotlinx-coroutines-core/common/src/flow/internal/Combine.kt

+58-18
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@ import kotlinx.coroutines.flow.*
1111
import kotlinx.coroutines.internal.*
1212
import kotlin.coroutines.*
1313
import kotlin.coroutines.intrinsics.*
14+
import kotlin.jvm.*
15+
import kotlin.math.*
1416

15-
internal fun getNull(): Symbol = NULL // Workaround for JS BE bug
17+
private class Update(@JvmField val index: Int, @JvmField val value: Any?)
1618

1719
@PublishedApi
1820
internal suspend fun <R, T> FlowCollector<R>.combineInternal(
@@ -22,27 +24,17 @@ internal suspend fun <R, T> FlowCollector<R>.combineInternal(
2224
): Unit = flowScope { // flow scope so any cancellation within the source flow will cancel the whole scope
2325
val size = flows.size
2426
if (size == 0) return@flowScope // bail-out for empty input
25-
val latestValues = Array<Any?>(size) { getNull() }
27+
val latestValues = Array<Any?>(size) { UNINITIALIZED }
2628
val isClosed = Array(size) { false }
27-
val resultChannel = Channel<Array<T>>(Channel.CONFLATED)
29+
val resultChannel = Channel<Update>(flows.size)
2830
val nonClosed = LocalAtomicInt(size)
29-
val remainingAbsentValues = LocalAtomicInt(size)
31+
var remainingAbsentValues = size
3032
for (i in 0 until size) {
3133
// Coroutine per flow that keeps track of its value and sends result to downstream
3234
launch {
3335
try {
3436
flows[i].collect { value ->
35-
val previous = latestValues[i]
36-
latestValues[i] = value
37-
if (previous === getNull()) remainingAbsentValues.decrementAndGet()
38-
if (remainingAbsentValues.value == 0) {
39-
val results = arrayFactory()
40-
for (index in 0 until size) {
41-
results[index] = getNull().unbox(latestValues[index])
42-
}
43-
// NB: here actually "stale" array can overwrite a fresh one and break linearizability
44-
resultChannel.send(results as Array<T>)
45-
}
37+
resultChannel.send(Update(i, value))
4638
yield() // Emulate fairness for backward compatibility
4739
}
4840
} finally {
@@ -55,8 +47,56 @@ internal suspend fun <R, T> FlowCollector<R>.combineInternal(
5547
}
5648
}
5749

50+
// val lastReceivedEpoch = IntArray(size)
51+
// var currentEpoch = 0
52+
// while (!resultChannel.isClosedForReceive) {
53+
// ++currentEpoch
54+
// var shouldSuspend = true
55+
// // Start batch
56+
// var elementsReceived = 0
57+
// while (true) {
58+
// // The very first receive in epoch should be suspending
59+
// val element = if (shouldSuspend) {
60+
// shouldSuspend = false
61+
// resultChannel.receiveOrNull()
62+
// } else {
63+
// resultChannel.poll()
64+
// }
65+
// if (element === null) break // End batch processing, nothing to receive
66+
// ++elementsReceived
67+
// val index = element.index
68+
// // Update valued
69+
// val previous = latestValues[index]
70+
// latestValues[index] = element.value
71+
// if (previous === UNINITIALIZED) --remainingAbsentValues
72+
// // Check epoch
73+
// // Received the second value from the same flow in the same epoch -- bail out
74+
// if (lastReceivedEpoch[index] == currentEpoch) break
75+
// lastReceivedEpoch[index] = currentEpoch
76+
// }
77+
//
78+
// // Process batch result
79+
// if (remainingAbsentValues == 0 && elementsReceived != 0) {
80+
// val results = arrayFactory()
81+
// for (i in 0 until size) {
82+
// results[i] = latestValues[i] as T?
83+
// }
84+
// transform(results as Array<T>)
85+
// }
86+
// }
87+
5888
resultChannel.consumeEach {
59-
transform(it)
89+
val index = it.index
90+
val previous = latestValues[index]
91+
latestValues[index] = it.value
92+
if (previous === UNINITIALIZED) --remainingAbsentValues
93+
if (remainingAbsentValues == 0) {
94+
val results = arrayFactory()
95+
for (i in 0 until size) {
96+
results[i] = latestValues[i] as T?
97+
}
98+
transform(results as Array<T>)
99+
}
60100
}
61101
}
62102

@@ -101,7 +141,7 @@ internal fun <T1, T2, R> zipImpl(flow: Flow<T1>, flow2: Flow<T2>, transform: sus
101141
flow.collect { value ->
102142
withContextUndispatched(newContext, cnt) {
103143
val otherValue = second.receiveOrNull() ?: throw AbortFlowException(this@unsafeFlow)
104-
emit(transform(getNull().unbox(value), getNull().unbox(otherValue)))
144+
emit(transform(NULL.unbox(value), NULL.unbox(otherValue)))
105145
}
106146
}
107147
}
@@ -129,6 +169,6 @@ private suspend fun withContextUndispatched(
129169
// Channel has any type due to onReceiveOrNull. This will be fixed after receiveOrClosed
130170
private fun CoroutineScope.asChannel(flow: Flow<*>): ReceiveChannel<Any> = produce {
131171
flow.collect { value ->
132-
return@collect channel.send(value ?: getNull())
172+
return@collect channel.send(value ?: NULL)
133173
}
134174
}

kotlinx-coroutines-core/common/src/flow/internal/NullSurrogate.kt

+9
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,20 @@ import kotlin.native.concurrent.*
1111
/**
1212
* This value is used a a surrogate `null` value when needed.
1313
* It should never leak to the outside world.
14+
* Its usage typically are paired with [Symbol.unbox] usages.
1415
*/
1516
@JvmField
1617
@SharedImmutable
1718
internal val NULL = Symbol("NULL")
1819

20+
/**
21+
* Symbol to indicate that the value is not yet initialized.
22+
* It should never leak to the outside world.
23+
*/
24+
@JvmField
25+
@SharedImmutable
26+
internal val UNINITIALIZED = Symbol("UNINITIALIZED")
27+
1928
/*
2029
* Symbol used to indicate that the flow is complete.
2130
* It should never leak to the outside world.

kotlinx-coroutines-core/common/test/flow/operators/CombineTest.kt

+7-6
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
* Copyright 2016-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
33
*/
44

5-
package kotlinx.coroutines.flow
5+
package kotlinx.coroutines.flow.operators
66

77
import kotlinx.coroutines.*
8+
import kotlinx.coroutines.flow.*
89
import kotlin.test.*
910
import kotlinx.coroutines.flow.combine as combineOriginal
1011
import kotlinx.coroutines.flow.combineTransform as combineTransformOriginal
@@ -208,19 +209,19 @@ abstract class CombineTestBase : TestBase() {
208209
}
209210
val f2 = flow {
210211
emit(1)
211-
hang { expect(3) }
212+
expectUnreached()
212213
}
213214

214-
val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach { expectUnreached() }
215+
val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach { expect(2) }
215216
assertFailsWith<CancellationException>(flow)
216-
finish(2)
217+
finish(3)
217218
}
218219

219220
@Test
220221
fun testCancellationExceptionDownstream() = runTest {
221222
val f1 = flow {
222223
emit(1)
223-
expect(1)
224+
expect(2)
224225
hang { expect(5) }
225226
}
226227
val f2 = flow {
@@ -230,7 +231,7 @@ abstract class CombineTestBase : TestBase() {
230231
}
231232

232233
val flow = f1.combineLatest(f2, { _, _ -> 1 }).onEach {
233-
expect(2)
234+
expect(1)
234235
yield()
235236
expect(4)
236237
throw CancellationException("")

0 commit comments

Comments
 (0)