Skip to content

Commit a86fc32

Browse files
committed
* Add kotlinx.coroutines.flow.Flow support
The `Flow` is essentially a multi-value reactive `Publisher`, so use `ReactiveAdapterRegistry` to convert any custom reactive streams result to `Flux` and `Mono` which we already support as reply types
1 parent 0e0ea00 commit a86fc32

File tree

3 files changed

+109
-71
lines changed

3 files changed

+109
-71
lines changed

spring-integration-core/src/main/java/org/springframework/integration/handler/AbstractMessageProducingHandler.java

Lines changed: 52 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,8 @@ public final void setAsync(boolean async) {
120120

121121
/**
122122
* @return true if this handler supports async replies.
123-
* @since 4.3
124123
* @see #setAsync(boolean)
124+
* @since 4.3
125125
*/
126126
protected boolean isAsync() {
127127
return this.async;
@@ -133,8 +133,8 @@ protected boolean isAsync() {
133133
* {@link #shouldCopyRequestHeaders() shouldCopyRequestHeaaders} is true.
134134
* At least one pattern as "*" means do not copy headers at all.
135135
* @param headers the headers to not propagate from the inbound message.
136-
* @since 4.3.10
137136
* @see org.springframework.util.PatternMatchUtils
137+
* @since 4.3.10
138138
*/
139139
@Override
140140
public void setNotPropagatedHeaders(String... headers) {
@@ -190,8 +190,7 @@ public Collection<String> getNotPropagatedHeaders() {
190190
/**
191191
* Add header patterns ("xxx*", "*xxx", "*xxx*" or "xxx*yyy")
192192
* that will NOT be copied from the inbound message if
193-
* {@link #shouldCopyRequestHeaders()} is true, instead of overwriting the existing
194-
* set.
193+
* {@link #shouldCopyRequestHeaders()} is true, instead of overwriting the existing set.
195194
* @param headers the headers to not propagate from the inbound message.
196195
* @since 4.3.10
197196
* @see #setNotPropagatedHeaders(String...)
@@ -310,28 +309,68 @@ private void doProduceOutput(Message<?> requestMessage, MessageHeaders requestHe
310309
replyChannel = getOutputChannel();
311310
}
312311

313-
if (this.async && (reply instanceof org.springframework.util.concurrent.ListenableFuture<?>
314-
|| reply instanceof CompletableFuture<?>
315-
|| reply instanceof Publisher<?>)) {
312+
ReactiveAdapter reactiveAdapter = null;
316313

317-
if (reply instanceof Publisher<?> &&
318-
replyChannel instanceof ReactiveStreamsSubscribableChannel) {
314+
if (this.async &&
315+
(reply instanceof org.springframework.util.concurrent.ListenableFuture<?>
316+
|| reply instanceof CompletableFuture<?>
317+
|| (reactiveAdapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, reply)) != null)) {
319318

320-
((ReactiveStreamsSubscribableChannel) replyChannel)
319+
if (replyChannel instanceof ReactiveStreamsSubscribableChannel reactiveStreamsSubscribableChannel) {
320+
Publisher<?> reactiveReply = toPublisherReply(reply, reactiveAdapter);
321+
reactiveStreamsSubscribableChannel
321322
.subscribeTo(
322-
Flux.from((Publisher<?>) reply)
323+
Flux.from(reactiveReply)
323324
.doOnError((ex) -> sendErrorMessage(requestMessage, ex))
324325
.map(result -> createOutputMessage(result, requestHeaders)));
325326
}
326327
else {
327-
asyncNonReactiveReply(requestMessage, reply, replyChannel);
328+
CompletableFuture<?> futureReply = toFutureReply(reply, reactiveAdapter);
329+
futureReply.whenComplete(new ReplyFutureCallback(requestMessage, replyChannel));
328330
}
329331
}
330332
else {
331333
sendOutput(createOutputMessage(reply, requestHeaders), replyChannel, false);
332334
}
333335
}
334336

337+
private static Publisher<?> toPublisherReply(Object reply, @Nullable ReactiveAdapter reactiveAdapter) {
338+
if (reactiveAdapter != null) {
339+
return reactiveAdapter.toPublisher(reply);
340+
}
341+
else {
342+
return Mono.fromFuture(toCompletableFuture(reply));
343+
}
344+
}
345+
346+
private static CompletableFuture<?> toFutureReply(Object reply, @Nullable ReactiveAdapter reactiveAdapter) {
347+
if (reactiveAdapter != null) {
348+
Mono<?> reactiveReply;
349+
Publisher<?> publisher = reactiveAdapter.toPublisher(reply);
350+
if (reactiveAdapter.isMultiValue()) {
351+
reactiveReply = Mono.just(publisher);
352+
}
353+
else {
354+
reactiveReply = Mono.from(publisher);
355+
}
356+
357+
return reactiveReply.publishOn(Schedulers.boundedElastic()).toFuture();
358+
}
359+
else {
360+
return toCompletableFuture(reply);
361+
}
362+
}
363+
364+
@SuppressWarnings("deprecation")
365+
private static CompletableFuture<?> toCompletableFuture(Object reply) {
366+
if (reply instanceof CompletableFuture<?>) {
367+
return (CompletableFuture<?>) reply;
368+
}
369+
else {
370+
return ((org.springframework.util.concurrent.ListenableFuture<?>) reply).completable();
371+
}
372+
}
373+
335374
private AbstractIntegrationMessageBuilder<?> addRoutingSlipHeader(Object reply, List<?> routingSlip,
336375
AtomicInteger routingSlipIndex) {
337376

@@ -354,30 +393,6 @@ else if (reply instanceof AbstractIntegrationMessageBuilder) {
354393
return builder;
355394
}
356395

357-
@SuppressWarnings("deprecation")
358-
private void asyncNonReactiveReply(Message<?> requestMessage, Object reply, @Nullable Object replyChannel) {
359-
CompletableFuture<?> future;
360-
if (reply instanceof CompletableFuture<?>) {
361-
future = (CompletableFuture<?>) reply;
362-
}
363-
else if (reply instanceof org.springframework.util.concurrent.ListenableFuture<?>) {
364-
future = ((org.springframework.util.concurrent.ListenableFuture<?>) reply).completable();
365-
}
366-
else {
367-
Mono<?> reactiveReply;
368-
ReactiveAdapter adapter = ReactiveAdapterRegistry.getSharedInstance().getAdapter(null, reply);
369-
if (adapter != null && adapter.isMultiValue()) {
370-
reactiveReply = Mono.just(reply);
371-
}
372-
else {
373-
reactiveReply = Mono.from((Publisher<?>) reply);
374-
}
375-
376-
future = reactiveReply.publishOn(Schedulers.boundedElastic()).toFuture();
377-
}
378-
future.whenComplete(new ReplyFutureCallback(requestMessage, replyChannel));
379-
}
380-
381396
private Object getOutputChannelFromRoutingSlip(Object reply, Message<?> requestMessage, List<?> routingSlip,
382397
AtomicInteger routingSlipIndex) {
383398

@@ -446,7 +461,7 @@ else if (output instanceof AbstractIntegrationMessageBuilder) {
446461
* <code>null</code>, and it must be an instance of either String or {@link MessageChannel}.
447462
* @param output the output object to send
448463
* @param replyChannelArg the 'replyChannel' value from the original request
449-
* @param useArgChannel - use the replyChannel argument (must not be null), not
464+
* @param useArgChannel use the replyChannel argument (must not be null), not
450465
* the configured output channel.
451466
*/
452467
protected void sendOutput(Object output, @Nullable Object replyChannelArg, boolean useArgChannel) {

spring-integration-core/src/main/java/org/springframework/integration/handler/support/MessagingMethodInvokerHelper.java

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
5656
import org.springframework.core.MethodParameter;
5757
import org.springframework.core.ParameterNameDiscoverer;
58+
import org.springframework.core.ReactiveAdapterRegistry;
5859
import org.springframework.core.annotation.AnnotationAttributes;
5960
import org.springframework.core.annotation.AnnotationUtils;
6061
import org.springframework.core.convert.ConversionFailedException;
@@ -329,7 +330,6 @@ else if (targetObject instanceof Consumer) {
329330
/**
330331
* A {@code boolean} flag to use SpEL Expression evaluation or {@link InvocableHandlerMethod}
331332
* for target method invocation.
332-
*
333333
* @param useSpelInvoker to use SpEL Expression evaluation or not.
334334
* @since 5.0
335335
*/
@@ -520,7 +520,6 @@ private boolean isProvidedMessageHandlerFactoryBean() {
520520
* This should not be needed in production but we have many tests
521521
* that don't run in an application context.
522522
*/
523-
524523
private void initializeHandler(HandlerMethod candidate) {
525524
ExpressionParser parser;
526525
if (candidate.useSpelInvoker == null) {
@@ -847,7 +846,7 @@ private void populateHandlerMethod(Map<Class<?>, HandlerMethod> candidateMethods
847846
if (handlerMethod1.isMessageMethod()) {
848847
if (fallbackMessageMethods.containsKey(targetParameterType)) {
849848
// we need to check for duplicate type matches,
850-
// but only if we end up falling back
849+
// but only if we end up falling back,
851850
// and we'll only keep track of the first one
852851
ambiguousFallbackMessageGenericType.compareAndSet(null, targetParameterType);
853852
}
@@ -887,7 +886,6 @@ private void findSingleSpecificMethodOnInterfacesIfProxy(Map<Class<?>, HandlerMe
887886
Map<Class<?>, HandlerMethod> candidateMethods) {
888887
if (AopUtils.isAopProxy(this.targetObject)) {
889888
final AtomicReference<Method> targetMethod = new AtomicReference<>();
890-
final AtomicReference<Class<?>> targetClass = new AtomicReference<>();
891889
Class<?>[] interfaces = ((Advised) this.targetObject).getProxiedInterfaces();
892890
for (Class<?> clazz : interfaces) {
893891
ReflectionUtils.doWithMethods(clazz, method1 -> {
@@ -897,7 +895,6 @@ private void findSingleSpecificMethodOnInterfacesIfProxy(Map<Class<?>, HandlerMe
897895
}
898896
else {
899897
targetMethod.set(method1);
900-
targetClass.set(clazz);
901898
}
902899
}, method12 -> method12.getName().equals(this.methodName));
903900
}
@@ -1029,7 +1026,8 @@ public boolean isAsync() {
10291026
Method methodToCheck = this.handlerMethodsList.get(0).values().iterator().next().method;
10301027
return Publisher.class.isAssignableFrom(methodToCheck.getReturnType())
10311028
|| CompletableFuture.class.isAssignableFrom(methodToCheck.getReturnType())
1032-
|| KotlinDetector.isSuspendingFunction(methodToCheck);
1029+
|| KotlinDetector.isSuspendingFunction(methodToCheck)
1030+
|| ReactiveAdapterRegistry.getSharedInstance().getAdapter(methodToCheck.getReturnType()) != null;
10331031
}
10341032
return false;
10351033
}
@@ -1066,7 +1064,7 @@ private static class HandlerMethod {
10661064

10671065
// The number of times InvocableHandlerMethod was attempted and failed - enables us to eventually
10681066
// give up trying to call it when it just doesn't seem to be possible.
1069-
// Switching to spelOnly afterwards forever.
1067+
// Switching to 'spelOnly' afterwards forever.
10701068
private volatile int failedAttempts = 0;
10711069

10721070
HandlerMethod(Method method, boolean canProcessMessageList) {
@@ -1359,7 +1357,6 @@ public static class ParametersWrapper {
13591357

13601358
/**
13611359
* SpEL Function to retrieve a required header.
1362-
*
13631360
* @param headers the headers.
13641361
* @param header the header name
13651362
* @return the header

spring-integration-core/src/test/kotlin/org/springframework/integration/function/FunctionsTests.kt

Lines changed: 52 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,14 @@
1717
package org.springframework.integration.function
1818

1919
import assertk.assertThat
20-
import assertk.assertions.containsAll
21-
import assertk.assertions.isEqualTo
22-
import assertk.assertions.isNotNull
23-
import assertk.assertions.isTrue
24-
import assertk.assertions.size
20+
import assertk.assertions.*
21+
import kotlinx.coroutines.flow.flow
2522
import org.junit.jupiter.api.Test
2623
import org.springframework.beans.factory.annotation.Autowired
2724
import org.springframework.beans.factory.annotation.Qualifier
2825
import org.springframework.context.annotation.Bean
2926
import org.springframework.context.annotation.Configuration
30-
import org.springframework.integration.annotation.EndpointId
31-
import org.springframework.integration.annotation.InboundChannelAdapter
32-
import org.springframework.integration.annotation.Poller
33-
import org.springframework.integration.annotation.ServiceActivator
34-
import org.springframework.integration.annotation.Transformer
27+
import org.springframework.integration.annotation.*
3528
import org.springframework.integration.channel.DirectChannel
3629
import org.springframework.integration.channel.FluxMessageChannel
3730
import org.springframework.integration.channel.QueueChannel
@@ -91,8 +84,8 @@ class FunctionsTests {
9184
val replyChannel = QueueChannel()
9285

9386
val message = MessageBuilder.withPayload("foo")
94-
.setReplyChannel(replyChannel)
95-
.build()
87+
.setReplyChannel(replyChannel)
88+
.build()
9689

9790
this.functionServiceChannel.send(message)
9891

@@ -101,8 +94,8 @@ class FunctionsTests {
10194
val payload = receive?.payload
10295

10396
assertThat(payload)
104-
.isNotNull()
105-
.isEqualTo("FOO")
97+
.isNotNull()
98+
.isEqualTo("FOO")
10699
}
107100

108101
@Test
@@ -142,8 +135,8 @@ class FunctionsTests {
142135
val mono = this.monoFunction.apply("test")
143136

144137
StepVerifier.create(mono.map(Message<*>::getPayload).cast(String::class.java))
145-
.expectNext("TEST")
146-
.verifyComplete()
138+
.expectNext("TEST")
139+
.verifyComplete()
147140

148141
val gateways = this.monoFunctionGateway.gateways
149142
assertThat(gateways).size().isEqualTo(3)
@@ -167,7 +160,30 @@ class FunctionsTests {
167160
suspendServiceChannel.send(
168161
MessageBuilder.withPayload(testPayload)
169162
.setReplyChannel(replyChannel)
170-
.build())
163+
.build()
164+
)
165+
166+
stepVerifier.verify(Duration.ofSeconds(10))
167+
}
168+
169+
@Autowired
170+
private lateinit var flowServiceChannel: MessageChannel
171+
172+
@Test
173+
fun `verify flow function`() {
174+
val replyChannel = FluxMessageChannel()
175+
val testPayload = "test flow"
176+
val stepVerifier =
177+
StepVerifier.create(Flux.from(replyChannel).map(Message<*>::getPayload).cast(String::class.java))
178+
.expectNext("$testPayload #1", "$testPayload #2", "$testPayload #3")
179+
.thenCancel()
180+
.verifyLater()
181+
182+
flowServiceChannel.send(
183+
MessageBuilder.withPayload(testPayload)
184+
.setReplyChannel(replyChannel)
185+
.build()
186+
)
171187

172188
stepVerifier.verify(Duration.ofSeconds(10))
173189
}
@@ -195,30 +211,40 @@ class FunctionsTests {
195211
fun counterChannel() = DirectChannel()
196212

197213
@Bean
198-
@InboundChannelAdapter(value = "counterChannel", autoStartup = "false",
199-
poller = Poller(fixedRate = "10", maxMessagesPerPoll = "1"))
214+
@InboundChannelAdapter(
215+
value = "counterChannel", autoStartup = "false",
216+
poller = Poller(fixedRate = "10", maxMessagesPerPoll = "1")
217+
)
200218
@EndpointId("kotlinSupplierChannelAdapter")
201219
fun kotlinSupplier(): () -> String {
202220
return { "baz" }
203221
}
204222

205223
@Bean
206224
fun flowFromSupplier() =
207-
integrationFlow({ "" }, { poller { it.fixedDelay(10).maxMessagesPerPoll(1) } }) {
208-
transform<String> { "blank" }
209-
channel { queue("fromSupplierQueue") }
210-
}
225+
integrationFlow({ "" }, { poller { it.fixedDelay(10).maxMessagesPerPoll(1) } }) {
226+
transform<String> { "blank" }
227+
channel { queue("fromSupplierQueue") }
228+
}
211229

212230
@Bean
213231
fun monoFunctionGateway() =
214-
integrationFlow<MonoFunction>({ proxyDefaultMethods(true) }) {
215-
handle<String>({ p, _ -> Mono.just(p).map(String::uppercase) }) { async(true) }
216-
}
232+
integrationFlow<MonoFunction>({ proxyDefaultMethods(true) }) {
233+
handle<String>({ p, _ -> Mono.just(p).map(String::uppercase) }) { async(true) }
234+
}
217235

218236

219237
@ServiceActivator(inputChannel = "suspendServiceChannel")
220238
suspend fun suspendServiceFunction(payload: String) = payload.uppercase()
221239

240+
@ServiceActivator(inputChannel = "flowServiceChannel")
241+
fun flowServiceFunction(payload: String) =
242+
flow {
243+
for (i in 1..3) {
244+
emit("$payload #$i")
245+
}
246+
}
247+
222248
}
223249

224250
interface MonoFunction : Function<String, Mono<Message<*>>>

0 commit comments

Comments
 (0)