diff --git a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java index e5d4a3296..c690d6d42 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQuery.java @@ -46,7 +46,6 @@ public class NativeQuery extends BaseQuery { @Nullable private Suggester suggester; @Nullable private FieldCollapse fieldCollapse; private List scriptedFields = Collections.emptyList(); - private List rescorerQueries = Collections.emptyList(); public NativeQuery(NativeQueryBuilder builder) { super(builder); @@ -56,7 +55,6 @@ public NativeQuery(NativeQueryBuilder builder) { this.suggester = builder.getSuggester(); this.fieldCollapse = builder.getFieldCollapse(); this.scriptedFields = builder.getScriptedFields(); - this.rescorerQueries = builder.getRescorerQueries(); } public NativeQuery(@Nullable Query query) { @@ -94,9 +92,4 @@ public FieldCollapse getFieldCollapse() { public List getScriptedFields() { return scriptedFields; } - - @Override - public List getRescorerQueries() { - return rescorerQueries; - } } diff --git a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java index de207c740..fa0173568 100644 --- a/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/client/elc/NativeQueryBuilder.java @@ -45,8 +45,6 @@ public class NativeQueryBuilder extends BaseQueryBuilder scriptedFields = new ArrayList<>(); - private List rescorerQueries = new ArrayList<>(); - public NativeQueryBuilder() {} @Nullable @@ -77,10 +75,6 @@ public List getScriptedFields() { return scriptedFields; } - public List getRescorerQueries() { - return rescorerQueries; - } - public NativeQueryBuilder withQuery(Query query) { Assert.notNull(query, "query must not be null"); @@ -135,14 +129,6 @@ public NativeQueryBuilder withScriptedField(ScriptedField scriptedField) { return this; } - public NativeQueryBuilder withResorerQuery(RescorerQuery resorerQuery) { - - Assert.notNull(resorerQuery, "resorerQuery must not be null"); - - this.rescorerQueries.add(resorerQuery); - return this; - } - public NativeQuery build() { return new NativeQuery(this); } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQuery.java index 0a34a3197..009a79a4e 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQuery.java @@ -47,8 +47,8 @@ */ public class BaseQuery implements Query { - protected Pageable pageable = DEFAULT_PAGE; @Nullable protected Sort sort; + protected Pageable pageable = DEFAULT_PAGE; protected List fields = new ArrayList<>(); @Nullable protected List storedFields; @Nullable protected SourceFilter sourceFilter; @@ -67,11 +67,11 @@ public class BaseQuery implements Query { @Nullable protected Duration timeout; private boolean explain = false; @Nullable protected List searchAfter; + @Nullable protected List indicesBoost; protected List rescorerQueries = new ArrayList<>(); @Nullable protected Boolean requestCache; protected List idsWithRouting = Collections.emptyList(); protected final List runtimeFields = new ArrayList<>(); - @Nullable protected List indicesBoost; public BaseQuery() {} @@ -79,17 +79,28 @@ public > BaseQuery(BaseQue this.sort = builder.getSort(); // do a setPageable after setting the sort, because the pageable may contain an additional sort this.setPageable(builder.getPageable() != null ? builder.getPageable() : DEFAULT_PAGE); - this.ids = builder.getIds(); - this.trackScores = builder.getTrackScores(); - this.maxResults = builder.getMaxResults(); - this.indicesOptions = builder.getIndicesOptions(); - this.minScore = builder.getMinScore(); - this.preference = builder.getPreference(); - this.sourceFilter = builder.getSourceFilter(); this.fields = builder.getFields(); - this.highlightQuery = builder.highlightQuery; + this.storedFields = builder.getStoredFields(); + this.sourceFilter = builder.getSourceFilter(); + this.minScore = builder.getMinScore(); + this.ids = builder.getIds().isEmpty() ? null : builder.getIds(); this.route = builder.getRoute(); + this.searchType = builder.getSearchType(); + this.indicesOptions = builder.getIndicesOptions(); + this.trackScores = builder.getTrackScores(); + this.preference = builder.getPreference(); + this.maxResults = builder.getMaxResults(); + this.highlightQuery = builder.getHighlightQuery(); + this.trackTotalHits = builder.getTrackTotalHits(); + this.trackTotalHitsUpTo = builder.getTrackTotalHitsUpTo(); + this.scrollTime = builder.getScrollTime(); + this.timeout = builder.getTimeout(); + this.explain = builder.getExplain(); + this.searchAfter = builder.getSearchAfter(); this.indicesBoost = builder.getIndicesBoost(); + this.rescorerQueries = builder.getRescorerQueries(); + this.requestCache = builder.getRequestCache(); + this.idsWithRouting = builder.getIdsWithRouting(); } @Override diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQueryBuilder.java index 29799064e..5a351e99c 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/BaseQueryBuilder.java @@ -15,6 +15,7 @@ */ package org.springframework.data.elasticsearch.core.query; +import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -23,7 +24,9 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; +import org.springframework.data.elasticsearch.core.RuntimeField; import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * base class for query builders. The different implementations of {@link Query} should derive from this class and then @@ -34,28 +37,51 @@ */ public abstract class BaseQueryBuilder> { - @Nullable private Pageable pageable; @Nullable private Sort sort; - @Nullable private Integer maxResults; - @Nullable private Collection ids; - private boolean trackScores; - @Nullable protected IndicesOptions indicesOptions; + @Nullable private Pageable pageable; + private final List fields = new ArrayList<>(); + @Nullable private List storedFields; + @Nullable private SourceFilter sourceFilter; private float minScore; + private final Collection ids = new ArrayList<>(); + @Nullable private String route; + protected Query.SearchType searchType = Query.SearchType.QUERY_THEN_FETCH; + @Nullable protected IndicesOptions indicesOptions; + private boolean trackScores; @Nullable private String preference; - @Nullable private SourceFilter sourceFilter; - private List fields = new ArrayList<>(); + @Nullable private Integer maxResults; @Nullable protected HighlightQuery highlightQuery; - @Nullable private String route; + @Nullable private Boolean trackTotalHits; + @Nullable protected Integer trackTotalHitsUpTo; + @Nullable protected Duration scrollTime; + @Nullable protected Duration timeout; + boolean explain = false; + @Nullable protected List searchAfter; + @Nullable private List indicesBoost; + protected final List rescorerQueries = new ArrayList<>(); + + @Nullable protected Boolean requestCache; + protected final List idsWithRouting = new ArrayList<>(); + protected final List runtimeFields = new ArrayList<>(); + + @Nullable + public Sort getSort() { + return sort; + } @Nullable public Pageable getPageable() { return pageable; } + public List getFields() { + return fields; + } + @Nullable - public Sort getSort() { - return sort; + public List getStoredFields() { + return storedFields; } @Nullable @@ -91,10 +117,6 @@ public SourceFilter getSourceFilter() { return sourceFilter; } - public List getFields() { - return fields; - } - @Nullable public HighlightQuery getHighlightQuery() { return highlightQuery; @@ -110,6 +132,56 @@ public List getIndicesBoost() { return indicesBoost; } + public Query.SearchType getSearchType() { + return searchType; + } + + @Nullable + public Boolean getTrackTotalHits() { + return trackTotalHits; + } + + @Nullable + public Integer getTrackTotalHitsUpTo() { + return trackTotalHitsUpTo; + } + + @Nullable + public Duration getScrollTime() { + return scrollTime; + } + + @Nullable + public Duration getTimeout() { + return timeout; + } + + public boolean getExplain() { + return explain; + } + + @Nullable + public List getSearchAfter() { + return searchAfter; + } + + @Nullable + public Boolean getRequestCache() { + return requestCache; + } + + public List getIdsWithRouting() { + return idsWithRouting; + } + + public List getRuntimeFields() { + return runtimeFields; + } + + public List getRescorerQueries() { + return rescorerQueries; + } + public SELF withPageable(Pageable pageable) { this.pageable = pageable; return self(); @@ -130,12 +202,18 @@ public SELF withMaxResults(Integer maxResults) { } public SELF withIds(String... ids) { - this.ids = Arrays.asList(ids); + + this.ids.clear(); + this.ids.addAll(Arrays.asList(ids)); return self(); } public SELF withIds(Collection ids) { - this.ids = ids; + + Assert.notNull(ids, "ids must not be null"); + + this.ids.clear(); + this.ids.addAll(ids); return self(); } @@ -165,11 +243,17 @@ public SELF withSourceFilter(SourceFilter sourceFilter) { } public SELF withFields(String... fields) { + + this.fields.clear(); Collections.addAll(this.fields, fields); return self(); } public SELF withFields(Collection fields) { + + Assert.notNull(fields, "fields must not be null"); + + this.fields.clear(); this.fields.addAll(fields); return self(); } @@ -184,16 +268,96 @@ public SELF withRoute(String route) { return self(); } - public SELF withIndicesBoost(List indicesBoost) { + public SELF withIndicesBoost(@Nullable List indicesBoost) { this.indicesBoost = indicesBoost; return self(); } + public SELF withStoredFields(@Nullable List storedFields) { + this.storedFields = storedFields; + return self(); + } + public SELF withIndicesBoost(IndexBoost... indicesBoost) { this.indicesBoost = Arrays.asList(indicesBoost); return self(); } + public SELF withSearchType(Query.SearchType searchType) { + this.searchType = searchType; + return self(); + } + + public SELF withTrackTotalHits(@Nullable Boolean trackTotalHits) { + this.trackTotalHits = trackTotalHits; + return self(); + } + + public SELF withTrackTotalHitsUpTo(@Nullable Integer trackTotalHitsUpTo) { + this.trackTotalHitsUpTo = trackTotalHitsUpTo; + return self(); + } + + public SELF withTimeout(@Nullable Duration timeout) { + this.timeout = timeout; + return self(); + } + + public SELF withScrollTime(@Nullable Duration scrollTime) { + this.scrollTime = scrollTime; + return self(); + } + + public SELF withExplain(boolean explain) { + this.explain = explain; + return self(); + } + + public SELF withSearchAfter(@Nullable List searchAfter) { + this.searchAfter = searchAfter; + return self(); + } + + public SELF withRequestCache(@Nullable Boolean requestCache) { + this.requestCache = requestCache; + return self(); + } + + public SELF withIdsWithRouting(List idsWithRouting) { + + Assert.notNull(idsWithRouting, "idsWithRouting must not be null"); + + this.idsWithRouting.clear(); + this.idsWithRouting.addAll(idsWithRouting); + return self(); + } + + public SELF withRuntimeFields(List runtimeFields) { + + Assert.notNull(runtimeFields, "runtimeFields must not be null"); + + this.runtimeFields.clear(); + this.runtimeFields.addAll(runtimeFields); + return self(); + } + + public SELF withRescorerQueries(List rescorerQueries) { + + Assert.notNull(rescorerQueries, "rescorerQueries must not be null"); + + this.rescorerQueries.clear(); + this.rescorerQueries.addAll(rescorerQueries); + return self(); + } + + public SELF withRescorerQuery(RescorerQuery rescorerQuery) { + + Assert.notNull(rescorerQuery, "rescorerQuery must not be null"); + + this.rescorerQueries.add(rescorerQuery); + return self(); + } + public abstract Q build(); private SELF self() { diff --git a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchELCIntegrationTests.java b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchELCIntegrationTests.java index 89253b901..a57522ee3 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchELCIntegrationTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/ElasticsearchELCIntegrationTests.java @@ -169,7 +169,7 @@ protected Query getQueryWithRescorer() { .filter(f -> f.exists(e -> e.field("rate"))) // .should(s -> s.term(t -> t.field("message").value("message"))) // )) // - .withResorerQuery( // + .withRescorerQuery( // new RescorerQuery(NativeQuery.builder() // .withQuery(q -> q // .functionScore(fs -> fs //