Skip to content

Commit fd653b4

Browse files
committed
* Add suspend functions support for Messaging Gateway
* Add convenient `CoroutinesUtils` for Coroutines types and `Continuation` argument fulfilling via `Mono` * Treat `suspend fun` in the `GatewayProxyFactoryBean` as a `Mono` return * Convert `Mono` to the `Continuation` resuming in the end of gateway call
1 parent 03b0bfe commit fd653b4

File tree

7 files changed

+166
-32
lines changed

7 files changed

+166
-32
lines changed

spring-integration-core/src/main/java/org/springframework/integration/gateway/GatewayMethodInboundMessageMapper.java

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,13 @@
3838
import org.springframework.expression.EvaluationContext;
3939
import org.springframework.expression.Expression;
4040
import org.springframework.expression.spel.standard.SpelExpressionParser;
41-
import org.springframework.expression.spel.support.StandardEvaluationContext;
4241
import org.springframework.integration.expression.ExpressionUtils;
4342
import org.springframework.integration.mapping.InboundMessageMapper;
4443
import org.springframework.integration.mapping.MessageMappingException;
4544
import org.springframework.integration.support.AbstractIntegrationMessageBuilder;
4645
import org.springframework.integration.support.DefaultMessageBuilderFactory;
4746
import org.springframework.integration.support.MessageBuilderFactory;
47+
import org.springframework.integration.util.CoroutinesUtils;
4848
import org.springframework.integration.util.MessagingAnnotationUtils;
4949
import org.springframework.lang.Nullable;
5050
import org.springframework.messaging.Message;
@@ -216,14 +216,6 @@ private Map<String, Object> evaluateHeaders(EvaluationContext methodInvocationEv
216216
return evaluatedHeaders;
217217
}
218218

219-
// TODO Remove in the future release. The MethodArgsHolder as a root object covers this use-case.
220-
private StandardEvaluationContext createMethodInvocationEvaluationContext(Object[] arguments) {
221-
StandardEvaluationContext context = ExpressionUtils.createStandardEvaluationContext(this.beanFactory);
222-
context.setVariable("args", arguments);
223-
context.setVariable("gatewayMethod", this.method);
224-
return context;
225-
}
226-
227219
@Nullable
228220
private Object evaluatePayloadExpression(String expressionString, Object argumentValue) {
229221
Expression expression =
@@ -294,19 +286,23 @@ public class DefaultMethodArgsMessageMapper implements MethodArgsMessageMapper {
294286
public Message<?> toMessage(MethodArgsHolder holder, @Nullable Map<String, Object> headersToMap) {
295287
Object messageOrPayload = null;
296288
boolean foundPayloadAnnotation = false;
297-
Object[] arguments = holder.getArgs();
298-
EvaluationContext methodInvocationEvaluationContext = createMethodInvocationEvaluationContext(arguments);
299-
Map<String, Object> headersToPopulate =
300-
headersToMap != null
301-
? new HashMap<>(headersToMap)
302-
: new HashMap<>();
289+
EvaluationContext methodInvocationEvaluationContext =
290+
ExpressionUtils.createStandardEvaluationContext(GatewayMethodInboundMessageMapper.this.beanFactory);
303291
if (GatewayMethodInboundMessageMapper.this.payloadExpression != null) {
304292
messageOrPayload =
305293
GatewayMethodInboundMessageMapper.this.payloadExpression.getValue(
306294
methodInvocationEvaluationContext, holder);
307295
}
296+
Map<String, Object> headersToPopulate =
297+
headersToMap != null
298+
? new HashMap<>(headersToMap)
299+
: new HashMap<>();
300+
Object[] arguments = holder.getArgs();
308301
for (int i = 0; i < GatewayMethodInboundMessageMapper.this.parameterList.size(); i++) {
309302
Object argumentValue = arguments[i];
303+
if (CoroutinesUtils.isContinuation(argumentValue)) {
304+
continue;
305+
}
310306
MethodParameter methodParameter = GatewayMethodInboundMessageMapper.this.parameterList.get(i);
311307
Annotation annotation =
312308
MessagingAnnotationUtils.findMessagePartAnnotation(methodParameter.getParameterAnnotations(),

spring-integration-core/src/main/java/org/springframework/integration/gateway/GatewayProxyFactoryBean.java

Lines changed: 44 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@
4646
import org.springframework.beans.factory.BeanFactory;
4747
import org.springframework.beans.factory.BeanInitializationException;
4848
import org.springframework.beans.factory.FactoryBean;
49+
import org.springframework.core.KotlinDetector;
4950
import org.springframework.core.MethodParameter;
5051
import org.springframework.core.ResolvableType;
52+
import org.springframework.core.convert.ConversionService;
5153
import org.springframework.core.task.AsyncTaskExecutor;
5254
import org.springframework.core.task.SimpleAsyncTaskExecutor;
5355
import org.springframework.core.task.support.TaskExecutorAdapter;
@@ -67,6 +69,7 @@
6769
import org.springframework.integration.support.management.IntegrationManagement;
6870
import org.springframework.integration.support.management.TrackableComponent;
6971
import org.springframework.integration.support.management.metrics.MetricsCaptor;
72+
import org.springframework.integration.util.CoroutinesUtils;
7073
import org.springframework.lang.Nullable;
7174
import org.springframework.messaging.Message;
7275
import org.springframework.messaging.MessageChannel;
@@ -498,13 +501,14 @@ public Object getObject() {
498501
@Nullable
499502
@SuppressWarnings("deprecation")
500503
public Object invoke(final MethodInvocation invocation) throws Throwable { // NOSONAR
501-
final Class<?> returnType;
502-
MethodInvocationGateway gateway = this.gatewayMap.get(invocation.getMethod());
504+
Method method = invocation.getMethod();
505+
Class<?> returnType;
506+
MethodInvocationGateway gateway = this.gatewayMap.get(method);
503507
if (gateway != null) {
504508
returnType = gateway.returnType;
505509
}
506510
else {
507-
returnType = invocation.getMethod().getReturnType();
511+
returnType = method.getReturnType();
508512
}
509513
if (this.asyncExecutor != null && !Object.class.equals(returnType)) {
510514
Invoker invoker = new Invoker(invocation);
@@ -524,7 +528,7 @@ else if (Future.class.isAssignableFrom(returnType)) {
524528
+ returnType.getSimpleName());
525529
}
526530
}
527-
if (Mono.class.isAssignableFrom(returnType)) {
531+
if (Mono.class.isAssignableFrom(returnType) || KotlinDetector.isSuspendingFunction(method)) {
528532
return doInvoke(invocation, false);
529533
}
530534
else {
@@ -534,8 +538,7 @@ else if (Future.class.isAssignableFrom(returnType)) {
534538

535539
@Nullable
536540
protected Object doInvoke(MethodInvocation invocation, boolean runningOnCallerThread) throws Throwable { // NOSONAR
537-
Method method = invocation.getMethod();
538-
if (AopUtils.isToStringMethod(method)) {
541+
if (AopUtils.isToStringMethod(invocation.getMethod())) {
539542
return "gateway proxy for service interface [" + this.serviceInterface + "]";
540543
}
541544
try {
@@ -575,16 +578,29 @@ private Object invokeGatewayMethod(MethodInvocation invocation, boolean runningO
575578
else {
576579
response = sendOrSendAndReceive(invocation, gateway, shouldReturnMessage, !oneWay);
577580
}
578-
return response(gateway.returnType, shouldReturnMessage, response);
581+
582+
Object continuation = null;
583+
if (gateway.isSuspendingFunction) {
584+
for (Object argument : invocation.getArguments()) {
585+
if (CoroutinesUtils.KOTLIN_CONTINUATION_CLASS.isAssignableFrom(argument.getClass())) {
586+
continuation = argument;
587+
break;
588+
}
589+
}
590+
}
591+
592+
return response(gateway.returnType, shouldReturnMessage, response, continuation);
579593
}
580594

581595
@Nullable
582-
private Object response(Class<?> returnType, boolean shouldReturnMessage, @Nullable Object response) {
596+
private Object response(Class<?> returnType, boolean shouldReturnMessage,
597+
@Nullable Object response, @Nullable Object continuation) {
598+
583599
if (shouldReturnMessage) {
584600
return response;
585601
}
586602
else {
587-
return response != null ? convert(response, returnType) : null;
603+
return response != null ? convert(response, returnType, continuation) : null;
588604
}
589605
}
590606

@@ -627,7 +643,7 @@ private Object sendOrSendAndReceive(MethodInvocation invocation, MethodInvocatio
627643

628644
Object[] args = invocation.getArguments();
629645
if (shouldReply) {
630-
if (gateway.isMonoReturn) {
646+
if (gateway.isMonoReturn || gateway.isSuspendingFunction) {
631647
Mono<Message<?>> messageMono = gateway.sendAndReceiveMessageReactive(args);
632648
if (!shouldReturnMessage) {
633649
return messageMono.map(Message::getPayload);
@@ -641,7 +657,7 @@ private Object sendOrSendAndReceive(MethodInvocation invocation, MethodInvocatio
641657
}
642658
}
643659
else {
644-
if (gateway.isMonoReturn) {
660+
if (gateway.isMonoReturn || gateway.isSuspendingFunction) {
645661
return Mono.fromRunnable(() -> gateway.send(args));
646662
}
647663
else {
@@ -1015,15 +1031,26 @@ protected void doStop() {
10151031

10161032
@SuppressWarnings("unchecked")
10171033
@Nullable
1018-
private <T> T convert(Object source, Class<T> expectedReturnType) {
1034+
private <T> T convert(Object source, Class<T> expectedReturnType, @Nullable Object continuation) {
1035+
if (continuation != null) {
1036+
return CoroutinesUtils.monoAwaitSingleOrNull((Mono<T>) source, continuation);
1037+
}
10191038
if (Future.class.isAssignableFrom(expectedReturnType)) {
10201039
return (T) source;
10211040
}
10221041
if (Mono.class.isAssignableFrom(expectedReturnType)) {
10231042
return (T) source;
10241043
}
1025-
if (getConversionService() != null) {
1026-
return getConversionService().convert(source, expectedReturnType);
1044+
1045+
1046+
return doConvert(source, expectedReturnType);
1047+
}
1048+
1049+
@Nullable
1050+
private <T> T doConvert(Object source, Class<T> expectedReturnType) {
1051+
ConversionService conversionService = getConversionService();
1052+
if (conversionService != null) {
1053+
return conversionService.convert(source, expectedReturnType);
10271054
}
10281055
else {
10291056
return this.typeConverter.convertIfNecessary(source, expectedReturnType);
@@ -1050,6 +1077,8 @@ private static final class MethodInvocationGateway extends MessagingGatewaySuppo
10501077

10511078
private boolean pollable;
10521079

1080+
private boolean isSuspendingFunction;
1081+
10531082
MethodInvocationGateway(GatewayMethodInboundMessageMapper messageMapper) {
10541083
setRequestMapper(messageMapper);
10551084
}
@@ -1088,6 +1117,7 @@ void setupReturnType(Class<?> serviceInterface, Method method) {
10881117
this.expectMessage = hasReturnParameterizedWithMessage(resolvableType);
10891118
}
10901119
this.isVoidReturn = isVoidReturnType(resolvableType);
1120+
this.isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
10911121
}
10921122

10931123
private boolean hasReturnParameterizedWithMessage(ResolvableType resolvableType) {

spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ protected boolean isAsync() {
133133
* {@link #shouldCopyRequestHeaders() shouldCopyRequestHeaaders} is true.
134134
* At least one pattern as "*" means do not copy headers at all.
135135
* @param headers the headers to not propagate from the inbound message.
136-
* @see org.springframework.util.PatternMatchUtils
137136
* @since 4.3.10
137+
* @see org.springframework.util.PatternMatchUtils
138138
*/
139139
@Override
140140
public void setNotPropagatedHeaders(String... headers) {

spring-integration-core/src/main/java/org/springframework/integration/handler/support/ContinuationHandlerMethodArgumentResolver.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.integration.handler.support;
1818

1919
import org.springframework.core.MethodParameter;
20+
import org.springframework.integration.util.CoroutinesUtils;
2021
import org.springframework.messaging.Message;
2122
import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver;
2223

@@ -33,7 +34,7 @@ public class ContinuationHandlerMethodArgumentResolver implements HandlerMethodA
3334

3435
@Override
3536
public boolean supportsParameter(MethodParameter parameter) {
36-
return "kotlin.coroutines.Continuation".equals(parameter.getParameterType().getName());
37+
return CoroutinesUtils.isContinuationType(parameter.getParameterType());
3738
}
3839

3940
@Override

spring-integration-core/src/main/java/org/springframework/integration/handler/support/MessagingMethodInvokerHelper.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
import org.springframework.integration.support.management.ManageableLifecycle;
8484
import org.springframework.integration.util.AbstractExpressionEvaluator;
8585
import org.springframework.integration.util.AnnotatedMethodFilter;
86+
import org.springframework.integration.util.CoroutinesUtils;
8687
import org.springframework.integration.util.FixedMethodFilter;
8788
import org.springframework.integration.util.MessagingAnnotationUtils;
8889
import org.springframework.integration.util.UniqueMethodFilter;
@@ -1176,7 +1177,7 @@ else if (Map.class.isAssignableFrom(parameterType)) {
11761177
populateMapParameterForExpression(sb, parameterType);
11771178
return true;
11781179
}
1179-
else if ("kotlin.coroutines.Continuation".equals(parameterType.getName())) {
1180+
else if (CoroutinesUtils.isContinuationType(parameterType)) {
11801181
sb.append("null");
11811182
}
11821183
else {
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.integration.util;
18+
19+
import org.springframework.core.KotlinDetector;
20+
import org.springframework.util.Assert;
21+
import org.springframework.util.ClassUtils;
22+
23+
import reactor.core.publisher.Mono;
24+
25+
/**
26+
* Additional utilities for working with Kotlin Coroutines.
27+
*
28+
* @author Artem Bilan
29+
*
30+
* @since 6.0
31+
*
32+
* @see org.springframework.core.CoroutinesUtils
33+
*/
34+
public final class CoroutinesUtils {
35+
36+
/**
37+
* The {@link kotlin.coroutines.Continuation} class object.
38+
*/
39+
public static final Class<?> KOTLIN_CONTINUATION_CLASS;
40+
41+
static {
42+
if (KotlinDetector.isKotlinPresent()) {
43+
Class<?> kotlinClass = null;
44+
try {
45+
kotlinClass = ClassUtils.forName("kotlin.coroutines.Continuation", ClassUtils.getDefaultClassLoader());
46+
}
47+
catch (ClassNotFoundException ex) {
48+
//Ignore: assume no Kotlin in classpath
49+
}
50+
finally {
51+
KOTLIN_CONTINUATION_CLASS = kotlinClass;
52+
}
53+
}
54+
else {
55+
KOTLIN_CONTINUATION_CLASS = null;
56+
}
57+
}
58+
59+
public static boolean isContinuationType(Class<?> candidate) {
60+
return KOTLIN_CONTINUATION_CLASS != null && KOTLIN_CONTINUATION_CLASS.isAssignableFrom(candidate);
61+
}
62+
63+
public static boolean isContinuation(Object candidate) {
64+
return KOTLIN_CONTINUATION_CLASS != null && KOTLIN_CONTINUATION_CLASS.isAssignableFrom(candidate.getClass());
65+
}
66+
67+
@SuppressWarnings("unchecked")
68+
public static <T> T monoAwaitSingleOrNull(Mono<T> source, Object continuation) {
69+
Assert.isAssignable(KOTLIN_CONTINUATION_CLASS, continuation.getClass());
70+
return (T) kotlinx.coroutines.reactor.MonoKt.awaitSingleOrNull(
71+
source, (kotlin.coroutines.Continuation<T>) continuation);
72+
}
73+
74+
private CoroutinesUtils() {
75+
}
76+
77+
}

spring-integration-core/src/test/kotlin/org/springframework/integration/function/FunctionsTests.kt

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.springframework.integration.function
1919
import assertk.assertThat
2020
import assertk.assertions.*
2121
import kotlinx.coroutines.flow.flow
22+
import kotlinx.coroutines.runBlocking
2223
import org.junit.jupiter.api.Test
2324
import org.springframework.beans.factory.annotation.Autowired
2425
import org.springframework.beans.factory.annotation.Qualifier
@@ -32,6 +33,7 @@ import org.springframework.integration.config.EnableIntegration
3233
import org.springframework.integration.dsl.integrationFlow
3334
import org.springframework.integration.endpoint.SourcePollingChannelAdapter
3435
import org.springframework.integration.gateway.GatewayProxyFactoryBean
36+
import org.springframework.integration.handler.ServiceActivatingHandler
3537
import org.springframework.messaging.Message
3638
import org.springframework.messaging.MessageChannel
3739
import org.springframework.messaging.PollableChannel
@@ -188,8 +190,25 @@ class FunctionsTests {
188190
stepVerifier.verify(Duration.ofSeconds(10))
189191
}
190192

193+
@Autowired
194+
private lateinit var suspendRequestChannel: DirectChannel
195+
196+
@Autowired
197+
private lateinit var suspendFunGateway: SuspendFunGateway
198+
199+
@Test
200+
fun `suspend gateway`() {
201+
suspendRequestChannel.subscribe(ServiceActivatingHandler { m -> m.payload.toString().uppercase() })
202+
203+
runBlocking {
204+
val reply = suspendFunGateway.suspendGateway("test suspend gateway")
205+
assertThat(reply).isEqualTo("TEST SUSPEND GATEWAY")
206+
}
207+
}
208+
191209
@Configuration
192210
@EnableIntegration
211+
@IntegrationComponentScan
193212
class Config {
194213

195214
@Bean
@@ -245,6 +264,16 @@ class FunctionsTests {
245264
}
246265
}
247266

267+
@Bean
268+
fun suspendRequestChannel() = DirectChannel()
269+
270+
}
271+
272+
@MessagingGateway(defaultRequestChannel = "suspendRequestChannel")
273+
interface SuspendFunGateway {
274+
275+
suspend fun suspendGateway(payload: String): String
276+
248277
}
249278

250279
interface MonoFunction : Function<String, Mono<Message<*>>>

0 commit comments

Comments
 (0)