Skip to content

Commit 3353ab4

Browse files
authored
Fix leaks on reconnection (#123)
1 parent a687b0e commit 3353ab4

File tree

2 files changed

+76
-27
lines changed

2 files changed

+76
-27
lines changed

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package io.rsocket.kotlin.core
1818

1919
import io.ktor.utils.io.core.*
2020
import io.rsocket.kotlin.*
21+
import io.rsocket.kotlin.internal.*
2122
import io.rsocket.kotlin.logging.*
2223
import io.rsocket.kotlin.payload.*
2324
import kotlinx.coroutines.*
@@ -97,30 +98,33 @@ private class ReconnectableRSocket(
9798
private val state: StateFlow<ReconnectState>,
9899
) : RSocket {
99100

100-
private val reconnectHandler = state.mapNotNull { it.handleState { null } }.take(1)
101+
private val reconnectHandler = state.mapNotNull { it.current() }.take(1)
101102

102-
//null pointer will never happen
103-
private suspend fun currentRSocket(): RSocket = state.value.handleState { reconnectHandler.first() }!!
103+
private suspend fun currentRSocket(closeable: Closeable): RSocket = closeable.closeOnError { currentRSocket() }
104104

105-
private inline fun ReconnectState.handleState(onReconnect: () -> RSocket?): RSocket? = when (this) {
106-
is ReconnectState.Connected -> when {
107-
rSocket.isActive -> rSocket //connection is ready to handle requests
108-
else -> onReconnect() //reconnection
109-
}
105+
private suspend fun currentRSocket(): RSocket = state.value.current() ?: reconnectHandler.first()
106+
107+
private fun ReconnectState.current(): RSocket? = when (this) {
108+
is ReconnectState.Connected -> rSocket.takeIf(RSocket::isActive) //connection is ready to handle requests
110109
is ReconnectState.Failed -> throw error //connection failed - fail requests
111-
ReconnectState.Connecting -> onReconnect() //reconnection
110+
ReconnectState.Connecting -> null //reconnection
112111
}
113112

114-
private suspend inline fun <T : Any> execSuspend(operation: RSocket.() -> T): T =
115-
currentRSocket().operation()
113+
override suspend fun metadataPush(metadata: ByteReadPacket): Unit =
114+
currentRSocket(metadata).metadataPush(metadata)
115+
116+
override suspend fun fireAndForget(payload: Payload): Unit =
117+
currentRSocket(payload).fireAndForget(payload)
116118

117-
private inline fun execFlow(crossinline operation: RSocket.() -> Flow<Payload>): Flow<Payload> =
118-
flow { emitAll(currentRSocket().operation()) }
119+
override suspend fun requestResponse(payload: Payload): Payload =
120+
currentRSocket(payload).requestResponse(payload)
119121

120-
override suspend fun metadataPush(metadata: ByteReadPacket): Unit = execSuspend { metadataPush(metadata) }
121-
override suspend fun fireAndForget(payload: Payload): Unit = execSuspend { fireAndForget(payload) }
122-
override suspend fun requestResponse(payload: Payload): Payload = execSuspend { requestResponse(payload) }
123-
override fun requestStream(payload: Payload): Flow<Payload> = execFlow { requestStream(payload) }
124-
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = execFlow { requestChannel(payloads) }
122+
override fun requestStream(payload: Payload): Flow<Payload> = flow {
123+
emitAll(currentRSocket(payload).requestStream(payload))
124+
}
125+
126+
override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = flow {
127+
emitAll(currentRSocket().requestChannel(payloads))
128+
}
125129

126130
}

rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package io.rsocket.kotlin.core
1818

1919
import app.cash.turbine.*
20+
import io.ktor.utils.io.core.*
2021
import io.rsocket.kotlin.*
2122
import io.rsocket.kotlin.logging.*
2223
import io.rsocket.kotlin.payload.*
@@ -54,7 +55,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
5455
val connect: suspend () -> RSocket = {
5556
if (first.value) {
5657
first.value = false
57-
rrHandler(firstJob)
58+
handler(firstJob)
5859
} else {
5960
error("Failed to connect")
6061
}
@@ -89,7 +90,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
8990
first.value = false
9091
error("Failed to connect")
9192
} else {
92-
rrHandler(handlerJob)
93+
handler(handlerJob)
9394
}
9495
}
9596
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
@@ -114,7 +115,7 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
114115
error("Failed to connect")
115116
} else {
116117
delay(200) //emulate connection establishment
117-
rrHandler(Job())
118+
handler(Job())
118119
}
119120
}
120121
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
@@ -137,13 +138,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
137138
when {
138139
first.value -> {
139140
first.value = false
140-
rrHandler(firstJob) //first connection
141+
handler(firstJob) //first connection
141142
}
142143
fails.value < 5 -> {
143144
delay(100)
144145
error("Failed to connect")
145146
}
146-
else -> rrHandler(Job())
147+
else -> handler(Job())
147148
}
148149
}
149150
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
@@ -170,13 +171,13 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
170171
when {
171172
first.value -> {
172173
first.value = false
173-
streamHandler(firstJob) //first connection
174+
handler(firstJob) //first connection
174175
}
175176
fails.value < 5 -> {
176177
delay(100)
177178
error("Failed to connect")
178179
}
179-
else -> streamHandler(Job())
180+
else -> handler(Job())
180181
}
181182
}
182183
val rSocket = ReconnectableRSocket(logger, connect) { cause, attempt ->
@@ -206,8 +207,52 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck {
206207
assertEquals(5, fails.value)
207208
}
208209

209-
private fun rrHandler(job: Job): RSocket = RSocketRequestHandler(job) { requestResponse { it } }
210-
private fun streamHandler(job: Job): RSocket = RSocketRequestHandler(job) {
210+
@Test
211+
fun testNoLeakMetadataPush() = testNoLeaksInteraction { metadataPush(it.data) }
212+
213+
@Test
214+
fun testNoLeakFireAndForget() = testNoLeaksInteraction { fireAndForget(it) }
215+
216+
@Test
217+
fun testNoLeakRequestResponse() = testNoLeaksInteraction { requestResponse(it) }
218+
219+
@Test
220+
fun testNoLeakRequestStream() = testNoLeaksInteraction { requestStream(it).collect() }
221+
222+
private inline fun testNoLeaksInteraction(crossinline interaction: suspend RSocket.(payload: Payload) -> Unit) = test {
223+
val firstJob = Job()
224+
val connect: suspend () -> RSocket = {
225+
if (first.compareAndSet(true, false)) {
226+
handler(firstJob)
227+
} else {
228+
error("Failed to connect")
229+
}
230+
}
231+
val rSocket = ReconnectableRSocket(logger, connect) { _, attempt ->
232+
delay(100)
233+
attempt < 5
234+
}
235+
236+
rSocket.requestResponse(Payload.Empty) //first request to be sure, that connected
237+
firstJob.cancelAndJoin() //cancel
238+
239+
val p = payload("text")
240+
assertFails {
241+
rSocket.interaction(p) //test release on reconnecting
242+
}
243+
assertTrue(p.data.isEmpty)
244+
245+
val p2 = payload("text")
246+
assertFails {
247+
rSocket.interaction(p2) //test release on failed
248+
}
249+
assertTrue(p2.data.isEmpty)
250+
}
251+
252+
private fun handler(job: Job): RSocket = RSocketRequestHandler(job) {
253+
requestResponse { payload ->
254+
payload
255+
}
211256
requestStream {
212257
flow {
213258
repeat(5) {

0 commit comments

Comments
 (0)