Skip to content

Commit d404ee6

Browse files
author
olme04
authored
provides fragmentation and reassembly (#177)
Co-authored-by: olme04 <olme04>
1 parent 5526b27 commit d404ee6

26 files changed

+502
-78
lines changed

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnector.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import kotlinx.coroutines.*
2727
@OptIn(TransportApi::class, RSocketLoggingApi::class)
2828
public class RSocketConnector internal constructor(
2929
private val loggerFactory: LoggerFactory,
30+
private val maxFragmentSize: Int,
3031
private val interceptors: Interceptors,
3132
private val connectionConfigProvider: () -> ConnectionConfig,
3233
private val acceptor: ConnectionAcceptor,
@@ -61,6 +62,7 @@ public class RSocketConnector internal constructor(
6162
try {
6263
val requester = connection.connect(
6364
isServer = false,
65+
maxFragmentSize = maxFragmentSize,
6466
interceptors = interceptors,
6567
connectionConfig = connectionConfig,
6668
acceptor = acceptor

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketConnectorBuilder.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ import kotlinx.coroutines.*
2626
public class RSocketConnectorBuilder internal constructor() {
2727
@RSocketLoggingApi
2828
public var loggerFactory: LoggerFactory = DefaultLoggerFactory
29+
public var maxFragmentSize: Int = 0
30+
set(value) {
31+
require(value == 0 || value >= 64) {
32+
"maxFragmentSize should be zero (no fragmentation) or greater than or equal to 64, but was $value"
33+
}
34+
field = value
35+
}
2936

3037
private val connectionConfig: ConnectionConfigBuilder = ConnectionConfigBuilder()
3138
private val interceptors: InterceptorsBuilder = InterceptorsBuilder()
@@ -96,6 +103,7 @@ public class RSocketConnectorBuilder internal constructor() {
96103
@OptIn(RSocketLoggingApi::class)
97104
internal fun build(): RSocketConnector = RSocketConnector(
98105
loggerFactory,
106+
maxFragmentSize,
99107
interceptors.build(),
100108
connectionConfig.producer(),
101109
acceptor ?: defaultAcceptor,

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServer.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ import kotlinx.coroutines.*
2727
@OptIn(TransportApi::class, RSocketLoggingApi::class)
2828
public class RSocketServer internal constructor(
2929
private val loggerFactory: LoggerFactory,
30+
private val maxFragmentSize: Int,
3031
private val interceptors: Interceptors,
3132
) {
3233

@@ -44,6 +45,7 @@ public class RSocketServer internal constructor(
4445
else -> try {
4546
connect(
4647
isServer = true,
48+
maxFragmentSize = maxFragmentSize,
4749
interceptors = interceptors,
4850
connectionConfig = ConnectionConfig(
4951
keepAlive = setupFrame.keepAlive,

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/RSocketServerBuilder.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ import io.rsocket.kotlin.logging.*
2222
public class RSocketServerBuilder internal constructor() {
2323
@RSocketLoggingApi
2424
public var loggerFactory: LoggerFactory = DefaultLoggerFactory
25+
public var maxFragmentSize: Int = 0
26+
set(value) {
27+
require(value == 0 || value >= 64) {
28+
"maxFragmentSize should be zero (no fragmentation) or greater than or equal to 64, but was $value"
29+
}
30+
field = value
31+
}
2532

2633
private val interceptors: InterceptorsBuilder = InterceptorsBuilder()
2734

@@ -30,7 +37,7 @@ public class RSocketServerBuilder internal constructor() {
3037
}
3138

3239
@OptIn(RSocketLoggingApi::class)
33-
internal fun build(): RSocketServer = RSocketServer(loggerFactory, interceptors.build())
40+
internal fun build(): RSocketServer = RSocketServer(loggerFactory, maxFragmentSize, interceptors.build())
3441
}
3542

3643
public fun RSocketServer(configure: RSocketServerBuilder.() -> Unit = {}): RSocketServer {

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Connect.kt

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@ import kotlinx.coroutines.*
2525
@OptIn(TransportApi::class)
2626
internal suspend inline fun Connection.connect(
2727
isServer: Boolean,
28+
maxFragmentSize: Int,
2829
interceptors: Interceptors,
2930
connectionConfig: ConnectionConfig,
3031
acceptor: ConnectionAcceptor
3132
): RSocket {
3233
val keepAliveHandler = KeepAliveHandler(connectionConfig.keepAlive)
3334
val prioritizer = Prioritizer()
34-
val streamsStorage = StreamsStorage(isServer)
35+
val frameSender = FrameSender(prioritizer, pool, maxFragmentSize)
36+
val streamsStorage = StreamsStorage(isServer, pool)
3537
val requestJob = SupervisorJob(job)
3638

3739
requestJob.invokeOnCompletion {
@@ -43,7 +45,7 @@ internal suspend inline fun Connection.connect(
4345
val requestScope = CoroutineScope(requestJob + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> })
4446
val connectionScope = CoroutineScope(job + Dispatchers.Unconfined + CoroutineExceptionHandler { _, _ -> })
4547

46-
val requester = interceptors.wrapRequester(RSocketRequester(job, prioritizer, streamsStorage, requestScope))
48+
val requester = interceptors.wrapRequester(RSocketRequester(job, frameSender, streamsStorage, requestScope, pool))
4749
val requestHandler = interceptors.wrapResponder(
4850
with(interceptors.wrapAcceptor(acceptor)) {
4951
ConnectionAcceptorContext(connectionConfig, requester).accept()
@@ -71,7 +73,7 @@ internal suspend inline fun Connection.connect(
7173

7274
// start frame handling
7375
connectionScope.launch {
74-
val rSocketResponder = RSocketResponder(prioritizer, requestHandler, requestScope)
76+
val rSocketResponder = RSocketResponder(frameSender, requestHandler, requestScope)
7577
while (isActive) {
7678
receiveFrame().closeOnError { frame ->
7779
when (frame.streamId) {
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright 2015-2020 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.rsocket.kotlin.internal
18+
19+
import io.ktor.utils.io.core.*
20+
import io.ktor.utils.io.core.internal.*
21+
import io.ktor.utils.io.pool.*
22+
import io.rsocket.kotlin.frame.*
23+
import io.rsocket.kotlin.frame.io.*
24+
import io.rsocket.kotlin.payload.*
25+
import kotlinx.coroutines.*
26+
import kotlin.math.*
27+
28+
private const val lengthSize = 3
29+
private const val headerSize = 6
30+
private const val fragmentOffset = lengthSize + headerSize
31+
private const val fragmentOffsetWithMetadata = fragmentOffset + lengthSize
32+
33+
internal class FrameSender(
34+
private val prioritizer: Prioritizer,
35+
private val pool: ObjectPool<ChunkBuffer>,
36+
private val maxFragmentSize: Int
37+
) {
38+
39+
suspend fun sendKeepAlive(respond: Boolean, lastPosition: Long, data: ByteReadPacket): Unit =
40+
prioritizer.send(KeepAliveFrame(respond, lastPosition, data))
41+
42+
suspend fun sendMetadataPush(metadata: ByteReadPacket): Unit = prioritizer.send(MetadataPushFrame(metadata))
43+
44+
suspend fun sendCancel(id: Int): Unit = withContext(NonCancellable) { prioritizer.send(CancelFrame(id)) }
45+
suspend fun sendError(id: Int, throwable: Throwable): Unit = withContext(NonCancellable) { prioritizer.send(ErrorFrame(id, throwable)) }
46+
suspend fun sendRequestN(id: Int, n: Int): Unit = prioritizer.send(RequestNFrame(id, n))
47+
48+
suspend fun sendRequestPayload(type: FrameType, streamId: Int, payload: Payload, initialRequest: Int = 0) {
49+
sendFragmented(type, streamId, payload, false, false, initialRequest)
50+
}
51+
52+
suspend fun sendNextPayload(streamId: Int, payload: Payload) {
53+
sendFragmented(FrameType.Payload, streamId, payload, false, true, 0)
54+
}
55+
56+
suspend fun sendNextCompletePayload(streamId: Int, payload: Payload) {
57+
sendFragmented(FrameType.Payload, streamId, payload, true, true, 0)
58+
}
59+
60+
suspend fun sendCompletePayload(streamId: Int) {
61+
sendFragmented(FrameType.Payload, streamId, Payload.Empty, true, false, 0)
62+
}
63+
64+
private suspend fun sendFragmented(
65+
type: FrameType,
66+
streamId: Int,
67+
payload: Payload,
68+
complete: Boolean,
69+
next: Boolean,
70+
initialRequest: Int
71+
) {
72+
//TODO release on fail ?
73+
if (!payload.isFragmentable(type.hasInitialRequest)) {
74+
prioritizer.send(RequestFrame(type, streamId, false, complete, next, initialRequest, payload))
75+
return
76+
}
77+
78+
val data = payload.data
79+
val metadata = payload.metadata
80+
81+
val fragmentSize = maxFragmentSize - fragmentOffset - (if (type.hasInitialRequest) Int.SIZE_BYTES else 0)
82+
83+
var first = true
84+
var remaining = fragmentSize
85+
if (metadata != null) remaining -= lengthSize
86+
87+
do {
88+
val metadataFragment = if (metadata != null && metadata.isNotEmpty) {
89+
if (!first) remaining -= lengthSize
90+
val length = min(metadata.remaining.toInt(), remaining)
91+
remaining -= length
92+
metadata.readPacket(pool, length)
93+
} else null
94+
95+
val dataFragment = if (remaining > 0 && data.isNotEmpty) {
96+
val length = min(data.remaining.toInt(), remaining)
97+
remaining -= length
98+
data.readPacket(pool, length)
99+
} else {
100+
ByteReadPacket.Empty
101+
}
102+
103+
val fType = if (first && type.isRequestType) type else FrameType.Payload
104+
val fragment = Payload(dataFragment, metadataFragment)
105+
val follows = metadata != null && metadata.isNotEmpty || data.isNotEmpty
106+
prioritizer.send(RequestFrame(fType, streamId, follows, (!follows && complete), !fType.isRequestType, initialRequest, fragment))
107+
first = false
108+
remaining = fragmentSize
109+
} while (follows)
110+
}
111+
112+
private fun Payload.isFragmentable(hasInitialRequest: Boolean) = when (maxFragmentSize) {
113+
0 -> false
114+
else -> when (val meta = metadata) {
115+
null -> data.remaining > maxFragmentSize - fragmentOffset - (if (hasInitialRequest) Int.SIZE_BYTES else 0)
116+
else -> data.remaining + meta.remaining > maxFragmentSize - fragmentOffsetWithMetadata - (if (hasInitialRequest) Int.SIZE_BYTES else 0)
117+
}
118+
}
119+
120+
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/Prioritizer.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ internal class Prioritizer {
3030
private val commonChannel = SafeChannel<Frame>(Channel.UNLIMITED)
3131

3232
suspend fun send(frame: Frame) {
33-
if (frame.type != FrameType.Cancel && frame.type != FrameType.Error) currentCoroutineContext().ensureActive()
33+
currentCoroutineContext().ensureActive()
3434
val channel = if (frame.streamId == 0) priorityChannel else commonChannel
3535
channel.send(frame)
3636
}

rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package io.rsocket.kotlin.internal
1818

1919
import io.ktor.utils.io.core.*
20+
import io.ktor.utils.io.core.internal.*
21+
import io.ktor.utils.io.pool.*
2022
import io.rsocket.kotlin.*
2123
import io.rsocket.kotlin.frame.*
2224
import io.rsocket.kotlin.internal.handler.*
@@ -28,83 +30,84 @@ import kotlinx.coroutines.flow.*
2830
@OptIn(ExperimentalStreamsApi::class)
2931
internal class RSocketRequester(
3032
connectionJob: Job,
31-
private val prioritizer: Prioritizer,
33+
private val sender: FrameSender,
3234
private val streamsStorage: StreamsStorage,
33-
private val requestScope: CoroutineScope
35+
private val requestScope: CoroutineScope,
36+
private val pool: ObjectPool<ChunkBuffer>
3437
) : RSocket {
3538
override val job: Job = connectionJob
3639

3740
override suspend fun metadataPush(metadata: ByteReadPacket) {
3841
ensureActiveOrRelease(metadata)
3942
metadata.closeOnError {
40-
prioritizer.send(MetadataPushFrame(metadata))
43+
sender.sendMetadataPush(metadata)
4144
}
4245
}
4346

4447
override suspend fun fireAndForget(payload: Payload) {
4548
ensureActiveOrRelease(payload)
4649

47-
val streamId = streamsStorage.nextId()
50+
val id = streamsStorage.nextId()
4851
try {
49-
prioritizer.send(RequestFireAndForgetFrame(streamId, payload))
52+
sender.sendRequestPayload(FrameType.RequestFnF, id, payload)
5053
} catch (cause: Throwable) {
5154
payload.release()
52-
if (job.isActive) prioritizer.send(CancelFrame(streamId)) //if cancelled during fragmentation
55+
if (job.isActive) sender.sendCancel(id) //if cancelled during fragmentation
5356
throw cause
5457
}
5558
}
5659

5760
override suspend fun requestResponse(payload: Payload): Payload {
5861
ensureActiveOrRelease(payload)
5962

60-
val streamId = streamsStorage.nextId()
63+
val id = streamsStorage.nextId()
6164

6265
val deferred = CompletableDeferred<Payload>()
63-
val handler = RequesterRequestResponseFrameHandler(streamId, streamsStorage, deferred)
64-
streamsStorage.save(streamId, handler)
66+
val handler = RequesterRequestResponseFrameHandler(id, streamsStorage, deferred, pool)
67+
streamsStorage.save(id, handler)
6568

66-
return handler.receiveOrCancel(streamId, payload) {
67-
prioritizer.send(RequestResponseFrame(streamId, payload))
69+
return handler.receiveOrCancel(id, payload) {
70+
sender.sendRequestPayload(FrameType.RequestResponse, id, payload)
6871
deferred.await()
6972
}
7073
}
7174

7275
override fun requestStream(payload: Payload): Flow<Payload> = requestFlow { strategy, initialRequest ->
7376
ensureActiveOrRelease(payload)
7477

75-
val streamId = streamsStorage.nextId()
78+
val id = streamsStorage.nextId()
7679

7780
val channel = SafeChannel<Payload>(Channel.UNLIMITED)
78-
val handler = RequesterRequestStreamFrameHandler(streamId, streamsStorage, channel)
79-
streamsStorage.save(streamId, handler)
81+
val handler = RequesterRequestStreamFrameHandler(id, streamsStorage, channel, pool)
82+
streamsStorage.save(id, handler)
8083

81-
handler.receiveOrCancel(streamId, payload) {
82-
prioritizer.send(RequestStreamFrame(streamId, initialRequest, payload))
83-
emitAllWithRequestN(channel, strategy) { prioritizer.send(RequestNFrame(streamId, it)) }
84+
handler.receiveOrCancel(id, payload) {
85+
sender.sendRequestPayload(FrameType.RequestStream, id, payload, initialRequest)
86+
emitAllWithRequestN(channel, strategy) { sender.sendRequestN(id, it) }
8487
}
8588
}
8689

8790
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> = requestFlow { strategy, initialRequest ->
8891
ensureActiveOrRelease(initPayload)
8992

90-
val streamId = streamsStorage.nextId()
93+
val id = streamsStorage.nextId()
9194

9295
val channel = SafeChannel<Payload>(Channel.UNLIMITED)
9396
val limiter = Limiter(0)
94-
val sender = Job(requestScope.coroutineContext.job)
95-
val handler = RequesterRequestChannelFrameHandler(streamId, streamsStorage, limiter, sender, channel)
96-
streamsStorage.save(streamId, handler)
97+
val payloadsJob = Job(requestScope.coroutineContext.job)
98+
val handler = RequesterRequestChannelFrameHandler(id, streamsStorage, limiter, payloadsJob, channel, pool)
99+
streamsStorage.save(id, handler)
97100

98-
handler.receiveOrCancel(streamId, initPayload) {
99-
prioritizer.send(RequestChannelFrame(streamId, initialRequest, initPayload))
101+
handler.receiveOrCancel(id, initPayload) {
102+
sender.sendRequestPayload(FrameType.RequestChannel, id, initPayload, initialRequest)
100103
//TODO lazy?
101-
requestScope.launch(sender) {
102-
handler.sendOrFail(streamId) {
103-
payloads.collectLimiting(limiter) { prioritizer.send(NextPayloadFrame(streamId, it)) }
104-
prioritizer.send(CompletePayloadFrame(streamId))
104+
requestScope.launch(payloadsJob) {
105+
handler.sendOrFail(id) {
106+
payloads.collectLimiting(limiter) { sender.sendNextPayload(id, it) }
107+
sender.sendCompletePayload(id)
105108
}
106109
}
107-
emitAllWithRequestN(channel, strategy) { prioritizer.send(RequestNFrame(streamId, it)) }
110+
emitAllWithRequestN(channel, strategy) { sender.sendRequestN(id, it) }
108111
}
109112
}
110113

@@ -114,7 +117,7 @@ internal class RSocketRequester(
114117
onSendComplete()
115118
} catch (cause: Throwable) {
116119
val isFailed = onSendFailed(cause)
117-
if (job.isActive && isFailed) prioritizer.send(ErrorFrame(id, cause))
120+
if (job.isActive && isFailed) sender.sendError(id, cause)
118121
throw cause
119122
}
120123
}
@@ -127,7 +130,7 @@ internal class RSocketRequester(
127130
} catch (cause: Throwable) {
128131
payload.release()
129132
val isCancelled = onReceiveCancelled(cause)
130-
if (job.isActive && isCancelled) prioritizer.send(CancelFrame(id))
133+
if (job.isActive && isCancelled) sender.sendCancel(id)
131134
throw cause
132135
}
133136
}

0 commit comments

Comments
 (0)