Commit ba559c2

DATAMONGO-1986 - Polishing.
Refactor duplicated code into AggregationUtil. Original pull request: #564.
1 parent 5f3ad68 commit ba559c2

File tree (4 files changed: +152 −123 lines)

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java (new)
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java
spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java
spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java (new file)

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
/*
 * Copyright 2018 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.data.mongodb.core;

import lombok.AllArgsConstructor;

import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;

import org.bson.Document;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.mongodb.core.aggregation.Aggregation;
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
import org.springframework.data.mongodb.core.convert.QueryMapper;
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
import org.springframework.lang.Nullable;
import org.springframework.util.ObjectUtils;

/**
 * Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions
 * and create type-bound {@link AggregationOperationContext}.
 *
 * @author Christoph Strobl
 * @author Mark Paluch
 * @since 2.0.8
 */
@AllArgsConstructor
class AggregationUtil {

	QueryMapper queryMapper;
	MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;

	/**
	 * Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself
	 * if it is not {@literal null}, creating a {@link TypeBasedAggregationOperationContext} if the aggregation contains
	 * type information (is a {@link TypedAggregation}), or using the {@link Aggregation#DEFAULT_CONTEXT}.
	 *
	 * @param aggregation must not be {@literal null}.
	 * @param context can be {@literal null}.
	 * @return the root {@link AggregationOperationContext} to use.
	 */
	AggregationOperationContext prepareAggregationContext(Aggregation aggregation,
			@Nullable AggregationOperationContext context) {

		if (context != null) {
			return context;
		}

		if (aggregation instanceof TypedAggregation) {
			return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext,
					queryMapper);
		}

		return Aggregation.DEFAULT_CONTEXT;
	}

	/**
	 * Extract and map the aggregation pipeline into a {@link List} of {@link Document}.
	 *
	 * @param collectionName must not be {@literal null}.
	 * @param aggregation must not be {@literal null}.
	 * @param context must not be {@literal null}.
	 * @return the aggregation command containing the mapped pipeline.
	 */
	Document createPipeline(String collectionName, Aggregation aggregation, AggregationOperationContext context) {

		if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
			return aggregation.toDocument(collectionName, context);
		}

		Document command = aggregation.toDocument(collectionName, context);
		command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));

		return command;
	}

	/**
	 * Extract the command and map the aggregation pipeline.
	 *
	 * @param collection must not be {@literal null}.
	 * @param aggregation must not be {@literal null}.
	 * @param context must not be {@literal null}.
	 * @return the aggregation command {@link Document}.
	 */
	Document createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {

		Document command = aggregation.toDocument(collection, context);

		if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
			return command;
		}

		command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));

		return command;
	}

	private List<Document> mapAggregationPipeline(List<Document> pipeline) {

		return pipeline.stream().map(val -> queryMapper.getMappedObject(val, Optional.empty()))
				.collect(Collectors.toList());
	}
}
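
For orientation, callers use the new helper roughly the way the rewritten aggregateStream and prepareAggregationCommand methods below do. The following is a minimal sketch only, assuming a queryMapper and mappingContext such as the ones MongoTemplate already holds; the pipeline and collection name are illustrative and not part of this commit.

	AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);

	// Illustrative pipeline; any Aggregation is handled the same way.
	Aggregation aggregation = Aggregation.newAggregation(Aggregation.project("firstName"));

	// No explicit context and no TypedAggregation given, so this falls back to Aggregation.DEFAULT_CONTEXT ...
	AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, null);

	// ... which makes createCommand additionally run the pipeline stages through the QueryMapper
	// before the command document is sent to the server.
	Document command = aggregationUtil.createCommand("person", aggregation, rootContext);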

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

Lines changed: 21 additions & 78 deletions
@@ -27,7 +27,6 @@
 import java.util.*;
 import java.util.Map.Entry;
 import java.util.concurrent.TimeUnit;
-import java.util.stream.Collectors;
 
 import org.bson.Document;
 import org.bson.conversions.Bson;
@@ -1934,7 +1933,8 @@ protected <O> AggregationResults<O> aggregate(Aggregation aggregation, String co
 		Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
 		Assert.notNull(outputType, "Output type must not be null!");
 
-		Document commandResult = new BatchAggregationLoader(this, readPreference, Integer.MAX_VALUE)
+		Document commandResult = new BatchAggregationLoader(this, queryMapper, mappingContext, readPreference,
+				Integer.MAX_VALUE)
 				.aggregate(collectionName, aggregation, context);
 
 		return new AggregationResults<>(returnPotentiallyMappedResults(outputType, commandResult, collectionName),
@@ -1957,9 +1957,9 @@ private <O> List<O> returnPotentiallyMappedResults(Class<O> outputType, Document
 			return Collections.emptyList();
 		}
 
-		DocumentCallback<O> callback = new UnwrapAndReadDocumentCallback<O>(mongoConverter, outputType, collectionName);
+		DocumentCallback<O> callback = new UnwrapAndReadDocumentCallback<>(mongoConverter, outputType, collectionName);
 
-		List<O> mappedResults = new ArrayList<O>();
+		List<O> mappedResults = new ArrayList<>();
 		for (Document document : resultSet) {
 			mappedResults.add(callback.doWith(document));
 		}
@@ -1974,17 +1974,18 @@ protected <O> CloseableIterator<O> aggregateStream(Aggregation aggregation, Stri
 		Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
 		Assert.notNull(outputType, "Output type must not be null!");
 
-		AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
+		AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
+		AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
 
-		Document command = aggregation.toDocument(collectionName, rootContext);
+		Document command = aggregationUtil.createCommand(collectionName, aggregation, rootContext);
 
 		assertNotExplain(command);
 
 		if (LOGGER.isDebugEnabled()) {
 			LOGGER.debug("Streaming aggregation: {}", serializeToJsonSafely(command));
 		}
 
-		ReadDocumentCallback<O> readCallback = new ReadDocumentCallback<O>(mongoConverter, outputType, collectionName);
+		ReadDocumentCallback<O> readCallback = new ReadDocumentCallback<>(mongoConverter, outputType, collectionName);
 
 		return execute(collectionName, new CollectionCallback<CloseableIterator<O>>() {
 
@@ -2008,7 +2009,7 @@ public CloseableIterator<O> doInCollection(MongoCollection<Document> collection)
 					cursor = cursor.collation(options.getCollation().map(Collation::toMongoCollation).get());
 				}
 
-				return new CloseableIterableCursorAdapter<O>(cursor.iterator(), exceptionTranslator, readCallback);
+				return new CloseableIterableCursorAdapter<>(cursor.iterator(), exceptionTranslator, readCallback);
 			}
 		});
 	}
@@ -2577,72 +2578,6 @@ private Document addFieldsForProjection(Document fields, Class<?> domainType, Cl
 		return fields;
 	}
 
-	/**
-	 * Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
-	 * is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
-	 * information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
-	 *
-	 * @param aggregation must not be {@literal null}.
-	 * @param context can be {@literal null}.
-	 * @return the root {@link AggregationOperationContext} to use.
-	 */
-	private AggregationOperationContext prepareAggregationContext(Aggregation aggregation,
-			@Nullable AggregationOperationContext context) {
-
-		if (context != null) {
-			return context;
-		}
-
-		if (aggregation instanceof TypedAggregation) {
-			return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext,
-					queryMapper);
-		}
-
-		return Aggregation.DEFAULT_CONTEXT;
-	}
-
-	/**
-	 * Extract and map the aggregation pipeline.
-	 *
-	 * @param aggregation
-	 * @param context
-	 * @return
-	 */
-	private Document aggregationToPipeline(String inputCollectionName, Aggregation aggregation, AggregationOperationContext context) {
-
-		if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
-			return aggregation.toDocument(inputCollectionName, context);
-		}
-
-		return queryMapper.getMappedObject(aggregation.toDocument(inputCollectionName, context), Optional.empty());
-	}
-
-	/**
-	 * Extract the command and map the aggregation pipeline.
-	 *
-	 * @param aggregation
-	 * @param context
-	 * @return
-	 */
-	private Document aggregationToCommand(String collection, Aggregation aggregation,
-			AggregationOperationContext context) {
-
-		Document command = aggregation.toDocument(collection, context);
-
-		if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
-			return command;
-		}
-
-		command.put("pipeline", mapAggregationPipeline(command.get("pipeline", List.class)));
-
-		return command;
-	}
-
-	private List<Document> mapAggregationPipeline(List<Document> pipeline) {
-
-		return pipeline.stream().map(val -> queryMapper.getMappedObject(val, Optional.empty()))
-				.collect(Collectors.toList());
-	}
 
 	/**
 	 * Tries to convert the given {@link RuntimeException} into a {@link DataAccessException} but returns the original
@@ -3157,12 +3092,18 @@ static class BatchAggregationLoader {
 		private static final String OK = "ok";
 
 		private final MongoTemplate template;
+		private final QueryMapper queryMapper;
+		private final MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
 		private final ReadPreference readPreference;
 		private final int batchSize;
 
-		BatchAggregationLoader(MongoTemplate template, ReadPreference readPreference, int batchSize) {
+		BatchAggregationLoader(MongoTemplate template, QueryMapper queryMapper,
+				MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext,
+				ReadPreference readPreference, int batchSize) {
 
 			this.template = template;
+			this.queryMapper = queryMapper;
+			this.mappingContext = mappingContext;
 			this.readPreference = readPreference;
 			this.batchSize = batchSize;
 		}
@@ -3185,11 +3126,13 @@ Document aggregate(String collectionName, Aggregation aggregation, AggregationOp
 		 * Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on
 		 * different server versions.
 		 */
-		private static Document prepareAggregationCommand(String collectionName, Aggregation aggregation,
+		private Document prepareAggregationCommand(String collectionName, Aggregation aggregation,
 				@Nullable AggregationOperationContext context, int batchSize) {
 
-			AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
-			Document command = aggregation.toDocument(collectionName, rootContext);
+			AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
+
+			AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
+			Document command = aggregationUtil.createCommand(collectionName, aggregation, rootContext);
 
 			if (!aggregation.getOptions().isExplain()) {
 				command.put(CURSOR_FIELD, new Document(BATCH_SIZE_FIELD, batchSize));

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

Lines changed: 3 additions & 42 deletions
@@ -733,8 +733,9 @@ protected <O> Flux<O> aggregate(Aggregation aggregation, String collectionName,
 		Assert.hasText(collectionName, "Collection name must not be null or empty!");
 		Assert.notNull(outputType, "Output type must not be null!");
 
-		AggregationOperationContext rootContext = prepareAggregationContext(aggregation, context);
-		Document command = aggregationToPipeline(collectionName, aggregation, rootContext);
+		AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
+		AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
+		Document command = aggregationUtil.createPipeline(collectionName, aggregation, rootContext);
 		AggregationOptions options = AggregationOptions.fromDocument(command);
 
 		Assert.isTrue(!options.isExplain(), "Cannot use explain option with streaming!");
@@ -2197,46 +2198,6 @@ private Function<Throwable, Throwable> translateException() {
 		};
 	}
 
-	/**
-	 * Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
-	 * is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
-	 * information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
-	 *
-	 * @param aggregation must not be {@literal null}.
-	 * @param context can be {@literal null}.
-	 * @return the root {@link AggregationOperationContext} to use.
-	 */
-	private AggregationOperationContext prepareAggregationContext(Aggregation aggregation,
-			@Nullable AggregationOperationContext context) {
-
-		if (context != null) {
-			return context;
-		}
-
-		if (aggregation instanceof TypedAggregation) {
-			return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext,
-					queryMapper);
-		}
-
-		return Aggregation.DEFAULT_CONTEXT;
-	}
-
-	/**
-	 * Extract and map the aggregation pipeline.
-	 *
-	 * @param aggregation
-	 * @param context
-	 * @return
-	 */
-	private Document aggregationToPipeline(String inputCollectionName, Aggregation aggregation, AggregationOperationContext context) {
-
-		if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
-			return aggregation.toDocument(inputCollectionName, context);
-		}
-
-		return queryMapper.getMappedObject(aggregation.toDocument(inputCollectionName, context), Optional.empty());
-	}
-
 	/**
 	 * Tries to convert the given {@link RuntimeException} into a {@link DataAccessException} but returns the original
 	 * exception if the conversation failed. Thus allows safe re-throwing of the return value.

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java

Lines changed: 10 additions & 3 deletions
@@ -31,12 +31,16 @@
 import org.springframework.data.mongodb.core.MongoTemplate.BatchAggregationLoader;
 import org.springframework.data.mongodb.core.aggregation.Aggregation;
 import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
+import org.springframework.data.mongodb.core.convert.DbRefResolver;
+import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
+import org.springframework.data.mongodb.core.convert.QueryMapper;
+import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
 
 import com.mongodb.ReadPreference;
 
 /**
  * Unit tests for {@link BatchAggregationLoader}.
- *
+ *
  * @author Christoph Strobl
  * @author Mark Paluch
  */
@@ -47,6 +51,7 @@ public class BatchAggregationLoaderUnitTests {
 			project().and("firstName").as("name"));
 
 	@Mock MongoTemplate template;
+	@Mock DbRefResolver dbRefResolver;
 	@Mock Document aggregationResult;
 	@Mock Document getMoreResult;
 
@@ -60,7 +65,9 @@ public class BatchAggregationLoaderUnitTests {
 
 	@Before
 	public void setUp() {
-		loader = new BatchAggregationLoader(template, ReadPreference.primary(), 10);
+		MongoMappingContext context = new MongoMappingContext();
+		loader = new BatchAggregationLoader(template, new QueryMapper(new MappingMongoConverter(dbRefResolver, context)),
+				context, ReadPreference.primary(), 10);
 	}
 
 	@Test // DATAMONGO-1824
@@ -84,7 +91,7 @@ public void shouldLoadJustOneBatchWhenAlreadyDoneWithFirst() {
 		when(aggregationResult.get("cursor")).thenReturn(cursorWithoutMore);
 
 		Document result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);
-
+
 		assertThat((List) result.get("result")).contains(luke);
 
 		verify(template).executeCommand(any(Document.class), any(ReadPreference.class));
