Skip to content

Commit 40ce6c3

Browse files
authored
ensures LoadbalancedRSocket select new rsocket upon re-subscription RC
Co-authored-by: alex079 <>
1 parent cdecc51 commit 40ce6c3

File tree

3 files changed

+95
-4
lines changed

3 files changed

+95
-4
lines changed

rsocket-core/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dependencies {
3434
testImplementation 'org.assertj:assertj-core'
3535
testImplementation 'org.junit.jupiter:junit-jupiter-api'
3636
testImplementation 'org.junit.jupiter:junit-jupiter-params'
37-
testImplementation 'org.mockito:mockito-core'
37+
testImplementation 'org.mockito:mockito-junit-jupiter'
3838
testImplementation 'org.awaitility:awaitility'
3939

4040
testRuntimeOnly 'ch.qos.logback:logback-classic'

rsocket-core/src/main/java/io/rsocket/loadbalance/LoadbalanceRSocketClient.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
import reactor.util.annotation.Nullable;
2828

2929
/**
30-
* An implementation of {@link RSocketClient backed by a pool of {@code RSocket} instances and using a {@link
31-
* LoadbalanceStrategy} to select the {@code RSocket} to use for a given request.
30+
* An implementation of {@link RSocketClient} backed by a pool of {@code RSocket} instances and
31+
* using a {@link LoadbalanceStrategy} to select the {@code RSocket} to use for a given request.
3232
*
3333
* @since 1.1
3434
*/
@@ -73,7 +73,7 @@ public Flux<Payload> requestStream(Mono<Payload> payloadMono) {
7373

7474
@Override
7575
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
76-
return rSocketPool.select().requestChannel(payloads);
76+
return source().flatMapMany(rSocket -> rSocket.requestChannel(payloads));
7777
}
7878

7979
@Override
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package io.rsocket.loadbalance;
2+
3+
import static java.util.Collections.singletonList;
4+
import static org.assertj.core.api.Assertions.assertThat;
5+
import static org.mockito.Mockito.verify;
6+
import static org.mockito.Mockito.verifyNoMoreInteractions;
7+
import static org.mockito.Mockito.when;
8+
9+
import io.rsocket.Payload;
10+
import io.rsocket.RSocket;
11+
import io.rsocket.core.RSocketClient;
12+
import io.rsocket.core.RSocketConnector;
13+
import io.rsocket.transport.ClientTransport;
14+
import io.rsocket.util.DefaultPayload;
15+
import java.time.Duration;
16+
import java.util.concurrent.atomic.AtomicInteger;
17+
import org.junit.jupiter.api.Test;
18+
import org.junit.jupiter.api.extension.ExtendWith;
19+
import org.mockito.Mock;
20+
import org.mockito.junit.jupiter.MockitoExtension;
21+
import org.reactivestreams.Publisher;
22+
import reactor.core.publisher.Flux;
23+
import reactor.core.publisher.Mono;
24+
import reactor.test.StepVerifier;
25+
26+
@ExtendWith(MockitoExtension.class)
27+
class LoadbalanceRSocketClientTest {
28+
29+
@Mock private ClientTransport clientTransport;
30+
@Mock private RSocketConnector rSocketConnector;
31+
32+
public static final Duration SHORT_DURATION = Duration.ofMillis(25);
33+
public static final Duration LONG_DURATION = Duration.ofMillis(75);
34+
35+
private static final Publisher<Payload> SOURCE =
36+
Flux.interval(SHORT_DURATION).map(String::valueOf).map(DefaultPayload::create);
37+
38+
private static final Mono<RSocket> PROGRESSING_HANDLER =
39+
Mono.just(
40+
new RSocket() {
41+
private final AtomicInteger i = new AtomicInteger();
42+
43+
@Override
44+
public Flux<Payload> requestChannel(Publisher<Payload> payloads) {
45+
return Flux.from(payloads)
46+
.delayElements(SHORT_DURATION)
47+
.map(Payload::getDataUtf8)
48+
.map(DefaultPayload::create)
49+
.take(i.incrementAndGet());
50+
}
51+
});
52+
53+
@Test
54+
void testChannelReconnection() {
55+
when(rSocketConnector.connect(clientTransport)).thenReturn(PROGRESSING_HANDLER);
56+
57+
RSocketClient client =
58+
LoadbalanceRSocketClient.create(
59+
rSocketConnector,
60+
Mono.just(singletonList(LoadbalanceTarget.from("key", clientTransport))));
61+
62+
Publisher<String> result =
63+
client
64+
.requestChannel(SOURCE)
65+
.repeatWhen(longFlux -> longFlux.delayElements(LONG_DURATION).take(5))
66+
.map(Payload::getDataUtf8)
67+
.log();
68+
69+
StepVerifier.create(result)
70+
.expectSubscription()
71+
.assertNext(s -> assertThat(s).isEqualTo("0"))
72+
.assertNext(s -> assertThat(s).isEqualTo("0"))
73+
.assertNext(s -> assertThat(s).isEqualTo("1"))
74+
.assertNext(s -> assertThat(s).isEqualTo("0"))
75+
.assertNext(s -> assertThat(s).isEqualTo("1"))
76+
.assertNext(s -> assertThat(s).isEqualTo("2"))
77+
.assertNext(s -> assertThat(s).isEqualTo("0"))
78+
.assertNext(s -> assertThat(s).isEqualTo("1"))
79+
.assertNext(s -> assertThat(s).isEqualTo("2"))
80+
.assertNext(s -> assertThat(s).isEqualTo("3"))
81+
.assertNext(s -> assertThat(s).isEqualTo("0"))
82+
.assertNext(s -> assertThat(s).isEqualTo("1"))
83+
.assertNext(s -> assertThat(s).isEqualTo("2"))
84+
.assertNext(s -> assertThat(s).isEqualTo("3"))
85+
.assertNext(s -> assertThat(s).isEqualTo("4"))
86+
.verifyComplete();
87+
88+
verify(rSocketConnector).connect(clientTransport);
89+
verifyNoMoreInteractions(rSocketConnector, clientTransport);
90+
}
91+
}

0 commit comments

Comments
 (0)