Skip to content

Commit 2e6059f

Browse files
committed
Add coroutines support to RSocket @MessageMapping
Closes gh-22780
1 parent 842e7e5 commit 2e6059f

File tree

7 files changed

+311
-3
lines changed

7 files changed

+311
-3
lines changed

spring-messaging/spring-messaging.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def rsocketVersion = "0.12.2-RC3-SNAPSHOT"
1212
dependencies {
1313
compile(project(":spring-beans"))
1414
compile(project(":spring-core"))
15+
compileOnly(project(":spring-core-coroutines"))
1516
optional(project(":spring-context"))
1617
optional(project(":spring-oxm"))
1718
optional("io.projectreactor.netty:reactor-netty")
@@ -35,6 +36,7 @@ dependencies {
3536
testCompile("org.jetbrains.kotlin:kotlin-reflect:${kotlinVersion}")
3637
testCompile("org.jetbrains.kotlin:kotlin-stdlib:${kotlinVersion}")
3738
testCompile("org.xmlunit:xmlunit-matchers:2.6.2")
39+
testCompile(project(":spring-core-coroutines"))
3840
testRuntime("com.sun.xml.bind:jaxb-core:2.3.0.1")
3941
testRuntime("com.sun.xml.bind:jaxb-impl:2.3.0.1")
4042
testRuntime("com.sun.activation:javax.activation:1.2.0")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.handler.annotation.support.reactive;
18+
19+
import reactor.core.publisher.Mono;
20+
21+
import org.springframework.core.MethodParameter;
22+
import org.springframework.messaging.Message;
23+
import org.springframework.messaging.handler.invocation.reactive.HandlerMethodArgumentResolver;
24+
25+
/**
26+
* No-op resolver for method arguments of type {@link kotlin.coroutines.Continuation}.
27+
*
28+
* @author Sebastien Deleuze
29+
* @since 5.2
30+
*/
31+
public class ContinuationHandlerMethodArgumentResolver implements HandlerMethodArgumentResolver {
32+
33+
@Override
34+
public boolean supportsParameter(MethodParameter parameter) {
35+
return "kotlin.coroutines.Continuation".equals(parameter.getParameterType().getName());
36+
}
37+
38+
@Override
39+
public Mono<Object> resolveArgument(MethodParameter parameter, Message<?> message) {
40+
return Mono.empty();
41+
}
42+
}

spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/reactive/MessageMappingMessageHandler.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
import org.springframework.context.ApplicationContext;
3535
import org.springframework.context.ConfigurableApplicationContext;
3636
import org.springframework.context.EmbeddedValueResolverAware;
37+
import org.springframework.core.KotlinDetector;
3738
import org.springframework.core.annotation.AnnotatedElementUtils;
3839
import org.springframework.core.codec.Decoder;
3940
import org.springframework.core.convert.ConversionService;
@@ -238,6 +239,11 @@ protected List<? extends HandlerMethodArgumentResolver> initArgumentResolvers()
238239
resolvers.add(new HeadersMethodArgumentResolver());
239240
resolvers.add(new DestinationVariableMethodArgumentResolver(this.conversionService));
240241

242+
// Type-based...
243+
if (KotlinDetector.isKotlinPresent()) {
244+
resolvers.add(new ContinuationHandlerMethodArgumentResolver());
245+
}
246+
241247
// Custom resolvers
242248
resolvers.addAll(getArgumentResolverConfigurer().getCustomResolvers());
243249

spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/AbstractEncoderMethodReturnValueHandler.java

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,20 @@
1616

1717
package org.springframework.messaging.handler.invocation.reactive;
1818

19+
import java.lang.reflect.Method;
1920
import java.util.Collections;
2021
import java.util.List;
2122
import java.util.Map;
2223

24+
import kotlin.reflect.KFunction;
25+
import kotlin.reflect.jvm.ReflectJvmMapping;
2326
import org.apache.commons.logging.Log;
2427
import org.apache.commons.logging.LogFactory;
2528
import org.reactivestreams.Publisher;
2629
import reactor.core.publisher.Flux;
2730
import reactor.core.publisher.Mono;
2831

32+
import org.springframework.core.KotlinDetector;
2933
import org.springframework.core.MethodParameter;
3034
import org.springframework.core.ReactiveAdapter;
3135
import org.springframework.core.ReactiveAdapterRegistry;
@@ -60,6 +64,8 @@ public abstract class AbstractEncoderMethodReturnValueHandler implements Handler
6064

6165
private static final ResolvableType OBJECT_RESOLVABLE_TYPE = ResolvableType.forClass(Object.class);
6266

67+
private static final String COROUTINES_FLOW_CLASS_NAME = "kotlinx.coroutines.flow.Flow";
68+
6369

6470
protected final Log logger = LogFactory.getLog(getClass());
6571

@@ -132,7 +138,11 @@ private Flux<DataBuffer> encodeContent(
132138
ResolvableType elementType;
133139
if (adapter != null) {
134140
publisher = adapter.toPublisher(content);
135-
ResolvableType genericType = returnValueType.getGeneric();
141+
boolean isUnwrapped = KotlinDetector.isKotlinReflectPresent() &&
142+
KotlinDetector.isKotlinType(returnType.getContainingClass()) &&
143+
KotlinDelegate.isSuspend(returnType.getMethod()) &&
144+
!COROUTINES_FLOW_CLASS_NAME.equals(returnValueType.toClass().getName());
145+
ResolvableType genericType = isUnwrapped ? returnValueType : returnValueType.getGeneric();
136146
elementType = getElementType(adapter, genericType);
137147
}
138148
else {
@@ -213,4 +223,16 @@ protected abstract Mono<Void> handleEncodedContent(
213223
*/
214224
protected abstract Mono<Void> handleNoContent(MethodParameter returnType, Message<?> message);
215225

226+
227+
/**
228+
* Inner class to avoid a hard dependency on Kotlin at runtime.
229+
*/
230+
private static class KotlinDelegate {
231+
232+
static private boolean isSuspend(Method method) {
233+
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
234+
return function != null && function.isSuspend();
235+
}
236+
}
237+
216238
}

spring-messaging/src/main/java/org/springframework/messaging/handler/invocation/reactive/InvocableHandlerMethod.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626

2727
import reactor.core.publisher.Mono;
2828

29+
import org.springframework.core.CoroutinesUtils;
2930
import org.springframework.core.DefaultParameterNameDiscoverer;
31+
import org.springframework.core.KotlinDetector;
3032
import org.springframework.core.MethodParameter;
3133
import org.springframework.core.ParameterNameDiscoverer;
3234
import org.springframework.core.ReactiveAdapter;
@@ -125,13 +127,20 @@ public void setReactiveAdapterRegistry(ReactiveAdapterRegistry registry) {
125127
* @param providedArgs optional list of argument values to match by type
126128
* @return a Mono with the result from the invocation.
127129
*/
130+
@SuppressWarnings("KotlinInternalInJava")
128131
public Mono<Object> invoke(Message<?> message, Object... providedArgs) {
129132

130133
return getMethodArgumentValues(message, providedArgs).flatMap(args -> {
131134
Object value;
132135
try {
133-
ReflectionUtils.makeAccessible(getBridgedMethod());
134-
value = getBridgedMethod().invoke(getBean(), args);
136+
Method method = getBridgedMethod();
137+
ReflectionUtils.makeAccessible(method);
138+
if (KotlinDetector.isKotlinReflectPresent() && KotlinDetector.isKotlinType(method.getDeclaringClass())) {
139+
value = CoroutinesUtils.invokeHandlerMethod(method, getBean(), args);
140+
}
141+
else {
142+
value = method.invoke(getBean(), args);
143+
}
135144
}
136145
catch (IllegalArgumentException ex) {
137146
assertTargetBean(getBridgedMethod(), getBean(), args);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
/*
2+
* Copyright 2002-2019 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.messaging.rsocket
18+
19+
import java.time.Duration
20+
21+
import io.netty.buffer.PooledByteBufAllocator
22+
import io.rsocket.RSocketFactory
23+
import io.rsocket.frame.decoder.PayloadDecoder
24+
import io.rsocket.transport.netty.server.CloseableChannel
25+
import io.rsocket.transport.netty.server.TcpServerTransport
26+
import kotlinx.coroutines.FlowPreview
27+
import kotlinx.coroutines.delay
28+
import kotlinx.coroutines.flow.Flow
29+
import kotlinx.coroutines.flow.flow
30+
import kotlinx.coroutines.flow.map
31+
import org.junit.AfterClass
32+
import org.junit.BeforeClass
33+
import org.junit.Test
34+
import reactor.core.publisher.Flux
35+
import reactor.core.publisher.ReplayProcessor
36+
import reactor.test.StepVerifier
37+
38+
import org.springframework.context.annotation.AnnotationConfigApplicationContext
39+
import org.springframework.context.annotation.Bean
40+
import org.springframework.context.annotation.Configuration
41+
import org.springframework.core.codec.CharSequenceEncoder
42+
import org.springframework.core.codec.StringDecoder
43+
import org.springframework.core.io.buffer.NettyDataBufferFactory
44+
import org.springframework.messaging.handler.annotation.MessageExceptionHandler
45+
import org.springframework.messaging.handler.annotation.MessageMapping
46+
import org.springframework.stereotype.Controller
47+
48+
/**
49+
* Coroutines server-side handling of RSocket requests.
50+
*
51+
* @author Sebastien Deleuze
52+
* @author Rossen Stoyanchev
53+
*/
54+
class RSocketClientToServerCoroutinesIntegrationTests {
55+
56+
@Test
57+
fun echoAsync() {
58+
val result = Flux.range(1, 3).concatMap { i -> requester.route("echo-async").data("Hello " + i!!).retrieveMono(String::class.java) }
59+
60+
StepVerifier.create(result)
61+
.expectNext("Hello 1 async").expectNext("Hello 2 async").expectNext("Hello 3 async")
62+
.expectComplete()
63+
.verify(Duration.ofSeconds(5))
64+
}
65+
66+
@Test
67+
fun echoStream() {
68+
val result = requester.route("echo-stream").data("Hello").retrieveFlux(String::class.java)
69+
70+
StepVerifier.create(result)
71+
.expectNext("Hello 0").expectNextCount(6).expectNext("Hello 7")
72+
.thenCancel()
73+
.verify(Duration.ofSeconds(5))
74+
}
75+
76+
@Test
77+
fun echoChannel() {
78+
val result = requester.route("echo-channel")
79+
.data(Flux.range(1, 10).map { i -> "Hello " + i!! }, String::class.java)
80+
.retrieveFlux(String::class.java)
81+
82+
StepVerifier.create(result)
83+
.expectNext("Hello 1 async").expectNextCount(8).expectNext("Hello 10 async")
84+
.thenCancel() // https://github.com/rsocket/rsocket-java/issues/613
85+
.verify(Duration.ofSeconds(5))
86+
}
87+
88+
@Test
89+
fun unitReturnValue() {
90+
val result = requester.route("unit-return-value").data("Hello").retrieveFlux(String::class.java)
91+
StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5))
92+
}
93+
94+
@Test
95+
fun unitReturnValueFromExceptionHandler() {
96+
val result = requester.route("unit-return-value").data("bad").retrieveFlux(String::class.java)
97+
StepVerifier.create(result).expectComplete().verify(Duration.ofSeconds(5))
98+
}
99+
100+
@Test
101+
fun handleWithThrownException() {
102+
val result = requester.route("thrown-exception").data("a").retrieveMono(String::class.java)
103+
StepVerifier.create(result)
104+
.expectNext("Invalid input error handled")
105+
.expectComplete()
106+
.verify(Duration.ofSeconds(5))
107+
}
108+
109+
@FlowPreview
110+
@Controller
111+
class ServerController {
112+
113+
val fireForgetPayloads = ReplayProcessor.create<String>()
114+
115+
@MessageMapping("echo-async")
116+
suspend fun echoAsync(payload: String): String {
117+
delay(10)
118+
return "$payload async"
119+
}
120+
121+
@MessageMapping("echo-stream")
122+
fun echoStream(payload: String): Flow<String> {
123+
var i = 0
124+
return flow {
125+
while(true) {
126+
delay(10)
127+
emit("$payload ${i++}")
128+
}
129+
}
130+
}
131+
132+
@MessageMapping("echo-channel")
133+
fun echoChannel(payloads: Flow<String>) = payloads.map {
134+
delay(10)
135+
"$it async"
136+
}
137+
138+
@MessageMapping("thrown-exception")
139+
suspend fun handleAndThrow(payload: String): String {
140+
delay(10)
141+
throw IllegalArgumentException("Invalid input error")
142+
}
143+
144+
@MessageMapping("unit-return-value")
145+
suspend fun unitReturnValue(payload: String) =
146+
if (payload != "bad") delay(10) else throw IllegalStateException("bad")
147+
148+
@MessageExceptionHandler
149+
suspend fun handleException(ex: IllegalArgumentException): String {
150+
delay(10)
151+
return "${ex.message} handled"
152+
}
153+
154+
@MessageExceptionHandler
155+
suspend fun handleExceptionWithVoidReturnValue(ex: IllegalStateException) {
156+
delay(10)
157+
}
158+
}
159+
160+
161+
@Configuration
162+
open class ServerConfig {
163+
164+
@Bean
165+
open fun controller(): ServerController {
166+
return ServerController()
167+
}
168+
169+
@Bean
170+
open fun messageHandlerAcceptor(): MessageHandlerAcceptor {
171+
val acceptor = MessageHandlerAcceptor()
172+
acceptor.rSocketStrategies = rsocketStrategies()
173+
return acceptor
174+
}
175+
176+
@Bean
177+
open fun rsocketStrategies(): RSocketStrategies {
178+
return RSocketStrategies.builder()
179+
.decoder(StringDecoder.allMimeTypes())
180+
.encoder(CharSequenceEncoder.allMimeTypes())
181+
.dataBufferFactory(NettyDataBufferFactory(PooledByteBufAllocator.DEFAULT))
182+
.build()
183+
}
184+
}
185+
186+
companion object {
187+
188+
private lateinit var context: AnnotationConfigApplicationContext
189+
190+
private lateinit var server: CloseableChannel
191+
192+
private val interceptor = FireAndForgetCountingInterceptor()
193+
194+
private lateinit var requester: RSocketRequester
195+
196+
197+
@BeforeClass
198+
@JvmStatic
199+
fun setupOnce() {
200+
context = AnnotationConfigApplicationContext(ServerConfig::class.java)
201+
202+
server = RSocketFactory.receive()
203+
.addServerPlugin(interceptor)
204+
.frameDecoder(PayloadDecoder.ZERO_COPY)
205+
.acceptor(context.getBean(MessageHandlerAcceptor::class.java))
206+
.transport(TcpServerTransport.create("localhost", 7000))
207+
.start()
208+
.block()!!
209+
210+
requester = RSocketRequester.builder()
211+
.rsocketFactory { factory -> factory.frameDecoder(PayloadDecoder.ZERO_COPY) }
212+
.rsocketStrategies(context.getBean(RSocketStrategies::class.java))
213+
.connectTcp("localhost", 7000)
214+
.block()!!
215+
}
216+
217+
@AfterClass
218+
@JvmStatic
219+
fun tearDownOnce() {
220+
requester.rsocket().dispose()
221+
server.dispose()
222+
}
223+
}
224+
225+
}

0 commit comments

Comments
 (0)