Skip to content

Commit b289d5f

Browse files
Add Rescore functionality.
Original Pull Request #1688 Closes #1686
1 parent eb816cc commit b289d5f

File tree

7 files changed

+402
-7
lines changed

7 files changed

+402
-7
lines changed

Diff for: src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java

+35-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.HashMap;
2525
import java.util.List;
2626
import java.util.Map;
27+
import java.util.Objects;
2728
import java.util.Optional;
2829

2930
import org.elasticsearch.action.DocWriteRequest;
@@ -37,7 +38,6 @@
3738
import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequestBuilder;
3839
import org.elasticsearch.action.admin.indices.refresh.RefreshRequest;
3940
import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest;
40-
import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse;
4141
import org.elasticsearch.action.admin.indices.template.delete.DeleteIndexTemplateRequest;
4242
import org.elasticsearch.action.bulk.BulkRequest;
4343
import org.elasticsearch.action.bulk.BulkRequestBuilder;
@@ -65,7 +65,6 @@
6565
import org.elasticsearch.client.indices.PutIndexTemplateRequest;
6666
import org.elasticsearch.client.indices.PutMappingRequest;
6767
import org.elasticsearch.common.geo.GeoDistance;
68-
import org.elasticsearch.common.settings.Settings;
6968
import org.elasticsearch.common.unit.DistanceUnit;
7069
import org.elasticsearch.common.unit.TimeValue;
7170
import org.elasticsearch.index.VersionType;
@@ -83,6 +82,8 @@
8382
import org.elasticsearch.search.builder.SearchSourceBuilder;
8483
import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
8584
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
85+
import org.elasticsearch.search.rescore.QueryRescoreMode;
86+
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
8687
import org.elasticsearch.search.sort.FieldSortBuilder;
8788
import org.elasticsearch.search.sort.GeoDistanceSortBuilder;
8889
import org.elasticsearch.search.sort.ScoreSortBuilder;
@@ -106,6 +107,7 @@
106107
import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty;
107108
import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates;
108109
import org.springframework.data.elasticsearch.core.query.*;
110+
import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode;
109111
import org.springframework.data.mapping.context.MappingContext;
110112
import org.springframework.lang.Nullable;
111113
import org.springframework.util.Assert;
@@ -119,6 +121,7 @@
119121
* @author Roman Puchkovskiy
120122
* @author Subhobrata Dey
121123
* @author Farid Faoudi
124+
* @author Peer Mueller
122125
* @since 4.0
123126
*/
124127
class RequestFactory {
@@ -1050,6 +1053,9 @@ private SearchRequest prepareSearchRequest(Query query, @Nullable Class<?> clazz
10501053
sourceBuilder.searchAfter(query.getSearchAfter().toArray());
10511054
}
10521055

1056+
query.getRescorerQueries().forEach(rescorer -> sourceBuilder.addRescorer(
1057+
getQueryRescorerBuilder(rescorer)));
1058+
10531059
request.source(sourceBuilder);
10541060
return request;
10551061
}
@@ -1136,6 +1142,9 @@ private SearchRequestBuilder prepareSearchRequestBuilder(Query query, Client cli
11361142
searchRequestBuilder.searchAfter(query.getSearchAfter().toArray());
11371143
}
11381144

1145+
query.getRescorerQueries().forEach(rescorer -> searchRequestBuilder.addRescorer(
1146+
getQueryRescorerBuilder(rescorer)));
1147+
11391148
return searchRequestBuilder;
11401149
}
11411150

@@ -1260,6 +1269,30 @@ private SortBuilder<?> getSortBuilder(Sort.Order order, @Nullable ElasticsearchP
12601269
}
12611270
}
12621271
}
1272+
1273+
private QueryRescorerBuilder getQueryRescorerBuilder(RescorerQuery rescorerQuery) {
1274+
1275+
QueryRescorerBuilder builder = new QueryRescorerBuilder(Objects.requireNonNull(getQuery(rescorerQuery.getQuery())));
1276+
1277+
if (rescorerQuery.getScoreMode() != ScoreMode.Default) {
1278+
builder.setScoreMode(QueryRescoreMode.valueOf(rescorerQuery.getScoreMode().name()));
1279+
}
1280+
1281+
if (rescorerQuery.getQueryWeight() != null) {
1282+
builder.setQueryWeight(rescorerQuery.getQueryWeight());
1283+
}
1284+
1285+
if (rescorerQuery.getRescoreQueryWeight() != null) {
1286+
builder.setRescoreQueryWeight(rescorerQuery.getRescoreQueryWeight());
1287+
}
1288+
1289+
if (rescorerQuery.getWindowSize() != null) {
1290+
builder.windowSize(rescorerQuery.getWindowSize());
1291+
}
1292+
1293+
return builder;
1294+
1295+
}
12631296
// endregion
12641297

12651298
// region update

Diff for: src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java

+18
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
* @author Sascha Woo
4242
* @author Farid Azaza
4343
* @author Peter-Josef Meisch
44+
* @author Peer Mueller
4445
*/
4546
abstract class AbstractQuery implements Query {
4647

@@ -63,6 +64,7 @@ abstract class AbstractQuery implements Query {
6364
@Nullable private TimeValue timeout;
6465
private boolean explain = false;
6566
@Nullable private List<Object> searchAfter;
67+
protected List<RescorerQuery> rescorerQueries = new ArrayList<>();
6668

6769
@Override
6870
@Nullable
@@ -295,4 +297,20 @@ public void setSearchAfter(@Nullable List<Object> searchAfter) {
295297
public List<Object> getSearchAfter() {
296298
return searchAfter;
297299
}
300+
301+
@Override
302+
public void addRescorerQuery(RescorerQuery rescorerQuery) {
303+
this.rescorerQueries.add(rescorerQuery);
304+
}
305+
306+
@Override
307+
public void setRescorerQueries(List<RescorerQuery> rescorerQueryList) {
308+
this.rescorerQueries.clear();
309+
this.rescorerQueries.addAll(rescorerQueryList);
310+
}
311+
312+
@Override
313+
public List<RescorerQuery> getRescorerQueries() {
314+
return rescorerQueries;
315+
}
298316
}

Diff for: src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java

+17-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import java.util.ArrayList;
2121
import java.util.Collection;
2222
import java.util.List;
23+
import java.util.stream.Collectors;
2324

2425
import org.elasticsearch.action.search.SearchType;
2526
import org.elasticsearch.action.support.IndicesOptions;
@@ -28,6 +29,7 @@
2829
import org.elasticsearch.search.aggregations.AbstractAggregationBuilder;
2930
import org.elasticsearch.search.collapse.CollapseBuilder;
3031
import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder;
32+
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
3133
import org.elasticsearch.search.sort.SortBuilder;
3234
import org.springframework.data.domain.Pageable;
3335
import org.springframework.lang.Nullable;
@@ -45,6 +47,7 @@
4547
* @author Martin Choraine
4648
* @author Farid Azaza
4749
* @author Peter-Josef Meisch
50+
* @author Peer Mueller
4851
*/
4952
public class NativeSearchQueryBuilder {
5053

@@ -70,6 +73,7 @@ public class NativeSearchQueryBuilder {
7073
@Nullable private Integer maxResults;
7174
@Nullable private Boolean trackTotalHits;
7275
@Nullable private TimeValue timeout;
76+
private final List<RescorerQuery> rescorerQueries = new ArrayList<>();
7377

7478
public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) {
7579
this.queryBuilder = queryBuilder;
@@ -183,12 +187,17 @@ public NativeSearchQueryBuilder withTrackTotalHits(Boolean trackTotalHits) {
183187
this.trackTotalHits = trackTotalHits;
184188
return this;
185189
}
186-
187-
public NativeSearchQueryBuilder withTimeout(TimeValue timeout) {
190+
191+
public NativeSearchQueryBuilder withTimeout(TimeValue timeout) {
188192
this.timeout = timeout;
189193
return this;
190194
}
191195

196+
public NativeSearchQueryBuilder withRescorerQuery(RescorerQuery rescorerQuery) {
197+
this.rescorerQueries.add(rescorerQuery);
198+
return this;
199+
}
200+
192201
public NativeSearchQuery build() {
193202

194203
NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilders,
@@ -250,11 +259,16 @@ public NativeSearchQuery build() {
250259
}
251260

252261
nativeSearchQuery.setTrackTotalHits(trackTotalHits);
253-
262+
254263
if (timeout != null) {
255264
nativeSearchQuery.setTimeout(timeout);
256265
}
257266

267+
if (!isEmpty(rescorerQueries)) {
268+
nativeSearchQuery.setRescorerQueries(
269+
rescorerQueries);
270+
}
271+
258272
return nativeSearchQuery;
259273
}
260274
}

Diff for: src/main/java/org/springframework/data/elasticsearch/core/query/Query.java

+23-1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.time.Duration;
1919
import java.util.Collection;
20+
import java.util.Collections;
2021
import java.util.List;
2122
import java.util.Optional;
2223

@@ -41,6 +42,7 @@
4142
* @author Christoph Strobl
4243
* @author Farid Azaza
4344
* @author Peter-Josef Meisch
45+
* @author Peer Mueller
4446
*/
4547
public interface Query {
4648

@@ -297,7 +299,7 @@ default boolean getExplain() {
297299

298300
/**
299301
* Sets the setSearchAfter objects for this query.
300-
*
302+
*
301303
* @param searchAfter the setSearchAfter objects. These are obtained with {@link SearchHit#getSortValues()} from a
302304
* search result.
303305
* @since 4.2
@@ -310,4 +312,24 @@ default boolean getExplain() {
310312
*/
311313
@Nullable
312314
List<Object> getSearchAfter();
315+
316+
/**
317+
* Sets the {@link RescorerQuery}.
318+
*
319+
* @param rescorerQuery the query to add to the list of rescorer queries
320+
* @since 4.2
321+
*/
322+
void addRescorerQuery(RescorerQuery rescorerQuery);
323+
324+
/**
325+
* Sets the {@link RescorerQuery}.
326+
*
327+
* @param rescorerQueryList list of rescorer queries set
328+
* @since 4.2
329+
*/
330+
void setRescorerQueries(List<RescorerQuery> rescorerQueryList);
331+
332+
default List<RescorerQuery> getRescorerQueries() {
333+
return Collections.emptyList();
334+
}
313335
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
/*
2+
* Copyright 2020-2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.elasticsearch.core.query;
17+
18+
import org.elasticsearch.index.query.QueryBuilder;
19+
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
20+
import org.springframework.lang.Nullable;
21+
22+
/**
23+
* Implementation of RescorerQuery to be used for rescoring filtered search results.
24+
*
25+
* @author Peer Mueller
26+
* @since 4.2
27+
*/
28+
public class RescorerQuery {
29+
30+
private final Query query;
31+
private ScoreMode scoreMode = ScoreMode.Default;
32+
@Nullable private Integer windowSize;
33+
@Nullable private Float queryWeight;
34+
@Nullable private Float rescoreQueryWeight;
35+
36+
public RescorerQuery(Query query) {
37+
this.query = query;
38+
}
39+
40+
public Query getQuery() {
41+
return query;
42+
}
43+
44+
public ScoreMode getScoreMode() {
45+
return scoreMode;
46+
}
47+
48+
@Nullable
49+
public Integer getWindowSize() {
50+
return windowSize;
51+
}
52+
53+
@Nullable
54+
public Float getQueryWeight() {
55+
return queryWeight;
56+
}
57+
58+
@Nullable
59+
public Float getRescoreQueryWeight() {
60+
return rescoreQueryWeight;
61+
}
62+
63+
public RescorerQuery withScoreMode(ScoreMode scoreMode) {
64+
this.scoreMode = scoreMode;
65+
return this;
66+
}
67+
68+
public RescorerQuery withWindowSize(int windowSize) {
69+
this.windowSize = windowSize;
70+
return this;
71+
}
72+
73+
public RescorerQuery withQueryWeight(float queryWeight) {
74+
this.queryWeight = queryWeight;
75+
return this;
76+
}
77+
78+
public RescorerQuery withRescoreQueryWeight(float rescoreQueryWeight) {
79+
this.rescoreQueryWeight = rescoreQueryWeight;
80+
return this;
81+
}
82+
83+
84+
85+
public enum ScoreMode {
86+
Default,
87+
Avg,
88+
Max,
89+
Min,
90+
Total,
91+
Multiply
92+
}
93+
94+
}

0 commit comments

Comments
 (0)