Skip to content

Commit dff3d62

Browse files
committed
Implement SendChannel#invokeOnClose
Fixes #341
1 parent 0a4f03f commit dff3d62

File tree

21 files changed

+233
-17
lines changed

21 files changed

+233
-17
lines changed

binary-compatibility-validator/reference-public-api/kotlinx-coroutines-core.txt

+3
Original file line numberDiff line numberDiff line change
@@ -500,6 +500,7 @@ public abstract class kotlinx/coroutines/experimental/channels/AbstractSendChann
500500
protected final fun getClosedForSend ()Lkotlinx/coroutines/experimental/channels/Closed;
501501
public final fun getOnSend ()Lkotlinx/coroutines/experimental/selects/SelectClause2;
502502
protected final fun getQueue ()Lkotlinx/coroutines/experimental/internal/LockFreeLinkedListHead;
503+
public fun invokeOnClose (Lkotlin/jvm/functions/Function1;)V
503504
protected abstract fun isBufferAlwaysFull ()Z
504505
protected abstract fun isBufferFull ()Z
505506
public final fun isClosedForSend ()Z
@@ -730,6 +731,7 @@ public final class kotlinx/coroutines/experimental/channels/ConflatedBroadcastCh
730731
public fun getOnSend ()Lkotlinx/coroutines/experimental/selects/SelectClause2;
731732
public final fun getValue ()Ljava/lang/Object;
732733
public final fun getValueOrNull ()Ljava/lang/Object;
734+
public fun invokeOnClose (Lkotlin/jvm/functions/Function1;)V
733735
public fun isClosedForSend ()Z
734736
public fun isFull ()Z
735737
public fun offer (Ljava/lang/Object;)Z
@@ -828,6 +830,7 @@ public abstract interface class kotlinx/coroutines/experimental/channels/Send {
828830
public abstract interface class kotlinx/coroutines/experimental/channels/SendChannel {
829831
public abstract fun close (Ljava/lang/Throwable;)Z
830832
public abstract fun getOnSend ()Lkotlinx/coroutines/experimental/selects/SelectClause2;
833+
public abstract fun invokeOnClose (Lkotlin/jvm/functions/Function1;)V
831834
public abstract fun isClosedForSend ()Z
832835
public abstract fun isFull ()Z
833836
public abstract fun offer (Ljava/lang/Object;)Z

common/kotlinx-coroutines-core-common/src/channels/AbstractChannel.kt

+36-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
package kotlinx.coroutines.experimental.channels
66

7+
import kotlinx.atomicfu.*
78
import kotlinx.coroutines.experimental.*
89
import kotlinx.coroutines.experimental.internal.*
910
import kotlinx.coroutines.experimental.internalAnnotations.*
@@ -32,6 +33,9 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
3233
*/
3334
protected abstract val isBufferFull: Boolean
3435

36+
// State transitions: null -> handler -> HANDLER_INVOKED
37+
private val onCloseHandler = atomic<Any?>(null)
38+
3539
// ------ internal functions for override by buffered channels ------
3640

3741
/**
@@ -247,11 +251,40 @@ public abstract class AbstractSendChannel<E> : SendChannel<E> {
247251
}
248252

249253
helpClose(closed)
254+
invokeOnCloseHandler(cause)
255+
// TODO We can get rid of afterClose
250256
onClosed(closed)
251257
afterClose(cause)
252258
return true
253259
}
254260

261+
private fun invokeOnCloseHandler(cause: Throwable?) {
262+
val handler = onCloseHandler.value
263+
if (handler !== null && handler !== HANDLER_INVOKED
264+
&& onCloseHandler.compareAndSet(handler, HANDLER_INVOKED)) {
265+
// CAS failed -> concurrent invokeOnClose() invoked handler
266+
(handler as Handler)(cause)
267+
}
268+
}
269+
270+
override fun invokeOnClose(handler: Handler) {
271+
// Intricate dance for concurrent invokeOnClose and close calls
272+
if (!onCloseHandler.compareAndSet(null, handler)) {
273+
val value = onCloseHandler.value
274+
if (value === HANDLER_INVOKED) {
275+
throw IllegalStateException("Another handler was already registered and successfully invoked")
276+
}
277+
278+
throw IllegalStateException("Another handler was already registered: $value")
279+
} else {
280+
val closedToken = closedForSend
281+
if (closedToken != null && onCloseHandler.compareAndSet(handler, HANDLER_INVOKED)) {
282+
// CAS failed -> close() call invoked handler
283+
(handler)(closedToken.closeCause)
284+
}
285+
}
286+
}
287+
255288
private fun helpClose(closed: Closed<*>) {
256289
/*
257290
* It's important to traverse list from right to left to avoid races with sender.
@@ -983,6 +1016,9 @@ public abstract class AbstractChannel<E> : AbstractSendChannel<E>(), Channel<E>
9831016
/** @suppress **This is unstable API and it is subject to change.** */
9841017
@JvmField internal val SEND_RESUMED = Symbol("SEND_RESUMED")
9851018

1019+
internal typealias Handler = (Throwable?) -> Unit
1020+
@JvmField internal val HANDLER_INVOKED = Any()
1021+
9861022
/**
9871023
* Represents sending waiter in the queue.
9881024
* @suppress **This is unstable API and it is subject to change.**
@@ -1043,4 +1079,3 @@ private abstract class Receive<in E> : LockFreeLinkedListNode(), ReceiveOrClosed
10431079
override val offerResult get() = OFFER_SUCCESS
10441080
abstract fun resumeReceiveClosed(closed: Closed<*>)
10451081
}
1046-

common/kotlinx-coroutines-core-common/src/channels/Channel.kt

+34
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,40 @@ public interface SendChannel<in E> {
8282
* receive on a failed channel throw the specified [cause] exception.
8383
*/
8484
public fun close(cause: Throwable? = null): Boolean
85+
86+
/**
87+
* Registers handler which is synchronously invoked once the channel is [closed][close]
88+
* or receiving side of this channel is [cancelled][ReceiveChannel.cancel].
89+
* Only one handler can be attached to the channel during channel's lifetime.
90+
* Handler is invoked when [isClosedForSend] starts to return `true`.
91+
* If channel is already closed, handler is invoked immediately.
92+
*
93+
* The meaning of `cause` that is passed to the handler:
94+
* * `null` if channel was closed or cancelled without corresponding argument
95+
* * close or cancel cause otherwise.
96+
*
97+
* Example of usage (exception handling is omitted):
98+
* ```
99+
* val events = Channel(UNLIMITED)
100+
* callbackBasedApi.registerCallback { event ->
101+
* events.offer(event)
102+
* }
103+
*
104+
* val uiUpdater = launch(UI, parent = UILifecycle) {
105+
* events.consume {}
106+
* events.cancel()
107+
* }
108+
*
109+
* events.invokeOnClose { callbackBasedApi.stop() }
110+
*
111+
* ```
112+
*
113+
* @throws UnsupportedOperationException if underlying channel doesn't support [invokeOnClose].
114+
* Implementation note: currently, [invokeOnClose] is unsupported only by Rx-like integrations
115+
*
116+
* @throws IllegalStateException if another handler was already registered
117+
*/
118+
public fun invokeOnClose(handler: (Throwable?) -> Unit)
85119
}
86120

87121
/**

common/kotlinx-coroutines-core-common/src/channels/ConflatedBroadcastChannel.kt

+29-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ public class ConflatedBroadcastChannel<E>() : BroadcastChannel<E> {
3838

3939
private val _state = atomic<Any>(INITIAL_STATE) // State | Closed
4040
private val _updating = atomic(0)
41+
// State transitions: null -> handler -> HANDLER_INVOKED
42+
private val onCloseHandler = atomic<Any?>(null)
4143

4244
private companion object {
4345
@JvmField
@@ -163,6 +165,7 @@ public class ConflatedBroadcastChannel<E>() : BroadcastChannel<E> {
163165
val update = if (cause == null) CLOSED else Closed(cause)
164166
if (_state.compareAndSet(state, update)) {
165167
(state as State<E>).subscribers?.forEach { it.close(cause) }
168+
invokeOnCloseHandler(cause)
166169
return true
167170
}
168171
}
@@ -171,6 +174,31 @@ public class ConflatedBroadcastChannel<E>() : BroadcastChannel<E> {
171174
}
172175
}
173176

177+
private fun invokeOnCloseHandler(cause: Throwable?) {
178+
val handler = onCloseHandler.value
179+
if (handler !== null && handler !== HANDLER_INVOKED
180+
&& onCloseHandler.compareAndSet(handler, HANDLER_INVOKED)) {
181+
(handler as Handler)(cause)
182+
}
183+
}
184+
185+
override fun invokeOnClose(handler: Handler) {
186+
// Intricate dance for concurrent invokeOnClose and close
187+
if (!onCloseHandler.compareAndSet(null, handler)) {
188+
val value = onCloseHandler.value
189+
if (value === HANDLER_INVOKED) {
190+
throw IllegalStateException("Another handler was already registered and successfully invoked")
191+
} else {
192+
throw IllegalStateException("Another handler was already registered: $value")
193+
}
194+
} else {
195+
val state = _state.value
196+
if (state is Closed && onCloseHandler.compareAndSet(handler, HANDLER_INVOKED)) {
197+
(handler)(state.closeCause)
198+
}
199+
}
200+
}
201+
174202
/**
175203
* Closes this broadcast channel. Same as [close].
176204
*/
@@ -249,4 +277,4 @@ public class ConflatedBroadcastChannel<E>() : BroadcastChannel<E> {
249277

250278
public override fun offerInternal(element: E): Any = super.offerInternal(element)
251279
}
252-
}
280+
}

common/kotlinx-coroutines-core-common/test/TestBase.common.kt

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ public expect open class TestBase constructor() {
1212
public fun expect(index: Int)
1313
public fun expectUnreached()
1414
public fun finish(index: Int)
15+
public fun reset() // Resets counter and finish flag. Workaround for parametrized tests absence in common
1516

1617
public fun runTest(
1718
expected: ((Throwable) -> Boolean)? = null,

common/kotlinx-coroutines-core-common/test/channels/BasicOperationsTest.kt

+39
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,45 @@ class BasicOperationsTest : TestBase() {
3131
TestChannelKind.values().forEach { kind -> testReceiveOrNullException(kind) }
3232
}
3333

34+
@Test
35+
fun testInvokeOnClose() = TestChannelKind.values().forEach { kind ->
36+
reset()
37+
val channel = kind.create()
38+
channel.invokeOnClose {
39+
if (it is AssertionError) {
40+
expect(3)
41+
}
42+
}
43+
expect(1)
44+
channel.offer(42)
45+
expect(2)
46+
channel.close(AssertionError())
47+
finish(4)
48+
}
49+
50+
@Test
51+
fun testInvokeOnClosed() = TestChannelKind.values().forEach { kind ->
52+
reset()
53+
expect(1)
54+
val channel = kind.create()
55+
channel.close()
56+
channel.invokeOnClose { expect(2) }
57+
assertFailsWith<IllegalStateException> { channel.invokeOnClose { expect(3) } }
58+
finish(3)
59+
}
60+
61+
@Test
62+
fun testMultipleInvokeOnClose() = TestChannelKind.values().forEach { kind ->
63+
reset()
64+
val channel = kind.create()
65+
channel.invokeOnClose { expect(3) }
66+
expect(1)
67+
assertFailsWith<IllegalStateException> { channel.invokeOnClose { expect(4) } }
68+
expect(2)
69+
channel.close()
70+
finish(4)
71+
}
72+
3473
private suspend fun testReceiveOrNull(kind: TestChannelKind) {
3574
val channel = kind.create()
3675
val d = async(coroutineContext) {

core/kotlinx-coroutines-core/test/TestBase.kt

+8-5
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,8 @@
44

55
package kotlinx.coroutines.experimental
66

7-
import org.junit.After
8-
import org.junit.Before
9-
import java.util.concurrent.atomic.AtomicBoolean
10-
import java.util.concurrent.atomic.AtomicInteger
11-
import java.util.concurrent.atomic.AtomicReference
7+
import org.junit.*
8+
import java.util.concurrent.atomic.*
129

1310
/**
1411
* Base class for tests, so that tests for predictable scheduling of actions in multiple coroutines sharing a single
@@ -88,6 +85,12 @@ public actual open class TestBase actual constructor() {
8885
check(!finished.getAndSet(true)) { "Should call 'finish(...)' at most once" }
8986
}
9087

88+
public actual fun reset() {
89+
check(actionIndex.get() == 0 || finished.get()) { "Expecting that 'finish(...)' was invoked, but it was not" }
90+
actionIndex.set(0)
91+
finished.set(false)
92+
}
93+
9194
private lateinit var threadsBefore: Set<Thread>
9295
private val SHUTDOWN_TIMEOUT = 10_000L // 10s at most to wait
9396

core/kotlinx-coroutines-core/test/channels/ChannelIsClosedLinearizabilityTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class ChannelIsClosedLinearizabilityTest : TestBase() {
1919
private lateinit var channel: Channel<Int>
2020

2121
@Reset
22-
fun reset() {
22+
fun resetChannel() {
2323
channel = Channel()
2424
}
2525

core/kotlinx-coroutines-core/test/channels/ChannelLinearizabilityTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class ChannelLinearizabilityTest : TestBase() {
2323
private lateinit var channel: Channel<Int>
2424

2525
@Reset
26-
fun reset() {
26+
fun resetChannel() {
2727
channel = Channel(capacity)
2828
}
2929

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2016-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package channels
6+
7+
import kotlinx.coroutines.experimental.*
8+
import kotlinx.coroutines.experimental.channels.*
9+
import org.junit.*
10+
import org.junit.Test
11+
import java.util.concurrent.*
12+
import java.util.concurrent.atomic.*
13+
import kotlin.test.*
14+
15+
class InvokeOnCloseStressTest : TestBase() {
16+
17+
private val iterations = 1000 * stressTestMultiplier
18+
19+
private val pool = newFixedThreadPoolContext(3, "InvokeOnCloseStressTest")
20+
21+
@After
22+
fun tearDown() {
23+
pool.close()
24+
}
25+
26+
@Test
27+
fun testInvokedExactlyOnce() = runBlocking {
28+
runStressTest(TestChannelKind.ARRAY_1)
29+
}
30+
31+
@Test
32+
fun testInvokedExactlyOnceBroadcast() = runBlocking {
33+
runStressTest(TestChannelKind.CONFLATED_BROADCAST)
34+
}
35+
36+
private suspend fun runStressTest(kind: TestChannelKind) {
37+
repeat(iterations) {
38+
val counter = AtomicInteger(0)
39+
val channel = kind.create()
40+
41+
val latch = CountDownLatch(1)
42+
val j1 = async(pool) {
43+
latch.await()
44+
channel.close()
45+
}
46+
47+
val j2 = async(pool) {
48+
latch.await()
49+
channel.invokeOnClose { counter.incrementAndGet() }
50+
}
51+
52+
val j3 = async(pool) {
53+
latch.await()
54+
channel.invokeOnClose { counter.incrementAndGet() }
55+
}
56+
57+
latch.countDown()
58+
joinAll(j1, j2, j3)
59+
assertEquals(1, counter.get())
60+
}
61+
}
62+
}

core/kotlinx-coroutines-core/test/internal/LockFreeListLinearizabilityTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class LockFreeListLinearizabilityTest : TestBase() {
1818
lateinit var q: LockFreeLinkedListHead
1919

2020
@Reset
21-
fun reset() {
21+
fun resetList() {
2222
q = LockFreeLinkedListHead()
2323
}
2424

core/kotlinx-coroutines-core/test/internal/LockFreeMPSCQueueLinearizabilityTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class LockFreeMPSCQueueLinearizabilityTest : TestBase() {
1717
private lateinit var q: LockFreeMPSCQueue<Int>
1818

1919
@Reset
20-
fun reset() {
20+
fun resetQueue() {
2121
q = LockFreeMPSCQueue()
2222
}
2323

core/kotlinx-coroutines-io/test/BufferReleaseLinearizabilityTest.kt

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class BufferReleaseLinearizabilityTest : TestBase() {
2020
private val lr = LinTesting()
2121

2222
@Reset
23-
fun reset() {
23+
fun resetChannel() {
2424
ch = ByteChannel(false)
2525
}
2626

core/kotlinx-coroutines-io/test/ByteChannelJoinLinearizabilityTest.kt

+1-2
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ class ByteChannelJoinLinearizabilityTest : TestBase() {
2222
private val lr = LinTesting()
2323

2424
@Reset
25-
fun reset() {
26-
// println("============== reset ====================")
25+
fun resetChannel() {
2726
from = ByteChannel(true)
2827
to = ByteChannel(true)
2928
}

0 commit comments

Comments
 (0)