Skip to content

Commit fe51d78

Browse files
mp911dechristophstrobl
authored andcommitted
Add ReactiveRedisMessageListenerContainer.receiveLater(…) to await subscriptions.
Original Pull Request: #2052
1 parent 328de94 commit fe51d78

File tree

2 files changed

+199
-19
lines changed

2 files changed

+199
-19
lines changed

src/main/java/org/springframework/data/redis/listener/ReactiveRedisMessageListenerContainer.java

+162
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
import java.util.Arrays;
2424
import java.util.Collection;
2525
import java.util.HashMap;
26+
import java.util.HashSet;
2627
import java.util.List;
2728
import java.util.Map;
2829
import java.util.Set;
2930
import java.util.concurrent.ConcurrentHashMap;
31+
import java.util.concurrent.atomic.AtomicBoolean;
3032
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
3133
import java.util.stream.Collectors;
3234
import java.util.stream.StreamSupport;
@@ -41,9 +43,11 @@
4143
import org.springframework.data.redis.connection.ReactiveSubscription.Message;
4244
import org.springframework.data.redis.connection.ReactiveSubscription.PatternMessage;
4345
import org.springframework.data.redis.connection.SubscriptionListener;
46+
import org.springframework.data.redis.connection.util.ByteArrayWrapper;
4447
import org.springframework.data.redis.serializer.RedisElementReader;
4548
import org.springframework.data.redis.serializer.RedisSerializationContext.SerializationPair;
4649
import org.springframework.data.redis.serializer.RedisSerializer;
50+
import org.springframework.data.redis.util.ByteUtils;
4751
import org.springframework.lang.Nullable;
4852
import org.springframework.util.Assert;
4953
import org.springframework.util.ObjectUtils;
@@ -157,6 +161,28 @@ public Flux<Message<String, String>> receive(ChannelTopic... channelTopics) {
157161
return receive(Arrays.asList(channelTopics), stringSerializationPair, stringSerializationPair);
158162
}
159163

164+
/**
165+
* Subscribe to one or more {@link ChannelTopic}s and receive a stream of {@link ChannelMessage} once the returned
166+
* {@link Mono} completes. Messages and channel names are treated as {@link String}. The message stream subscribes
167+
* lazily to the Redis channels and unsubscribes if the inner {@link org.reactivestreams.Subscription} is
168+
* {@link org.reactivestreams.Subscription#cancel() cancelled}.
169+
* <p/>
170+
* The returned {@link Mono} completes once the connection has been subscribed to the given {@link Topic topics}. Note
171+
* that cancelling the returned {@link Mono} can leave the connection in a subscribed state.
172+
*
173+
* @param channelTopics the channels to subscribe.
174+
* @return the message stream.
175+
* @throws InvalidDataAccessApiUsageException if {@code patternTopics} is empty.
176+
* @since 2.6
177+
*/
178+
public Mono<Flux<Message<String, String>>> receiveLater(ChannelTopic... channelTopics) {
179+
180+
Assert.notNull(channelTopics, "ChannelTopics must not be null!");
181+
Assert.noNullElements(channelTopics, "ChannelTopics must not contain null elements!");
182+
183+
return receiveLater(Arrays.asList(channelTopics), stringSerializationPair, stringSerializationPair);
184+
}
185+
160186
/**
161187
* Subscribe to one or more {@link PatternTopic}s and receive a stream of {@link PatternMessage}. Messages, pattern,
162188
* and channel names are treated as {@link String}. The message stream subscribes lazily to the Redis channels and
@@ -178,6 +204,30 @@ public Flux<PatternMessage<String, String, String>> receive(PatternTopic... patt
178204
.map(m -> (PatternMessage<String, String, String>) m);
179205
}
180206

207+
/**
208+
* Subscribe to one or more {@link PatternTopic}s and receive a stream of {@link PatternMessage} once the returned
209+
* {@link Mono} completes. Messages, pattern, and channel names are treated as {@link String}. The message stream
210+
* subscribes lazily to the Redis channels and unsubscribes if the inner {@link org.reactivestreams.Subscription} is
211+
* {@link org.reactivestreams.Subscription#cancel() cancelled}.
212+
* <p/>
213+
* The returned {@link Mono} completes once the connection has been subscribed to the given {@link Topic topics}. Note
214+
* that cancelling the returned {@link Mono} can leave the connection in a subscribed state.
215+
*
216+
* @param patternTopics the channels to subscribe.
217+
* @return the message stream.
218+
* @throws InvalidDataAccessApiUsageException if {@code patternTopics} is empty.
219+
* @since 2.6
220+
*/
221+
@SuppressWarnings("unchecked")
222+
public Mono<Flux<PatternMessage<String, String, String>>> receiveLater(PatternTopic... patternTopics) {
223+
224+
Assert.notNull(patternTopics, "PatternTopic must not be null!");
225+
Assert.noNullElements(patternTopics, "PatternTopic must not contain null elements!");
226+
227+
return receiveLater(Arrays.asList(patternTopics), stringSerializationPair, stringSerializationPair)
228+
.map(it -> it.map(m -> (PatternMessage<String, String, String>) m));
229+
}
230+
181231
/**
182232
* Subscribe to one or more {@link Topic}s and receive a stream of {@link ChannelMessage}. The stream may contain
183233
* {@link PatternMessage} if subscribed to patterns. Messages, and channel names are serialized/deserialized using the
@@ -281,6 +331,68 @@ private <C, B> Flux<Message<C, B>> doReceive(SerializationPair<C> channelSeriali
281331
.map(message -> readMessage(channelSerializer.getReader(), messageSerializer.getReader(), message));
282332
}
283333

334+
/**
335+
* Subscribe to one or more {@link Topic}s and receive a stream of {@link ChannelMessage}. The returned {@link Mono}
336+
* completes once the connection has been subscribed to the given {@link Topic topics}. Note that cancelling the
337+
* returned {@link Mono} can leave the connection in a subscribed state.
338+
*
339+
* @param topics the channels to subscribe.
340+
* @param channelSerializer serialization pair to decode the channel/pattern name.
341+
* @param messageSerializer serialization pair to decode the message body.
342+
* @return the message stream.
343+
* @throws InvalidDataAccessApiUsageException if {@code topics} is empty.
344+
* @since 2.6
345+
*/
346+
private <C, B> Mono<Flux<Message<C, B>>> receiveLater(Iterable<? extends Topic> topics,
347+
SerializationPair<C> channelSerializer, SerializationPair<B> messageSerializer) {
348+
349+
Assert.notNull(topics, "Topics must not be null!");
350+
Assert.notNull(channelSerializer, "Channel serializer must not be null!");
351+
Assert.notNull(messageSerializer, "Message serializer must not be null!");
352+
353+
verifyConnection();
354+
355+
ByteBuffer[] patterns = getTargets(topics, PatternTopic.class);
356+
ByteBuffer[] channels = getTargets(topics, ChannelTopic.class);
357+
358+
if (ObjectUtils.isEmpty(patterns) && ObjectUtils.isEmpty(channels)) {
359+
throw new InvalidDataAccessApiUsageException("No channels or patterns to subscribe to.");
360+
}
361+
362+
return Mono.defer(() -> {
363+
364+
SubscriptionReadyListener readyListener = SubscriptionReadyListener.create(topics, stringSerializationPair);
365+
366+
return doReceiveLater(channelSerializer, messageSerializer,
367+
getRequiredConnection().pubSubCommands().createSubscription(readyListener), patterns, channels)
368+
.delayUntil(it -> readyListener.getTrigger());
369+
});
370+
}
371+
372+
private <C, B> Mono<Flux<Message<C, B>>> doReceiveLater(SerializationPair<C> channelSerializer,
373+
SerializationPair<B> messageSerializer, Mono<ReactiveSubscription> subscription, ByteBuffer[] patterns,
374+
ByteBuffer[] channels) {
375+
376+
return subscription.flatMap(it -> {
377+
378+
Mono<Void> subscribe = subscribe(patterns, channels, it).doOnSuccess(v -> getSubscribers(it).registered());
379+
380+
Sinks.One<Message<ByteBuffer, ByteBuffer>> terminalSink = Sinks.one();
381+
382+
Flux<Message<C, B>> receiver = it.receive().doOnCancel(() -> {
383+
384+
Subscribers subscribers = getSubscribers(it);
385+
if (subscribers.unregister()) {
386+
subscriptions.remove(it);
387+
it.cancel().subscribe(v -> terminalSink.tryEmitEmpty(), terminalSink::tryEmitError);
388+
}
389+
}).mergeWith(terminalSink.asMono())
390+
.map(message -> readMessage(channelSerializer.getReader(), messageSerializer.getReader(), message));
391+
392+
return subscribe.then(Mono.just(receiver));
393+
});
394+
}
395+
284396
private static Mono<Void> subscribe(ByteBuffer[] patterns, ByteBuffer[] channels, ReactiveSubscription it) {
285397

286398
Assert.isTrue(!ObjectUtils.isEmpty(channels) || !ObjectUtils.isEmpty(patterns),
@@ -418,4 +530,54 @@ boolean unregister() {
418530
return false;
419531
}
420532
}
533+
534+
static class SubscriptionReadyListener extends AtomicBoolean implements SubscriptionListener {
535+
536+
private final Set<ByteArrayWrapper> toSubscribe;
537+
private final Sinks.Empty<Void> sink = Sinks.empty();
538+
539+
private SubscriptionReadyListener(Set<ByteArrayWrapper> topics) {
540+
this.toSubscribe = topics;
541+
}
542+
543+
public static SubscriptionReadyListener create(Iterable<? extends Topic> topics,
544+
SerializationPair<String> serializationPair) {
545+
546+
Set<ByteArrayWrapper> wrappers = new HashSet<>();
547+
548+
for (Topic topic : topics) {
549+
wrappers.add(new ByteArrayWrapper(ByteUtils.getBytes(serializationPair.getWriter().write(topic.getTopic()))));
550+
}
551+
552+
return new SubscriptionReadyListener(wrappers);
553+
}
554+
555+
@Override
556+
public void onChannelSubscribed(byte[] channel, long count) {
557+
removeRemaining(channel);
558+
}
559+
560+
@Override
561+
public void onPatternSubscribed(byte[] pattern, long count) {
562+
removeRemaining(pattern);
563+
}
564+
565+
private void removeRemaining(byte[] channel) {
566+
567+
boolean done;
568+
569+
synchronized (toSubscribe) {
570+
toSubscribe.remove(new ByteArrayWrapper(channel));
571+
done = toSubscribe.isEmpty();
572+
}
573+
574+
if (done && compareAndSet(false, true)) {
575+
sink.emitEmpty(Sinks.EmitFailureHandler.FAIL_FAST);
576+
}
577+
}
578+
579+
public Mono<Void> getTrigger() {
580+
return sink.asMono();
581+
}
582+
}
421583
}

src/test/java/org/springframework/data/redis/listener/ReactiveRedisMessageListenerContainerIntegrationTests.java

+37-19
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import reactor.core.Disposable;
2121
import reactor.test.StepVerifier;
2222

23+
import java.nio.ByteBuffer;
2324
import java.time.Duration;
2425
import java.util.Collection;
2526
import java.util.Collections;
@@ -28,6 +29,7 @@
2829
import java.util.concurrent.LinkedBlockingDeque;
2930
import java.util.concurrent.TimeUnit;
3031
import java.util.concurrent.atomic.AtomicReference;
32+
import java.util.function.Function;
3133
import java.util.function.Supplier;
3234

3335
import org.awaitility.Awaitility;
@@ -36,6 +38,7 @@
3638

3739
import org.springframework.data.redis.connection.Message;
3840
import org.springframework.data.redis.connection.MessageListener;
41+
import org.springframework.data.redis.connection.ReactiveRedisConnection;
3942
import org.springframework.data.redis.connection.ReactiveSubscription;
4043
import org.springframework.data.redis.connection.ReactiveSubscription.ChannelMessage;
4144
import org.springframework.data.redis.connection.ReactiveSubscription.PatternMessage;
@@ -62,6 +65,7 @@ public class ReactiveRedisMessageListenerContainerIntegrationTests {
6265

6366
private final LettuceConnectionFactory connectionFactory;
6467
private @Nullable RedisConnection connection;
68+
private @Nullable ReactiveRedisConnection reactiveConnection;
6569

6670
/**
6771
* @param connectionFactory
@@ -79,6 +83,7 @@ public static Collection<Object[]> testParams() {
7983
@BeforeEach
8084
void before() {
8185
connection = connectionFactory.getConnection();
86+
reactiveConnection = connectionFactory.getReactiveConnection();
8287
}
8388

8489
@AfterEach
@@ -87,16 +92,21 @@ void tearDown() {
8792
if (connection != null) {
8893
connection.close();
8994
}
95+
96+
if (reactiveConnection != null) {
97+
reactiveConnection.close();
98+
}
9099
}
91100

92-
@ParameterizedRedisTest // DATAREDIS-612
101+
@ParameterizedRedisTest // DATAREDIS-612, GH-1622
93102
void shouldReceiveChannelMessages() {
94103

95104
ReactiveRedisMessageListenerContainer container = new ReactiveRedisMessageListenerContainer(connectionFactory);
96105

97-
container.receive(ChannelTopic.of(CHANNEL1)).as(StepVerifier::create) //
98-
.then(awaitSubscription(container::getActiveSubscriptions))
99-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
106+
container.receiveLater(ChannelTopic.of(CHANNEL1)) //
107+
.doOnNext(it -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
108+
.flatMapMany(Function.identity()) //
109+
.as(StepVerifier::create) //
100110
.assertNext(c -> {
101111

102112
assertThat(c.getChannel()).isEqualTo(CHANNEL1);
@@ -136,9 +146,10 @@ public void onChannelUnsubscribed(byte[] channel, long count) {
136146
}
137147
};
138148

139-
container.receive(Collections.singletonList(ChannelTopic.of(CHANNEL1)), listener).as(StepVerifier::create) //
149+
container.receive(Collections.singletonList(ChannelTopic.of(CHANNEL1)), listener) //
150+
.as(StepVerifier::create) //
140151
.then(awaitSubscription(container::getActiveSubscriptions))
141-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
152+
.then(() -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
142153
.assertNext(c -> {
143154

144155
assertThat(c.getChannel()).isEqualTo(CHANNEL1);
@@ -154,14 +165,14 @@ public void onChannelUnsubscribed(byte[] channel, long count) {
154165
container.destroy();
155166
}
156167

157-
@ParameterizedRedisTest // DATAREDIS-612
168+
@ParameterizedRedisTest // DATAREDIS-612, GH-1622
158169
void shouldReceivePatternMessages() {
159170

160171
ReactiveRedisMessageListenerContainer container = new ReactiveRedisMessageListenerContainer(connectionFactory);
161172

162-
container.receive(PatternTopic.of(PATTERN1)).as(StepVerifier::create) //
163-
.then(awaitSubscription(container::getActiveSubscriptions))
164-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
173+
container.receiveLater(PatternTopic.of(PATTERN1)) //
174+
.doOnNext(it -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())).flatMapMany(Function.identity()) //
175+
.as(StepVerifier::create) //
165176
.assertNext(c -> {
166177

167178
assertThat(c.getPattern()).isEqualTo(PATTERN1);
@@ -206,7 +217,7 @@ public void onPatternUnsubscribed(byte[] pattern, long count) {
206217
.cast(PatternMessage.class) //
207218
.as(StepVerifier::create) //
208219
.then(awaitSubscription(container::getActiveSubscriptions))
209-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
220+
.then(() -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
210221
.assertNext(c -> {
211222

212223
assertThat(c.getPattern()).isEqualTo(PATTERN1);
@@ -223,19 +234,22 @@ public void onPatternUnsubscribed(byte[] pattern, long count) {
223234
container.destroy();
224235
}
225236

226-
@ParameterizedRedisTest // DATAREDIS-612
227-
void shouldPublishAndReceiveMessage() throws InterruptedException {
237+
@ParameterizedRedisTest // DATAREDIS-612, GH-1622
238+
void shouldPublishAndReceiveMessage() throws Exception {
228239

229240
ReactiveRedisMessageListenerContainer container = new ReactiveRedisMessageListenerContainer(connectionFactory);
230241
ReactiveRedisTemplate<String, String> template = new ReactiveRedisTemplate<>(connectionFactory,
231242
RedisSerializationContext.string());
232243

233244
BlockingQueue<PatternMessage<String, String, String>> messages = new LinkedBlockingDeque<>();
234-
Disposable subscription = container.receive(PatternTopic.of(PATTERN1)).doOnNext(messages::add).subscribe();
245+
CompletableFuture<Void> subscribed = new CompletableFuture<>();
246+
Disposable subscription = container.receiveLater(PatternTopic.of(PATTERN1))
247+
.doOnNext(it -> subscribed.complete(null)).flatMapMany(Function.identity()).doOnNext(messages::add).subscribe();
235248

236-
StepVerifier.create(template.convertAndSend(CHANNEL1, MESSAGE), 0) //
237-
.then(awaitSubscription(container::getActiveSubscriptions)) //
238-
.thenRequest(1).expectNextCount(1) //
249+
subscribed.get(5, TimeUnit.SECONDS);
250+
251+
template.convertAndSend(CHANNEL1, MESSAGE).as(StepVerifier::create) //
252+
.expectNextCount(1) //
239253
.verifyComplete();
240254

241255
PatternMessage<String, String, String> message = messages.poll(1, TimeUnit.SECONDS);
@@ -257,7 +271,7 @@ void listenToChannelShouldReceiveChannelMessagesCorrectly() throws InterruptedEx
257271

258272
template.listenToChannel(CHANNEL1).as(StepVerifier::create) //
259273
.thenAwait(Duration.ofMillis(100)) // just make sure we the subscription completed
260-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
274+
.then(() -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
261275
.assertNext(message -> {
262276

263277
assertThat(message).isInstanceOf(ChannelMessage.class);
@@ -276,7 +290,7 @@ void listenToPatternShouldReceiveMessagesCorrectly() {
276290

277291
template.listenToPattern(PATTERN1).as(StepVerifier::create) //
278292
.thenAwait(Duration.ofMillis(100)) // just make sure we the subscription completed
279-
.then(() -> connection.publish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
293+
.then(() -> doPublish(CHANNEL1.getBytes(), MESSAGE.getBytes())) //
280294
.assertNext(message -> {
281295

282296
assertThat(message).isInstanceOf(PatternMessage.class);
@@ -288,6 +302,10 @@ void listenToPatternShouldReceiveMessagesCorrectly() {
288302
.verify();
289303
}
290304

305+
private void doPublish(byte[] channel, byte[] message) {
306+
reactiveConnection.pubSubCommands().publish(ByteBuffer.wrap(channel), ByteBuffer.wrap(message)).subscribe();
307+
}
308+
291309
private static Runnable awaitSubscription(Supplier<Collection<ReactiveSubscription>> activeSubscriptions) {
292310

293311
return () -> {

0 commit comments

Comments
 (0)