Skip to content

Commit 775bf66

Browse files
authored
Refactor DefaultReactiveElasticsearchClient to do request customization with the WebClient. (#1795)
Original Pull Request #1795 Closes #1794
1 parent f8fbf77 commit 775bf66

File tree

7 files changed

+66
-41
lines changed

7 files changed

+66
-41
lines changed

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultReactiveElasticsearchClient.java

+14-12
Original file line numberDiff line numberDiff line change
@@ -281,15 +281,23 @@ private static WebClientProvider getWebClientProvider(ClientConfiguration client
281281
scheme = "https";
282282
}
283283

284-
ReactorClientHttpConnector connector = new ReactorClientHttpConnector(httpClient);
285-
WebClientProvider provider = WebClientProvider.create(scheme, connector);
284+
WebClientProvider provider = WebClientProvider.create(scheme, new ReactorClientHttpConnector(httpClient));
286285

287286
if (clientConfiguration.getPathPrefix() != null) {
288287
provider = provider.withPathPrefix(clientConfiguration.getPathPrefix());
289288
}
290289

291-
provider = provider.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
292-
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer());
290+
provider = provider //
291+
.withDefaultHeaders(clientConfiguration.getDefaultHeaders()) //
292+
.withWebClientConfigurer(clientConfiguration.getWebClientConfigurer()) //
293+
.withRequestConfigurer(requestHeadersSpec -> requestHeadersSpec.headers(httpHeaders -> {
294+
HttpHeaders suppliedHeaders = clientConfiguration.getHeadersSupplier().get();
295+
296+
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
297+
httpHeaders.addAll(suppliedHeaders);
298+
}
299+
}));
300+
293301
return provider;
294302
}
295303

@@ -584,12 +592,6 @@ private RequestBodySpec sendRequest(WebClient webClient, String logId, Request r
584592
request.getOptions().getHeaders().forEach(it -> theHeaders.add(it.getName(), it.getValue()));
585593
}
586594
}
587-
588-
// plus the ones from the supplier
589-
HttpHeaders suppliedHeaders = headersSupplier.get();
590-
if (suppliedHeaders != null && suppliedHeaders != HttpHeaders.EMPTY) {
591-
theHeaders.addAll(suppliedHeaders);
592-
}
593595
});
594596

595597
if (request.getEntity() != null) {
@@ -599,8 +601,8 @@ private RequestBodySpec sendRequest(WebClient webClient, String logId, Request r
599601
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters(),
600602
body::get);
601603

602-
requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue()));
603-
requestBodySpec.body(Mono.fromSupplier(body), String.class);
604+
requestBodySpec.contentType(MediaType.valueOf(request.getEntity().getContentType().getValue()))
605+
.body(Mono.fromSupplier(body), String.class);
604606
} else {
605607
ClientLogger.logRequest(logId, request.getMethod().toUpperCase(), request.getEndpoint(), request.getParameters());
606608
}

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/DefaultWebClientProvider.java

+28-8
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ class DefaultWebClientProvider implements WebClientProvider {
4848
private final HttpHeaders headers;
4949
private final @Nullable String pathPrefix;
5050
private final Function<WebClient, WebClient> webClientConfigurer;
51+
private final Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer;
5152

5253
/**
5354
* Create new {@link DefaultWebClientProvider} with empty {@link HttpHeaders} and no-op {@literal error listener}.
@@ -56,7 +57,7 @@ class DefaultWebClientProvider implements WebClientProvider {
5657
* @param connector can be {@literal null}.
5758
*/
5859
DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector) {
59-
this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity());
60+
this(scheme, connector, e -> {}, HttpHeaders.EMPTY, null, Function.identity(), requestHeadersSpec -> {});
6061
}
6162

6263
/**
@@ -66,18 +67,21 @@ class DefaultWebClientProvider implements WebClientProvider {
6667
* @param connector can be {@literal null}.
6768
* @param errorListener must not be {@literal null}.
6869
* @param headers must not be {@literal null}.
69-
* @param pathPrefix can be {@literal null}
70+
* @param pathPrefix can be {@literal null}.
7071
* @param webClientConfigurer must not be {@literal null}.
72+
* @param requestConfigurer must not be {@literal null}.
7173
*/
7274
private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector connector,
7375
Consumer<Throwable> errorListener, HttpHeaders headers, @Nullable String pathPrefix,
74-
Function<WebClient, WebClient> webClientConfigurer) {
76+
Function<WebClient, WebClient> webClientConfigurer, Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
7577

7678
Assert.notNull(scheme, "Scheme must not be null! A common scheme would be 'http'.");
7779
Assert.notNull(errorListener, "errorListener must not be null! You may want use a no-op one 'e -> {}' instead.");
7880
Assert.notNull(headers, "headers must not be null! Think about using 'HttpHeaders.EMPTY' as an alternative.");
7981
Assert.notNull(webClientConfigurer,
8082
"webClientConfigurer must not be null! You may want use a no-op one 'Function.identity()' instead.");
83+
Assert.notNull(requestConfigurer,
84+
"requestConfigurer must not be null! You may want use a no-op one 'r -> {}' instead.\"");
8185

8286
this.cachedClients = new ConcurrentHashMap<>();
8387
this.scheme = scheme;
@@ -86,6 +90,7 @@ private DefaultWebClientProvider(String scheme, @Nullable ClientHttpConnector co
8690
this.headers = headers;
8791
this.pathPrefix = pathPrefix;
8892
this.webClientConfigurer = webClientConfigurer;
93+
this.requestConfigurer = requestConfigurer;
8994
}
9095

9196
@Override
@@ -106,6 +111,7 @@ public Consumer<Throwable> getErrorListener() {
106111
return this.errorListener;
107112
}
108113

114+
@Nullable
109115
@Override
110116
public String getPathPrefix() {
111117
return pathPrefix;
@@ -120,7 +126,17 @@ public WebClientProvider withDefaultHeaders(HttpHeaders headers) {
120126
merged.addAll(this.headers);
121127
merged.addAll(headers);
122128

123-
return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer);
129+
return new DefaultWebClientProvider(scheme, connector, errorListener, merged, pathPrefix, webClientConfigurer,
130+
requestConfigurer);
131+
}
132+
133+
@Override
134+
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
135+
136+
Assert.notNull(requestConfigurer, "requestConfigurer must not be null.");
137+
138+
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
139+
requestConfigurer);
124140
}
125141

126142
@Override
@@ -129,26 +145,30 @@ public WebClientProvider withErrorListener(Consumer<Throwable> errorListener) {
129145
Assert.notNull(errorListener, "Error listener must not be null.");
130146

131147
Consumer<Throwable> listener = this.errorListener.andThen(errorListener);
132-
return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer);
148+
return new DefaultWebClientProvider(scheme, this.connector, listener, headers, pathPrefix, webClientConfigurer,
149+
requestConfigurer);
133150
}
134151

135152
@Override
136153
public WebClientProvider withPathPrefix(String pathPrefix) {
137154
Assert.notNull(pathPrefix, "pathPrefix must not be null.");
138155

139156
return new DefaultWebClientProvider(this.scheme, this.connector, this.errorListener, this.headers, pathPrefix,
140-
webClientConfigurer);
157+
webClientConfigurer, requestConfigurer);
141158
}
142159

143160
@Override
144161
public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer) {
145-
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer);
162+
return new DefaultWebClientProvider(scheme, connector, errorListener, headers, pathPrefix, webClientConfigurer,
163+
requestConfigurer);
146164

147165
}
148166

149167
protected WebClient createWebClientForSocketAddress(InetSocketAddress socketAddress) {
150168

151-
Builder builder = WebClient.builder().defaultHeaders(it -> it.addAll(getDefaultHeaders()));
169+
Builder builder = WebClient.builder() //
170+
.defaultHeaders(it -> it.addAll(getDefaultHeaders())) //
171+
.defaultRequest(requestConfigurer);
152172

153173
if (connector != null) {
154174
builder = builder.clientConnector(connector);

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/HostProvider.java

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,9 +54,9 @@ static HostProvider<?> provider(WebClientProvider clientProvider, Supplier<HttpH
5454
Assert.notEmpty(endpoints, "Please provide at least one endpoint to connect to.");
5555

5656
if (endpoints.length == 1) {
57-
return new SingleNodeHostProvider(clientProvider, headersSupplier, endpoints[0]);
57+
return new SingleNodeHostProvider(clientProvider, endpoints[0]);
5858
} else {
59-
return new MultiNodeHostProvider(clientProvider, headersSupplier, endpoints);
59+
return new MultiNodeHostProvider(clientProvider, endpoints);
6060
}
6161
}
6262

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/MultiNodeHostProvider.java

+1-7
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@
2828
import java.util.List;
2929
import java.util.Map;
3030
import java.util.concurrent.ConcurrentHashMap;
31-
import java.util.function.Supplier;
3231

3332
import org.slf4j.Logger;
3433
import org.slf4j.LoggerFactory;
3534
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
3635
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
3736
import org.springframework.data.elasticsearch.client.NoReachableHostException;
38-
import org.springframework.http.HttpHeaders;
3937
import org.springframework.lang.Nullable;
4038
import org.springframework.web.reactive.function.client.ClientResponse;
4139
import org.springframework.web.reactive.function.client.WebClient;
@@ -53,14 +51,11 @@ class MultiNodeHostProvider implements HostProvider<MultiNodeHostProvider> {
5351
private final static Logger LOG = LoggerFactory.getLogger(MultiNodeHostProvider.class);
5452

5553
private final WebClientProvider clientProvider;
56-
private final Supplier<HttpHeaders> headersSupplier;
5754
private final Map<InetSocketAddress, ElasticsearchHost> hosts;
5855

59-
MultiNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier,
60-
InetSocketAddress... endpoints) {
56+
MultiNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress... endpoints) {
6157

6258
this.clientProvider = clientProvider;
63-
this.headersSupplier = headersSupplier;
6459
this.hosts = new ConcurrentHashMap<>();
6560
for (InetSocketAddress endpoint : endpoints) {
6661
this.hosts.put(endpoint, new ElasticsearchHost(endpoint, State.UNKNOWN));
@@ -166,7 +161,6 @@ private Flux<Tuple2<InetSocketAddress, State>> checkNodes(@Nullable State state)
166161

167162
Mono<ClientResponse> clientResponseMono = createWebClient(host) //
168163
.head().uri("/") //
169-
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
170164
.exchangeToMono(Mono::just) //
171165
.timeout(Duration.ofSeconds(1)) //
172166
.doOnError(throwable -> {

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/SingleNodeHostProvider.java

+1-7
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919

2020
import java.net.InetSocketAddress;
2121
import java.util.Collections;
22-
import java.util.function.Supplier;
2322

2423
import org.springframework.data.elasticsearch.client.ElasticsearchHost;
2524
import org.springframework.data.elasticsearch.client.ElasticsearchHost.State;
2625
import org.springframework.data.elasticsearch.client.NoReachableHostException;
27-
import org.springframework.http.HttpHeaders;
2826
import org.springframework.web.reactive.function.client.WebClient;
2927

3028
/**
@@ -38,15 +36,12 @@
3836
class SingleNodeHostProvider implements HostProvider<SingleNodeHostProvider> {
3937

4038
private final WebClientProvider clientProvider;
41-
private final Supplier<HttpHeaders> headersSupplier;
4239
private final InetSocketAddress endpoint;
4340
private volatile ElasticsearchHost state;
4441

45-
SingleNodeHostProvider(WebClientProvider clientProvider, Supplier<HttpHeaders> headersSupplier,
46-
InetSocketAddress endpoint) {
42+
SingleNodeHostProvider(WebClientProvider clientProvider, InetSocketAddress endpoint) {
4743

4844
this.clientProvider = clientProvider;
49-
this.headersSupplier = headersSupplier;
5045
this.endpoint = endpoint;
5146
this.state = new ElasticsearchHost(this.endpoint, State.UNKNOWN);
5247
}
@@ -60,7 +55,6 @@ public Mono<ClusterInformation> clusterInfo() {
6055

6156
return createWebClient(endpoint) //
6257
.head().uri("/") //
63-
.headers(httpHeaders -> httpHeaders.addAll(headersSupplier.get())) //
6458
.exchangeToMono(it -> {
6559
if (it.statusCode().isError()) {
6660
state = ElasticsearchHost.offline(endpoint);

Diff for: src/main/java/org/springframework/data/elasticsearch/client/reactive/WebClientProvider.java

+13-3
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con
101101

102102
/**
103103
* Obtain the {@link String pathPrefix} to be used.
104-
*
104+
*
105105
* @return the pathPrefix if set.
106106
* @since 4.0
107107
*/
@@ -126,7 +126,7 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con
126126

127127
/**
128128
* Create a new instance of {@link WebClientProvider} where HTTP requests are called with the given path prefix.
129-
*
129+
*
130130
* @param pathPrefix Path prefix to add to requests
131131
* @return new instance of {@link WebClientProvider}
132132
* @since 4.0
@@ -136,10 +136,20 @@ static WebClientProvider create(String scheme, @Nullable ClientHttpConnector con
136136
/**
137137
* Create a new instance of {@link WebClientProvider} calling the given {@link Function} to configure the
138138
* {@link WebClient}.
139-
*
139+
*
140140
* @param webClientConfigurer configuration function
141141
* @return new instance of {@link WebClientProvider}
142142
* @since 4.0
143143
*/
144144
WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient> webClientConfigurer);
145+
146+
/**
147+
* Create a new instance of {@link WebClientProvider} calling the given {@link Consumer} to configure the requests of
148+
* this {@link WebClient}.
149+
*
150+
* @param requestConfigurer request configuration callback
151+
* @return new instance of {@link WebClientProvider}
152+
* @since 4.3
153+
*/
154+
WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer);
145155
}

Diff for: src/test/java/org/springframework/data/elasticsearch/client/reactive/ReactiveMockClientTestsUtils.java

+7-2
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ public static <T extends HostProvider<T>> MockDelegatingElasticsearchHostProvide
8383

8484
if (hosts.length == 1) {
8585
// noinspection unchecked
86-
delegate = (T) new SingleNodeHostProvider(clientProvider, HttpHeaders::new, getInetSocketAddress(hosts[0])) {};
86+
delegate = (T) new SingleNodeHostProvider(clientProvider, getInetSocketAddress(hosts[0])) {};
8787
} else {
8888
// noinspection unchecked
89-
delegate = (T) new MultiNodeHostProvider(clientProvider, HttpHeaders::new, Arrays.stream(hosts)
89+
delegate = (T) new MultiNodeHostProvider(clientProvider, Arrays.stream(hosts)
9090
.map(ReactiveMockClientTestsUtils::getInetSocketAddress).toArray(InetSocketAddress[]::new)) {};
9191
}
9292

@@ -297,6 +297,11 @@ public WebClientProvider withWebClientConfigurer(Function<WebClient, WebClient>
297297
throw new UnsupportedOperationException("not implemented");
298298
}
299299

300+
@Override
301+
public WebClientProvider withRequestConfigurer(Consumer<WebClient.RequestHeadersSpec<?>> requestConfigurer) {
302+
throw new UnsupportedOperationException("not implemented");
303+
}
304+
300305
public Send when(String host) {
301306
InetSocketAddress inetSocketAddress = getInetSocketAddress(host);
302307
return new CallbackImpl(get(host), headersUriSpecMap.get(inetSocketAddress),

0 commit comments

Comments
 (0)