diff --git a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt index 8e2949e7f..821cada52 100644 --- a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt +++ b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt @@ -25,12 +25,12 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlin.random.* -@OptIn(ExperimentalStreamsApi::class) +@OptIn(ExperimentalStreamsApi::class, DelicateCoroutinesApi::class) class RSocketKotlinBenchmark : RSocketBenchmark() { private val requestStrategy = PrefetchStrategy(64, 0) + private val benchJob = Job() lateinit var client: RSocket - lateinit var server: Job lateinit var payload: Payload lateinit var payloadsFlow: Flow @@ -40,9 +40,7 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { override fun setup() { payload = createPayload(payloadSize) payloadsFlow = flow { repeat(5000) { emit(payloadCopy()) } } - - val localServer = LocalServer() - server = RSocketServer().bind(localServer) { + val server = RSocketServer().bindIn(CoroutineScope(benchJob + Dispatchers.Unconfined), LocalServerTransport()) { RSocketRequestHandler { requestResponse { it.release() @@ -59,14 +57,14 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { } } client = runBlocking { - RSocketConnector().connect(localServer) + RSocketConnector().connect(server) } } override fun cleanup() { runBlocking { - client.job.runCatching { cancelAndJoin() } - server.runCatching { cancelAndJoin() } + client.coroutineContext.job.cancelAndJoin() + benchJob.cancelAndJoin() } } diff --git a/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt b/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt index 34175927a..920e9396e 100644 --- a/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt @@ -25,8 +25,7 @@ import kotlinx.coroutines.flow.* @TransportApi fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestStream { requestPayload -> val data = requestPayload.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt b/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt index 4bedc12f3..05423acaa 100644 --- a/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt @@ -23,8 +23,7 @@ import kotlinx.coroutines.flow.* @TransportApi fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestStream { requestPayload -> val data = requestPayload.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt index ee17b6700..eb18dbc3b 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt @@ -22,8 +22,7 @@ import kotlinx.coroutines.flow.* fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestChannel { init, request -> println("Init with: ${init.data.readText()}") diff --git a/examples/interactions/src/jvmMain/kotlin/RequestResponseErrorExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestResponseErrorExample.kt index 187412305..c3789fade 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestResponseErrorExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestResponseErrorExample.kt @@ -20,8 +20,7 @@ import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestResponse { val data = it.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/RequestResponseExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestResponseExample.kt index 500d90910..27a194c43 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestResponseExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestResponseExample.kt @@ -20,8 +20,7 @@ import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestResponse { val data = it.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt index 9b0966214..8a36ee9a4 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt @@ -21,8 +21,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestStream { val data = it.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/ServerRequestExample.kt b/examples/interactions/src/jvmMain/kotlin/ServerRequestExample.kt index 374151576..d085a2df9 100644 --- a/examples/interactions/src/jvmMain/kotlin/ServerRequestExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ServerRequestExample.kt @@ -20,8 +20,7 @@ import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* fun main(): Unit = runBlocking { - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { RSocketRequestHandler { requestResponse { val clientRequest = it.data.readText() diff --git a/examples/interactions/src/jvmMain/kotlin/ServerSetupExample.kt b/examples/interactions/src/jvmMain/kotlin/ServerSetupExample.kt index 7cd1d6ba1..73060b42e 100644 --- a/examples/interactions/src/jvmMain/kotlin/ServerSetupExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ServerSetupExample.kt @@ -23,9 +23,7 @@ import kotlinx.coroutines.flow.* fun main(): Unit = runBlocking { - - val server = LocalServer() - RSocketServer().bind(server) { + val server = RSocketServer().bindIn(this, LocalServerTransport()) { val data = config.setupPayload.metadata?.readText() ?: error("Empty metadata") RSocketRequestHandler { when (data) { @@ -43,8 +41,8 @@ fun main(): Unit = runBlocking { suspend fun client1() { val rSocketClient = RSocketConnector().connect(server) - rSocketClient.job.join() - println("Client 1 canceled: ${rSocketClient.job.isCancelled}") + rSocketClient.coroutineContext.job.join() + println("Client 1 canceled: ${rSocketClient.coroutineContext.job.isCancelled}") try { rSocketClient.requestResponse(Payload.Empty) } catch (e: Throwable) { diff --git a/examples/multiplatform-chat/src/clientMain/kotlin/Api.kt b/examples/multiplatform-chat/src/clientMain/kotlin/Api.kt index 0f4b72f9a..b8a440e02 100644 --- a/examples/multiplatform-chat/src/clientMain/kotlin/Api.kt +++ b/examples/multiplatform-chat/src/clientMain/kotlin/Api.kt @@ -16,13 +16,12 @@ import io.ktor.client.* import io.ktor.client.features.websocket.* -import io.ktor.network.selector.* -import io.ktor.util.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.ktor.* import io.rsocket.kotlin.transport.ktor.client.* +import kotlinx.coroutines.* class Api(rSocket: RSocket) { private val proto = ConfiguredProtoBuf @@ -42,9 +41,10 @@ suspend fun connectToApiUsingWS(name: String): Api { return Api(client.rSocket(port = 9000)) } -@OptIn(InternalAPI::class) suspend fun connectToApiUsingTCP(name: String): Api { - val transport = TcpClientTransport(SelectorManager(), "0.0.0.0", 8000) + val transport = TcpClientTransport("0.0.0.0", 8000, CoroutineExceptionHandler { coroutineContext, throwable -> + println("FAIL: $coroutineContext, $throwable") + }) return Api(connector(name).connect(transport)) } diff --git a/examples/multiplatform-chat/src/serverJvmMain/kotlin/App.kt b/examples/multiplatform-chat/src/serverJvmMain/kotlin/App.kt index bbf7fc031..8a44f4cbd 100644 --- a/examples/multiplatform-chat/src/serverJvmMain/kotlin/App.kt +++ b/examples/multiplatform-chat/src/serverJvmMain/kotlin/App.kt @@ -15,11 +15,9 @@ */ import io.ktor.application.* -import io.ktor.network.selector.* import io.ktor.routing.* import io.ktor.server.cio.* import io.ktor.server.engine.* -import io.ktor.util.* import io.ktor.websocket.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* @@ -31,7 +29,7 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlinx.serialization.* -@OptIn(ExperimentalSerializationApi::class, ExperimentalMetadataApi::class, InternalAPI::class) +@OptIn(ExperimentalSerializationApi::class, ExperimentalMetadataApi::class, DelicateCoroutinesApi::class) fun main() { val proto = ConfiguredProtoBuf val users = Users() @@ -97,7 +95,7 @@ fun main() { } //start TCP server - rSocketServer.bind(TcpServerTransport(ActorSelectorManager(Dispatchers.IO), port = 9000), acceptor) + rSocketServer.bind(TcpServerTransport(port = 8000), acceptor) //start WS server embeddedServer(CIO, port = 9000) { diff --git a/playground/src/commonMain/kotlin/TCP.kt b/playground/src/commonMain/kotlin/TCP.kt index 724266b32..2b6db3344 100644 --- a/playground/src/commonMain/kotlin/TCP.kt +++ b/playground/src/commonMain/kotlin/TCP.kt @@ -14,21 +14,18 @@ * limitations under the License. */ -import io.ktor.network.selector.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.ktor.* -import kotlin.coroutines.* - -suspend fun runTcpClient(dispatcher: CoroutineContext) { - val transport = TcpClientTransport(SelectorManager(dispatcher), "0.0.0.0", 4444) +suspend fun runTcpClient() { + val transport = TcpClientTransport("0.0.0.0", 4444) RSocketConnector().connect(transport).doSomething() } //to test nodejs tcp server -suspend fun testNodeJsServer(dispatcher: CoroutineContext) { - val transport = TcpClientTransport(SelectorManager(dispatcher), "127.0.0.1", 9000) +suspend fun testNodeJsServer() { + val transport = TcpClientTransport("127.0.0.1", 9000) val client = RSocketConnector().connect(transport) val response = client.requestResponse(buildPayload { data("Hello from JVM") }) diff --git a/playground/src/jvmMain/kotlin/TcpClientApp.kt b/playground/src/jvmMain/kotlin/TcpClientApp.kt index 83f4b4f5f..f1189def1 100644 --- a/playground/src/jvmMain/kotlin/TcpClientApp.kt +++ b/playground/src/jvmMain/kotlin/TcpClientApp.kt @@ -14,8 +14,6 @@ * limitations under the License. */ -import kotlinx.coroutines.* +suspend fun main(): Unit = runTcpClient() -suspend fun main(): Unit = runTcpClient(Dispatchers.IO) - -//suspend fun main(): Unit = testNodeJsServer(Dispatchers.IO) +//suspend fun main(): Unit = testNodeJsServer() diff --git a/playground/src/jvmMain/kotlin/TcpServerApp.kt b/playground/src/jvmMain/kotlin/TcpServerApp.kt index 2e05bde42..ec7ad96a0 100644 --- a/playground/src/jvmMain/kotlin/TcpServerApp.kt +++ b/playground/src/jvmMain/kotlin/TcpServerApp.kt @@ -14,15 +14,14 @@ * limitations under the License. */ -import io.ktor.network.selector.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.transport.ktor.* import kotlinx.coroutines.* import kotlin.coroutines.* suspend fun runTcpServer(dispatcher: CoroutineContext) { - val transport = TcpServerTransport(SelectorManager(dispatcher), "0.0.0.0", 4444) - RSocketServer().bind(transport, rSocketAcceptor).join() + val transport = TcpServerTransport("0.0.0.0", 4444) + RSocketServer().bindIn(CoroutineScope(dispatcher), transport, rSocketAcceptor).handlerJob.join() } suspend fun main(): Unit = runTcpServer(Dispatchers.IO) diff --git a/playground/src/nativeMain/kotlin/TcpApp.kt b/playground/src/nativeMain/kotlin/TcpApp.kt index 33bb93d9d..184844af7 100644 --- a/playground/src/nativeMain/kotlin/TcpApp.kt +++ b/playground/src/nativeMain/kotlin/TcpApp.kt @@ -14,14 +14,11 @@ * limitations under the License. */ -import io.ktor.util.* import kotlinx.coroutines.* -import kotlin.coroutines.* -@OptIn(InternalAPI::class) fun main() { runBlocking { - runTcpClient(EmptyCoroutineContext) + runTcpClient() } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt index 9ae322d85..dd3ed05d4 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Connection.kt @@ -27,9 +27,7 @@ import kotlinx.coroutines.* * That interface isn't stable for inheritance. */ @TransportApi -public interface Connection { - public val job: Job - +public interface Connection : CoroutineScope { public val pool: ObjectPool get() = ChunkBuffer.Pool public suspend fun send(packet: ByteReadPacket) @@ -37,7 +35,7 @@ public interface Connection { } @OptIn(TransportApi::class) -internal suspend fun Connection.receiveFrame(): Frame = receive().readFrame(pool) +internal suspend inline fun Connection.receiveFrame(block: (frame: Frame) -> T): T = receive().readFrame(pool).closeOnError(block) @OptIn(TransportApi::class) internal suspend fun Connection.sendFrame(frame: Frame) { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt index 81ac2854b..78873dc7e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt @@ -21,8 +21,7 @@ import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* -public interface RSocket { - public val job: Job +public interface RSocket : CoroutineScope { public suspend fun metadataPush(metadata: ByteReadPacket) { metadata.release() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt index de79a1da4..01e8aaf09 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt @@ -20,6 +20,7 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* public class RSocketRequestHandlerBuilder internal constructor() { private var metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null @@ -53,19 +54,29 @@ public class RSocketRequestHandlerBuilder internal constructor() { requestChannel = block } - internal fun build(job: Job): RSocket = - RSocketRequestHandler(job, metadataPush, fireAndForget, requestResponse, requestStream, requestChannel) + internal fun build(parentContext: CoroutineContext): RSocket = + RSocketRequestHandler( + parentContext + Job(parentContext[Job]), + metadataPush, + fireAndForget, + requestResponse, + requestStream, + requestChannel + ) } @Suppress("FunctionName") -public fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket { +public fun RSocketRequestHandler( + parentContext: CoroutineContext = EmptyCoroutineContext, + configure: RSocketRequestHandlerBuilder.() -> Unit +): RSocket { val builder = RSocketRequestHandlerBuilder() builder.configure() - return builder.build(Job(parentJob)) + return builder.build(parentContext) } private class RSocketRequestHandler( - override val job: Job, + override val coroutineContext: CoroutineContext, private val metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null, private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null, private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null, diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt index 55e89179c..cdc62cf08 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt @@ -35,11 +35,13 @@ public class RSocketConnector internal constructor( ) { public suspend fun connect(transport: ClientTransport): RSocket = when (reconnectPredicate) { - null -> connectOnce(transport) - else -> ReconnectableRSocket( - logger = loggerFactory.logger("io.rsocket.kotlin.connection"), - connect = { connectOnce(transport) }, - predicate = reconnectPredicate + //TODO current coroutineContext job is overriden by transport coroutineContext jov + null -> withContext(transport.coroutineContext) { connectOnce(transport) } + else -> connectWithReconnect( + transport.coroutineContext, + loggerFactory.logger("io.rsocket.kotlin.connection"), + { connectOnce(transport) }, + reconnectPredicate, ) } @@ -48,7 +50,7 @@ public class RSocketConnector internal constructor( val connectionConfig = try { connectionConfigProvider() } catch (cause: Throwable) { - connection.job.cancel("Connection config provider failed", cause) + connection.cancel("Connection config provider failed", cause) throw cause } val setupFrame = SetupFrame( @@ -60,7 +62,8 @@ public class RSocketConnector internal constructor( payload = connectionConfig.setupPayload.copy() //copy needed, as it can be used in acceptor ) try { - val requester = connection.connect( + val requester = connect( + connection = connection, isServer = false, maxFragmentSize = maxFragmentSize, interceptors = interceptors, @@ -72,7 +75,7 @@ public class RSocketConnector internal constructor( } catch (cause: Throwable) { connectionConfig.setupPayload.release() setupFrame.release() - connection.job.cancel("Connection establishment failed", cause) + connection.cancel("Connection establishment failed", cause) throw cause } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt index 3eacc2f27..38a56c45e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt @@ -22,6 +22,7 @@ import io.rsocket.kotlin.keepalive.* import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* +import kotlin.coroutines.* public class RSocketConnectorBuilder internal constructor() { @RSocketLoggingApi @@ -117,7 +118,7 @@ public class RSocketConnectorBuilder internal constructor() { } private class EmptyRSocket : RSocket { - override val job: Job = Job() + override val coroutineContext: CoroutineContext = Job() } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt index 1c7092c43..f7dc1638b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt @@ -31,12 +31,23 @@ public class RSocketServer internal constructor( private val interceptors: Interceptors, ) { + @DelicateCoroutinesApi public fun bind( transport: ServerTransport, acceptor: ConnectionAcceptor, - ): T = transport.start { it.wrapConnection().bind(acceptor).join() } + ): T = bindIn(GlobalScope, transport, acceptor) - private suspend fun Connection.bind(acceptor: ConnectionAcceptor): Job = receiveFrame().closeOnError { setupFrame -> + public fun bindIn( + scope: CoroutineScope, + transport: ServerTransport, + acceptor: ConnectionAcceptor, + ): T = with(transport) { + scope.start { + it.wrapConnection().bind(acceptor).join() + } + } + + private suspend fun Connection.bind(acceptor: ConnectionAcceptor): Job = receiveFrame { setupFrame -> when { setupFrame !is SetupFrame -> failSetup(RSocketError.Setup.Invalid("Invalid setup frame: ${setupFrame.type}")) setupFrame.version != Version.Current -> failSetup(RSocketError.Setup.Unsupported("Unsupported version: ${setupFrame.version}")) @@ -44,6 +55,7 @@ public class RSocketServer internal constructor( setupFrame.resumeToken != null -> failSetup(RSocketError.Setup.Unsupported("Resume is not supported")) else -> try { connect( + connection = this, isServer = true, maxFragmentSize = maxFragmentSize, interceptors = interceptors, @@ -54,16 +66,17 @@ public class RSocketServer internal constructor( ), acceptor = acceptor ) - job + coroutineContext.job } catch (e: Throwable) { failSetup(RSocketError.Setup.Rejected(e.message ?: "Rejected by server acceptor")) } } } + @Suppress("SuspendFunctionOnCoroutineScope") private suspend fun Connection.failSetup(error: RSocketError.Setup): Nothing { sendFrame(ErrorFrame(0, error)) - job.cancel("Connection establishment failed", error) + cancel("Connection establishment failed", error) throw error } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt index ff1ec5b69..88bee3c4b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt @@ -16,25 +16,27 @@ package io.rsocket.kotlin.internal -import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.frame.* import kotlinx.coroutines.* @OptIn(TransportApi::class) -internal suspend inline fun Connection.connect( +internal suspend inline fun connect( + connection: Connection, isServer: Boolean, maxFragmentSize: Int, interceptors: Interceptors, connectionConfig: ConnectionConfig, acceptor: ConnectionAcceptor ): RSocket { - val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive) val prioritizer = Prioritizer() - val frameSender = FrameSender(prioritizer, pool, maxFragmentSize) - val streamsStorage = StreamsStorage(isServer, pool) - val requestJob = SupervisorJob(job) + val frameSender = FrameSender(prioritizer, connection.pool, maxFragmentSize) + val streamsStorage = StreamsStorage(isServer, connection.pool) + val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive, frameSender) + + val requestJob = SupervisorJob(connection.coroutineContext[Job]) + val requestContext = connection.coroutineContext + requestJob requestJob.invokeOnCompletion { prioritizer.close(it) @@ -42,53 +44,42 @@ internal suspend inline fun Connection.connect( connectionConfig.setupPayload.release() } - val requestScope = CoroutineScope(requestJob + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> }) - val connectionScope = CoroutineScope(job + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> }) - - val requester = interceptors.wrapRequester(RSocketRequester(job, frameSender, streamsStorage, requestScope, pool)) + val requester = interceptors.wrapRequester( + RSocketRequester(requestContext + CoroutineName("rSocket-requester"), frameSender, streamsStorage, connection.pool) + ) val requestHandler = interceptors.wrapResponder( with(interceptors.wrapAcceptor(acceptor)) { ConnectionAcceptorContext(connectionConfig, requester).accept() } ) + val responder = RSocketResponder(requestContext + CoroutineName("rSocket-responder"), frameSender, requestHandler) // link completing of connection and requestHandler - job.invokeOnCompletion { requestHandler.job.cancel("Connection closed", it) } - requestHandler.job.invokeOnCompletion { if (it != null) job.cancel("Request handler failed", it) } + connection.coroutineContext[Job]?.invokeOnCompletion { requestHandler.cancel("Connection closed", it) } + requestHandler.coroutineContext[Job]?.invokeOnCompletion { if (it != null) connection.cancel("Request handler failed", it) } // start keepalive ticks - connectionScope.launch { - while (isActive) { - keepAliveHandler.tick() - prioritizer.send(KeepAliveFrame(true, 0, ByteReadPacket.Empty)) - } + (connection + CoroutineName("rSocket-connection-keep-alive")).launch { + while (isActive) keepAliveHandler.tick() } // start sending frames to connection - connectionScope.launch { - while (isActive) { - sendFrame(prioritizer.receive()) - } + (connection + CoroutineName("rSocket-connection-send")).launch { + while (isActive) connection.sendFrame(prioritizer.receive()) } // start frame handling - connectionScope.launch { - val rSocketResponder = RSocketResponder(frameSender, requestHandler, requestScope) - while (isActive) { - receiveFrame().closeOnError { frame -> - when (frame.streamId) { - 0 -> when (frame) { - is MetadataPushFrame -> rSocketResponder.handleMetadataPush(frame.metadata) - is ErrorFrame -> job.cancel("Error frame received on 0 stream", frame.throwable) - is KeepAliveFrame -> { - keepAliveHandler.mark() - if (frame.respond) prioritizer.send(KeepAliveFrame(false, 0, frame.data)) else Unit - } - is LeaseFrame -> frame.release().also { error("lease isn't implemented") } - else -> frame.release() - } - else -> streamsStorage.handleFrame(frame, rSocketResponder) + (connection + CoroutineName("rSocket-connection-receive")).launch { + while (isActive) connection.receiveFrame { frame -> + when (frame.streamId) { + 0 -> when (frame) { + is MetadataPushFrame -> responder.handleMetadataPush(frame.metadata) + is ErrorFrame -> connection.cancel("Error frame received on 0 stream", frame.throwable) + is KeepAliveFrame -> keepAliveHandler.mark(frame) + is LeaseFrame -> frame.release().also { error("lease isn't implemented") } + else -> frame.release() } + else -> streamsStorage.handleFrame(frame, responder) } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt index cb428f94c..fad9b95d5 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/KeepAliveHandler.kt @@ -16,23 +16,30 @@ package io.rsocket.kotlin.internal +import io.ktor.utils.io.core.* import io.rsocket.kotlin.* +import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.keepalive.* import kotlinx.atomicfu.* import kotlinx.coroutines.* -internal class KeepAliveHandler(private val keepAlive: KeepAlive) { +internal class KeepAliveHandler( + private val keepAlive: KeepAlive, + private val sender: FrameSender +) { private val lastMark = atomic(currentMillis()) // mark initial timestamp for keepalive - fun mark() { + suspend fun mark(frame: KeepAliveFrame) { lastMark.value = currentMillis() + if (frame.respond) sender.sendKeepAlive(false, 0, frame.data) } - // return boolean because of native suspend fun tick() { delay(keepAlive.intervalMillis.toLong()) - if (currentMillis() - lastMark.value < keepAlive.maxLifetimeMillis) return - throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") + if (currentMillis() - lastMark.value >= keepAlive.maxLifetimeMillis) + throw RSocketError.ConnectionError("No keep-alive for ${keepAlive.maxLifetimeMillis} ms") + + sender.sendKeepAlive(true, 0, ByteReadPacket.Empty) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt index 0eeaad3dc..dbf01562d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Limiter.kt @@ -18,6 +18,7 @@ package io.rsocket.kotlin.internal import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* +import kotlinx.atomicfu.locks.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlin.coroutines.* @@ -31,15 +32,17 @@ internal suspend inline fun Flow.collectLimiting(limiter: Limiter, cros } } -//TODO revisit 2 atomics -internal class Limiter(initial: Int) { +//TODO revisit 2 atomics and sync object +internal class Limiter(initial: Int) : SynchronizedObject() { private val requests = atomic(initial) private val awaiter = atomic?>(null) fun updateRequests(n: Int) { if (n <= 0) return - requests += n - awaiter.getAndSet(null)?.takeIf(CancellableContinuation::isActive)?.resume(Unit) + synchronized(this) { + requests += n + awaiter.getAndSet(null)?.takeIf(CancellableContinuation::isActive)?.resume(Unit) + } } suspend fun useRequest() { @@ -47,8 +50,10 @@ internal class Limiter(initial: Int) { currentCoroutineContext().ensureActive() } else { suspendCancellableCoroutine { - awaiter.value = it - if (requests.value >= 0 && it.isActive) it.resume(Unit) + synchronized(this) { + awaiter.value = it + if (requests.value >= 0 && it.isActive) it.resume(Unit) + } } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index 933db562b..b14649664 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -26,16 +26,16 @@ import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* +//TODO may be need to move all calls on transport dispatcher @OptIn(ExperimentalStreamsApi::class) internal class RSocketRequester( - connectionJob: Job, + override val coroutineContext: CoroutineContext, private val sender: FrameSender, private val streamsStorage: StreamsStorage, - private val requestScope: CoroutineScope, private val pool: ObjectPool ) : RSocket { - override val job: Job = connectionJob override suspend fun metadataPush(metadata: ByteReadPacket) { ensureActiveOrRelease(metadata) @@ -52,7 +52,7 @@ internal class RSocketRequester( sender.sendRequestPayload(FrameType.RequestFnF, id, payload) } catch (cause: Throwable) { payload.release() - if (job.isActive) sender.sendCancel(id) //if cancelled during fragmentation + if (isActive) sender.sendCancel(id) //if cancelled during fragmentation throw cause } } @@ -94,14 +94,14 @@ internal class RSocketRequester( val channel = SafeChannel(Channel.UNLIMITED) val limiter = Limiter(0) - val payloadsJob = Job(requestScope.coroutineContext.job) + val payloadsJob = Job(this@RSocketRequester.coroutineContext.job) val handler = RequesterRequestChannelFrameHandler(id, streamsStorage, limiter, payloadsJob, channel, pool) streamsStorage.save(id, handler) handler.receiveOrCancel(id, initPayload) { sender.sendRequestPayload(FrameType.RequestChannel, id, initPayload, initialRequest) //TODO lazy? - requestScope.launch(payloadsJob) { + launch(payloadsJob) { handler.sendOrFail(id) { payloads.collectLimiting(limiter) { sender.sendNextPayload(id, it) } sender.sendCompletePayload(id) @@ -117,7 +117,7 @@ internal class RSocketRequester( onSendComplete() } catch (cause: Throwable) { val isFailed = onSendFailed(cause) - if (job.isActive && isFailed) sender.sendError(id, cause) + if (isActive && isFailed) sender.sendError(id, cause) throw cause } } @@ -130,14 +130,14 @@ internal class RSocketRequester( } catch (cause: Throwable) { payload.release() val isCancelled = onReceiveCancelled(cause) - if (job.isActive && isCancelled) sender.sendCancel(id) + if (isActive && isCancelled) sender.sendCancel(id) throw cause } } private fun ensureActiveOrRelease(closeable: Closeable) { - if (job.isActive) return + if (isActive) return closeable.close() - job.ensureActive() + ensureActive() } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt index ca25c468e..b24623323 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt @@ -21,26 +21,20 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.internal.handler.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* +import kotlin.coroutines.* @OptIn(ExperimentalStreamsApi::class) internal class RSocketResponder( + override val coroutineContext: CoroutineContext, private val sender: FrameSender, - private val requestHandler: RSocket, - private val requestScope: CoroutineScope, -) { + private val requestHandler: RSocket +) : CoroutineScope { - private fun Job.closeOnCompletion(closeable: Closeable): Job { - invokeOnCompletion { - closeable.close() - } - return this - } - - fun handleMetadataPush(metadata: ByteReadPacket): Job = requestScope.launch { + fun handleMetadataPush(metadata: ByteReadPacket): Job = launch { requestHandler.metadataPush(metadata) }.closeOnCompletion(metadata) - fun handleFireAndForget(payload: Payload, handler: ResponderFireAndForgetFrameHandler): Job = requestScope.launch { + fun handleFireAndForget(payload: Payload, handler: ResponderFireAndForgetFrameHandler): Job = launch { try { requestHandler.fireAndForget(payload) } finally { @@ -48,21 +42,21 @@ internal class RSocketResponder( } }.closeOnCompletion(payload) - fun handleRequestResponse(payload: Payload, id: Int, handler: ResponderRequestResponseFrameHandler): Job = requestScope.launch { + fun handleRequestResponse(payload: Payload, id: Int, handler: ResponderRequestResponseFrameHandler): Job = launch { handler.sendOrFail(id, payload) { val response = requestHandler.requestResponse(payload) sender.sendNextCompletePayload(id, response) } }.closeOnCompletion(payload) - fun handleRequestStream(payload: Payload, id: Int, handler: ResponderRequestStreamFrameHandler): Job = requestScope.launch { + fun handleRequestStream(payload: Payload, id: Int, handler: ResponderRequestStreamFrameHandler): Job = launch { handler.sendOrFail(id, payload) { requestHandler.requestStream(payload).collectLimiting(handler.limiter) { sender.sendNextPayload(id, it) } sender.sendCompletePayload(id) } }.closeOnCompletion(payload) - fun handleRequestChannel(payload: Payload, id: Int, handler: ResponderRequestChannelFrameHandler): Job = requestScope.launch { + fun handleRequestChannel(payload: Payload, id: Int, handler: ResponderRequestChannelFrameHandler): Job = launch { val payloads = requestFlow { strategy, initialRequest -> handler.receiveOrCancel(id) { sender.sendRequestN(id, initialRequest) @@ -94,9 +88,16 @@ internal class RSocketResponder( onReceiveComplete() } catch (cause: Throwable) { val isCancelled = onReceiveCancelled(cause) - if (requestScope.isActive && isCancelled) sender.sendCancel(id) + if (isActive && isCancelled) sender.sendCancel(id) throw cause } } + private fun Job.closeOnCompletion(closeable: Closeable): Job { + invokeOnCompletion { + closeable.close() + } + return this + } + } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt index 7ae189a0f..bd1444917 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/ReconnectableRSocket.kt @@ -22,17 +22,19 @@ import io.rsocket.kotlin.logging.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* internal typealias ReconnectPredicate = suspend (cause: Throwable, attempt: Long) -> Boolean @OptIn(RSocketLoggingApi::class) -@Suppress("FunctionName") -internal suspend fun ReconnectableRSocket( +internal suspend fun connectWithReconnect( + coroutineContext: CoroutineContext, logger: Logger, connect: suspend () -> RSocket, predicate: ReconnectPredicate, ): RSocket { - val job = Job() + val child = Job(coroutineContext[Job]) + val childContext = coroutineContext + child val state = flow { emit(ReconnectState.Connecting) //init - state = connecting val rSocket = connect() @@ -49,18 +51,23 @@ internal suspend fun ReconnectableRSocket( when (value) { is ReconnectState.Connected -> { logger.debug { "Connection established" } - value.rSocket.job.join() //await for connection completion + value.rSocket.coroutineContext.job.join() //await for connection completion logger.debug { "Connection closed. Reconnecting..." } } - is ReconnectState.Failed -> job.completeExceptionally(value.error) //reconnect failed, fail job + is ReconnectState.Failed -> child.cancel("Reconnect failed", value.error) //reconnect failed, fail job ReconnectState.Connecting -> Unit //skip, still waiting for new connection } }.restarting() //reconnect if old connection completed - .stateIn(CoroutineScope(Dispatchers.Unconfined + job)) + .stateIn(CoroutineScope(childContext), SharingStarted.Eagerly, ReconnectState.Connecting) - return ReconnectableRSocket(job, state).apply { + return ReconnectableRSocket(childContext, state).apply { //await first connection to fail fast if something - currentRSocket() + try { + currentRSocket() + } catch (error: Throwable) { + child.cancel() //if during connecting, cancelled from user side + throw error + } } } @@ -73,7 +80,7 @@ private sealed class ReconnectState { } private class ReconnectableRSocket( - override val job: Job, + override val coroutineContext: CoroutineContext, private val state: StateFlow, ) : RSocket { @@ -82,7 +89,7 @@ private class ReconnectableRSocket( private suspend fun currentRSocket(closeable: Closeable): RSocket = closeable.closeOnError { currentRSocket() } private fun ReconnectState.current(): RSocket? = when (this) { - is ReconnectState.Connected -> rSocket.takeIf { it.job.isActive } //connection is ready to handle requests + is ReconnectState.Connected -> rSocket.takeIf(RSocket::isActive) //connection is ready to handle requests is ReconnectState.Failed -> throw error //connection failed - fail requests ReconnectState.Connecting -> null //reconnection } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt index 6ca1a3fc6..01b574791 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/StreamsStorage.kt @@ -66,7 +66,7 @@ internal class StreamsStorage(private val isServer: Boolean, private val pool: O FrameType.RequestChannel -> ResponderRequestChannelFrameHandler(id, this, responder, initialRequest, pool) else -> error("Wrong request frame type") // should never happen } - handlers[id] = handler + save(id, handler) handler.handleRequest(frame) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ClientTransport.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ClientTransport.kt index 499336aea..4c439a6b6 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ClientTransport.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ClientTransport.kt @@ -17,8 +17,20 @@ package io.rsocket.kotlin.transport import io.rsocket.kotlin.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +public fun interface ClientTransport : CoroutineScope { + override val coroutineContext: CoroutineContext get() = EmptyCoroutineContext -public fun interface ClientTransport { @TransportApi public suspend fun connect(): Connection } + +@OptIn(TransportApi::class) +public fun ClientTransport(coroutineContext: CoroutineContext, transport: ClientTransport): ClientTransport = object : ClientTransport { + override val coroutineContext: CoroutineContext get() = coroutineContext + + @TransportApi + override suspend fun connect(): Connection = transport.connect() +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ServerTransport.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ServerTransport.kt index 167cd845f..7f4d19290 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ServerTransport.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/transport/ServerTransport.kt @@ -17,8 +17,9 @@ package io.rsocket.kotlin.transport import io.rsocket.kotlin.* +import kotlinx.coroutines.* public fun interface ServerTransport { @TransportApi - public fun start(accept: suspend (Connection) -> Unit): T + public fun CoroutineScope.start(accept: suspend CoroutineScope.(Connection) -> Unit): T } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt index b46f7e5e3..57b555c61 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/ConnectionEstablishmentTest.kt @@ -64,10 +64,10 @@ class ConnectionEstablishmentTest : SuspendTest, TestWithLeakCheck { assertEquals(errorMessage, frame.throwable.message) } val sender = sendingRSocket.await() - assertFalse(sender.job.isActive) + assertFalse(sender.isActive) expectNoEventsIn(100) } - val error = connection.job.getCancellationException().cause + val error = connection.coroutineContext.job.getCancellationException().cause assertTrue(error is RSocketError.Setup.Rejected) assertEquals(errorMessage, error.message) } @@ -88,7 +88,6 @@ class ConnectionEstablishmentTest : SuspendTest, TestWithLeakCheck { } }.connect { connection } } - println(p.data) assertTrue(p.data.isEmpty) } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 0abe3f450..616681ce5 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -40,10 +40,12 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } private suspend fun start(handler: RSocket? = null): RSocket { - val localServer = LocalServer(testJob) - RSocketServer { + val localServer = RSocketServer { loggerFactory = LoggerFactory { PrintLogger.withLevel(LoggingLevel.DEBUG).logger("SERVER |$it") } - }.bind(localServer) { + }.bindIn( + CoroutineScope(Dispatchers.Unconfined + testJob + CoroutineExceptionHandler { c, e -> println("$c -> $e") }), + LocalServerTransport(InUseTrackingPool) + ) { handler ?: RSocketRequestHandler { requestResponse { it } requestStream { @@ -52,7 +54,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } requestChannel { init, payloads -> init.release() - payloads.onEach { it.release() }.launchIn(CoroutineScope(job)) + payloads.onEach { it.release() }.launchIn(this) flow { repeat(10) { emitOrClose(payload("server got -> [$it]")) } } } } @@ -340,7 +342,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { val responderDeferred = CompletableDeferred>() val requester = start(RSocketRequestHandler { requestChannel { init, payloads -> - responderDeferred.complete(payloads.onStart { emit(init) }.produceIn(CoroutineScope(job))) + responderDeferred.complete(payloads.onStart { emit(init) }.produceIn(this)) responderSendChannel.consumeAsFlow() } @@ -348,7 +350,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { val requesterReceiveChannel = requester .requestChannel(payload("initData", "initMetadata"), requesterSendChannel.consumeAsFlow()) - .produceIn(CoroutineScope(requester.job)) + .produceIn(requester) val responderReceiveChannel = responderDeferred.await() diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt index f2431fd39..c5d6c8097 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt @@ -35,12 +35,17 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { private val first = atomic(true) private val logger = DefaultLoggerFactory.logger("io.rsocket.kotlin.connection") + private suspend fun connectWithReconnect( + connect: suspend () -> RSocket, + predicate: ReconnectPredicate, + ): RSocket = connectWithReconnect(Dispatchers.Unconfined, logger, connect, predicate) + @Test fun testConnectFail() = test { val connect: suspend () -> RSocket = { error("Failed to connect") } assertFailsWith(IllegalStateException::class, "Failed to connect") { - ReconnectableRSocket(logger, connect) { cause, attempt -> + connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -61,7 +66,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { error("Failed to connect") } } - val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> + val rSocket = connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -70,7 +75,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { assertEquals(Payload.Empty, rSocket.requestResponse(Payload.Empty)) - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) assertEquals(0, fails.value) firstJob.cancelAndJoin() @@ -79,7 +84,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { rSocket.requestResponse(Payload.Empty) } - assertFalse(rSocket.job.isActive) + assertFalse(rSocket.isActive) assertEquals(6, fails.value) } @@ -94,7 +99,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { handler(handlerJob) } } - val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> + val rSocket = connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -104,7 +109,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { assertEquals(Payload.Empty, rSocket.requestResponse(Payload.Empty)) assertTrue(handlerJob.isActive) - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) assertEquals(1, fails.value) } @@ -119,7 +124,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { handler(Job()) } } - val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> + val rSocket = connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -128,7 +133,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { assertEquals(Payload.Empty, rSocket.requestResponse(Payload.Empty)) - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) assertEquals(5, fails.value) } @@ -148,7 +153,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { else -> handler(Job()) } } - val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> + val rSocket = connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -161,7 +166,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { assertEquals(Payload.Empty, rSocket.requestResponse(Payload.Empty)) - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) assertEquals(5, fails.value) } @@ -181,7 +186,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { else -> handler(Job()) } } - val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt -> + val rSocket = connectWithReconnect(connect) { cause, attempt -> fails.incrementAndGet() assertTrue(cause is IllegalStateException) assertEquals("Failed to connect", cause.message) @@ -204,7 +209,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { expectComplete() } - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) assertEquals(5, fails.value) } @@ -232,7 +237,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { error("Failed to connect") } } - val rSocket = ReconnectableRSocket(logger, connect) { _, attempt -> + val rSocket = connectWithReconnect(connect) { _, attempt -> delay(100) attempt < 5 } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index ce260b3a5..7a24b7b06 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -34,7 +34,8 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { override suspend fun before() { super.before() - requester = connection.connect( + requester = connect( + connection = connection, isServer = false, maxFragmentSize = 0, interceptors = InterceptorsBuilder().build(), @@ -42,16 +43,15 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { keepAlive = KeepAlive(Duration.seconds(1000), Duration.seconds(1000)), payloadMimeType = DefaultPayloadMimeType, setupPayload = Payload.Empty - ), - acceptor = { RSocketRequestHandler { } } - ) + ) + ) { RSocketRequestHandler { } } } @Test fun testInvalidFrameOnStream0() = test { connection.sendToReceiver(NextPayloadFrame(0, payload("data", "metadata"))) //should be just released delay(100) - assertTrue(requester.job.isActive) + assertTrue(requester.isActive) } @Test @@ -258,8 +258,8 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { val errorMessage = "error" connection.sendToReceiver(ErrorFrame(0, RSocketError.Setup.Rejected(errorMessage))) delay(100) - assertFalse(requester.job.isActive) - val error = requester.job.getCancellationException().cause + assertFalse(requester.isActive) + val error = requester.coroutineContext.job.getCancellationException().cause assertTrue(error is RSocketError.Setup.Rejected) assertEquals(errorMessage, error.message) } @@ -385,7 +385,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { delay(200) connection.test { expectFrame { assertTrue(it is RequestFrame) } - connection.job.cancel() + connection.cancel() expectComplete() } } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt index 2bc4e44c4..f918627c3 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/keepalive/KeepAliveTest.kt @@ -31,13 +31,13 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { private suspend fun requester( keepAlive: KeepAlive = KeepAlive(Duration.milliseconds(100), Duration.seconds(1)) - ): RSocket = connection.connect( + ): RSocket = connect( + connection = connection, isServer = false, maxFragmentSize = 0, interceptors = InterceptorsBuilder().build(), - connectionConfig = ConnectionConfig(keepAlive, DefaultPayloadMimeType, Payload.Empty), - acceptor = { RSocketRequestHandler { } } - ) + connectionConfig = ConnectionConfig(keepAlive, DefaultPayloadMimeType, Payload.Empty) + ) { RSocketRequestHandler { } } @Test fun requesterSendKeepAlive() = test { @@ -62,7 +62,7 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { } } delay(Duration.seconds(1.5)) - assertTrue(rSocket.job.isActive) + assertTrue(rSocket.isActive) connection.test { repeat(50) { expectItem() @@ -92,7 +92,7 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { @Test fun noKeepAliveSentAfterRSocketCanceled() = test { - requester().job.cancel() + requester().cancel() connection.test { expectNoEventsIn(500) } @@ -102,9 +102,9 @@ class KeepAliveTest : TestWithConnection(), TestWithLeakCheck { fun rSocketCanceledOnMissingKeepAliveTicks() = test { val rSocket = requester() connection.test { - while (rSocket.job.isActive) kotlin.runCatching { expectItem() } + while (rSocket.isActive) kotlin.runCatching { expectItem() } } - assertTrue(rSocket.job.getCancellationException().cause is RSocketError.ConnectionError) + assertTrue(rSocket.coroutineContext.job.getCancellationException().cause is RSocketError.ConnectionError) } } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt index 5188a8501..7e92b12b4 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestConnection.kt @@ -22,24 +22,25 @@ import io.ktor.utils.io.core.internal.* import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* +import io.rsocket.kotlin.internal.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* import kotlin.coroutines.* import kotlin.time.* -class TestConnection : Connection, CoroutineScope { +class TestConnection : Connection { override val pool: ObjectPool = InUseTrackingPool - override val job: Job = Job() - override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined + override val coroutineContext: CoroutineContext = + Job() + Dispatchers.Unconfined + CoroutineExceptionHandler { c, e -> println("$c -> $e") } private val sendChannel = Channel(Channel.UNLIMITED) private val receiveChannel = Channel(Channel.UNLIMITED) init { - job.invokeOnCompletion { + coroutineContext.job.invokeOnCompletion { sendChannel.close(it) - receiveChannel.cancel(it?.let { it as? CancellationException ?: CancellationException("Connection completed") }) + @Suppress("INVISIBLE_MEMBER") receiveChannel.fullClose(it) } } @@ -51,16 +52,17 @@ class TestConnection : Connection, CoroutineScope { return receiveChannel.receive() } - @Suppress("INVISIBLE_MEMBER") //for toPacket suspend fun sendToReceiver(vararg frames: Frame) { - frames.forEach { receiveChannel.send(it.toPacket(InUseTrackingPool)) } + frames.forEach { + val packet = @Suppress("INVISIBLE_MEMBER") it.toPacket(InUseTrackingPool) + receiveChannel.send(packet) + } } - @Suppress("INVISIBLE_MEMBER") //for readFrame - private fun sentAsFlow(): Flow = sendChannel.receiveAsFlow().map { it.readFrame(InUseTrackingPool) } - suspend fun test(validate: suspend FlowTurbine.() -> Unit) { - sentAsFlow().test(validate = validate) + sendChannel.consumeAsFlow().map { + @Suppress("INVISIBLE_MEMBER") it.readFrame(InUseTrackingPool) + }.test(validate = validate) } } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt index 1b2a3317d..1ba65f142 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt @@ -21,9 +21,10 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* +import kotlin.coroutines.* class TestRSocket : RSocket { - override val job: Job = Job() + override val coroutineContext: CoroutineContext = Job() override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadata.release() diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestWithConnection.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestWithConnection.kt index 35d03d2d7..ca169e793 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestWithConnection.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestWithConnection.kt @@ -22,6 +22,6 @@ abstract class TestWithConnection : SuspendTest { val connection: TestConnection = TestConnection() override suspend fun after() { - connection.job.cancelAndJoin() + connection.coroutineContext.job.cancelAndJoin() } } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index bc3cae5af..88f45fbf2 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -32,7 +32,7 @@ abstract class TransportTest : SuspendTest, TestWithLeakCheck { lateinit var client: RSocket //should be assigned in `before` override suspend fun after() { - client.job.cancelAndJoin() + client.coroutineContext.job.cancelAndJoin() } @Test diff --git a/rsocket-transport-ktor/rsocket-transport-ktor-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/client/WebSocketClientTransport.kt b/rsocket-transport-ktor/rsocket-transport-ktor-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/client/WebSocketClientTransport.kt index 0613f47fc..4a0209058 100644 --- a/rsocket-transport-ktor/rsocket-transport-ktor-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/client/WebSocketClientTransport.kt +++ b/rsocket-transport-ktor/rsocket-transport-ktor-client/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/client/WebSocketClientTransport.kt @@ -26,14 +26,14 @@ import io.ktor.http.* import io.rsocket.kotlin.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.* +import kotlinx.coroutines.* public fun WebSocketClientTransport( httpClient: HttpClient, request: HttpRequestBuilder.() -> Unit, -): ClientTransport = ClientTransport { +): ClientTransport = ClientTransport(httpClient.coroutineContext + SupervisorJob(httpClient.coroutineContext[Job])) { val session = httpClient.webSocketSession(request) - @Suppress("INVISIBLE_MEMBER") - WebSocketConnection(session) + @Suppress("INVISIBLE_MEMBER") WebSocketConnection(session) } public fun WebSocketClientTransport( diff --git a/rsocket-transport-ktor/rsocket-transport-ktor-server/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/server/WebSocketServerTransport.kt b/rsocket-transport-ktor/rsocket-transport-ktor-server/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/server/WebSocketServerTransport.kt index a10312307..d89fe7ad8 100644 --- a/rsocket-transport-ktor/rsocket-transport-ktor-server/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/server/WebSocketServerTransport.kt +++ b/rsocket-transport-ktor/rsocket-transport-ktor-server/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/server/WebSocketServerTransport.kt @@ -29,13 +29,11 @@ internal fun Route.serverTransport( ): ServerTransport = ServerTransport { acceptor -> when (path) { null -> webSocket(protocol) { - @Suppress("INVISIBLE_MEMBER") - val connection = WebSocketConnection(this) + val connection = @Suppress("INVISIBLE_MEMBER") WebSocketConnection(this) acceptor(connection) } else -> webSocket(path, protocol) { - @Suppress("INVISIBLE_MEMBER") - val connection = WebSocketConnection(this) + val connection = @Suppress("INVISIBLE_MEMBER") WebSocketConnection(this) acceptor(connection) } } diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpClientTransport.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpClientTransport.kt index 9b4e8a18e..70ef6742e 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpClientTransport.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpClientTransport.kt @@ -21,23 +21,41 @@ package io.rsocket.kotlin.transport.ktor import io.ktor.network.selector.* import io.ktor.network.sockets.* +import io.ktor.util.* import io.ktor.util.network.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* import io.rsocket.kotlin.transport.* +import kotlinx.coroutines.* +import kotlin.coroutines.* + +//TODO user should close ClientTransport manually if there is no job provided in context + +//this dispatcher will be used, if no dispatcher were provided by user in client and server +internal expect val defaultDispatcher: CoroutineDispatcher public fun TcpClientTransport( - selector: SelectorManager, hostname: String, port: Int, + context: CoroutineContext = EmptyCoroutineContext, + pool: ObjectPool = ChunkBuffer.Pool, intercept: (Socket) -> Socket = { it }, //f.e. for tls, which is currently supported by ktor only on JVM - configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {}, -): ClientTransport = TcpClientTransport(selector, NetworkAddress(hostname, port), intercept, configure) + configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {} +): ClientTransport = TcpClientTransport(NetworkAddress(hostname, port), context, pool, intercept, configure) public fun TcpClientTransport( - selector: SelectorManager, remoteAddress: NetworkAddress, + context: CoroutineContext = EmptyCoroutineContext, + pool: ObjectPool = ChunkBuffer.Pool, intercept: (Socket) -> Socket = { it }, //f.e. for tls, which is currently supported by ktor only on JVM - configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {}, -): ClientTransport = ClientTransport { - val socket = aSocket(selector).tcp().connect(remoteAddress, configure) - TcpConnection(intercept(socket)) + configure: SocketOptions.TCPClientSocketOptions.() -> Unit = {} +): ClientTransport { + val transportJob = SupervisorJob(context[Job]) + val transportContext = defaultDispatcher + context + transportJob + CoroutineName("rSocket-tcp-client") + val selector = @OptIn(InternalAPI::class) SelectorManager(transportContext) + Job(transportJob).invokeOnCompletion { selector.close() } + return ClientTransport(transportContext) { + val socket = aSocket(selector).tcp().connect(remoteAddress, configure) + TcpConnection(intercept(socket), transportContext + Job(transportJob), pool) + } } diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt index 1bc143aa4..139b2dae6 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpConnection.kt @@ -20,41 +20,38 @@ import io.ktor.network.sockets.* import io.ktor.util.cio.* import io.ktor.utils.io.* import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* import io.rsocket.kotlin.Connection import io.rsocket.kotlin.frame.io.* import io.rsocket.kotlin.internal.* import kotlinx.coroutines.* -import kotlinx.coroutines.CancellationException import kotlin.coroutines.* -import kotlin.native.concurrent.* - -@SharedImmutable -internal val ignoreExceptionHandler = CoroutineExceptionHandler { _, _ -> } @OptIn(TransportApi::class) -internal class TcpConnection(private val socket: Socket) : Connection, CoroutineScope { - override val job: Job = socket.socketContext - override val coroutineContext: CoroutineContext = job + Dispatchers.Unconfined + ignoreExceptionHandler - - @Suppress("INVISIBLE_MEMBER") - private val sendChannel = SafeChannel(8) +internal class TcpConnection( + socket: Socket, + override val coroutineContext: CoroutineContext, + override val pool: ObjectPool +) : Connection { + private val socketConnection = socket.connection() - @Suppress("INVISIBLE_MEMBER") - private val receiveChannel = SafeChannel(8) + private val sendChannel = @Suppress("INVISIBLE_MEMBER") SafeChannel(8) + private val receiveChannel = @Suppress("INVISIBLE_MEMBER") SafeChannel(8) init { launch { - socket.openWriteChannel(autoFlush = true).use { + socketConnection.output.use { while (isActive) { val packet = sendChannel.receive() val length = packet.remaining.toInt() try { writePacket { - @Suppress("INVISIBLE_MEMBER") - writeLength(length) + @Suppress("INVISIBLE_MEMBER") writeLength(length) writePacket(packet) } + flush() } catch (e: Throwable) { packet.close() throw e @@ -63,10 +60,9 @@ internal class TcpConnection(private val socket: Socket) : Connection, Coroutine } } launch { - socket.openReadChannel().apply { + socketConnection.input.apply { while (isActive) { - @Suppress("INVISIBLE_MEMBER") - val length = readPacket(3).readLength() + val length = @Suppress("INVISIBLE_MEMBER") readPacket(3).readLength() val packet = readPacket(length) try { receiveChannel.send(packet) @@ -77,14 +73,15 @@ internal class TcpConnection(private val socket: Socket) : Connection, Coroutine } } } - job.invokeOnCompletion { cause -> - val error = cause?.let { it as? CancellationException ?: CancellationException("Connection failed", it) } - sendChannel.cancel(error) - receiveChannel.cancel(error) + coroutineContext.job.invokeOnCompletion { + @Suppress("INVISIBLE_MEMBER") sendChannel.fullClose(it) + @Suppress("INVISIBLE_MEMBER") receiveChannel.fullClose(it) + socketConnection.input.cancel(it) + socketConnection.output.close(it) + socketConnection.socket.close() } } override suspend fun send(packet: ByteReadPacket): Unit = sendChannel.send(packet) - override suspend fun receive(): ByteReadPacket = receiveChannel.receive() } diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTransport.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTransport.kt index 65467f297..6026fc20c 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTransport.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTransport.kt @@ -20,31 +20,49 @@ package io.rsocket.kotlin.transport.ktor import io.ktor.network.selector.* import io.ktor.network.sockets.* +import io.ktor.util.* import io.ktor.util.network.* +import io.ktor.utils.io.core.* +import io.ktor.utils.io.core.internal.* +import io.ktor.utils.io.pool.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* +public class TcpServer internal constructor( + public val handlerJob: Job, + public val serverSocket: Deferred +) + public fun TcpServerTransport( - selector: SelectorManager, hostname: String = "0.0.0.0", port: Int = 0, + pool: ObjectPool = ChunkBuffer.Pool, configure: SocketOptions.AcceptorOptions.() -> Unit = {}, -): ServerTransport = TcpServerTransport(selector, NetworkAddress(hostname, port), configure) +): ServerTransport = TcpServerTransport(NetworkAddress(hostname, port), pool, configure) -@OptIn(DelicateCoroutinesApi::class) //TODO ? public fun TcpServerTransport( - selector: SelectorManager, localAddress: NetworkAddress? = null, + pool: ObjectPool = ChunkBuffer.Pool, configure: SocketOptions.AcceptorOptions.() -> Unit = {}, -): ServerTransport = ServerTransport { accept -> - val serverSocket = aSocket(selector).tcp().bind(localAddress, configure) - GlobalScope.launch(serverSocket.socketContext + Dispatchers.Unconfined + ignoreExceptionHandler, CoroutineStart.UNDISPATCHED) { - supervisorScope { - while (isActive) { - val clientSocket = serverSocket.accept() - val connection = TcpConnection(clientSocket) - launch(start = CoroutineStart.UNDISPATCHED) { accept(connection) } +): ServerTransport = ServerTransport { accept -> + val serverSocketDeferred = CompletableDeferred() + val handlerJob = launch(defaultDispatcher + coroutineContext) { + @OptIn(InternalAPI::class) SelectorManager(coroutineContext).use { selector -> + aSocket(selector).tcp().bind(localAddress, configure).use { serverSocket -> + serverSocketDeferred.complete(serverSocket) + val connectionScope = + CoroutineScope(coroutineContext + SupervisorJob(coroutineContext[Job]) + CoroutineName("rSocket-tcp-server")) + while (isActive) { + val clientSocket = serverSocket.accept() + connectionScope.launch { + accept(TcpConnection(clientSocket, coroutineContext, pool)) + }.invokeOnCompletion { + clientSocket.close() + } + } } } } - serverSocket.socketContext + handlerJob.invokeOnCompletion { it?.let(serverSocketDeferred::completeExceptionally) } + TcpServer(handlerJob, serverSocketDeferred) } + diff --git a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt index ff8f22dbd..460b656fc 100644 --- a/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt +++ b/rsocket-transport-ktor/src/commonMain/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnection.kt @@ -22,10 +22,7 @@ import io.rsocket.kotlin.* import kotlinx.coroutines.* @TransportApi -internal class WebSocketConnection(private val session: WebSocketSession) : Connection { - - override val job: Job = session.coroutineContext.job - +internal class WebSocketConnection(private val session: WebSocketSession) : Connection, CoroutineScope by session { override suspend fun send(packet: ByteReadPacket) { session.send(packet.readBytes()) } diff --git a/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/PortProvider.kt b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/PortProvider.kt new file mode 100644 index 000000000..33f9de457 --- /dev/null +++ b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/PortProvider.kt @@ -0,0 +1,9 @@ +package io.rsocket.kotlin.transport.ktor + +import kotlinx.atomicfu.* +import kotlin.random.* + +object PortProvider { + private val port = atomic(Random.nextInt(20, 90) * 100) + fun next(): Int = port.incrementAndGet() +} diff --git a/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTest.kt b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTest.kt index 81b618406..399b152c7 100644 --- a/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTest.kt +++ b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpServerTest.kt @@ -16,40 +16,33 @@ package io.rsocket.kotlin.transport.ktor -import io.ktor.network.selector.* +import io.ktor.util.network.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* import io.rsocket.kotlin.test.* -import kotlinx.atomicfu.* import kotlinx.coroutines.* -import kotlin.random.* import kotlin.test.* -abstract class TcpServerTest( - private val clientSelector: SelectorManager, - private val serverSelector: SelectorManager -) : SuspendTest, TestWithLeakCheck { - private val currentPort = port.incrementAndGet() - private val serverTransport = TcpServerTransport(serverSelector, port = currentPort) - private val clientTransport = TcpClientTransport(clientSelector, "0.0.0.0", port = currentPort) - - private lateinit var server: Job +abstract class TcpServerTest : SuspendTest, TestWithLeakCheck { + private val testJob = Job() + private val testContext = testJob + CoroutineExceptionHandler { c, e -> println("$c -> $e") } + private val address = NetworkAddress("0.0.0.0", PortProvider.next()) + private val serverTransport = TcpServerTransport(address, InUseTrackingPool) + private val clientTransport = TcpClientTransport(address, testContext, InUseTrackingPool) override suspend fun after() { - server.cancelAndJoin() - clientSelector.close() - serverSelector.close() + testJob.cancelAndJoin() } @Test fun testFailedConnection() = test { - server = RSocketServer().bind(serverTransport) { + val server = RSocketServer().bindIn(CoroutineScope(testContext), serverTransport) { if (config.setupPayload.data.readText() == "ok") { RSocketRequestHandler { requestResponse { it } } } else error("FAILED") - } + }.also { it.serverSocket.await() } suspend fun newClient(text: String) = RSocketConnector { connectionConfig { @@ -72,25 +65,26 @@ abstract class TcpServerTest( client3.requestResponse(payload("ok")).release() client1.requestResponse(payload("ok")).release() - assertTrue(client1.job.isActive) - assertFalse(client2.job.isActive) - assertTrue(client3.job.isActive) + assertTrue(client1.isActive) + assertFalse(client2.isActive) + assertTrue(client3.isActive) - assertTrue(server.isActive) + assertTrue(server.serverSocket.await().socketContext.isActive) + assertTrue(server.handlerJob.isActive) - client1.job.cancelAndJoin() - client2.job.cancelAndJoin() - client3.job.cancelAndJoin() + client1.coroutineContext.job.cancelAndJoin() + client2.coroutineContext.job.cancelAndJoin() + client3.coroutineContext.job.cancelAndJoin() } @Test fun testFailedHandler() = test { val handlers = mutableListOf() - server = RSocketServer().bind(serverTransport) { + val server = RSocketServer().bindIn(CoroutineScope(testContext), serverTransport) { RSocketRequestHandler { requestResponse { it } }.also { handlers += it } - } + }.also { it.serverSocket.await() } suspend fun newClient() = RSocketConnector().connect(clientTransport) @@ -102,7 +96,7 @@ abstract class TcpServerTest( client2.requestResponse(payload("1")).release() - handlers[1].job.apply { + handlers[1].coroutineContext.job.apply { cancel("FAILED") join() } @@ -119,18 +113,15 @@ abstract class TcpServerTest( client1.requestResponse(payload("1")).release() - assertTrue(client1.job.isActive) - assertFalse(client2.job.isActive) - assertTrue(client3.job.isActive) + assertTrue(client1.isActive) + assertFalse(client2.isActive) + assertTrue(client3.isActive) - assertTrue(server.isActive) - - client1.job.cancelAndJoin() - client2.job.cancelAndJoin() - client3.job.cancelAndJoin() - } + assertTrue(server.serverSocket.await().socketContext.isActive) + assertTrue(server.handlerJob.isActive) - companion object { - private val port = atomic(Random.nextInt(20, 90) * 100) + client1.coroutineContext.job.cancelAndJoin() + client2.coroutineContext.job.cancelAndJoin() + client3.coroutineContext.job.cancelAndJoin() } } diff --git a/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpTransportTest.kt b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpTransportTest.kt index a835b0944..5d87df2db 100644 --- a/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpTransportTest.kt +++ b/rsocket-transport-ktor/src/commonTest/kotlin/io/rsocket/kotlin/transport/ktor/TcpTransportTest.kt @@ -16,38 +16,22 @@ package io.rsocket.kotlin.transport.ktor -import io.ktor.network.selector.* import io.ktor.util.network.* import io.rsocket.kotlin.test.* -import io.rsocket.kotlin.transport.* -import kotlinx.atomicfu.* import kotlinx.coroutines.* -import kotlin.random.* -abstract class TcpTransportTest( - private val clientSelector: SelectorManager, - protected val serverSelector: SelectorManager -) : TransportTest() { - private lateinit var server: Job - - //on Native uses default transport - //on JVM uses little different, because of bug in KTOR - abstract fun serverTransport(address: NetworkAddress): ServerTransport +abstract class TcpTransportTest : TransportTest() { + private val testJob = Job() override suspend fun before() { - val address = NetworkAddress("0.0.0.0", port.incrementAndGet()) - server = SERVER.bind(serverTransport(address), ACCEPTOR) - client = CONNECTOR.connect(TcpClientTransport(clientSelector, address)) + val address = NetworkAddress("0.0.0.0", PortProvider.next()) + val context = testJob + CoroutineExceptionHandler { c, e -> println("$c -> $e") } + SERVER.bindIn(CoroutineScope(context), TcpServerTransport(address, InUseTrackingPool), ACCEPTOR).serverSocket.await() + client = CONNECTOR.connect(TcpClientTransport(address, context, InUseTrackingPool)) } override suspend fun after() { super.after() - server.cancelAndJoin() - clientSelector.close() - serverSelector.close() - } - - companion object { - private val port = atomic(Random.nextInt(20, 90) * 100) + testJob.cancelAndJoin() } } diff --git a/rsocket-transport-ktor/src/jsMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt b/rsocket-transport-ktor/src/jsMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt new file mode 100644 index 000000000..360633462 --- /dev/null +++ b/rsocket-transport-ktor/src/jsMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt @@ -0,0 +1,5 @@ +package io.rsocket.kotlin.transport.ktor + +import kotlinx.coroutines.* + +internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.Default diff --git a/rsocket-transport-ktor/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt b/rsocket-transport-ktor/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt new file mode 100644 index 000000000..42b0e8fc5 --- /dev/null +++ b/rsocket-transport-ktor/src/jvmMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt @@ -0,0 +1,5 @@ +package io.rsocket.kotlin.transport.ktor + +import kotlinx.coroutines.* + +internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.IO diff --git a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/JvmTcpTransportTest.kt b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/JvmTcpTransportTest.kt index 514a030c3..21bda15d2 100644 --- a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/JvmTcpTransportTest.kt +++ b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/JvmTcpTransportTest.kt @@ -16,35 +16,5 @@ package io.rsocket.kotlin.transport.ktor -import io.ktor.network.selector.* -import io.ktor.network.sockets.* -import io.ktor.util.network.* -import io.rsocket.kotlin.Connection -import io.rsocket.kotlin.transport.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* - -class JvmTcpTransportTest : TcpTransportTest(ActorSelectorManager(Dispatchers.IO), ActorSelectorManager(Dispatchers.IO)) { - //hack because of https://youtrack.jetbrains.com/issue/KTOR-2881 - private var serverConnection: Connection? by atomic(null) - override fun serverTransport(address: NetworkAddress): ServerTransport = ServerTransport { accept -> - val serverSocket = aSocket(serverSelector).tcp().bind(address) - GlobalScope.launch( - serverSocket.socketContext + Dispatchers.Unconfined + ignoreExceptionHandler, - CoroutineStart.UNDISPATCHED - ) { - val clientSocket = serverSocket.accept() - serverConnection = TcpConnection(clientSocket) - accept(serverConnection!!) - } - serverSocket.socketContext - } - - override suspend fun after() { - //we need to cancel server connection job manually on JVM - serverConnection?.job?.cancelAndJoin() - super.after() - } -} - -class JvmTcpServerTest : TcpServerTest(ActorSelectorManager(Dispatchers.IO), ActorSelectorManager(Dispatchers.IO)) +class JvmTcpTransportTest : TcpTransportTest() +class JvmTcpServerTest : TcpServerTest() diff --git a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnectionTest.kt b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnectionTest.kt index a830abc03..0328f9058 100644 --- a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnectionTest.kt +++ b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketConnectionTest.kt @@ -29,7 +29,6 @@ import io.rsocket.kotlin.transport.ktor.client.* import io.rsocket.kotlin.transport.ktor.server.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* -import kotlin.random.* import kotlin.test.* import io.ktor.client.engine.cio.CIO as ClientCIO import io.ktor.client.features.websocket.WebSockets as ClientWebSockets @@ -39,7 +38,7 @@ import io.rsocket.kotlin.transport.ktor.client.RSocketSupport as ClientRSocketSu import io.rsocket.kotlin.transport.ktor.server.RSocketSupport as ServerRSocketSupport class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { - private val port = Random.nextInt(20, 90) * 100 + private val port = PortProvider.next() private val client = HttpClient(ClientCIO) { install(ClientWebSockets) install(ClientRSocketSupport) { @@ -69,7 +68,7 @@ class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { } } } - }.also { responderJob = it.job } + }.also { responderJob = it.coroutineContext.job } } } } @@ -86,7 +85,7 @@ class WebSocketConnectionTest : SuspendTest, TestWithLeakCheck { @Test fun testWorks() = test { val rSocket = client.rSocket(port = port) - val requesterJob = rSocket.job + val requesterJob = rSocket.coroutineContext.job rSocket .requestStream(Payload.Empty) diff --git a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketTransportTest.kt b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketTransportTest.kt index afc3daf35..758f950f4 100644 --- a/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketTransportTest.kt +++ b/rsocket-transport-ktor/src/jvmTest/kotlin/io/rsocket/kotlin/transport/ktor/WebSocketTransportTest.kt @@ -21,14 +21,10 @@ import io.ktor.client.* import io.ktor.client.engine.* import io.ktor.routing.* import io.ktor.server.engine.* -import io.ktor.websocket.* -import io.rsocket.kotlin.* import io.rsocket.kotlin.test.* -import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.ktor.client.* -import kotlinx.atomicfu.* +import io.rsocket.kotlin.transport.ktor.server.* import kotlinx.coroutines.* -import kotlin.random.* import io.ktor.client.features.websocket.WebSockets as ClientWebSockets import io.ktor.websocket.WebSockets as ServerWebSockets import io.rsocket.kotlin.transport.ktor.client.RSocketSupport as ClientRSocketSupport @@ -38,45 +34,34 @@ abstract class WebSocketTransportTest( clientEngine: HttpClientEngineFactory<*>, serverEngine: ApplicationEngineFactory<*, *>, ) : TransportTest() { - - private var serverConnection: Connection? by atomic(null) + private val port = PortProvider.next() + private val testJob = Job() private val httpClient = HttpClient(clientEngine) { install(ClientWebSockets) install(ClientRSocketSupport) { connector = CONNECTOR } } - private val currentPort = port.incrementAndGet() - - private val server = embeddedServer(serverEngine, currentPort) { + private val server = (GlobalScope + testJob).embeddedServer(serverEngine, port) { install(ServerWebSockets) install(ServerRSocketSupport) { server = SERVER } - install(Routing) { - //hack to really await completion of server connection and expect no leaks - val serverTransport = ServerTransport { acceptor -> - webSocket { - serverConnection = WebSocketConnection(this) - acceptor(serverConnection!!) - } - } - SERVER.bind(serverTransport, ACCEPTOR) - } + install(Routing) { rSocket(acceptor = ACCEPTOR) } } override suspend fun before() { super.before() server.start() - client = trySeveralTimes { httpClient.rSocket(port = currentPort) } + client = trySeveralTimes { httpClient.rSocket(port = port) } } override suspend fun after() { super.after() - server.stop(0, 1000) + server.stop(200, 1000) + testJob.cancelAndJoin() httpClient.close() - httpClient.coroutineContext.job.join() - serverConnection?.job?.cancelAndJoin() + httpClient.coroutineContext.job.cancelAndJoin() } private suspend inline fun trySeveralTimes(block: () -> R): R { @@ -91,8 +76,4 @@ abstract class WebSocketTransportTest( } throw error } - - companion object { - private val port = atomic(Random.nextInt(20, 90) * 100) - } } diff --git a/rsocket-transport-ktor/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt b/rsocket-transport-ktor/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt new file mode 100644 index 000000000..1ec04304a --- /dev/null +++ b/rsocket-transport-ktor/src/nativeMain/kotlin/io/rsocket/kotlin/transport/ktor/defaultDispatcher.kt @@ -0,0 +1,5 @@ +package io.rsocket.kotlin.transport.ktor + +import kotlinx.coroutines.* + +internal actual val defaultDispatcher: CoroutineDispatcher get() = Dispatchers.Unconfined diff --git a/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt b/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/transport/ktor/NativeTcpTransportTest.kt similarity index 55% rename from rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt rename to rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/transport/ktor/NativeTcpTransportTest.kt index 1da63a23f..a370916e7 100644 --- a/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/NativeTcpTransportTest.kt +++ b/rsocket-transport-ktor/src/nativeTest/kotlin/io/rsocket/kotlin/transport/ktor/NativeTcpTransportTest.kt @@ -14,17 +14,7 @@ * limitations under the License. */ -package io.rsocket.kotlin +package io.rsocket.kotlin.transport.ktor -import io.ktor.network.selector.* -import io.ktor.util.network.* -import io.rsocket.kotlin.transport.* -import io.rsocket.kotlin.transport.ktor.* -import kotlinx.coroutines.* - -class NativeTcpTransportTest : TcpTransportTest(SelectorManager(), SelectorManager()) { - override fun serverTransport(address: NetworkAddress): ServerTransport = - TcpServerTransport(serverSelector, address) -} - -class NativeTcpServerTest : TcpServerTest(SelectorManager(), SelectorManager()) +class NativeTcpTransportTest : TcpTransportTest() +class NativeTcpServerTest : TcpServerTest() diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt index 72b44156b..58972a217 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalConnection.kt @@ -20,15 +20,15 @@ import io.ktor.utils.io.core.* import io.ktor.utils.io.core.internal.* import io.ktor.utils.io.pool.* import io.rsocket.kotlin.* -import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlin.coroutines.* @OptIn(TransportApi::class) internal class LocalConnection( private val sender: SendChannel, private val receiver: ReceiveChannel, override val pool: ObjectPool, - override val job: Job + override val coroutineContext: CoroutineContext ) : Connection { override suspend fun send(packet: ByteReadPacket) { diff --git a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt index f3a1c1573..4e889709a 100644 --- a/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt +++ b/rsocket-transport-local/src/commonMain/kotlin/io/rsocket/kotlin/transport/local/LocalServer.kt @@ -15,6 +15,7 @@ */ @file:OptIn(TransportApi::class) +@file:Suppress("FunctionName") package io.rsocket.kotlin.transport.local @@ -26,39 +27,41 @@ import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.transport.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* +import kotlin.coroutines.* -public class LocalServer( - parentJob: Job? = null, - private val pool: ObjectPool = ChunkBuffer.Pool, -) : ServerTransport, ClientTransport { - public val job: Job = SupervisorJob(parentJob) - private val connections = Channel() +public fun LocalServerTransport( + pool: ObjectPool = ChunkBuffer.Pool +): ServerTransport = ServerTransport { accept -> + val connections = Channel() + val handlerJob = launch { + supervisorScope { + connections.consumeEach { connection -> + launch { accept(connection) } + } + } + } + LocalServer(pool, connections, coroutineContext + SupervisorJob(handlerJob)) +} +public class LocalServer internal constructor( + private val pool: ObjectPool, + private val connections: Channel, + override val coroutineContext: CoroutineContext +) : ClientTransport { override suspend fun connect(): Connection { - @Suppress("INVISIBLE_MEMBER") - val clientChannel = SafeChannel(Channel.UNLIMITED) - - @Suppress("INVISIBLE_MEMBER") - val serverChannel = SafeChannel(Channel.UNLIMITED) - val connectionJob = Job(job) + val clientChannel = @Suppress("INVISIBLE_MEMBER") SafeChannel(Channel.UNLIMITED) + val serverChannel = @Suppress("INVISIBLE_MEMBER") SafeChannel(Channel.UNLIMITED) + val connectionJob = Job(coroutineContext[Job]) connectionJob.invokeOnCompletion { - val error = CancellationException("Connection failed", it) - clientChannel.cancel(error) - serverChannel.cancel(error) + @Suppress("INVISIBLE_MEMBER") clientChannel.fullClose(it) + @Suppress("INVISIBLE_MEMBER") serverChannel.fullClose(it) } - val clientConnection = LocalConnection(serverChannel, clientChannel, pool, connectionJob) - val serverConnection = LocalConnection(clientChannel, serverChannel, pool, connectionJob) + val connectionContext = coroutineContext + connectionJob + val clientConnection = + LocalConnection(serverChannel, clientChannel, pool, connectionContext + CoroutineName("rSocket-local-client")) + val serverConnection = + LocalConnection(clientChannel, serverChannel, pool, connectionContext + CoroutineName("rSocket-local-server")) connections.send(serverConnection) return clientConnection } - - @OptIn(DelicateCoroutinesApi::class) - override fun start(accept: suspend (Connection) -> Unit): Job = - GlobalScope.launch(job + Dispatchers.Unconfined, CoroutineStart.UNDISPATCHED) { - supervisorScope { - connections.consumeEach { connection -> - launch(start = CoroutineStart.UNDISPATCHED) { accept(connection) } - } - } - } } diff --git a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt index 7a47141a2..2be32e46b 100644 --- a/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt +++ b/rsocket-transport-local/src/commonTest/kotlin/io/rsocket/kotlin/transport/local/LocalTransportTest.kt @@ -20,14 +20,15 @@ import io.rsocket.kotlin.test.* import kotlinx.coroutines.* class LocalTransportTest : TransportTest() { - - private val testJob: Job = Job() + private val testJob = Job() override suspend fun before() { - super.before() - val localServer = LocalServer(testJob, InUseTrackingPool) - SERVER.bind(localServer, ACCEPTOR) - client = CONNECTOR.connect(localServer) + val server = SERVER.bindIn( + CoroutineScope(testJob + CoroutineExceptionHandler { c, e -> println("$c -> $e") }), + LocalServerTransport(InUseTrackingPool), + ACCEPTOR + ) + client = CONNECTOR.connect(server) } override suspend fun after() { diff --git a/settings.gradle.kts b/settings.gradle.kts index 3d13730bc..be056ed4b 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -41,11 +41,6 @@ pluginManagement { dependencyResolutionManagement { repositories { mavenCentral() - jcenter { - content { - includeModule("org.jetbrains.kotlinx", "kotlinx-nodejs") - } - } } } @@ -80,7 +75,8 @@ fun includeExample(name: String) { include("examples:$name") } -includeExample("nodejs-tcp-transport") +//TODO ignore for now, as `kotlinx-nodejs` isn't maintained now +//includeExample("nodejs-tcp-transport") includeExample("interactions") includeExample("multiplatform-chat")