From 130cf63570216b73754f160979ab28d84c5057e7 Mon Sep 17 00:00:00 2001 From: pemueller Date: Thu, 11 Feb 2021 23:03:05 +0100 Subject: [PATCH 1/4] Add Rescore functionality Closes #1686 --- .../elasticsearch/core/RequestFactory.java | 7 ++ .../core/query/AbstractQuery.java | 54 +++++++++++---- .../core/query/NativeSearchQueryBuilder.java | 22 ++++++- .../data/elasticsearch/core/query/Query.java | 22 +++++++ .../core/query/RescorerQuery.java | 40 +++++++++++ .../core/ElasticsearchTemplateTests.java | 66 ++++++++++++++++++- 6 files changed, 195 insertions(+), 16 deletions(-) create mode 100644 src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java diff --git a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java index 2adfe70eb..c897a9d4a 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java @@ -120,6 +120,7 @@ * @author Roman Puchkovskiy * @author Subhobrata Dey * @author Farid Faoudi + * @author Peer Mueller * @since 4.0 */ class RequestFactory { @@ -1147,6 +1148,9 @@ private SearchRequest prepareSearchRequest(Query query, @Nullable Class clazz sourceBuilder.explain(query.getExplain()); + query.getRescorerQueries().stream().map(RescorerQuery::getRescorerBuilder) + .forEach(sourceBuilder::addRescorer); + request.source(sourceBuilder); return request; } @@ -1229,6 +1233,9 @@ private SearchRequestBuilder prepareSearchRequestBuilder(Query query, Client cli searchRequestBuilder.setExplain(query.getExplain()); + query.getRescorerQueries().stream().map(RescorerQuery::getRescorerBuilder) + .forEach(searchRequestBuilder::addRescorer); + return searchRequestBuilder; } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index b2d9e0a9d..07370494a 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -41,27 +41,41 @@ * @author Sascha Woo * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ abstract class AbstractQuery implements Query { protected Pageable pageable = DEFAULT_PAGE; - @Nullable protected Sort sort; + @Nullable + protected Sort sort; protected List fields = new ArrayList<>(); - @Nullable protected SourceFilter sourceFilter; + @Nullable + protected SourceFilter sourceFilter; protected float minScore; - @Nullable protected Collection ids; - @Nullable protected String route; + @Nullable + protected Collection ids; + @Nullable + protected String route; protected SearchType searchType = SearchType.DFS_QUERY_THEN_FETCH; - @Nullable protected IndicesOptions indicesOptions; + @Nullable + protected IndicesOptions indicesOptions; protected boolean trackScores; - @Nullable protected String preference; - @Nullable protected Integer maxResults; - @Nullable protected HighlightQuery highlightQuery; - @Nullable private Boolean trackTotalHits; - @Nullable private Integer trackTotalHitsUpTo; - @Nullable private Duration scrollTime; - @Nullable private TimeValue timeout; + @Nullable + protected String preference; + @Nullable + protected Integer maxResults; + @Nullable + protected HighlightQuery highlightQuery; + @Nullable + private Boolean trackTotalHits; + @Nullable + private Integer trackTotalHitsUpTo; + @Nullable + private Duration scrollTime; + @Nullable + private TimeValue timeout; private boolean explain = false; + protected List rescorerQueries = new ArrayList<>(); @Override @Nullable @@ -283,4 +297,20 @@ public boolean getExplain() { public void setExplain(boolean explain) { this.explain = explain; } + + @Override + public void addRescorerQuery(RescorerQuery rescorerQuery) { + this.rescorerQueries.add(rescorerQuery); + } + + @Override + public void setRescorerQueries(List rescorerQueryList) { + this.rescorerQueries.clear(); + this.rescorerQueries.addAll(rescorerQueryList); + } + + @Override + public List getRescorerQueries() { + return rescorerQueries; + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java index 45af3f1f5..13d96bd3b 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.stream.Collectors; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.action.support.IndicesOptions; @@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.SortBuilder; import org.springframework.data.domain.Pageable; import org.springframework.lang.Nullable; @@ -45,6 +47,7 @@ * @author Martin Choraine * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ public class NativeSearchQueryBuilder { @@ -70,6 +73,7 @@ public class NativeSearchQueryBuilder { @Nullable private Integer maxResults; @Nullable private Boolean trackTotalHits; @Nullable private TimeValue timeout; + private final List queryRescorerBuilders = new ArrayList<>(); public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -183,12 +187,17 @@ public NativeSearchQueryBuilder withTrackTotalHits(Boolean trackTotalHits) { this.trackTotalHits = trackTotalHits; return this; } - - public NativeSearchQueryBuilder withTimeout(TimeValue timeout) { + + public NativeSearchQueryBuilder withTimeout(TimeValue timeout) { this.timeout = timeout; return this; } + public NativeSearchQueryBuilder withRescorerQuery(QueryRescorerBuilder queryRescorerBuilder) { + this.queryRescorerBuilders.add(queryRescorerBuilder); + return this; + } + public NativeSearchQuery build() { NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilders, @@ -250,11 +259,18 @@ public NativeSearchQuery build() { } nativeSearchQuery.setTrackTotalHits(trackTotalHits); - + if (timeout != null) { nativeSearchQuery.setTimeout(timeout); } + if (!isEmpty(queryRescorerBuilders)) { + nativeSearchQuery.setRescorerQueries( + queryRescorerBuilders.stream() + .map(RescorerQuery::new) + .collect(Collectors.toList())); + } + return nativeSearchQuery; } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java index 790870cce..5627db0bc 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java @@ -17,6 +17,7 @@ import java.time.Duration; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -40,6 +41,7 @@ * @author Christoph Strobl * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ public interface Query { @@ -293,4 +295,24 @@ default boolean hasScrollTime() { default boolean getExplain() { return false; } + + /** + * Sets the {@link RescorerQuery}. + * + * @param rescorerQuery the query to add to the list of rescorer queries + * @since 4.2 + */ + void addRescorerQuery(RescorerQuery rescorerQuery); + + /** + * Sets the {@link RescorerQuery}. + * + * @param rescorerQueryList list of rescorer queries set + * @since 4.2 + */ + void setRescorerQueries(List rescorerQueryList); + + default List getRescorerQueries() { + return Collections.emptyList(); + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java new file mode 100644 index 000000000..aad9195d0 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core.query; + +import org.elasticsearch.search.rescore.QueryRescorerBuilder; + +/** + * Encapsulates an Elasticsearch {@link QueryRescorerBuilder} to prevent leaking of Elasticsearch + * classes into the query API. + * + * @author Peer Mueller + * @since 4.2 + */ +public class RescorerQuery { + + private final QueryRescorerBuilder queryRescorerBuilder; + + public RescorerQuery(QueryRescorerBuilder queryRescorerBuilder) { + this.queryRescorerBuilder = queryRescorerBuilder; + } + + public QueryRescorerBuilder getRescorerBuilder() { + return queryRescorerBuilder; + } + + +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index 59b8c1f25..002a6d0d1 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -53,12 +53,19 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.metadata.AliasMetadata; +import org.elasticsearch.common.lucene.search.function.CombineFunction; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import org.elasticsearch.index.VersionType; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; +import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder; import org.elasticsearch.join.query.ParentIdQueryBuilder; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescoreMode; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -121,6 +128,7 @@ * @author Roman Puchkovskiy * @author Subhobrata Dey * @author Farid Faoudi + * @author Peer Mueller */ @SpringIntegrationTest public abstract class ElasticsearchTemplateTests { @@ -3108,7 +3116,63 @@ void shouldReturnHighlightFieldsInSearchHit() { assertThat(highlightField.get(1)).contains("message"); } - @Test // DATAES-738 + @Test + void shouldRunRescoreQueryInSearchQuery() { + IndexCoordinates index = IndexCoordinates.of("test-index-rescore-entity-template"); + + // matches main query better + SampleEntity entity = SampleEntity.builder() // + .id("1") // + .message("some message") // + .rate(java.lang.Integer.MAX_VALUE) + .version(System.currentTimeMillis()) // + .build(); + + // high score from rescore query + SampleEntity entity2 = SampleEntity.builder() // + .id("2") // + .message("nothing") // + .rate(1) + .version(System.currentTimeMillis()) // + .build(); + + List indexQueries = getIndexQueries(Arrays.asList(entity, entity2)); + + operations.bulkIndex(indexQueries, index); + indexOperations.refresh(); + + NativeSearchQuery query = new NativeSearchQueryBuilder() // + .withQuery( + boolQuery().filter(existsQuery("rate")).should(termQuery("message", "message"))) // + .withRescorerQuery(new QueryRescorerBuilder( + new FunctionScoreQueryBuilder( + new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(1f)), + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(100f))}) + .scoreMode(FunctionScoreQuery.ScoreMode.SUM) + .maxBoost(80f) + .boostMode(CombineFunction.REPLACE)) + .setScoreMode(QueryRescoreMode.Max) + .windowSize(100)) + .build(); + + SearchHits searchHits = operations.search(query, SampleEntity.class, index); + + assertThat(searchHits).isNotNull(); + assertThat(searchHits.getSearchHits()).hasSize(2); + + SearchHit searchHit = searchHits.getSearchHit(0); + assertThat(searchHit.getContent().getMessage()).isEqualTo("nothing"); + //score capped to 80 + assertThat(searchHit.getScore()).isEqualTo(80f); + } + + @Test + // DATAES-738 void shouldSaveEntityWithIndexCoordinates() { String id = "42"; SampleEntity entity = new SampleEntity(); From 79b2320aa085e22cbfc38eaf722029325e805161 Mon Sep 17 00:00:00 2001 From: pemueller Date: Thu, 11 Feb 2021 23:13:47 +0100 Subject: [PATCH 2/4] formatting issues --- .../core/query/AbstractQuery.java | 36 +++++++------------ 1 file changed, 12 insertions(+), 24 deletions(-) diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index 07370494a..7997ea9ef 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -46,34 +46,22 @@ abstract class AbstractQuery implements Query { protected Pageable pageable = DEFAULT_PAGE; - @Nullable - protected Sort sort; + @Nullable protected Sort sort; protected List fields = new ArrayList<>(); - @Nullable - protected SourceFilter sourceFilter; + @Nullable protected SourceFilter sourceFilter; protected float minScore; - @Nullable - protected Collection ids; - @Nullable - protected String route; + @Nullable protected Collection ids; + @Nullable protected String route; protected SearchType searchType = SearchType.DFS_QUERY_THEN_FETCH; - @Nullable - protected IndicesOptions indicesOptions; + @Nullable protected IndicesOptions indicesOptions; protected boolean trackScores; - @Nullable - protected String preference; - @Nullable - protected Integer maxResults; - @Nullable - protected HighlightQuery highlightQuery; - @Nullable - private Boolean trackTotalHits; - @Nullable - private Integer trackTotalHitsUpTo; - @Nullable - private Duration scrollTime; - @Nullable - private TimeValue timeout; + @Nullable protected String preference; + @Nullable protected Integer maxResults; + @Nullable protected HighlightQuery highlightQuery; + @Nullable private Boolean trackTotalHits; + @Nullable private Integer trackTotalHitsUpTo; + @Nullable private Duration scrollTime; + @Nullable private TimeValue timeout; private boolean explain = false; protected List rescorerQueries = new ArrayList<>(); From e3035d3665f20a805f7e1af54b1414187233367b Mon Sep 17 00:00:00 2001 From: pemueller Date: Thu, 11 Feb 2021 23:14:53 +0100 Subject: [PATCH 3/4] formatting issues --- .../core/query/AbstractQuery.java | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index 7997ea9ef..25fba5911 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -46,22 +46,22 @@ abstract class AbstractQuery implements Query { protected Pageable pageable = DEFAULT_PAGE; - @Nullable protected Sort sort; + @Nullable protected Sort sort; protected List fields = new ArrayList<>(); - @Nullable protected SourceFilter sourceFilter; + @Nullable protected SourceFilter sourceFilter; protected float minScore; - @Nullable protected Collection ids; - @Nullable protected String route; + @Nullable protected Collection ids; + @Nullable protected String route; protected SearchType searchType = SearchType.DFS_QUERY_THEN_FETCH; - @Nullable protected IndicesOptions indicesOptions; + @Nullable protected IndicesOptions indicesOptions; protected boolean trackScores; - @Nullable protected String preference; - @Nullable protected Integer maxResults; - @Nullable protected HighlightQuery highlightQuery; - @Nullable private Boolean trackTotalHits; - @Nullable private Integer trackTotalHitsUpTo; - @Nullable private Duration scrollTime; - @Nullable private TimeValue timeout; + @Nullable protected String preference; + @Nullable protected Integer maxResults; + @Nullable protected HighlightQuery highlightQuery; + @Nullable private Boolean trackTotalHits; + @Nullable private Integer trackTotalHitsUpTo; + @Nullable private Duration scrollTime; + @Nullable private TimeValue timeout; private boolean explain = false; protected List rescorerQueries = new ArrayList<>(); From 1dc3a16bfe4d8cfb8eb97a80f42cf5fb7be492e7 Mon Sep 17 00:00:00 2001 From: pemueller Date: Thu, 11 Mar 2021 02:32:29 +0100 Subject: [PATCH 4/4] converted RescorerQuery to a proper domain object to support spring data query implementations --- .../elasticsearch/core/RequestFactory.java | 38 ++++- .../core/query/NativeSearchQueryBuilder.java | 12 +- .../core/query/RescorerQuery.java | 68 ++++++++- .../core/ElasticsearchTemplateTests.java | 34 +++-- .../core/RequestFactoryTests.java | 144 ++++++++++++++++++ 5 files changed, 262 insertions(+), 34 deletions(-) diff --git a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java index d8dadf925..3be0ee0ef 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import org.elasticsearch.action.DocWriteRequest; @@ -37,7 +38,6 @@ import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequestBuilder; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; -import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.admin.indices.template.delete.DeleteIndexTemplateRequest; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -65,7 +65,6 @@ import org.elasticsearch.client.indices.PutIndexTemplateRequest; import org.elasticsearch.client.indices.PutMappingRequest; import org.elasticsearch.common.geo.GeoDistance; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.VersionType; @@ -83,6 +82,8 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescoreMode; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; @@ -106,6 +107,7 @@ import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.*; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.mapping.context.MappingContext; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -1051,8 +1053,8 @@ private SearchRequest prepareSearchRequest(Query query, @Nullable Class clazz sourceBuilder.searchAfter(query.getSearchAfter().toArray()); } - query.getRescorerQueries().stream().map(RescorerQuery::getRescorerBuilder) - .forEach(sourceBuilder::addRescorer); + query.getRescorerQueries().forEach(rescorer -> sourceBuilder.addRescorer( + getQueryRescorerBuilder(rescorer))); request.source(sourceBuilder); return request; @@ -1140,8 +1142,8 @@ private SearchRequestBuilder prepareSearchRequestBuilder(Query query, Client cli searchRequestBuilder.searchAfter(query.getSearchAfter().toArray()); } - query.getRescorerQueries().stream().map(RescorerQuery::getRescorerBuilder) - .forEach(searchRequestBuilder::addRescorer); + query.getRescorerQueries().forEach(rescorer -> searchRequestBuilder.addRescorer( + getQueryRescorerBuilder(rescorer))); return searchRequestBuilder; } @@ -1267,6 +1269,30 @@ private SortBuilder getSortBuilder(Sort.Order order, @Nullable ElasticsearchP } } } + + private QueryRescorerBuilder getQueryRescorerBuilder(RescorerQuery rescorerQuery) { + + QueryRescorerBuilder builder = new QueryRescorerBuilder(Objects.requireNonNull(getQuery(rescorerQuery.getQuery()))); + + if (rescorerQuery.getScoreMode() != ScoreMode.Default) { + builder.setScoreMode(QueryRescoreMode.valueOf(rescorerQuery.getScoreMode().name())); + } + + if (rescorerQuery.getQueryWeight() != null) { + builder.setQueryWeight(rescorerQuery.getQueryWeight()); + } + + if (rescorerQuery.getRescoreQueryWeight() != null) { + builder.setRescoreQueryWeight(rescorerQuery.getRescoreQueryWeight()); + } + + if (rescorerQuery.getWindowSize() != null) { + builder.windowSize(rescorerQuery.getWindowSize()); + } + + return builder; + + } // endregion // region update diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java index 13d96bd3b..7a74f41bf 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java @@ -73,7 +73,7 @@ public class NativeSearchQueryBuilder { @Nullable private Integer maxResults; @Nullable private Boolean trackTotalHits; @Nullable private TimeValue timeout; - private final List queryRescorerBuilders = new ArrayList<>(); + private final List rescorerQueries = new ArrayList<>(); public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -193,8 +193,8 @@ public NativeSearchQueryBuilder withTimeout(TimeValue timeout) { return this; } - public NativeSearchQueryBuilder withRescorerQuery(QueryRescorerBuilder queryRescorerBuilder) { - this.queryRescorerBuilders.add(queryRescorerBuilder); + public NativeSearchQueryBuilder withRescorerQuery(RescorerQuery rescorerQuery) { + this.rescorerQueries.add(rescorerQuery); return this; } @@ -264,11 +264,9 @@ public NativeSearchQuery build() { nativeSearchQuery.setTimeout(timeout); } - if (!isEmpty(queryRescorerBuilders)) { + if (!isEmpty(rescorerQueries)) { nativeSearchQuery.setRescorerQueries( - queryRescorerBuilders.stream() - .map(RescorerQuery::new) - .collect(Collectors.toList())); + rescorerQueries); } return nativeSearchQuery; diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java index aad9195d0..b8be4b9a6 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java @@ -15,26 +15,80 @@ */ package org.springframework.data.elasticsearch.core.query; +import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.springframework.lang.Nullable; /** - * Encapsulates an Elasticsearch {@link QueryRescorerBuilder} to prevent leaking of Elasticsearch - * classes into the query API. + * Implementation of RescorerQuery to be used for rescoring filtered search results. * * @author Peer Mueller * @since 4.2 */ public class RescorerQuery { - private final QueryRescorerBuilder queryRescorerBuilder; + private final Query query; + private ScoreMode scoreMode = ScoreMode.Default; + @Nullable private Integer windowSize; + @Nullable private Float queryWeight; + @Nullable private Float rescoreQueryWeight; - public RescorerQuery(QueryRescorerBuilder queryRescorerBuilder) { - this.queryRescorerBuilder = queryRescorerBuilder; + public RescorerQuery(Query query) { + this.query = query; } - public QueryRescorerBuilder getRescorerBuilder() { - return queryRescorerBuilder; + public Query getQuery() { + return query; } + public ScoreMode getScoreMode() { + return scoreMode; + } + + @Nullable + public Integer getWindowSize() { + return windowSize; + } + + @Nullable + public Float getQueryWeight() { + return queryWeight; + } + + @Nullable + public Float getRescoreQueryWeight() { + return rescoreQueryWeight; + } + + public RescorerQuery withScoreMode(ScoreMode scoreMode) { + this.scoreMode = scoreMode; + return this; + } + + public RescorerQuery withWindowSize(int windowSize) { + this.windowSize = windowSize; + return this; + } + + public RescorerQuery withQueryWeight(float queryWeight) { + this.queryWeight = queryWeight; + return this; + } + + public RescorerQuery withRescoreQueryWeight(float rescoreQueryWeight) { + this.rescoreQueryWeight = rescoreQueryWeight; + return this; + } + + + + public enum ScoreMode { + Default, + Avg, + Max, + Min, + Total, + Multiply + } } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index c65af7360..915cab29b 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -56,6 +56,7 @@ import org.elasticsearch.common.lucene.search.function.CombineFunction; import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import org.elasticsearch.index.VersionType; +import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder; @@ -99,6 +100,7 @@ import org.springframework.data.elasticsearch.core.join.JoinField; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.*; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.elasticsearch.junit.jupiter.SpringIntegrationTest; import org.springframework.data.util.StreamUtils; import org.springframework.lang.Nullable; @@ -3158,20 +3160,24 @@ void shouldRunRescoreQueryInSearchQuery() { NativeSearchQuery query = new NativeSearchQueryBuilder() // .withQuery( boolQuery().filter(existsQuery("rate")).should(termQuery("message", "message"))) // - .withRescorerQuery(new QueryRescorerBuilder( - new FunctionScoreQueryBuilder( - new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ - new FilterFunctionBuilder( - new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) - .setWeight(1f)), - new FilterFunctionBuilder( - new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) - .setWeight(100f))}) - .scoreMode(FunctionScoreQuery.ScoreMode.SUM) - .maxBoost(80f) - .boostMode(CombineFunction.REPLACE)) - .setScoreMode(QueryRescoreMode.Max) - .windowSize(100)) + .withRescorerQuery(new RescorerQuery( + new NativeSearchQueryBuilder().withQuery( + QueryBuilders.functionScoreQuery( + new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(1f)), + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(100f))} + ) + .scoreMode(FunctionScoreQuery.ScoreMode.SUM) + .maxBoost(80f) + .boostMode(CombineFunction.REPLACE) + ).build() + ) + .withScoreMode(ScoreMode.Max) + .withWindowSize(100)) .build(); SearchHits searchHits = operations.search(query, SampleEntity.class, index); diff --git a/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java b/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java index 40a29a362..9aaf7cd31 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java @@ -39,10 +39,18 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.client.indices.PutIndexTemplateRequest; +import org.elasticsearch.common.lucene.search.function.CombineFunction; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchPhraseQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; +import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.json.JSONException; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.DisplayName; @@ -70,12 +78,15 @@ import org.springframework.data.elasticsearch.core.query.IndexQueryBuilder; import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder; import org.springframework.data.elasticsearch.core.query.Query; +import org.springframework.data.elasticsearch.core.query.RescorerQuery; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.elasticsearch.core.query.SeqNoPrimaryTerm; import org.springframework.lang.Nullable; /** * @author Peter-Josef Meisch * @author Roman Puchkovskiy + * @author Peer Mueller */ @ExtendWith(MockitoExtension.class) class RequestFactoryTests { @@ -511,6 +522,139 @@ private String requestToString(ToXContent request) throws IOException { return XContentHelper.toXContent(request, XContentType.JSON, true).utf8ToString(); } + @Test + void shouldBuildSearchWithRescorerQuery() throws JSONException { + CriteriaQuery query = new CriteriaQuery(new Criteria("lastName").is("Smith")); + RescorerQuery rescorerQuery = new RescorerQuery( new NativeSearchQueryBuilder() // + .withQuery( + QueryBuilders.functionScoreQuery(new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ + new FilterFunctionBuilder(QueryBuilders.existsQuery("someField"), + new GaussDecayFunctionBuilder("someField", 0, 100000.0, null, 0.683) + .setWeight(5.022317f)), + new FilterFunctionBuilder(QueryBuilders.existsQuery("anotherField"), + new GaussDecayFunctionBuilder("anotherField", "202102", "31536000s", null, 0.683) + .setWeight(4.170836f))}) + .scoreMode(FunctionScoreQuery.ScoreMode.SUM) + .maxBoost(50.0f) + .boostMode(CombineFunction.AVG) + .boost(1.5f)) + .build() + ) + .withWindowSize(50) + .withQueryWeight(2.0f) + .withRescoreQueryWeight(5.0f) + .withScoreMode(ScoreMode.Multiply); + + RescorerQuery anotherRescorerQuery = new RescorerQuery(new NativeSearchQueryBuilder() // + .withQuery( + QueryBuilders.matchPhraseQuery("message", "the quick brown").slop(2)) + .build() + ) + .withWindowSize(100) + .withQueryWeight(0.7f) + .withRescoreQueryWeight(1.2f); + + query.addRescorerQuery(rescorerQuery); + query.addRescorerQuery(anotherRescorerQuery); + + converter.updateQuery(query, Person.class); + + String expected = '{' + // + " \"query\": {" + // + " \"bool\": {" + // + " \"must\": [" + // + " {" + // + " \"query_string\": {" + // + " \"query\": \"Smith\"," + // + " \"fields\": [" + // + " \"last-name^1.0\"" + // + " ]" + // + " }" + // + " }" + // + " ]" + // + " }" + // + " }," + // + " \"rescore\": [{\n" + + " \"window_size\" : 100,\n" + + " \"query\" : {\n" + + " \"rescore_query\" : {\n" + + " \"match_phrase\" : {\n" + + " \"message\" : {\n" + + " \"query\" : \"the quick brown\",\n" + + " \"slop\" : 2\n" + + " }\n" + + " }\n" + + " },\n" + + " \"query_weight\" : 0.7,\n" + + " \"rescore_query_weight\" : 1.2\n" + + " }\n" + + " }," + + " {\n" + + " \"window_size\": 50,\n" + + " \"query\": {\n" + + " \"rescore_query\": {\n" + + " \"function_score\": {\n" + + " \"query\": {\n" + + " \"match_all\": {\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"functions\": [\n" + + " {\n" + + " \"filter\": {\n" + + " \"exists\": {\n" + + " \"field\": \"someField\",\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"weight\": 5.022317,\n" + + " \"gauss\": {\n" + + " \"someField\": {\n" + + " \"origin\": 0.0,\n" + + " \"scale\": 100000.0,\n" + + " \"decay\": 0.683\n" + + " },\n" + + " \"multi_value_mode\": \"MIN\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"filter\": {\n" + + " \"exists\": {\n" + + " \"field\": \"anotherField\",\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"weight\": 4.170836,\n" + + " \"gauss\": {\n" + + " \"anotherField\": {\n" + + " \"origin\": \"202102\",\n" + + " \"scale\": \"31536000s\",\n" + + " \"decay\": 0.683\n" + + " },\n" + + " \"multi_value_mode\": \"MIN\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"score_mode\": \"sum\",\n" + + " \"boost_mode\": \"avg\",\n" + + " \"max_boost\": 50.0,\n" + + " \"boost\": 1.5\n" + + " }\n" + + " },\n" + + " \"query_weight\": 2.0," + + " \"rescore_query_weight\": 5.0," + + " \"score_mode\": \"multiply\"" + + " }\n" + + " }\n" + + " ]\n" + + '}'; + + String searchRequest = requestFactory.searchRequest(query, Person.class, IndexCoordinates.of("persons")).source() + .toString(); + + assertEquals(expected, searchRequest, false); + } + @Data @Builder @NoArgsConstructor