Skip to content

Commit c9b8b1a

Browse files
authored
Adapt custom routing.
Original Pull Request #2474 Closes #2087
1 parent 9c80dc9 commit c9b8b1a

10 files changed

+132
-74
lines changed

Diff for: src/main/java/org/springframework/data/elasticsearch/client/elc/ElasticsearchTemplate.java

+11-7
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ public ByQueryResponse delete(Query query, Class<?> clazz, IndexCoordinates inde
174174

175175
Assert.notNull(query, "query must not be null");
176176

177-
DeleteByQueryRequest request = requestConverter.documentDeleteByQueryRequest(query, clazz, index,
178-
getRefreshPolicy());
177+
DeleteByQueryRequest request = requestConverter.documentDeleteByQueryRequest(query, routingResolver.getRouting(),
178+
clazz, index, getRefreshPolicy());
179179

180180
DeleteByQueryResponse response = execute(client -> client.deleteByQuery(request));
181181

@@ -309,7 +309,8 @@ public long count(Query query, @Nullable Class<?> clazz, IndexCoordinates index)
309309
Assert.notNull(query, "query must not be null");
310310
Assert.notNull(index, "index must not be null");
311311

312-
SearchRequest searchRequest = requestConverter.searchRequest(query, clazz, index, true);
312+
SearchRequest searchRequest = requestConverter.searchRequest(query, routingResolver.getRouting(), clazz, index,
313+
true);
313314

314315
SearchResponse<EntityAsMap> searchResponse = execute(client -> client.search(searchRequest, EntityAsMap.class));
315316

@@ -331,7 +332,8 @@ public <T> SearchHits<T> search(Query query, Class<T> clazz, IndexCoordinates in
331332
}
332333

333334
protected <T> SearchHits<T> doSearch(Query query, Class<T> clazz, IndexCoordinates index) {
334-
SearchRequest searchRequest = requestConverter.searchRequest(query, clazz, index, false);
335+
SearchRequest searchRequest = requestConverter.searchRequest(query, routingResolver.getRouting(), clazz, index,
336+
false);
335337
SearchResponse<EntityAsMap> searchResponse = execute(client -> client.search(searchRequest, EntityAsMap.class));
336338

337339
// noinspection DuplicatedCode
@@ -343,7 +345,7 @@ protected <T> SearchHits<T> doSearch(Query query, Class<T> clazz, IndexCoordinat
343345
}
344346

345347
protected <T> SearchHits<T> doSearch(SearchTemplateQuery query, Class<T> clazz, IndexCoordinates index) {
346-
var searchTemplateRequest = requestConverter.searchTemplate(query, index);
348+
var searchTemplateRequest = requestConverter.searchTemplate(query, routingResolver.getRouting(), index);
347349
var searchTemplateResponse = execute(client -> client.searchTemplate(searchTemplateRequest, EntityAsMap.class));
348350

349351
// noinspection DuplicatedCode
@@ -374,7 +376,8 @@ public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query
374376
Assert.notNull(query, "query must not be null");
375377
Assert.notNull(query.getPageable(), "pageable of query must not be null.");
376378

377-
SearchRequest request = requestConverter.searchRequest(query, clazz, index, false, scrollTimeInMillis);
379+
SearchRequest request = requestConverter.searchRequest(query, routingResolver.getRouting(), clazz, index, false,
380+
scrollTimeInMillis);
378381
SearchResponse<EntityAsMap> response = execute(client -> client.search(request, EntityAsMap.class));
379382

380383
return getSearchScrollHits(clazz, index, response);
@@ -492,7 +495,8 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
492495
@SuppressWarnings({ "unchecked", "rawtypes" })
493496
private List<SearchHits<?>> doMultiSearch(List<MultiSearchQueryParameter> multiSearchQueryParameters) {
494497

495-
MsearchRequest request = requestConverter.searchMsearchRequest(multiSearchQueryParameters);
498+
MsearchRequest request = requestConverter.searchMsearchRequest(multiSearchQueryParameters,
499+
routingResolver.getRouting());
496500

497501
MsearchResponse<EntityAsMap> msearchResponse = execute(client -> client.msearch(request, EntityAsMap.class));
498502
List<MultiSearchResponseItem<EntityAsMap>> responseItems = msearchResponse.responses();

Diff for: src/main/java/org/springframework/data/elasticsearch/client/elc/ReactiveElasticsearchTemplate.java

+14-10
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,7 @@ protected <T> Mono<Tuple2<T, IndexResponseMetaData>> doIndex(T entity, IndexCoor
111111
return Mono.just(entity) //
112112
.zipWith(//
113113
Mono.from(execute((ClientCallback<Publisher<IndexResponse>>) client -> client.index(indexRequest))) //
114-
.map(indexResponse -> new IndexResponseMetaData(
115-
indexResponse.id(), //
114+
.map(indexResponse -> new IndexResponseMetaData(indexResponse.id(), //
116115
indexResponse.index(), //
117116
indexResponse.seqNo(), //
118117
indexResponse.primaryTerm(), //
@@ -171,8 +170,8 @@ public Mono<ByQueryResponse> delete(Query query, Class<?> entityType, IndexCoord
171170

172171
Assert.notNull(query, "query must not be null");
173172

174-
DeleteByQueryRequest request = requestConverter.documentDeleteByQueryRequest(query, entityType, index,
175-
getRefreshPolicy());
173+
DeleteByQueryRequest request = requestConverter.documentDeleteByQueryRequest(query, routingResolver.getRouting(),
174+
entityType, index, getRefreshPolicy());
176175
return Mono
177176
.from(execute((ClientCallback<Publisher<DeleteByQueryResponse>>) client -> client.deleteByQuery(request)))
178177
.map(responseConverter::byQueryResponse);
@@ -391,7 +390,8 @@ private Flux<SearchDocument> doFindUnbounded(Query query, Class<?> clazz, IndexC
391390

392391
baseQuery.setPointInTime(new Query.PointInTime(psa.getPit(), pitKeepAlive));
393392
baseQuery.addSort(Sort.by("_shard_doc"));
394-
SearchRequest firstSearchRequest = requestConverter.searchRequest(baseQuery, clazz, index, false, true);
393+
SearchRequest firstSearchRequest = requestConverter.searchRequest(baseQuery, routingResolver.getRouting(),
394+
clazz, index, false, true);
395395

396396
return Mono.from(execute((ClientCallback<Publisher<ResponseBody<EntityAsMap>>>) client -> client
397397
.search(firstSearchRequest, EntityAsMap.class))).expand(entityAsMapSearchResponse -> {
@@ -404,7 +404,8 @@ private Flux<SearchDocument> doFindUnbounded(Query query, Class<?> clazz, IndexC
404404
List<Object> sortOptions = hits.get(hits.size() - 1).sort().stream().map(TypeUtils::toObject)
405405
.collect(Collectors.toList());
406406
baseQuery.setSearchAfter(sortOptions);
407-
SearchRequest followSearchRequest = requestConverter.searchRequest(baseQuery, clazz, index, false, true);
407+
SearchRequest followSearchRequest = requestConverter.searchRequest(baseQuery,
408+
routingResolver.getRouting(), clazz, index, false, true);
408409
return Mono.from(execute((ClientCallback<Publisher<ResponseBody<EntityAsMap>>>) client -> client
409410
.search(followSearchRequest, EntityAsMap.class)));
410411
});
@@ -460,7 +461,8 @@ protected Mono<Long> doCount(Query query, Class<?> entityType, IndexCoordinates
460461
Assert.notNull(query, "query must not be null");
461462
Assert.notNull(index, "index must not be null");
462463

463-
SearchRequest searchRequest = requestConverter.searchRequest(query, entityType, index, true);
464+
SearchRequest searchRequest = requestConverter.searchRequest(query, routingResolver.getRouting(), entityType, index,
465+
true);
464466

465467
return Mono
466468
.from(execute((ClientCallback<Publisher<ResponseBody<EntityAsMap>>>) client -> client.search(searchRequest,
@@ -470,7 +472,8 @@ protected Mono<Long> doCount(Query query, Class<?> entityType, IndexCoordinates
470472

471473
private Flux<SearchDocument> doFindBounded(Query query, Class<?> clazz, IndexCoordinates index) {
472474

473-
SearchRequest searchRequest = requestConverter.searchRequest(query, clazz, index, false, false);
475+
SearchRequest searchRequest = requestConverter.searchRequest(query, routingResolver.getRouting(), clazz, index,
476+
false, false);
474477

475478
return Mono
476479
.from(execute((ClientCallback<Publisher<ResponseBody<EntityAsMap>>>) client -> client.search(searchRequest,
@@ -481,7 +484,7 @@ private Flux<SearchDocument> doFindBounded(Query query, Class<?> clazz, IndexCoo
481484

482485
private Flux<SearchDocument> doSearch(SearchTemplateQuery query, Class<?> clazz, IndexCoordinates index) {
483486

484-
var request = requestConverter.searchTemplate(query, index);
487+
var request = requestConverter.searchTemplate(query, routingResolver.getRouting(), index);
485488

486489
return Mono
487490
.from(execute((ClientCallback<Publisher<SearchTemplateResponse<EntityAsMap>>>) client -> client
@@ -496,7 +499,8 @@ protected <T> Mono<SearchDocumentResponse> doFindForResponse(Query query, Class<
496499
Assert.notNull(query, "query must not be null");
497500
Assert.notNull(index, "index must not be null");
498501

499-
SearchRequest searchRequest = requestConverter.searchRequest(query, clazz, index, false);
502+
SearchRequest searchRequest = requestConverter.searchRequest(query, routingResolver.getRouting(), clazz, index,
503+
false);
500504

501505
// noinspection unchecked
502506
SearchDocumentCallback<T> callback = new ReadSearchDocumentCallback<>((Class<T>) clazz, index);

Diff for: src/main/java/org/springframework/data/elasticsearch/client/elc/RequestConverter.java

+43-20
Original file line numberDiff line numberDiff line change
@@ -837,8 +837,8 @@ public DeleteRequest documentDeleteRequest(String id, @Nullable String routing,
837837
});
838838
}
839839

840-
public DeleteByQueryRequest documentDeleteByQueryRequest(Query query, Class<?> clazz, IndexCoordinates index,
841-
@Nullable RefreshPolicy refreshPolicy) {
840+
public DeleteByQueryRequest documentDeleteByQueryRequest(Query query, @Nullable String routing, Class<?> clazz,
841+
IndexCoordinates index, @Nullable RefreshPolicy refreshPolicy) {
842842

843843
Assert.notNull(query, "query must not be null");
844844
Assert.notNull(index, "index must not be null");
@@ -857,6 +857,8 @@ public DeleteByQueryRequest documentDeleteByQueryRequest(Query query, Class<?> c
857857

858858
if (query.getRoute() != null) {
859859
b.routing(query.getRoute());
860+
} else if (StringUtils.hasText(routing)) {
861+
b.routing(routing);
860862
}
861863

862864
return b;
@@ -998,45 +1000,53 @@ public UpdateByQueryRequest documentUpdateByQueryRequest(UpdateQuery updateQuery
9981000

9991001
// region search
10001002

1001-
public <T> SearchRequest searchRequest(Query query, @Nullable Class<T> clazz, IndexCoordinates indexCoordinates,
1002-
boolean forCount) {
1003-
return searchRequest(query, clazz, indexCoordinates, forCount, false, null);
1003+
public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
1004+
IndexCoordinates indexCoordinates, boolean forCount) {
1005+
return searchRequest(query, routing, clazz, indexCoordinates, forCount, false, null);
10041006
}
10051007

1006-
public <T> SearchRequest searchRequest(Query query, @Nullable Class<T> clazz, IndexCoordinates indexCoordinates,
1007-
boolean forCount, long scrollTimeInMillis) {
1008-
return searchRequest(query, clazz, indexCoordinates, forCount, true, scrollTimeInMillis);
1008+
public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
1009+
IndexCoordinates indexCoordinates, boolean forCount, long scrollTimeInMillis) {
1010+
return searchRequest(query, routing, clazz, indexCoordinates, forCount, true, scrollTimeInMillis);
10091011
}
10101012

1011-
public <T> SearchRequest searchRequest(Query query, @Nullable Class<T> clazz, IndexCoordinates indexCoordinates,
1012-
boolean forCount, boolean forBatchedSearch) {
1013-
return searchRequest(query, clazz, indexCoordinates, forCount, forBatchedSearch, null);
1013+
public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
1014+
IndexCoordinates indexCoordinates, boolean forCount, boolean forBatchedSearch) {
1015+
return searchRequest(query, routing, clazz, indexCoordinates, forCount, forBatchedSearch, null);
10141016
}
10151017

1016-
public <T> SearchRequest searchRequest(Query query, @Nullable Class<T> clazz, IndexCoordinates indexCoordinates,
1017-
boolean forCount, boolean forBatchedSearch, @Nullable Long scrollTimeInMillis) {
1018+
public <T> SearchRequest searchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
1019+
IndexCoordinates indexCoordinates, boolean forCount, boolean forBatchedSearch,
1020+
@Nullable Long scrollTimeInMillis) {
10181021

10191022
String[] indexNames = indexCoordinates.getIndexNames();
10201023
Assert.notNull(query, "query must not be null");
10211024
Assert.notNull(indexCoordinates, "indexCoordinates must not be null");
10221025

10231026
elasticsearchConverter.updateQuery(query, clazz);
10241027
SearchRequest.Builder builder = new SearchRequest.Builder();
1025-
prepareSearchRequest(query, clazz, indexCoordinates, builder, forCount, forBatchedSearch);
1028+
prepareSearchRequest(query, routing, clazz, indexCoordinates, builder, forCount, forBatchedSearch);
10261029

10271030
if (scrollTimeInMillis != null) {
10281031
builder.scroll(t -> t.time(scrollTimeInMillis + "ms"));
10291032
}
10301033

10311034
builder.query(getQuery(query, clazz));
10321035

1036+
if (StringUtils.hasText(query.getRoute())) {
1037+
builder.routing(query.getRoute());
1038+
}
1039+
if (StringUtils.hasText(routing)) {
1040+
builder.routing(routing);
1041+
}
1042+
10331043
addFilter(query, builder);
10341044

10351045
return builder.build();
10361046
}
10371047

10381048
public MsearchRequest searchMsearchRequest(
1039-
List<ElasticsearchTemplate.MultiSearchQueryParameter> multiSearchQueryParameters) {
1049+
List<ElasticsearchTemplate.MultiSearchQueryParameter> multiSearchQueryParameters, @Nullable String routing) {
10401050

10411051
// basically the same stuff as in prepareSearchRequest, but the new Elasticsearch has different builders for a
10421052
// normal search and msearch
@@ -1049,11 +1059,16 @@ public MsearchRequest searchMsearchRequest(
10491059
.header(h -> {
10501060
h //
10511061
.index(Arrays.asList(param.index().getIndexNames())) //
1052-
.routing(query.getRoute()) //
10531062
.searchType(searchType(query.getSearchType())) //
10541063
.requestCache(query.getRequestCache()) //
10551064
;
10561065

1066+
if (StringUtils.hasText(query.getRoute())) {
1067+
h.routing(query.getRoute());
1068+
} else if (StringUtils.hasText(routing)) {
1069+
h.routing(routing);
1070+
}
1071+
10571072
if (query.getPreference() != null) {
10581073
h.preference(query.getPreference());
10591074
}
@@ -1156,8 +1171,8 @@ public MsearchRequest searchMsearchRequest(
11561171
});
11571172
}
11581173

1159-
private <T> void prepareSearchRequest(Query query, @Nullable Class<T> clazz, IndexCoordinates indexCoordinates,
1160-
SearchRequest.Builder builder, boolean forCount, boolean forBatchedSearch) {
1174+
private <T> void prepareSearchRequest(Query query, @Nullable String routing, @Nullable Class<T> clazz,
1175+
IndexCoordinates indexCoordinates, SearchRequest.Builder builder, boolean forCount, boolean forBatchedSearch) {
11611176

11621177
String[] indexNames = indexCoordinates.getIndexNames();
11631178

@@ -1190,6 +1205,8 @@ private <T> void prepareSearchRequest(Query query, @Nullable Class<T> clazz, Ind
11901205

11911206
if (query.getRoute() != null) {
11921207
builder.routing(query.getRoute());
1208+
} else if (StringUtils.hasText(routing)) {
1209+
builder.routing(routing);
11931210
}
11941211

11951212
if (query.getPreference() != null) {
@@ -1559,7 +1576,8 @@ public ClosePointInTimeRequest searchClosePointInTime(String pit) {
15591576
return ClosePointInTimeRequest.of(cpit -> cpit.id(pit));
15601577
}
15611578

1562-
public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, IndexCoordinates index) {
1579+
public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, @Nullable String routing,
1580+
IndexCoordinates index) {
15631581

15641582
Assert.notNull(query, "query must not be null");
15651583

@@ -1570,10 +1588,15 @@ public SearchTemplateRequest searchTemplate(SearchTemplateQuery query, IndexCoor
15701588
.id(query.getId()) //
15711589
.index(Arrays.asList(index.getIndexNames())) //
15721590
.preference(query.getPreference()) //
1573-
.routing(query.getRoute()) //
15741591
.searchType(searchType(query.getSearchType())).source(query.getSource()) //
15751592
;
15761593

1594+
if (query.getRoute() != null) {
1595+
builder.routing(query.getRoute());
1596+
} else if (StringUtils.hasText(routing)) {
1597+
builder.routing(routing);
1598+
}
1599+
15771600
var expandWildcards = query.getExpandWildcards();
15781601
if (!expandWildcards.isEmpty()) {
15791602
builder.expandWildcards(expandWildcards(expandWildcards));

Diff for: src/main/java/org/springframework/data/elasticsearch/client/erhlc/ElasticsearchRestTemplate.java

+10-8
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,8 @@ protected String doDelete(String id, @Nullable String routing, IndexCoordinates
261261

262262
@Override
263263
public ByQueryResponse delete(Query query, Class<?> clazz, IndexCoordinates index) {
264-
DeleteByQueryRequest deleteByQueryRequest = requestFactory.deleteByQueryRequest(query, clazz, index);
264+
DeleteByQueryRequest deleteByQueryRequest = requestFactory.deleteByQueryRequest(query, routingResolver.getRouting(),
265+
clazz, index);
265266
return ResponseConverter
266267
.byQueryResponseOf(execute(client -> client.deleteByQuery(deleteByQueryRequest, RequestOptions.DEFAULT)));
267268
}
@@ -398,7 +399,7 @@ public long count(Query query, @Nullable Class<?> clazz, IndexCoordinates index)
398399

399400
final Boolean trackTotalHits = query.getTrackTotalHits();
400401
query.setTrackTotalHits(true);
401-
SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
402+
SearchRequest searchRequest = requestFactory.searchRequest(query, routingResolver.getRouting(), clazz, index);
402403
query.setTrackTotalHits(trackTotalHits);
403404

404405
searchRequest.source().size(0);
@@ -409,7 +410,7 @@ public long count(Query query, @Nullable Class<?> clazz, IndexCoordinates index)
409410

410411
@Override
411412
public <T> SearchHits<T> search(Query query, Class<T> clazz, IndexCoordinates index) {
412-
SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
413+
SearchRequest searchRequest = requestFactory.searchRequest(query, routingResolver.getRouting(), clazz, index);
413414
SearchResponse response = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));
414415

415416
ReadDocumentCallback<T> documentCallback = new ReadDocumentCallback<>(elasticsearchConverter, clazz, index);
@@ -431,7 +432,7 @@ public <T> SearchScrollHits<T> searchScrollStart(long scrollTimeInMillis, Query
431432

432433
Assert.notNull(query.getPageable(), "pageable of query must not be null.");
433434

434-
SearchRequest searchRequest = requestFactory.searchRequest(query, clazz, index);
435+
SearchRequest searchRequest = requestFactory.searchRequest(query, routingResolver.getRouting(), clazz, index);
435436
searchRequest.scroll(TimeValue.timeValueMillis(scrollTimeInMillis));
436437

437438
SearchResponse response = execute(client -> client.search(searchRequest, RequestOptions.DEFAULT));
@@ -477,7 +478,7 @@ public SearchResponse suggest(SuggestBuilder suggestion, IndexCoordinates index)
477478
public <T> List<SearchHits<T>> multiSearch(List<? extends Query> queries, Class<T> clazz, IndexCoordinates index) {
478479
MultiSearchRequest request = new MultiSearchRequest();
479480
for (Query query : queries) {
480-
request.add(requestFactory.searchRequest(query, clazz, index));
481+
request.add(requestFactory.searchRequest(query, routingResolver.getRouting(), clazz, index));
481482
}
482483

483484
MultiSearchResponse.Item[] items = getMultiSearchResult(request);
@@ -504,7 +505,8 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
504505
Iterator<Class<?>> it = classes.iterator();
505506
for (Query query : queries) {
506507
Class<?> clazz = it.next();
507-
request.add(requestFactory.searchRequest(query, clazz, getIndexCoordinatesFor(clazz)));
508+
request
509+
.add(requestFactory.searchRequest(query, routingResolver.getRouting(), clazz, getIndexCoordinatesFor(clazz)));
508510
}
509511

510512
MultiSearchResponse.Item[] items = getMultiSearchResult(request);
@@ -538,7 +540,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
538540
MultiSearchRequest request = new MultiSearchRequest();
539541
Iterator<Class<?>> it = classes.iterator();
540542
for (Query query : queries) {
541-
request.add(requestFactory.searchRequest(query, it.next(), index));
543+
request.add(requestFactory.searchRequest(query, routingResolver.getRouting(), it.next(), index));
542544
}
543545

544546
MultiSearchResponse.Item[] items = getMultiSearchResult(request);
@@ -572,7 +574,7 @@ public List<SearchHits<?>> multiSearch(List<? extends Query> queries, List<Class
572574
Iterator<Class<?>> it = classes.iterator();
573575
Iterator<IndexCoordinates> indexesIt = indexes.iterator();
574576
for (Query query : queries) {
575-
request.add(requestFactory.searchRequest(query, it.next(), indexesIt.next()));
577+
request.add(requestFactory.searchRequest(query, routingResolver.getRouting(), it.next(), indexesIt.next()));
576578
}
577579

578580
MultiSearchResponse.Item[] items = getMultiSearchResult(request);

0 commit comments

Comments
 (0)