diff --git a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java index 386784d62..20e1d4a7b 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/RequestFactory.java @@ -77,7 +77,6 @@ import org.elasticsearch.index.reindex.UpdateByQueryRequest; import org.elasticsearch.index.reindex.UpdateByQueryRequestBuilder; import org.elasticsearch.script.Script; -import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.fetch.subphase.FetchSourceContext; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; @@ -1119,9 +1118,11 @@ private void prepareNativeSearch(NativeSearchQuery query, SearchSourceBuilder so } if (!isEmpty(query.getAggregations())) { - for (AbstractAggregationBuilder aggregationBuilder : query.getAggregations()) { - sourceBuilder.aggregation(aggregationBuilder); - } + query.getAggregations().forEach(sourceBuilder::aggregation); + } + + if (!isEmpty(query.getPipelineAggregations())) { + query.getPipelineAggregations().forEach(sourceBuilder::aggregation); } } @@ -1144,9 +1145,11 @@ private void prepareNativeSearch(SearchRequestBuilder searchRequestBuilder, Nati } if (!isEmpty(nativeSearchQuery.getAggregations())) { - for (AbstractAggregationBuilder aggregationBuilder : nativeSearchQuery.getAggregations()) { - searchRequestBuilder.addAggregation(aggregationBuilder); - } + nativeSearchQuery.getAggregations().forEach(searchRequestBuilder::addAggregation); + } + + if (!isEmpty(nativeSearchQuery.getPipelineAggregations())) { + nativeSearchQuery.getPipelineAggregations().forEach(searchRequestBuilder::addAggregation); } } diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQuery.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQuery.java index 5201ea73b..63ec22aff 100644 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQuery.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQuery.java @@ -22,6 +22,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.sort.SortBuilder; @@ -48,6 +49,7 @@ public class NativeSearchQuery extends AbstractQuery { private final List scriptFields = new ArrayList<>(); @Nullable private CollapseBuilder collapseBuilder; @Nullable private List> aggregations; + @Nullable private List pipelineAggregations; @Nullable private HighlightBuilder highlightBuilder; @Nullable private HighlightBuilder.Field[] highlightFields; @Nullable private List indicesBoost; @@ -143,6 +145,11 @@ public List> getAggregations() { return aggregations; } + @Nullable + public List getPipelineAggregations() { + return pipelineAggregations; + } + public void addAggregation(AbstractAggregationBuilder aggregationBuilder) { if (aggregations == null) { @@ -156,6 +163,10 @@ public void setAggregations(List> aggregations) { this.aggregations = aggregations; } + public void setPipelineAggregations(List pipelineAggregationBuilders) { + this.pipelineAggregations = pipelineAggregationBuilders; + } + @Nullable public List getIndicesBoost() { return indicesBoost; diff --git a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java index 33abb43b1..93c84ca4e 100755 --- a/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java +++ b/src/main/java/org/springframework/data/elasticsearch/core/query/NativeSearchQueryBuilder.java @@ -27,6 +27,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.script.mustache.SearchTemplateRequestBuilder; import org.elasticsearch.search.aggregations.AbstractAggregationBuilder; +import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; import org.elasticsearch.search.sort.SortBuilder; @@ -55,6 +56,7 @@ public class NativeSearchQueryBuilder { private final List scriptFields = new ArrayList<>(); private final List> sortBuilders = new ArrayList<>(); private final List> aggregationBuilders = new ArrayList<>(); + private final List pipelineAggregationBuilders = new ArrayList<>(); @Nullable private HighlightBuilder highlightBuilder; @Nullable private HighlightBuilder.Field[] highlightFields; private Pageable pageable = Pageable.unpaged(); @@ -105,6 +107,14 @@ public NativeSearchQueryBuilder addAggregation(AbstractAggregationBuilder agg return this; } + /** + * @since 4.3 + */ + public NativeSearchQueryBuilder addAggregation(PipelineAggregationBuilder pipelineAggregationBuilder) { + this.pipelineAggregationBuilders.add(pipelineAggregationBuilder); + return this; + } + public NativeSearchQueryBuilder withHighlightBuilder(HighlightBuilder highlightBuilder) { this.highlightBuilder = highlightBuilder; return this; @@ -239,6 +249,10 @@ public NativeSearchQuery build() { nativeSearchQuery.setAggregations(aggregationBuilders); } + if (!isEmpty(pipelineAggregationBuilders)) { + nativeSearchQuery.setPipelineAggregations(pipelineAggregationBuilders); + } + if (minScore > 0) { nativeSearchQuery.setMinScore(minScore); } diff --git a/src/test/java/org/springframework/data/elasticsearch/core/aggregation/ElasticsearchTemplateAggregationTests.java b/src/test/java/org/springframework/data/elasticsearch/core/aggregation/ElasticsearchTemplateAggregationTests.java index 86bef45bc..e98d640dd 100644 --- a/src/test/java/org/springframework/data/elasticsearch/core/aggregation/ElasticsearchTemplateAggregationTests.java +++ b/src/test/java/org/springframework/data/elasticsearch/core/aggregation/ElasticsearchTemplateAggregationTests.java @@ -18,6 +18,7 @@ import static org.assertj.core.api.Assertions.*; import static org.elasticsearch.index.query.QueryBuilders.*; import static org.elasticsearch.search.aggregations.AggregationBuilders.*; +import static org.elasticsearch.search.aggregations.PipelineAggregatorBuilders.*; import static org.springframework.data.elasticsearch.annotations.FieldType.*; import static org.springframework.data.elasticsearch.annotations.FieldType.Integer; @@ -26,9 +27,14 @@ import java.util.List; import org.elasticsearch.action.search.SearchType; +import org.elasticsearch.search.aggregations.Aggregation; import org.elasticsearch.search.aggregations.Aggregations; +import org.elasticsearch.search.aggregations.pipeline.InternalStatsBucket; +import org.elasticsearch.search.aggregations.pipeline.ParsedStatsBucket; +import org.elasticsearch.search.aggregations.pipeline.StatsBucket; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Configuration; @@ -109,7 +115,7 @@ public void after() { indexOperations.delete(); } - @Test + @Test // DATAES-96 public void shouldReturnAggregatedResponseForGivenSearchQuery() { // given @@ -130,6 +136,56 @@ public void shouldReturnAggregatedResponseForGivenSearchQuery() { assertThat(searchHits.hasSearchHits()).isFalse(); } + @Test // #1255 + @DisplayName("should work with pipeline aggregations") + void shouldWorkWithPipelineAggregations() { + + IndexInitializer.init(operations.indexOps(PipelineAggsEntity.class)); + operations.save( // + new PipelineAggsEntity("1-1", "one"), // + new PipelineAggsEntity("2-1", "two"), // + new PipelineAggsEntity("2-2", "two"), // + new PipelineAggsEntity("3-1", "three"), // + new PipelineAggsEntity("3-2", "three"), // + new PipelineAggsEntity("3-3", "three") // + ); // + + NativeSearchQuery searchQuery = new NativeSearchQueryBuilder() // + .withQuery(matchAllQuery()) // + .withSearchType(SearchType.DEFAULT) // + .addAggregation(terms("keyword_aggs").field("keyword")) // + .addAggregation(statsBucket("keyword_bucket_stats", "keyword_aggs._count")) // + .withMaxResults(0) // + .build(); + + SearchHits searchHits = operations.search(searchQuery, PipelineAggsEntity.class); + + Aggregations aggregations = searchHits.getAggregations(); + assertThat(aggregations).isNotNull(); + assertThat(aggregations.asMap().get("keyword_aggs")).isNotNull(); + Aggregation keyword_bucket_stats = aggregations.asMap().get("keyword_bucket_stats"); + assertThat(keyword_bucket_stats).isInstanceOf(StatsBucket.class); + if (keyword_bucket_stats instanceof ParsedStatsBucket) { + // Rest client + ParsedStatsBucket statsBucket = (ParsedStatsBucket) keyword_bucket_stats; + assertThat(statsBucket.getMin()).isEqualTo(1.0); + assertThat(statsBucket.getMax()).isEqualTo(3.0); + assertThat(statsBucket.getAvg()).isEqualTo(2.0); + assertThat(statsBucket.getSum()).isEqualTo(6.0); + assertThat(statsBucket.getCount()).isEqualTo(3L); + } + if (keyword_bucket_stats instanceof InternalStatsBucket) { + // transport client + InternalStatsBucket statsBucket = (InternalStatsBucket) keyword_bucket_stats; + assertThat(statsBucket.getMin()).isEqualTo(1.0); + assertThat(statsBucket.getMax()).isEqualTo(3.0); + assertThat(statsBucket.getAvg()).isEqualTo(2.0); + assertThat(statsBucket.getSum()).isEqualTo(6.0); + assertThat(statsBucket.getCount()).isEqualTo(3L); + } + } + + // region entities @Document(indexName = "test-index-articles-core-aggregation") static class ArticleEntity { @@ -256,4 +312,34 @@ public IndexQuery buildIndex() { } } + @Document(indexName = "pipeline-aggs") + static class PipelineAggsEntity { + @Id private String id; + @Field(type = Keyword) private String keyword; + + public PipelineAggsEntity() {} + + public PipelineAggsEntity(String id, String keyword) { + this.id = id; + this.keyword = keyword; + } + + public String getId() { + return id; + } + + public void setId(String id) { + this.id = id; + } + + public String getKeyword() { + return keyword; + } + + public void setKeyword(String keyword) { + this.keyword = keyword; + } + } + // endregion + }