diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java index b3cef1f6d9..a5a89cf9ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java @@ -145,6 +145,28 @@ static void appendLimitAndOffsetIfPresent(List aggregation aggregationPipeline.add(Aggregation.limit(pageable.getPageSize())); } + + /** + * Append {@code $skip} and {@code $limit} aggregation stage if {@link ConvertingParameterAccessor#getSort()} is + * present. + * + * @param aggregationPipeline + * @param accessor + */ + static void appendModifiedLimitAndOffsetIfPresent(List aggregationPipeline, + ConvertingParameterAccessor accessor) { + + Pageable pageable = accessor.getPageable(); + if (pageable.isUnpaged()) { + return; + } + + if (pageable.getOffset() > 0) { + aggregationPipeline.add(Aggregation.skip(pageable.getOffset())); + } + + aggregationPipeline.add(Aggregation.limit(pageable.getPageSize()+1)); + } /** * Extract a single entry from the given {@link Document}.
diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java index 4deb7d0d52..a4f8ed94bd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java @@ -20,6 +20,8 @@ import java.util.stream.Collectors; import org.bson.Document; +import org.springframework.data.domain.Pageable; +import org.springframework.data.domain.SliceImpl; import org.springframework.data.mapping.model.SpELExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; @@ -76,18 +78,17 @@ public StringBasedAggregation(MongoQueryMethod method, MongoOperations mongoOper protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProcessor, ConvertingParameterAccessor accessor, Class typeToRead) { - if (method.isPageQuery() || method.isSliceQuery()) { - throw new InvalidMongoDbApiUsageException(String.format( - "Repository aggregation method '%s' does not support '%s' return type. Please use eg. 'List' instead.", - method.getName(), method.getReturnType().getType().getSimpleName())); - } - Class sourceType = method.getDomainClass(); Class targetType = typeToRead; List pipeline = computePipeline(method, accessor); AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead); - AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor); + + if (method.isSliceQuery()) { + AggregationUtils.appendModifiedLimitAndOffsetIfPresent(pipeline, accessor); + }else{ + AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor); + } boolean isSimpleReturnType = isSimpleReturnType(typeToRead); boolean isRawAggregationResult = ClassUtils.isAssignable(AggregationResults.class, typeToRead); @@ -118,7 +119,17 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces return result.getMappedResults(); } - + + List mappedResults = result.getMappedResults(); + + if(method.isSliceQuery()) { + + Pageable pageable = accessor.getPageable(); + int pageSize = pageable.getPageSize(); + boolean hasNext = mappedResults.size() > pageSize; + return new SliceImpl(hasNext ? mappedResults.subList(0, pageSize) : mappedResults, pageable, hasNext); + } + Object uniqueResult = result.getUniqueMappedResult(); return isSimpleReturnType diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java index 00506229ea..9a5f058490 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java @@ -36,6 +36,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.mockito.junit.jupiter.MockitoSettings; import org.mockito.quality.Strictness; +import org.springframework.data.domain.Slice; +import org.springframework.data.domain.SliceImpl; import org.springframework.data.domain.Page; import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; @@ -220,13 +222,10 @@ public void aggregateWithCollationParameter() { } @Test // DATAMONGO-2506 - public void aggregateRaisesErrorOnInvalidReturnType() { - - StringBasedAggregation sba = createAggregationForMethod("invalidPageReturnType", Pageable.class); - assertThatExceptionOfType(InvalidMongoDbApiUsageException.class) // - .isThrownBy(() -> sba.execute(new Object[] { PageRequest.of(0, 1) })) // - .withMessageContaining("invalidPageReturnType") // - .withMessageContaining("Page"); + public void aggregationWithSliceReturnType() { + StringBasedAggregation sba = createAggregationForMethod("aggregationWithSliceReturnType", Pageable.class); + Object result = sba.execute(new Object[] { PageRequest.of(0, 1) }); + assertThat(result.getClass()).isEqualTo(SliceImpl.class); } @Test // DATAMONGO-2557 @@ -319,7 +318,7 @@ private interface SampleRepository extends Repository { PersonAggregate aggregateWithCollation(Collation collation); @Aggregation(RAW_GROUP_BY_LASTNAME_STRING) - Page invalidPageReturnType(Pageable page); + Slice aggregationWithSliceReturnType(Pageable page); @Aggregation(RAW_GROUP_BY_LASTNAME_STRING) String simpleReturnType();