Skip to content

Commit 90d03d9

Browse files
committed
Polishing.
Let appendLimitAndOffsetIfPresent accept unary operators for adjusting limit/offset values instead of appendModifiedLimitAndOffsetIfPresent. Apply simple type extraction for Slice. Add support for aggregation result streaming. Extend tests, add author tags, update docs. See #3543. Original pull request: #3645.
1 parent 9a48e32 commit 90d03d9

File tree

6 files changed

+185
-78
lines changed

6 files changed

+185
-78
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AggregationUtils.java

+13-16
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
import java.time.Duration;
1919
import java.util.List;
2020
import java.util.Map;
21+
import java.util.function.IntUnaryOperator;
22+
import java.util.function.LongUnaryOperator;
2123

2224
import org.bson.Document;
2325
import org.springframework.data.domain.Pageable;
@@ -42,6 +44,7 @@
4244
*
4345
* @author Christoph Strobl
4446
* @author Mark Paluch
47+
* @author Divya Srivastava
4548
* @since 2.2
4649
*/
4750
abstract class AggregationUtils {
@@ -133,39 +136,33 @@ static void appendSortIfPresent(List<AggregationOperation> aggregationPipeline,
133136
*/
134137
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
135138
ConvertingParameterAccessor accessor) {
136-
137-
Pageable pageable = accessor.getPageable();
138-
if (pageable.isUnpaged()) {
139-
return;
140-
}
141-
142-
if (pageable.getOffset() > 0) {
143-
aggregationPipeline.add(Aggregation.skip(pageable.getOffset()));
144-
}
145-
146-
aggregationPipeline.add(Aggregation.limit(pageable.getPageSize()));
139+
appendLimitAndOffsetIfPresent(aggregationPipeline, accessor, LongUnaryOperator.identity(),
140+
IntUnaryOperator.identity());
147141
}
148-
142+
149143
/**
150144
* Append {@code $skip} and {@code $limit} aggregation stage if {@link ConvertingParameterAccessor#getSort()} is
151145
* present.
152146
*
153147
* @param aggregationPipeline
154148
* @param accessor
149+
* @param offsetOperator
150+
* @param limitOperator
151+
* @since 3.3
155152
*/
156-
static void appendModifiedLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
157-
ConvertingParameterAccessor accessor) {
153+
static void appendLimitAndOffsetIfPresent(List<AggregationOperation> aggregationPipeline,
154+
ConvertingParameterAccessor accessor, LongUnaryOperator offsetOperator, IntUnaryOperator limitOperator) {
158155

159156
Pageable pageable = accessor.getPageable();
160157
if (pageable.isUnpaged()) {
161158
return;
162159
}
163160

164161
if (pageable.getOffset() > 0) {
165-
aggregationPipeline.add(Aggregation.skip(pageable.getOffset()));
162+
aggregationPipeline.add(Aggregation.skip(offsetOperator.applyAsLong(pageable.getOffset())));
166163
}
167164

168-
aggregationPipeline.add(Aggregation.limit(pageable.getPageSize()+1));
165+
aggregationPipeline.add(Aggregation.limit(limitOperator.applyAsInt(pageable.getPageSize())));
169166
}
170167

171168
/**

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/StringBasedAggregation.java

+51-21
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717

1818
import java.util.ArrayList;
1919
import java.util.List;
20-
import java.util.stream.Collectors;
20+
import java.util.function.LongUnaryOperator;
21+
import java.util.stream.Stream;
2122

2223
import org.bson.Document;
24+
2325
import org.springframework.data.domain.Pageable;
2426
import org.springframework.data.domain.SliceImpl;
2527
import org.springframework.data.mapping.model.SpELExpressionEvaluator;
@@ -42,7 +44,12 @@
4244
import org.springframework.util.ClassUtils;
4345

4446
/**
47+
* {@link AbstractMongoQuery} implementation to run string-based aggregations using
48+
* {@link org.springframework.data.mongodb.repository.Aggregation}.
49+
*
4550
* @author Christoph Strobl
51+
* @author Divya Srivastava
52+
* @author Mark Paluch
4653
* @since 2.2
4754
*/
4855
public class StringBasedAggregation extends AbstractMongoQuery {
@@ -64,6 +71,12 @@ public StringBasedAggregation(MongoQueryMethod method, MongoOperations mongoOper
6471
ExpressionParser expressionParser, QueryMethodEvaluationContextProvider evaluationContextProvider) {
6572
super(method, mongoOperations, expressionParser, evaluationContextProvider);
6673

74+
if (method.isPageQuery()) {
75+
throw new InvalidMongoDbApiUsageException(String.format(
76+
"Repository aggregation method '%s' does not support '%s' return type. Please use 'Slice' or 'List' instead.",
77+
method.getName(), method.getReturnType().getType().getSimpleName()));
78+
}
79+
6780
this.mongoOperations = mongoOperations;
6881
this.mongoConverter = mongoOperations.getConverter();
6982
this.expressionParser = expressionParser;
@@ -83,10 +96,11 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
8396

8497
List<AggregationOperation> pipeline = computePipeline(method, accessor);
8598
AggregationUtils.appendSortIfPresent(pipeline, accessor, typeToRead);
86-
99+
87100
if (method.isSliceQuery()) {
88-
AggregationUtils.appendModifiedLimitAndOffsetIfPresent(pipeline, accessor);
89-
}else{
101+
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor, LongUnaryOperator.identity(),
102+
limit -> limit + 1);
103+
} else {
90104
AggregationUtils.appendLimitAndOffsetIfPresent(pipeline, accessor);
91105
}
92106

@@ -96,47 +110,63 @@ protected Object doExecute(MongoQueryMethod method, ResultProcessor resultProces
96110
if (isSimpleReturnType) {
97111
targetType = Document.class;
98112
} else if (isRawAggregationResult) {
113+
114+
// 🙈
99115
targetType = method.getReturnType().getRequiredActualType().getRequiredComponentType().getType();
100116
}
101117

102118
AggregationOptions options = computeOptions(method, accessor);
103119
TypedAggregation<?> aggregation = new TypedAggregation<>(sourceType, pipeline, options);
104120

105-
AggregationResults<?> result = mongoOperations.aggregate(aggregation, targetType);
121+
if (method.isStreamQuery()) {
122+
123+
Stream<?> stream = mongoOperations.aggregateStream(aggregation, targetType).stream();
124+
125+
if (isSimpleReturnType) {
126+
return stream.map(it -> AggregationUtils.extractSimpleTypeResult((Document) it, typeToRead, mongoConverter));
127+
}
128+
129+
return stream;
130+
}
131+
132+
AggregationResults<Object> result = (AggregationResults<Object>) mongoOperations.aggregate(aggregation, targetType);
106133

107134
if (isRawAggregationResult) {
108135
return result;
109136
}
110137

138+
List<Object> results = result.getMappedResults();
111139
if (method.isCollectionQuery()) {
140+
return isSimpleReturnType ? convertResults(typeToRead, results) : results;
141+
}
112142

113-
if (isSimpleReturnType) {
114-
115-
return result.getMappedResults().stream()
116-
.map(it -> AggregationUtils.extractSimpleTypeResult((Document) it, typeToRead, mongoConverter))
117-
.collect(Collectors.toList());
118-
}
143+
if (method.isSliceQuery()) {
119144

120-
return result.getMappedResults();
121-
}
122-
123-
List mappedResults = result.getMappedResults();
124-
125-
if(method.isSliceQuery()) {
126-
127145
Pageable pageable = accessor.getPageable();
128146
int pageSize = pageable.getPageSize();
129-
boolean hasNext = mappedResults.size() > pageSize;
130-
return new SliceImpl<Object>(hasNext ? mappedResults.subList(0, pageSize) : mappedResults, pageable, hasNext);
147+
List<Object> resultsToUse = isSimpleReturnType ? convertResults(typeToRead, results) : results;
148+
boolean hasNext = resultsToUse.size() > pageSize;
149+
return new SliceImpl<>(hasNext ? resultsToUse.subList(0, pageSize) : resultsToUse, pageable, hasNext);
131150
}
132-
151+
133152
Object uniqueResult = result.getUniqueMappedResult();
134153

135154
return isSimpleReturnType
136155
? AggregationUtils.extractSimpleTypeResult((Document) uniqueResult, typeToRead, mongoConverter)
137156
: uniqueResult;
138157
}
139158

159+
private List<Object> convertResults(Class<?> typeToRead, List<Object> mappedResults) {
160+
161+
List<Object> list = new ArrayList<>(mappedResults.size());
162+
for (Object it : mappedResults) {
163+
Object extractSimpleTypeResult = AggregationUtils.extractSimpleTypeResult((Document) it, typeToRead,
164+
mongoConverter);
165+
list.add(extractSimpleTypeResult);
166+
}
167+
return list;
168+
}
169+
140170
private boolean isSimpleReturnType(Class<?> targetType) {
141171
return MongoSimpleTypes.HOLDER.isSimpleType(targetType);
142172
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java

+19-7
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
import org.springframework.data.domain.Example;
4444
import org.springframework.data.domain.Page;
4545
import org.springframework.data.domain.PageRequest;
46+
import org.springframework.data.domain.Pageable;
4647
import org.springframework.data.domain.Range;
4748
import org.springframework.data.domain.Slice;
4849
import org.springframework.data.domain.Sort;
@@ -1269,13 +1270,16 @@ void annotatedQueryShouldAllowPositionalParameterInFieldsProjectionWithDbRef() {
12691270
@Test // DATAMONGO-2153
12701271
void findListOfSingleValue() {
12711272

1272-
assertThat(repository.findAllLastnames()) //
1273-
.contains("Lessard") //
1274-
.contains("Keys") //
1275-
.contains("Tinsley") //
1276-
.contains("Beauford") //
1277-
.contains("Moore") //
1278-
.contains("Matthews"); //
1273+
assertThat(repository.findAllLastnames()).contains("Lessard", "Keys", "Tinsley", "Beauford", "Moore", "Matthews");
1274+
}
1275+
1276+
@Test // GH-3543
1277+
void findStreamOfSingleValue() {
1278+
1279+
try (Stream<String> lastnames = repository.findAllLastnamesAsStream()) {
1280+
assertThat(lastnames) //
1281+
.contains("Lessard", "Keys", "Tinsley", "Beauford", "Moore", "Matthews");
1282+
}
12791283
}
12801284

12811285
@Test // DATAMONGO-2153
@@ -1290,6 +1294,14 @@ void annotatedAggregationWithPlaceholderValue() {
12901294
.contains(new PersonAggregate("Matthews", Arrays.asList("Dave", "Oliver August")));
12911295
}
12921296

1297+
@Test // GH-3543
1298+
void annotatedAggregationWithPlaceholderAsSlice() {
1299+
1300+
Slice<PersonAggregate> slice = repository.groupByLastnameAndAsSlice("firstname", Pageable.ofSize(5));
1301+
assertThat(slice).hasSize(5);
1302+
assertThat(slice.hasNext()).isTrue();
1303+
}
1304+
12931305
@Test // DATAMONGO-2153
12941306
void annotatedAggregationWithSort() {
12951307

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java

+6
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,15 @@ Page<Person> findByCustomQueryLastnameAndAddressStreetInList(String lastname, Li
379379
@Aggregation("{ '$project': { '_id' : '$lastname' } }")
380380
List<String> findAllLastnames();
381381

382+
@Aggregation("{ '$project': { '_id' : '$lastname' } }")
383+
Stream<String> findAllLastnamesAsStream();
384+
382385
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
383386
List<PersonAggregate> groupByLastnameAnd(String property);
384387

388+
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
389+
Slice<PersonAggregate> groupByLastnameAndAsSlice(String property, Pageable pageable);
390+
385391
@Aggregation("{ '$group': { '_id' : '$lastname', names : { $addToSet : '$?0' } } }")
386392
List<PersonAggregate> groupByLastnameAnd(String property, Sort sort);
387393

0 commit comments

Comments
 (0)