Skip to content

Commit 9a48e32

Browse files
divya_jnu08mp911de
divya_jnu08
authored andcommitted
Aggregation query method should be able to return Slice and Stream.
Aggregation query methods can not return Slice and Stream. interface PersonRepository extends CrudReppsitory<Person, String> { @Aggregation("{ $group: { _id : $lastname, names : { $addToSet : ?0 } } }") Slice<PersonAggregate> groupByLastnameAnd(String property, Pageable page); @Aggregation("{ $group: { _id : $lastname, names : { $addToSet : $firstname } } }") Stream<PersonAggregate> groupByLastnameAndFirstnamesAsStream(); } Closes #3543. Original pull request: #3645.
1 parent ede6927 commit 9a48e32

File tree

3 files changed

+48
-16
lines changed

3 files changed

+48
-16
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java

+22
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,28 @@ static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregation
145145

146146
aggregationPipeline.add(Aggregation.limit(pageable.getPageSize()));
147147
}
148+
149+
/**
150+
* Append {@code $skip} and {@code $limit} aggregation stage if {@link ConvertingParameterAccessor#getSort()} is
151+
* present.
152+
*
153+
* @param aggregationPipeline
154+
* @param accessor
155+
*/
156+
static void appendModifiedLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
157+
ConvertingParameterAccessor accessor) {
158+
159+
Pageable pageable = accessor.getPageable();
160+
if (pageable.isUnpaged()) {
161+
return;
162+
}
163+
164+
if (pageable.getOffset() > 0) {
165+
aggregationPipeline.add(Aggregation.skip(pageable.getOffset()));
166+
}
167+
168+
aggregationPipeline.add(Aggregation.limit(pageable.getPageSize()+1));
169+
}
148170

149171
/**
150172
* Extract a single entry from the given {@link Document}. <br />

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java

+19-8
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
import java.util.stream.Collectors;
2121

2222
import org.bson.Document;
23+
import org.springframework.data.domain.Pageable;
24+
import org.springframework.data.domain.SliceImpl;
2325
import org.springframework.data.mapping.model.SpELExpressionEvaluator;
2426
import org.springframework.data.mongodb.InvalidMongoDbApiUsageException;
2527
import org.springframework.data.mongodb.core.MongoOperations;
@@ -76,18 +78,17 @@ public StringBasedAggregation(MongoQueryMethod method, MongoOperations mongoOper
7678
protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProcessor,
7779
ConvertingParameterAccessor accessor, Class<?> typeToRead) {
7880

79-
if (method.isPageQuery() || method.isSliceQuery()) {
80-
throw new InvalidMongoDbApiUsageException(String.format(
81-
"Repository aggregation method '%s' does not support '%s' return type. Please use eg. 'List' instead.",
82-
method.getName(), method.getReturnType().getType().getSimpleName()));
83-
}
84-
8581
Class<?> sourceType = method.getDomainClass();
8682
Class<?> targetType = typeToRead;
8783

8884
List<AggregationOperation> pipeline = computePipeline(method, accessor);
8985
AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);
90-
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor);
86+
87+
if (method.isSliceQuery()) {
88+
AggregationUtils.appendModifiedLimitAndOffsetIfPresent(pipeline, accessor);
89+
}else{
90+
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor);
91+
}
9192

9293
boolean isSimpleReturnType = isSimpleReturnType(typeToRead);
9394
boolean isRawAggregationResult = ClassUtils.isAssignable(AggregationResults.class, typeToRead);
@@ -118,7 +119,17 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
118119

119120
return result.getMappedResults();
120121
}
121-
122+
123+
List mappedResults = result.getMappedResults();
124+
125+
if(method.isSliceQuery()) {
126+
127+
Pageable pageable = accessor.getPageable();
128+
int pageSize = pageable.getPageSize();
129+
boolean hasNext = mappedResults.size() > pageSize;
130+
return new SliceImpl<Object>(hasNext ? mappedResults.subList(0, pageSize) : mappedResults, pageable, hasNext);
131+
}
132+
122133
Object uniqueResult = result.getUniqueMappedResult();
123134

124135
return isSimpleReturnType

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StringBasedAggregationUnitTests.java

+7-8
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import org.mockito.junit.jupiter.MockitoExtension;
3737
import org.mockito.junit.jupiter.MockitoSettings;
3838
import org.mockito.quality.Strictness;
39+
import org.springframework.data.domain.Slice;
40+
import org.springframework.data.domain.SliceImpl;
3941
import org.springframework.data.domain.Page;
4042
import org.springframework.data.domain.PageRequest;
4143
import org.springframework.data.domain.Pageable;
@@ -220,13 +222,10 @@ public void aggregateWithCollationParameter() {
220222
}
221223

222224
@Test // DATAMONGO-2506
223-
public void aggregateRaisesErrorOnInvalidReturnType() {
224-
225-
StringBasedAggregation sba = createAggregationForMethod("invalidPageReturnType", Pageable.class);
226-
assertThatExceptionOfType(InvalidMongoDbApiUsageException.class) //
227-
.isThrownBy(() -> sba.execute(new Object[] { PageRequest.of(0, 1) })) //
228-
.withMessageContaining("invalidPageReturnType") //
229-
.withMessageContaining("Page");
225+
public void aggregationWithSliceReturnType() {
226+
StringBasedAggregation sba = createAggregationForMethod("aggregationWithSliceReturnType", Pageable.class);
227+
Object result = sba.execute(new Object[] { PageRequest.of(0, 1) });
228+
assertThat(result.getClass()).isEqualTo(SliceImpl.class);
230229
}
231230

232231
@Test // DATAMONGO-2557
@@ -319,7 +318,7 @@ private interface SampleRepository extends Repository<Person, Long> {
319318
PersonAggregate aggregateWithCollation(Collation collation);
320319

321320
@Aggregation(RAW_GROUP_BY_LASTNAME_STRING)
322-
Page<Person> invalidPageReturnType(Pageable page);
321+
Slice<Person> aggregationWithSliceReturnType(Pageable page);
323322

324323
@Aggregation(RAW_GROUP_BY_LASTNAME_STRING)
325324
String simpleReturnType();

0 commit comments

Comments
 (0)