diff --git a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java index 822b6cff7..3be0ee0ef 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java @@ -24,6 +24,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import org.elasticsearch.action.DocWriteRequest; @@ -37,7 +38,6 @@ import org.elasticsearch.action.admin.indices.mapping.put.PutMappingRequestBuilder; import org.elasticsearch.action.admin.indices.refresh.RefreshRequest; import org.elasticsearch.action.admin.indices.settings.get.GetSettingsRequest; -import org.elasticsearch.action.admin.indices.settings.get.GetSettingsResponse; import org.elasticsearch.action.admin.indices.template.delete.DeleteIndexTemplateRequest; import org.elasticsearch.action.bulk.BulkRequest; import org.elasticsearch.action.bulk.BulkRequestBuilder; @@ -65,7 +65,6 @@ import org.elasticsearch.client.indices.PutIndexTemplateRequest; import org.elasticsearch.client.indices.PutMappingRequest; import org.elasticsearch.common.geo.GeoDistance; -import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.DistanceUnit; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.index.VersionType; @@ -83,6 +82,8 @@ import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescoreMode; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.GeoDistanceSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; @@ -106,6 +107,7 @@ import org.springframework.data.elasticsearch.core.mapping.ElasticsearchPersistentProperty; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.*; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.mapping.context.MappingContext; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -119,6 +121,7 @@ * @author Roman Puchkovskiy * @author Subhobrata Dey * @author Farid Faoudi + * @author Peer Mueller * @since 4.0 */ class RequestFactory { @@ -1050,6 +1053,9 @@ private SearchRequest prepareSearchRequest(Query query, @Nullable Class clazz sourceBuilder.searchAfter(query.getSearchAfter().toArray()); } + query.getRescorerQueries().forEach(rescorer -> sourceBuilder.addRescorer( + getQueryRescorerBuilder(rescorer))); + request.source(sourceBuilder); return request; } @@ -1136,6 +1142,9 @@ private SearchRequestBuilder prepareSearchRequestBuilder(Query query, Client cli searchRequestBuilder.searchAfter(query.getSearchAfter().toArray()); } + query.getRescorerQueries().forEach(rescorer -> searchRequestBuilder.addRescorer( + getQueryRescorerBuilder(rescorer))); + return searchRequestBuilder; } @@ -1260,6 +1269,30 @@ private SortBuilder getSortBuilder(Sort.Order order, @Nullable ElasticsearchP } } } + + private QueryRescorerBuilder getQueryRescorerBuilder(RescorerQuery rescorerQuery) { + + QueryRescorerBuilder builder = new QueryRescorerBuilder(Objects.requireNonNull(getQuery(rescorerQuery.getQuery()))); + + if (rescorerQuery.getScoreMode() != ScoreMode.Default) { + builder.setScoreMode(QueryRescoreMode.valueOf(rescorerQuery.getScoreMode().name())); + } + + if (rescorerQuery.getQueryWeight() != null) { + builder.setQueryWeight(rescorerQuery.getQueryWeight()); + } + + if (rescorerQuery.getRescoreQueryWeight() != null) { + builder.setRescoreQueryWeight(rescorerQuery.getRescoreQueryWeight()); + } + + if (rescorerQuery.getWindowSize() != null) { + builder.windowSize(rescorerQuery.getWindowSize()); + } + + return builder; + + } // endregion // region update diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java index a00f2dcee..252638ef5 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/AbstractQuery.java @@ -41,6 +41,7 @@ * @author Sascha Woo * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ abstract class AbstractQuery implements Query { @@ -63,6 +64,7 @@ abstract class AbstractQuery implements Query { @Nullable private TimeValue timeout; private boolean explain = false; @Nullable private List searchAfter; + protected List rescorerQueries = new ArrayList<>(); @Override @Nullable @@ -295,4 +297,20 @@ public void setSearchAfter(@Nullable List searchAfter) { public List getSearchAfter() { return searchAfter; } + + @Override + public void addRescorerQuery(RescorerQuery rescorerQuery) { + this.rescorerQueries.add(rescorerQuery); + } + + @Override + public void setRescorerQueries(List rescorerQueryList) { + this.rescorerQueries.clear(); + this.rescorerQueries.addAll(rescorerQueryList); + } + + @Override + public List getRescorerQueries() { + return rescorerQueries; + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java index 45af3f1f5..7a74f41bf 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java @@ -20,6 +20,7 @@ import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.stream.Collectors; import org.elasticsearch.action.search.SearchType; import org.elasticsearch.action.support.IndicesOptions; @@ -28,6 +29,7 @@ import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.SortBuilder; import org.springframework.data.domain.Pageable; import org.springframework.lang.Nullable; @@ -45,6 +47,7 @@ * @author Martin Choraine * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ public class NativeSearchQueryBuilder { @@ -70,6 +73,7 @@ public class NativeSearchQueryBuilder { @Nullable private Integer maxResults; @Nullable private Boolean trackTotalHits; @Nullable private TimeValue timeout; + private final List rescorerQueries = new ArrayList<>(); public NativeSearchQueryBuilder withQuery(QueryBuilder queryBuilder) { this.queryBuilder = queryBuilder; @@ -183,12 +187,17 @@ public NativeSearchQueryBuilder withTrackTotalHits(Boolean trackTotalHits) { this.trackTotalHits = trackTotalHits; return this; } - - public NativeSearchQueryBuilder withTimeout(TimeValue timeout) { + + public NativeSearchQueryBuilder withTimeout(TimeValue timeout) { this.timeout = timeout; return this; } + public NativeSearchQueryBuilder withRescorerQuery(RescorerQuery rescorerQuery) { + this.rescorerQueries.add(rescorerQuery); + return this; + } + public NativeSearchQuery build() { NativeSearchQuery nativeSearchQuery = new NativeSearchQuery(queryBuilder, filterBuilder, sortBuilders, @@ -250,11 +259,16 @@ public NativeSearchQuery build() { } nativeSearchQuery.setTrackTotalHits(trackTotalHits); - + if (timeout != null) { nativeSearchQuery.setTimeout(timeout); } + if (!isEmpty(rescorerQueries)) { + nativeSearchQuery.setRescorerQueries( + rescorerQueries); + } + return nativeSearchQuery; } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java index 4e93d0725..ee926d5dd 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/Query.java @@ -17,6 +17,7 @@ import java.time.Duration; import java.util.Collection; +import java.util.Collections; import java.util.List; import java.util.Optional; @@ -41,6 +42,7 @@ * @author Christoph Strobl * @author Farid Azaza * @author Peter-Josef Meisch + * @author Peer Mueller */ public interface Query { @@ -297,7 +299,7 @@ default boolean getExplain() { /** * Sets the setSearchAfter objects for this query. - * + * * @param searchAfter the setSearchAfter objects. These are obtained with {@link SearchHit#getSortValues()} from a * search result. * @since 4.2 @@ -310,4 +312,24 @@ default boolean getExplain() { */ @Nullable List getSearchAfter(); + + /** + * Sets the {@link RescorerQuery}. + * + * @param rescorerQuery the query to add to the list of rescorer queries + * @since 4.2 + */ + void addRescorerQuery(RescorerQuery rescorerQuery); + + /** + * Sets the {@link RescorerQuery}. + * + * @param rescorerQueryList list of rescorer queries set + * @since 4.2 + */ + void setRescorerQueries(List rescorerQueryList); + + default List getRescorerQueries() { + return Collections.emptyList(); + } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java new file mode 100644 index 000000000..b8be4b9a6 --- /dev/null +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/RescorerQuery.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020-2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.elasticsearch.core.query; + +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; +import org.springframework.lang.Nullable; + +/** + * Implementation of RescorerQuery to be used for rescoring filtered search results. + * + * @author Peer Mueller + * @since 4.2 + */ +public class RescorerQuery { + + private final Query query; + private ScoreMode scoreMode = ScoreMode.Default; + @Nullable private Integer windowSize; + @Nullable private Float queryWeight; + @Nullable private Float rescoreQueryWeight; + + public RescorerQuery(Query query) { + this.query = query; + } + + public Query getQuery() { + return query; + } + + public ScoreMode getScoreMode() { + return scoreMode; + } + + @Nullable + public Integer getWindowSize() { + return windowSize; + } + + @Nullable + public Float getQueryWeight() { + return queryWeight; + } + + @Nullable + public Float getRescoreQueryWeight() { + return rescoreQueryWeight; + } + + public RescorerQuery withScoreMode(ScoreMode scoreMode) { + this.scoreMode = scoreMode; + return this; + } + + public RescorerQuery withWindowSize(int windowSize) { + this.windowSize = windowSize; + return this; + } + + public RescorerQuery withQueryWeight(float queryWeight) { + this.queryWeight = queryWeight; + return this; + } + + public RescorerQuery withRescoreQueryWeight(float rescoreQueryWeight) { + this.rescoreQueryWeight = rescoreQueryWeight; + return this; + } + + + + public enum ScoreMode { + Default, + Avg, + Max, + Min, + Total, + Multiply + } + +} diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java index a83142205..915cab29b 100755 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchTemplateTests.java @@ -53,12 +53,20 @@ import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.action.update.UpdateRequest; import org.elasticsearch.cluster.metadata.AliasMetadata; +import org.elasticsearch.common.lucene.search.function.CombineFunction; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import org.elasticsearch.index.VersionType; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; +import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder; import org.elasticsearch.join.query.ParentIdQueryBuilder; import org.elasticsearch.script.Script; import org.elasticsearch.script.ScriptType; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.rescore.QueryRescoreMode; +import org.elasticsearch.search.rescore.QueryRescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -92,6 +100,7 @@ import org.springframework.data.elasticsearch.core.join.JoinField; import org.springframework.data.elasticsearch.core.mapping.IndexCoordinates; import org.springframework.data.elasticsearch.core.query.*; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.elasticsearch.junit.jupiter.SpringIntegrationTest; import org.springframework.data.util.StreamUtils; import org.springframework.lang.Nullable; @@ -121,6 +130,7 @@ * @author Roman Puchkovskiy * @author Subhobrata Dey * @author Farid Faoudi + * @author Peer Mueller */ @SpringIntegrationTest public abstract class ElasticsearchTemplateTests { @@ -3122,7 +3132,67 @@ void shouldReturnHighlightFieldsInSearchHit() { assertThat(highlightField.get(1)).contains("message"); } - @Test // DATAES-738 + @Test + void shouldRunRescoreQueryInSearchQuery() { + IndexCoordinates index = IndexCoordinates.of("test-index-rescore-entity-template"); + + // matches main query better + SampleEntity entity = SampleEntity.builder() // + .id("1") // + .message("some message") // + .rate(java.lang.Integer.MAX_VALUE) + .version(System.currentTimeMillis()) // + .build(); + + // high score from rescore query + SampleEntity entity2 = SampleEntity.builder() // + .id("2") // + .message("nothing") // + .rate(1) + .version(System.currentTimeMillis()) // + .build(); + + List indexQueries = getIndexQueries(Arrays.asList(entity, entity2)); + + operations.bulkIndex(indexQueries, index); + indexOperations.refresh(); + + NativeSearchQuery query = new NativeSearchQueryBuilder() // + .withQuery( + boolQuery().filter(existsQuery("rate")).should(termQuery("message", "message"))) // + .withRescorerQuery(new RescorerQuery( + new NativeSearchQueryBuilder().withQuery( + QueryBuilders.functionScoreQuery( + new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(1f)), + new FilterFunctionBuilder( + new GaussDecayFunctionBuilder("rate", 0, 10, null, 0.5) + .setWeight(100f))} + ) + .scoreMode(FunctionScoreQuery.ScoreMode.SUM) + .maxBoost(80f) + .boostMode(CombineFunction.REPLACE) + ).build() + ) + .withScoreMode(ScoreMode.Max) + .withWindowSize(100)) + .build(); + + SearchHits searchHits = operations.search(query, SampleEntity.class, index); + + assertThat(searchHits).isNotNull(); + assertThat(searchHits.getSearchHits()).hasSize(2); + + SearchHit searchHit = searchHits.getSearchHit(0); + assertThat(searchHit.getContent().getMessage()).isEqualTo("nothing"); + //score capped to 80 + assertThat(searchHit.getScore()).isEqualTo(80f); + } + + @Test + // DATAES-738 void shouldSaveEntityWithIndexCoordinates() { String id = "42"; SampleEntity entity = new SampleEntity(); diff --git a/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java b/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java index 40a29a362..9aaf7cd31 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/RequestFactoryTests.java @@ -39,10 +39,18 @@ import org.elasticsearch.action.search.SearchRequestBuilder; import org.elasticsearch.client.Client; import org.elasticsearch.client.indices.PutIndexTemplateRequest; +import org.elasticsearch.common.lucene.search.function.CombineFunction; +import org.elasticsearch.common.lucene.search.function.FunctionScoreQuery; import org.elasticsearch.common.unit.TimeValue; import org.elasticsearch.common.xcontent.ToXContent; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.common.xcontent.XContentType; +import org.elasticsearch.index.query.MatchPhraseQueryBuilder; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder; +import org.elasticsearch.index.query.functionscore.FunctionScoreQueryBuilder.FilterFunctionBuilder; +import org.elasticsearch.index.query.functionscore.GaussDecayFunctionBuilder; +import org.elasticsearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.json.JSONException; import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.DisplayName; @@ -70,12 +78,15 @@ import org.springframework.data.elasticsearch.core.query.IndexQueryBuilder; import org.springframework.data.elasticsearch.core.query.NativeSearchQueryBuilder; import org.springframework.data.elasticsearch.core.query.Query; +import org.springframework.data.elasticsearch.core.query.RescorerQuery; +import org.springframework.data.elasticsearch.core.query.RescorerQuery.ScoreMode; import org.springframework.data.elasticsearch.core.query.SeqNoPrimaryTerm; import org.springframework.lang.Nullable; /** * @author Peter-Josef Meisch * @author Roman Puchkovskiy + * @author Peer Mueller */ @ExtendWith(MockitoExtension.class) class RequestFactoryTests { @@ -511,6 +522,139 @@ private String requestToString(ToXContent request) throws IOException { return XContentHelper.toXContent(request, XContentType.JSON, true).utf8ToString(); } + @Test + void shouldBuildSearchWithRescorerQuery() throws JSONException { + CriteriaQuery query = new CriteriaQuery(new Criteria("lastName").is("Smith")); + RescorerQuery rescorerQuery = new RescorerQuery( new NativeSearchQueryBuilder() // + .withQuery( + QueryBuilders.functionScoreQuery(new FunctionScoreQueryBuilder.FilterFunctionBuilder[]{ + new FilterFunctionBuilder(QueryBuilders.existsQuery("someField"), + new GaussDecayFunctionBuilder("someField", 0, 100000.0, null, 0.683) + .setWeight(5.022317f)), + new FilterFunctionBuilder(QueryBuilders.existsQuery("anotherField"), + new GaussDecayFunctionBuilder("anotherField", "202102", "31536000s", null, 0.683) + .setWeight(4.170836f))}) + .scoreMode(FunctionScoreQuery.ScoreMode.SUM) + .maxBoost(50.0f) + .boostMode(CombineFunction.AVG) + .boost(1.5f)) + .build() + ) + .withWindowSize(50) + .withQueryWeight(2.0f) + .withRescoreQueryWeight(5.0f) + .withScoreMode(ScoreMode.Multiply); + + RescorerQuery anotherRescorerQuery = new RescorerQuery(new NativeSearchQueryBuilder() // + .withQuery( + QueryBuilders.matchPhraseQuery("message", "the quick brown").slop(2)) + .build() + ) + .withWindowSize(100) + .withQueryWeight(0.7f) + .withRescoreQueryWeight(1.2f); + + query.addRescorerQuery(rescorerQuery); + query.addRescorerQuery(anotherRescorerQuery); + + converter.updateQuery(query, Person.class); + + String expected = '{' + // + " \"query\": {" + // + " \"bool\": {" + // + " \"must\": [" + // + " {" + // + " \"query_string\": {" + // + " \"query\": \"Smith\"," + // + " \"fields\": [" + // + " \"last-name^1.0\"" + // + " ]" + // + " }" + // + " }" + // + " ]" + // + " }" + // + " }," + // + " \"rescore\": [{\n" + + " \"window_size\" : 100,\n" + + " \"query\" : {\n" + + " \"rescore_query\" : {\n" + + " \"match_phrase\" : {\n" + + " \"message\" : {\n" + + " \"query\" : \"the quick brown\",\n" + + " \"slop\" : 2\n" + + " }\n" + + " }\n" + + " },\n" + + " \"query_weight\" : 0.7,\n" + + " \"rescore_query_weight\" : 1.2\n" + + " }\n" + + " }," + + " {\n" + + " \"window_size\": 50,\n" + + " \"query\": {\n" + + " \"rescore_query\": {\n" + + " \"function_score\": {\n" + + " \"query\": {\n" + + " \"match_all\": {\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"functions\": [\n" + + " {\n" + + " \"filter\": {\n" + + " \"exists\": {\n" + + " \"field\": \"someField\",\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"weight\": 5.022317,\n" + + " \"gauss\": {\n" + + " \"someField\": {\n" + + " \"origin\": 0.0,\n" + + " \"scale\": 100000.0,\n" + + " \"decay\": 0.683\n" + + " },\n" + + " \"multi_value_mode\": \"MIN\"\n" + + " }\n" + + " },\n" + + " {\n" + + " \"filter\": {\n" + + " \"exists\": {\n" + + " \"field\": \"anotherField\",\n" + + " \"boost\": 1.0\n" + + " }\n" + + " },\n" + + " \"weight\": 4.170836,\n" + + " \"gauss\": {\n" + + " \"anotherField\": {\n" + + " \"origin\": \"202102\",\n" + + " \"scale\": \"31536000s\",\n" + + " \"decay\": 0.683\n" + + " },\n" + + " \"multi_value_mode\": \"MIN\"\n" + + " }\n" + + " }\n" + + " ],\n" + + " \"score_mode\": \"sum\",\n" + + " \"boost_mode\": \"avg\",\n" + + " \"max_boost\": 50.0,\n" + + " \"boost\": 1.5\n" + + " }\n" + + " },\n" + + " \"query_weight\": 2.0," + + " \"rescore_query_weight\": 5.0," + + " \"score_mode\": \"multiply\"" + + " }\n" + + " }\n" + + " ]\n" + + '}'; + + String searchRequest = requestFactory.searchRequest(query, Person.class, IndexCoordinates.of("persons")).source() + .toString(); + + assertEquals(expected, searchRequest, false); + } + @Data @Builder @NoArgsConstructor