From 6eebf6beebd19e7d89015a5d656a7a3cd7aafeaf Mon Sep 17 00:00:00 2001 From: Oleg Yukhnevich Date: Fri, 4 Dec 2020 22:03:27 +0300 Subject: [PATCH] update coroutines and some job/channel related logic --- .../src/jsMain/kotlin/Server.kt | 2 +- gradle.properties | 4 +- playground/src/commonMain/kotlin/streams.kt | 3 +- .../kotlin/{Cancelable.kt => Cancellable.kt} | 12 +++-- .../kotlin/io/rsocket/kotlin/Connection.kt | 2 +- .../kotlin/io/rsocket/kotlin/RSocket.kt | 2 +- .../rsocket/kotlin/RSocketRequestHandler.kt | 4 +- .../kotlin/core/RSocketConnectorBuilder.kt | 6 +-- .../io/rsocket/kotlin/core/RSocketServer.kt | 2 +- .../kotlin/core/ReconnectableRSocket.kt | 45 +++++++------------ .../kotlin/internal/CloseOperations.kt | 19 ++++++++ .../io/rsocket/kotlin/internal/Prioritizer.kt | 15 ++++--- .../kotlin/internal/RSocketRequester.kt | 2 +- .../kotlin/internal/RSocketResponder.kt | 2 +- .../rsocket/kotlin/internal/RSocketState.kt | 35 ++++++++++----- .../io/rsocket/kotlin/SetupRejectionTest.kt | 3 ++ .../rsocket/kotlin/keepalive/KeepAliveTest.kt | 2 +- .../io/rsocket/kotlin/test/TestConnection.kt | 2 +- .../io/rsocket/kotlin/test/TestRSocket.kt | 2 +- .../kotlin/transport/ktor/TcpConnection.kt | 19 ++++++-- .../transport/ktor/WebSocketConnection.kt | 2 +- .../rsocket/kotlin/NativeTcpTransportTest.kt | 2 +- .../kotlin/transport/local/LocalConnection.kt | 24 ++-------- .../kotlin/transport/local/LocalServer.kt | 30 ++++++++++--- 24 files changed, 141 insertions(+), 100 deletions(-) rename rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/{Cancelable.kt => Cancellable.kt} (63%) diff --git a/examples/nodejs-tcp-transport/src/jsMain/kotlin/Server.kt b/examples/nodejs-tcp-transport/src/jsMain/kotlin/Server.kt index 03b47c09a..42bd89265 100644 --- a/examples/nodejs-tcp-transport/src/jsMain/kotlin/Server.kt +++ b/examples/nodejs-tcp-transport/src/jsMain/kotlin/Server.kt @@ -57,7 +57,7 @@ fun NodeJsTcpServerTransport(port: Int, onStart: () -> Unit = {}): ServerTranspo // nodejs TCP transport connection - may not work in all cases, not tested properly @OptIn(ExperimentalCoroutinesApi::class, TransportApi::class) class NodeJsTcpConnection(private val socket: Socket) : Connection { - override val job: Job = Job() + override val job: CompletableJob = Job() private val sendChannel = Channel(8) private val receiveChannel = Channel(8) diff --git a/gradle.properties b/gradle.properties index 3b0862ce0..0e7e7993d 100644 --- a/gradle.properties +++ b/gradle.properties @@ -19,9 +19,9 @@ group=io.rsocket.kotlin version=0.12.0 #Versions -kotlinVersion=1.4.20 +kotlinVersion=1.4.21 ktorVersion=1.4.3 -kotlinxCoroutinesVersion=1.3.9-native-mt-2 +kotlinxCoroutinesVersion=1.4.2-native-mt kotlinxAtomicfuVersion=0.14.4 kotlinxSerializationVersion=1.0.1 kotlinxBenchmarkVersion=0.2.0-dev-20 diff --git a/playground/src/commonMain/kotlin/streams.kt b/playground/src/commonMain/kotlin/streams.kt index 0a5341ee5..8e0a93a99 100644 --- a/playground/src/commonMain/kotlin/streams.kt +++ b/playground/src/commonMain/kotlin/streams.kt @@ -18,12 +18,11 @@ import io.rsocket.kotlin.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* @ExperimentalStreamsApi private suspend fun s() { val flow = flow { - val strategy = coroutineContext[RequestStrategy]!!.provide() + val strategy = currentCoroutineContext()[RequestStrategy]!!.provide() var i = strategy.firstRequest() println("INIT: $i") var r = 0 diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancelable.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancellable.kt similarity index 63% rename from rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancelable.kt rename to rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancellable.kt index 45529d6a0..9199a1805 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancelable.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Cancellable.kt @@ -18,12 +18,10 @@ package io.rsocket.kotlin import kotlinx.coroutines.* -interface Cancelable { - val job: Job +interface Cancellable { + val job: CompletableJob } -val Cancelable.isActive: Boolean get() = job.isActive -fun Cancelable.cancel(cause: CancellationException? = null): Unit = job.cancel(cause) -fun Cancelable.cancel(message: String, cause: Throwable? = null): Unit = job.cancel(message, cause) -suspend fun Cancelable.join(): Unit = job.join() -suspend fun Cancelable.cancelAndJoin(): Unit = job.cancelAndJoin() +val Cancellable.isActive: Boolean get() = job.isActive +suspend fun Cancellable.join(): Unit = job.join() +suspend fun Cancellable.cancelAndJoin(): Unit = job.cancelAndJoin() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt index 4871b8d8c..b62b852af 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt @@ -25,7 +25,7 @@ import io.rsocket.kotlin.frame.* * That interface isn't stable for inheritance. */ @TransportApi -interface Connection : Cancelable { +interface Connection : Cancellable { @DangerousInternalIoApi val pool: ObjectPool diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt index d85514781..d267cbf50 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt @@ -20,7 +20,7 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.flow.* -interface RSocket : Cancelable { +interface RSocket : Cancellable { suspend fun metadataPush(metadata: ByteReadPacket) { metadata.release() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt index 43d195234..297fe561f 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt @@ -53,7 +53,7 @@ class RSocketRequestHandlerBuilder internal constructor() { requestChannel = block } - internal fun build(job: Job): RSocket = + internal fun build(job: CompletableJob): RSocket = RSocketRequestHandler(job, metadataPush, fireAndForget, requestResponse, requestStream, requestChannel) } @@ -65,7 +65,7 @@ fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandl } private class RSocketRequestHandler( - override val job: Job, + override val job: CompletableJob, private val metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null, private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null, private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null, diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt index fe86f4fd7..33d9345b9 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt @@ -100,10 +100,10 @@ public class RSocketConnectorBuilder internal constructor() { ) private companion object { - private val defaultAcceptor: ConnectionAcceptor = ConnectionAcceptor { EmptyRSocket } + private val defaultAcceptor: ConnectionAcceptor = ConnectionAcceptor { EmptyRSocket() } - private object EmptyRSocket : RSocket { - override val job: Job = NonCancellable + private class EmptyRSocket : RSocket { + override val job: CompletableJob = Job() } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 0b38a0d9d..f68d720ca 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -68,7 +68,7 @@ class RSocketServer internal constructor( private suspend fun Connection.failSetup(error: RSocketError.Setup): Nothing { sendFrame(ErrorFrame(0, error)) - cancel("Setup failed", error) + job.completeExceptionally(error) throw error } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt index 24c36154b..5398186a8 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt @@ -27,40 +27,40 @@ import kotlinx.coroutines.flow.* internal typealias ReconnectPredicate = suspend (cause: Throwable, attempt: Long) -> Boolean @Suppress("FunctionName") -@OptIn(ExperimentalCoroutinesApi::class, FlowPreview::class) +@OptIn(FlowPreview::class) internal suspend fun ReconnectableRSocket( logger: Logger, connect: suspend () -> RSocket, predicate: ReconnectPredicate, ): RSocket { - val state = MutableStateFlow(ReconnectState.Connecting) - - val job = + val job = Job() + val state = connect.asFlow() .map { ReconnectState.Connected(it) } //if connection established - state = connected .onStart { emit(ReconnectState.Connecting) } //init - state = connecting - .retryWhen { cause, attempt -> + .retryWhen { cause, attempt -> //reconnection logic logger.debug(cause) { "Connection establishment failed, attempt: $attempt. Trying to reconnect..." } predicate(cause, attempt) - } //reconnection logic - .catch { + } + .catch { //reconnection failed - state = failed logger.debug(it) { "Reconnection failed" } emit(ReconnectState.Failed(it)) - } //reconnection failed - state = failed - .mapNotNull { - state.value = it //set state //TODO replace with Flow.stateIn when coroutines 1.4.0-native-mt will be released + } + .transform { value -> + emit(value) //emit before any action, to pass value directly to state - when (it) { + when (value) { is ReconnectState.Connected -> { logger.debug { "Connection established" } - it.rSocket.join() //await for connection completion + value.rSocket.join() //await for connection completion logger.debug { "Connection closed. Reconnecting..." } } - is ReconnectState.Failed -> throw it.error //reconnect failed, cancel job - ReconnectState.Connecting -> null //skip, still waiting for new connection + is ReconnectState.Failed -> job.completeExceptionally(value.error) //reconnect failed, fail job + ReconnectState.Connecting -> Unit //skip, still waiting for new connection } } - .launchRestarting() //reconnect if old connection completed/failed + .restarting() //reconnect if old connection completed + .stateIn(CoroutineScope(Dispatchers.Unconfined + job)) //await first connection to fail fast if something state.mapNotNull { @@ -74,17 +74,7 @@ internal suspend fun ReconnectableRSocket( return ReconnectableRSocket(job, state) } -private fun Flow<*>.launchRestarting(): Job = GlobalScope.launch(Dispatchers.Unconfined) { - while (isActive) { - try { - collect() - } catch (e: Throwable) { - // KLUDGE: K/N - cancel("Reconnection failed", e) - break - } - } -} +private fun Flow.restarting(): Flow = flow { while (true) emitAll(this@restarting) } private sealed class ReconnectState { object Connecting : ReconnectState() @@ -92,9 +82,8 @@ private sealed class ReconnectState { data class Connected(val rSocket: RSocket) : ReconnectState() } -@OptIn(ExperimentalCoroutinesApi::class, FlowPreview::class) private class ReconnectableRSocket( - override val job: Job, + override val job: CompletableJob, private val state: StateFlow, ) : RSocket { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt index 791eaac5e..dc0a863ef 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/CloseOperations.kt @@ -19,6 +19,7 @@ package io.rsocket.kotlin.internal import io.ktor.utils.io.core.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlin.native.concurrent.* internal inline fun Closeable.closeOnError(block: () -> T): T { try { @@ -33,9 +34,27 @@ internal fun ReceiveChannel<*>.cancelConsumed(cause: Throwable?) { cancel(cause?.let { it as? CancellationException ?: CancellationException("Channel was consumed, consumer had failed", it) }) } +//TODO Can be removed after fix of https://github.com/Kotlin/kotlinx.coroutines/issues/2435 internal fun ReceiveChannel.closeReceivedElements() { try { while (true) poll()?.close() ?: break } catch (e: Throwable) { } } + +@SharedImmutable +private val onUndeliveredCloseable: (Closeable) -> Unit = Closeable::close + +@Suppress("FunctionName") +internal fun SafeChannel(capacity: Int): Channel = Channel(capacity, onUndeliveredElement = onUndeliveredCloseable) + +//TODO check after fix of https://github.com/Kotlin/kotlinx.coroutines/issues/2435 +// and https://github.com/Kotlin/kotlinx.coroutines/issues/974 +internal fun SendChannel.safeOffer(element: E) { + try { + if (!offer(element)) element.close() + } catch (cause: Throwable) { + element.close() + throw cause + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt index 6d5708c1a..771f9a63a 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt @@ -17,19 +17,20 @@ package io.rsocket.kotlin.internal import io.rsocket.kotlin.frame.* +import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.selects.* internal class Prioritizer { - private val priorityChannel = Channel(Channel.UNLIMITED) - private val commonChannel = Channel(Channel.UNLIMITED) + private val priorityChannel = SafeChannel(Channel.UNLIMITED) + private val commonChannel = SafeChannel(Channel.UNLIMITED) fun send(frame: Frame) { - commonChannel.offer(frame) + commonChannel.safeOffer(frame) } fun sendPrioritized(frame: Frame) { - priorityChannel.offer(frame) + priorityChannel.safeOffer(frame) } suspend fun receive(): Frame { @@ -41,10 +42,10 @@ internal class Prioritizer { } } - fun close(throwable: Throwable?) { + fun cancel(error: CancellationException) { priorityChannel.closeReceivedElements() commonChannel.closeReceivedElements() - priorityChannel.close(throwable) - commonChannel.close(throwable) + priorityChannel.cancel(error) + commonChannel.cancel(error) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index 3df9046bc..be7381c65 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -27,7 +27,7 @@ import kotlinx.coroutines.flow.* internal class RSocketRequester( private val state: RSocketState, private val streamId: StreamId, -) : RSocket, Cancelable by state { +) : RSocket, Cancellable by state { override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadata.closeOnError { checkAvailable() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt index 78f2ea127..29b1dd32b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt @@ -24,7 +24,7 @@ import kotlinx.coroutines.* internal class RSocketResponder( private val state: RSocketState, private val requestHandler: RSocket, -) : Cancelable by state { +) : Cancellable by state { fun handleMetadataPush(frame: MetadataPushFrame) { state.launch { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt index ae84bdb20..2d6590a11 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt @@ -33,7 +33,7 @@ import kotlinx.coroutines.flow.* internal class RSocketState( private val connection: Connection, keepAlive: KeepAlive, -) : Cancelable by connection { +) : Cancellable by connection { private val prioritizer = Prioritizer() private val requestScope = CoroutineScope(SupervisorJob(job)) private val scope = CoroutineScope(job) @@ -53,7 +53,7 @@ internal class RSocketState( } fun createReceiverFor(streamId: Int, initFrame: RequestFrame? = null): ReceiveChannel { - val receiver = Channel(Channel.UNLIMITED) + val receiver = SafeChannel(Channel.UNLIMITED) initFrame?.let(receiver::offer) //used only in RequestChannel on responder side receivers[streamId] = receiver return receiver @@ -71,7 +71,7 @@ internal class RSocketState( if (cause != null) send(CancelFrame(streamId)) receivers.remove(streamId)?.apply { closeReceivedElements() - close(cause) + cancelConsumed(cause) } } } @@ -120,7 +120,7 @@ internal class RSocketState( when (val streamId = frame.streamId) { 0 -> when (frame) { is ErrorFrame -> { - cancel("Zero stream error", frame.throwable) + job.completeExceptionally(frame.throwable) frame.release() //TODO } is KeepAliveFrame -> keepAliveHandler.receive(frame) @@ -146,7 +146,7 @@ internal class RSocketState( frame.release() } is RequestFrame -> when (frame.type) { - FrameType.Payload -> receivers[streamId]?.offer(frame) + FrameType.Payload -> receivers[streamId]?.safeOffer(frame) ?: frame.release() FrameType.RequestFnF -> responder.handleFireAndForget(frame) FrameType.RequestResponse -> responder.handlerRequestResponse(frame) FrameType.RequestStream -> responder.handleRequestStream(frame) @@ -164,20 +164,35 @@ internal class RSocketState( fun start(requestHandler: RSocket) { val responder = RSocketResponder(this, requestHandler) keepAliveHandler.startIn(scope) - requestHandler.job.invokeOnCompletion { cancel("Request handled stopped", it) } + requestHandler.job.invokeOnCompletion { + when (it) { + null -> job.complete() + is CancellationException -> job.cancel(it) + else -> job.completeExceptionally(it) + } + } job.invokeOnCompletion { error -> - requestHandler.cancel("Connection closed", error) + when (error) { + null -> requestHandler.job.complete() + is CancellationException -> requestHandler.job.cancel(error) + else -> requestHandler.job.completeExceptionally(error) + } + val cancelError = error as? CancellationException ?: CancellationException("Connection closed", error) receivers.values().forEach { it.closeReceivedElements() - it.close((error as? CancellationException)?.cause ?: error) + it.cancel(cancelError) } + senders.values().forEach { it.cancel(cancelError) } receivers.clear() limits.clear() senders.clear() - prioritizer.close(error) + prioritizer.cancel(cancelError) } scope.launch { - while (connection.isActive) connection.sendFrame(prioritizer.receive()) + while (connection.isActive) { + val frame = prioritizer.receive() + frame.closeOnError { connection.sendFrame(frame) } + } } scope.launch { while (connection.isActive) { diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt index 815d2da07..d4dc37bb5 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/SetupRejectionTest.kt @@ -56,6 +56,9 @@ class SetupRejectionTest : SuspendTest, TestWithLeakCheck { } val sender = sendingRSocket.await() assertFalse(sender.isActive) + val error = expectError() + assertTrue(error is RSocketError.Setup.Rejected) + assertEquals(errorMessage, error.message) } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 6c37bd581..2c77c5b7e 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -87,7 +87,7 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { @Test fun noKeepAliveSentAfterRSocketCanceled() = test { - requester().cancel() + requester().job.cancel() connection.test { expectNoEventsIn(500) } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt index 50ee14bf9..0dce89279 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt @@ -30,7 +30,7 @@ import kotlin.time.* class TestConnection : Connection, CoroutineScope { override val pool: ObjectPool = InUseTrackingPool - override val job: Job = Job() + override val job: CompletableJob = Job() override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined private val sendChannel = Channel(Channel.UNLIMITED) diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt index 189755a97..37e0ba6a4 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt @@ -23,7 +23,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* class TestRSocket : RSocket { - override val job: Job = Job() + override val job: CompletableJob = Job() override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadata.release() diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt index 4257ba716..f7ab86c0a 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt @@ -25,16 +25,18 @@ import io.ktor.utils.io.core.internal.* import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.io.* import kotlinx.coroutines.* +import kotlinx.coroutines.CancellationException import kotlinx.coroutines.channels.* import kotlin.coroutines.* +import kotlin.native.concurrent.* @OptIn(KtorExperimentalAPI::class, TransportApi::class, DangerousInternalIoApi::class) internal class TcpConnection(private val socket: Socket) : Connection, CoroutineScope { - override val job: Job = Job(socket.socketContext) + override val job: CompletableJob = Job(socket.socketContext) override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined - private val sendChannel = Channel(8) - private val receiveChannel = Channel(8) + private val sendChannel = SafeChannel(8) + private val receiveChannel = SafeChannel(8) init { launch { @@ -58,9 +60,20 @@ internal class TcpConnection(private val socket: Socket) : Connection, Coroutine } } } + job.invokeOnCompletion { cause -> + val error = cause?.let { it as? CancellationException ?: CancellationException("Connection failed", it) } + sendChannel.cancel(error) + receiveChannel.cancel(error) + } } override suspend fun send(packet: ByteReadPacket): Unit = sendChannel.send(packet) override suspend fun receive(): ByteReadPacket = receiveChannel.receive() } + +@SharedImmutable +private val onUndeliveredCloseable: (Closeable) -> Unit = Closeable::close + +@Suppress("FunctionName") +private fun SafeChannel(capacity: Int): Channel = Channel(capacity, onUndeliveredElement = onUndeliveredCloseable) diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt index b1b03ac1f..2ed4ec258 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt @@ -24,7 +24,7 @@ import kotlinx.coroutines.* @TransportApi public class WebSocketConnection(private val session: WebSocketSession) : Connection { - override val job: Job get() = session.coroutineContext[Job]!! + override val job: CompletableJob = Job(session.coroutineContext[Job]) override suspend fun send(packet: ByteReadPacket) { session.send(Frame.Binary(true, packet)) diff --git a/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt b/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt index f7faf4a25..64fcfe6d9 100644 --- a/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt +++ b/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt @@ -42,7 +42,7 @@ class NativeTcpTransportTest : TransportTest() { override suspend fun after() { serverJob.cancel() - client.cancel() + client.job.cancel() server.close() serverJob.join() client.join() diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt index 712bbdf6e..a9310f850 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt @@ -25,21 +25,12 @@ import kotlinx.coroutines.channels.* @OptIn(DangerousInternalIoApi::class, TransportApi::class) internal class LocalConnection( - private val sender: Channel, - private val receiver: Channel, + private val sender: SendChannel, + private val receiver: ReceiveChannel, override val pool: ObjectPool, parentJob: Job? = null, -) : Connection, Cancelable { - override val job: Job = Job(parentJob) - - init { - job.invokeOnCompletion { - sender.closeReceivedElements() - receiver.closeReceivedElements() - sender.close(it) - receiver.close(it) - } - } +) : Connection, Cancellable { + override val job: CompletableJob = Job(parentJob) override suspend fun send(packet: ByteReadPacket) { sender.send(packet) @@ -49,10 +40,3 @@ internal class LocalConnection( return receiver.receive() } } - -private fun ReceiveChannel.closeReceivedElements() { - try { - while (true) poll()?.close() ?: break - } catch (e: Throwable) { - } -} diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt index 310ed0033..01bb69aff 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt @@ -25,6 +25,7 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlin.native.concurrent.* @Suppress("FunctionName") @OptIn(DangerousInternalIoApi::class) @@ -35,22 +36,41 @@ public class LocalServer internal constructor( parentJob: Job?, private val pool: ObjectPool, -) : Cancelable, ServerTransport, ClientTransport { +) : Cancellable, ServerTransport, ClientTransport { private val connections = Channel() - override val job: Job = SupervisorJob(parentJob) + override val job: CompletableJob = SupervisorJob(parentJob) override suspend fun connect(): Connection { - val clientChannel = Channel(Channel.UNLIMITED) - val serverChannel = Channel(Channel.UNLIMITED) + val clientChannel = SafeChannel(Channel.UNLIMITED) + val serverChannel = SafeChannel(Channel.UNLIMITED) val connectionJob = Job(job) + connectionJob.invokeOnCompletion { cause -> + val error = cause?.let { it as? CancellationException ?: CancellationException("Connection failed", it) } + clientChannel.closeReceivedElements() + serverChannel.closeReceivedElements() + clientChannel.cancel(error) + serverChannel.cancel(error) + } val clientConnection = LocalConnection(serverChannel, clientChannel, pool, connectionJob) val serverConnection = LocalConnection(clientChannel, serverChannel, pool, connectionJob) connections.send(serverConnection) return clientConnection } - @OptIn(ExperimentalCoroutinesApi::class) override fun start(accept: suspend (Connection) -> Unit): Job = GlobalScope.launch(job) { connections.consumeEach { launch(job) { accept(it) } } } } + +@SharedImmutable +private val onUndeliveredCloseable: (Closeable) -> Unit = Closeable::close + +@Suppress("FunctionName") +private fun SafeChannel(capacity: Int): Channel = Channel(capacity, onUndeliveredElement = onUndeliveredCloseable) + +private fun ReceiveChannel.closeReceivedElements() { + try { + while (true) poll()?.close() ?: break + } catch (e: Throwable) { + } +}