Skip to content

GH-3902: Add Kotlin Coroutines Support #3905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ ext {
junit4Version = '4.13.2'
junitJupiterVersion = '5.9.0'
jythonVersion = '2.7.3'
kotlinCoroutinesVersion = '1.6.4'
kryoVersion = '5.3.0'
lettuceVersion = '6.2.0.RELEASE'
log4jVersion = '2.19.0'
Expand Down Expand Up @@ -168,6 +169,7 @@ allprojects {
mavenBom "org.apache.camel:camel-bom:$camelVersion"
mavenBom "org.testcontainers:testcontainers-bom:$testcontainersVersion"
mavenBom "org.apache.groovy:groovy-bom:$groovyVersion"
mavenBom "org.jetbrains.kotlinx:kotlinx-coroutines-bom:$kotlinCoroutinesVersion"
}

}
Expand Down Expand Up @@ -541,7 +543,7 @@ project('spring-integration-core') {
}
optionalApi "io.github.resilience4j:resilience4j-ratelimiter:$resilience4jVersion"
optionalApi "org.apache.avro:avro:$avroVersion"
optionalApi 'org.jetbrains.kotlin:kotlin-stdlib-jdk8'
optionalApi 'org.jetbrains.kotlinx:kotlinx-coroutines-reactor'

testImplementation "org.aspectj:aspectjweaver:$aspectjVersion"
testImplementation "org.hamcrest:hamcrest-core:$hamcrestVersion"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
import org.springframework.integration.context.IntegrationContextUtils;
import org.springframework.integration.context.IntegrationProperties;
import org.springframework.integration.handler.LoggingHandler;
import org.springframework.integration.handler.support.IntegrationMessageHandlerMethodFactory;
import org.springframework.integration.json.JsonPathUtils;
import org.springframework.integration.support.DefaultMessageBuilderFactory;
import org.springframework.integration.support.SmartLifecycleRoleController;
Expand Down Expand Up @@ -462,10 +463,10 @@ private void registerListMessageHandlerMethodFactory() {
}

private static BeanDefinitionBuilder createMessageHandlerMethodFactoryBeanDefinition(boolean listCapable) {
return BeanDefinitionBuilder.genericBeanDefinition(MessageHandlerMethodFactoryCreatingFactoryBean.class,
() -> new MessageHandlerMethodFactoryCreatingFactoryBean(listCapable))
return BeanDefinitionBuilder.genericBeanDefinition(IntegrationMessageHandlerMethodFactory.class,
() -> new IntegrationMessageHandlerMethodFactory(listCapable))
.addConstructorArgValue(listCapable)
.addPropertyReference("argumentResolverMessageConverter",
.addPropertyReference("messageConverter",
IntegrationContextUtils.ARGUMENT_RESOLVER_MESSAGE_CONVERTER_BEAN_NAME);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.integration.support.AbstractIntegrationMessageBuilder;
import org.springframework.integration.support.DefaultMessageBuilderFactory;
import org.springframework.integration.support.MessageBuilderFactory;
import org.springframework.integration.util.CoroutinesUtils;
import org.springframework.integration.util.MessagingAnnotationUtils;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
Expand Down Expand Up @@ -289,6 +290,9 @@ public Message<?> toMessage(MethodArgsHolder holder, @Nullable Map<String, Objec
for (int i = 0; i < GatewayMethodInboundMessageMapper.this.parameterList.size(); i++) {
Object argumentValue = arguments[i];
MethodParameter methodParameter = GatewayMethodInboundMessageMapper.this.parameterList.get(i);
if (CoroutinesUtils.isContinuationType(methodParameter.getParameterType())) {
continue;
}
Annotation annotation =
MessagingAnnotationUtils.findMessagePartAnnotation(methodParameter.getParameterAnnotations(),
false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanInitializationException;
import org.springframework.beans.factory.FactoryBean;
import org.springframework.core.KotlinDetector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.convert.ConversionService;
import org.springframework.core.task.AsyncTaskExecutor;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.core.task.support.TaskExecutorAdapter;
Expand All @@ -67,6 +69,7 @@
import org.springframework.integration.support.management.IntegrationManagement;
import org.springframework.integration.support.management.TrackableComponent;
import org.springframework.integration.support.management.metrics.MetricsCaptor;
import org.springframework.integration.util.CoroutinesUtils;
import org.springframework.lang.Nullable;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
Expand Down Expand Up @@ -498,13 +501,14 @@ public Object getObject() {
@Nullable
@SuppressWarnings("deprecation")
public Object invoke(final MethodInvocation invocation) throws Throwable { // NOSONAR
final Class<?> returnType;
MethodInvocationGateway gateway = this.gatewayMap.get(invocation.getMethod());
Method method = invocation.getMethod();
Class<?> returnType;
MethodInvocationGateway gateway = this.gatewayMap.get(method);
if (gateway != null) {
returnType = gateway.returnType;
}
else {
returnType = invocation.getMethod().getReturnType();
returnType = method.getReturnType();
}
if (this.asyncExecutor != null && !Object.class.equals(returnType)) {
Invoker invoker = new Invoker(invocation);
Expand All @@ -524,7 +528,7 @@ else if (Future.class.isAssignableFrom(returnType)) {
+ returnType.getSimpleName());
}
}
if (Mono.class.isAssignableFrom(returnType)) {
if (Mono.class.isAssignableFrom(returnType) || KotlinDetector.isSuspendingFunction(method)) {
return doInvoke(invocation, false);
}
else {
Expand All @@ -534,8 +538,7 @@ else if (Future.class.isAssignableFrom(returnType)) {

@Nullable
protected Object doInvoke(MethodInvocation invocation, boolean runningOnCallerThread) throws Throwable { // NOSONAR
Method method = invocation.getMethod();
if (AopUtils.isToStringMethod(method)) {
if (AopUtils.isToStringMethod(invocation.getMethod())) {
return "gateway proxy for service interface [" + this.serviceInterface + "]";
}
try {
Expand Down Expand Up @@ -575,16 +578,29 @@ private Object invokeGatewayMethod(MethodInvocation invocation, boolean runningO
else {
response = sendOrSendAndReceive(invocation, gateway, shouldReturnMessage, !oneWay);
}
return response(gateway.returnType, shouldReturnMessage, response);

Object continuation = null;
if (gateway.isSuspendingFunction) {
for (Object argument : invocation.getArguments()) {
if (argument != null && CoroutinesUtils.isContinuation(argument)) {
continuation = argument;
break;
}
}
}

return response(gateway.returnType, shouldReturnMessage, response, continuation);
}

@Nullable
private Object response(Class<?> returnType, boolean shouldReturnMessage, @Nullable Object response) {
private Object response(Class<?> returnType, boolean shouldReturnMessage,
@Nullable Object response, @Nullable Object continuation) {

if (shouldReturnMessage) {
return response;
}
else {
return response != null ? convert(response, returnType) : null;
return response != null ? convert(response, returnType, continuation) : null;
}
}

Expand Down Expand Up @@ -627,7 +643,7 @@ private Object sendOrSendAndReceive(MethodInvocation invocation, MethodInvocatio

Object[] args = invocation.getArguments();
if (shouldReply) {
if (gateway.isMonoReturn) {
if (gateway.isMonoReturn || gateway.isSuspendingFunction) {
Mono<Message<?>> messageMono = gateway.sendAndReceiveMessageReactive(args);
if (!shouldReturnMessage) {
return messageMono.map(Message::getPayload);
Expand All @@ -641,7 +657,7 @@ private Object sendOrSendAndReceive(MethodInvocation invocation, MethodInvocatio
}
}
else {
if (gateway.isMonoReturn) {
if (gateway.isMonoReturn || gateway.isSuspendingFunction) {
return Mono.fromRunnable(() -> gateway.send(args));
}
else {
Expand Down Expand Up @@ -1013,17 +1029,28 @@ protected void doStop() {
this.gatewayMap.values().forEach(MethodInvocationGateway::stop);
}

@SuppressWarnings("unchecked")
@Nullable
private <T> T convert(Object source, Class<T> expectedReturnType) {
@SuppressWarnings("unchecked")
private <T> T convert(Object source, Class<T> expectedReturnType, @Nullable Object continuation) {
if (continuation != null) {
return CoroutinesUtils.monoAwaitSingleOrNull((Mono<T>) source, continuation);
}
if (Future.class.isAssignableFrom(expectedReturnType)) {
return (T) source;
}
if (Mono.class.isAssignableFrom(expectedReturnType)) {
return (T) source;
}
if (getConversionService() != null) {
return getConversionService().convert(source, expectedReturnType);


return doConvert(source, expectedReturnType);
}

@Nullable
private <T> T doConvert(Object source, Class<T> expectedReturnType) {
ConversionService conversionService = getConversionService();
if (conversionService != null) {
return conversionService.convert(source, expectedReturnType);
}
else {
return this.typeConverter.convertIfNecessary(source, expectedReturnType);
Expand All @@ -1050,6 +1077,8 @@ private static final class MethodInvocationGateway extends MessagingGatewaySuppo

private boolean pollable;

private boolean isSuspendingFunction;

MethodInvocationGateway(GatewayMethodInboundMessageMapper messageMapper) {
setRequestMapper(messageMapper);
}
Expand Down Expand Up @@ -1088,6 +1117,7 @@ void setupReturnType(Class<?> serviceInterface, Method method) {
this.expectMessage = hasReturnParameterizedWithMessage(resolvableType);
}
this.isVoidReturn = isVoidReturnType(resolvableType);
this.isSuspendingFunction = KotlinDetector.isSuspendingFunction(method);
}

private boolean hasReturnParameterizedWithMessage(ResolvableType resolvableType) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@
import org.reactivestreams.Publisher;

import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.ReactiveAdapter;
import org.springframework.core.ReactiveAdapterRegistry;
import org.springframework.core.convert.ConversionService;
import org.springframework.integration.IntegrationMessageHeaderAccessor;
import org.springframework.integration.channel.ReactiveStreamsSubscribableChannel;
import org.springframework.integration.context.IntegrationContextUtils;
Expand Down Expand Up @@ -188,8 +190,7 @@ public Collection<String> getNotPropagatedHeaders() {
/**
* Add header patterns ("xxx*", "*xxx", "*xxx*" or "xxx*yyy")
* that will NOT be copied from the inbound message if
* {@link #shouldCopyRequestHeaders()} is true, instead of overwriting the existing
* set.
* {@link #shouldCopyRequestHeaders()} is true, instead of overwriting the existing set.
* @param headers the headers to not propagate from the inbound message.
* @since 4.3.10
* @see #setNotPropagatedHeaders(String...)
Expand Down Expand Up @@ -308,28 +309,68 @@ private void doProduceOutput(Message<?> requestMessage, MessageHeaders requestHe
replyChannel = getOutputChannel();
}

if (this.async && (reply instanceof org.springframework.util.concurrent.ListenableFuture<?>
|| reply instanceof CompletableFuture<?>
|| reply instanceof Publisher<?>)) {
ReactiveAdapter reactiveAdapter = null;

if (reply instanceof Publisher<?> &&
replyChannel instanceof ReactiveStreamsSubscribableChannel) {
if (this.async &&
(reply instanceof org.springframework.util.concurrent.ListenableFuture<?>
|| reply instanceof CompletableFuture<?>
|| (reactiveAdapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, reply)) != null)) {

((ReactiveStreamsSubscribableChannel) replyChannel)
if (replyChannel instanceof ReactiveStreamsSubscribableChannel reactiveStreamsSubscribableChannel) {
Publisher<?> reactiveReply = toPublisherReply(reply, reactiveAdapter);
reactiveStreamsSubscribableChannel
.subscribeTo(
Flux.from((Publisher<?>) reply)
Flux.from(reactiveReply)
.doOnError((ex) -> sendErrorMessage(requestMessage, ex))
.map(result -> createOutputMessage(result, requestHeaders)));
}
else {
asyncNonReactiveReply(requestMessage, reply, replyChannel);
CompletableFuture<?> futureReply = toFutureReply(reply, reactiveAdapter);
futureReply.whenComplete(new ReplyFutureCallback(requestMessage, replyChannel));
}
}
else {
sendOutput(createOutputMessage(reply, requestHeaders), replyChannel, false);
}
}

private static Publisher<?> toPublisherReply(Object reply, @Nullable ReactiveAdapter reactiveAdapter) {
if (reactiveAdapter != null) {
return reactiveAdapter.toPublisher(reply);
}
else {
return Mono.fromFuture(toCompletableFuture(reply));
}
}

private static CompletableFuture<?> toFutureReply(Object reply, @Nullable ReactiveAdapter reactiveAdapter) {
if (reactiveAdapter != null) {
Mono<?> reactiveReply;
Publisher<?> publisher = reactiveAdapter.toPublisher(reply);
if (reactiveAdapter.isMultiValue()) {
reactiveReply = Mono.just(publisher);
}
else {
reactiveReply = Mono.from(publisher);
}

return reactiveReply.publishOn(Schedulers.boundedElastic()).toFuture();
}
else {
return toCompletableFuture(reply);
}
}

@SuppressWarnings("deprecation")
private static CompletableFuture<?> toCompletableFuture(Object reply) {
if (reply instanceof CompletableFuture<?>) {
return (CompletableFuture<?>) reply;
}
else {
return ((org.springframework.util.concurrent.ListenableFuture<?>) reply).completable();
}
}

private AbstractIntegrationMessageBuilder<?> addRoutingSlipHeader(Object reply, List<?> routingSlip,
AtomicInteger routingSlipIndex) {

Expand All @@ -352,30 +393,6 @@ else if (reply instanceof AbstractIntegrationMessageBuilder) {
return builder;
}

@SuppressWarnings("deprecation")
private void asyncNonReactiveReply(Message<?> requestMessage, Object reply, @Nullable Object replyChannel) {
CompletableFuture<?> future;
if (reply instanceof CompletableFuture<?>) {
future = (CompletableFuture<?>) reply;
}
else if (reply instanceof org.springframework.util.concurrent.ListenableFuture<?>) {
future = ((org.springframework.util.concurrent.ListenableFuture<?>) reply).completable();
}
else {
Mono<?> reactiveReply;
ReactiveAdapter adapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, reply);
if (adapter != null && adapter.isMultiValue()) {
reactiveReply = Mono.just(reply);
}
else {
reactiveReply = Mono.from((Publisher<?>) reply);
}

future = reactiveReply.publishOn(Schedulers.boundedElastic()).toFuture();
}
future.whenComplete(new ReplyFutureCallback(requestMessage, replyChannel));
}

private Object getOutputChannelFromRoutingSlip(Object reply, Message<?> requestMessage, List<?> routingSlip,
AtomicInteger routingSlipIndex) {

Expand Down Expand Up @@ -444,7 +461,7 @@ else if (output instanceof AbstractIntegrationMessageBuilder) {
* <code>null</code>, and it must be an instance of either String or {@link MessageChannel}.
* @param output the output object to send
* @param replyChannelArg the 'replyChannel' value from the original request
* @param useArgChannel - use the replyChannel argument (must not be null), not
* @param useArgChannel use the replyChannel argument (must not be null), not
* the configured output channel.
*/
protected void sendOutput(Object output, @Nullable Object replyChannelArg, boolean useArgChannel) {
Expand Down Expand Up @@ -522,6 +539,22 @@ protected Object resolveErrorChannel(final MessageHeaders requestHeaders) {
return errorChannel;
}

protected void setupMessageProcessor(MessageProcessor<?> processor) {
if (processor instanceof AbstractMessageProcessor<?> abstractMessageProcessor) {
ConversionService conversionService = getConversionService();
if (conversionService != null) {
abstractMessageProcessor.setConversionService(conversionService);
}
}
BeanFactory beanFactory = getBeanFactory();
if (processor instanceof BeanFactoryAware beanFactoryAware && beanFactory != null) {
beanFactoryAware.setBeanFactory(beanFactory);
}
if (!this.async && processor instanceof MethodInvokingMessageProcessor<?> methodInvokingMessageProcessor) {
this.async = methodInvokingMessageProcessor.isAsync();
}
}

private final class ReplyFutureCallback implements BiConsumer<Object, Throwable> {

private final Message<?> requestMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ public boolean isRunning() {
return this.delegate.isRunning();
}

public boolean isAsync() {
return this.delegate.isAsync();
}

@Override
@Nullable
@SuppressWarnings("unchecked")
Expand Down
Loading