Skip to content

Commit 41897c7

Browse files
committed
DATAMONGO-1986 - Polishing.
Refactor duplicated code into AggregationUtil. Original pull request: #564.
1 parent f3397e9 commit 41897c7

File tree

6 files changed

+151
-11
lines changed

6 files changed

+151
-11
lines changed
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/*
2+
* Copyright 2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.core;
17+
18+
import lombok.AllArgsConstructor;
19+
20+
import java.util.List;
21+
22+
import org.springframework.data.mapping.context.MappingContext;
23+
import org.springframework.data.mongodb.core.aggregation.Aggregation;
24+
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
25+
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
26+
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
27+
import org.springframework.data.mongodb.core.convert.QueryMapper;
28+
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
29+
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
30+
import org.springframework.util.ObjectUtils;
31+
32+
import com.mongodb.BasicDBList;
33+
import com.mongodb.DBObject;
34+
35+
/**
36+
* Utility methods to map {@link org.springframework.data.mongodb.core.aggregation.Aggregation} pipeline definitions and
37+
* create type-bound {@link AggregationOperationContext}.
38+
*
39+
* @author Christoph Strobl
40+
* @author Mark Paluch
41+
* @since 1.10.13
42+
*/
43+
@AllArgsConstructor
44+
class AggregationUtil {
45+
46+
QueryMapper queryMapper;
47+
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
48+
49+
/**
50+
* Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
51+
* is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
52+
* information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
53+
*
54+
* @param aggregation must not be {@literal null}.
55+
* @param context can be {@literal null}.
56+
* @return the root {@link AggregationOperationContext} to use.
57+
*/
58+
AggregationOperationContext prepareAggregationContext(Aggregation aggregation, AggregationOperationContext context) {
59+
60+
if (context != null) {
61+
return context;
62+
}
63+
64+
if (aggregation instanceof TypedAggregation) {
65+
return new TypeBasedAggregationOperationContext(((TypedAggregation) aggregation).getInputType(), mappingContext,
66+
queryMapper);
67+
}
68+
69+
return Aggregation.DEFAULT_CONTEXT;
70+
}
71+
72+
/**
73+
* Extract and map the aggregation pipeline into a {@link List} of {@link Document}.
74+
*
75+
* @param aggregation
76+
* @param context
77+
* @return
78+
*/
79+
DBObject createPipeline(String collectionName, Aggregation aggregation, AggregationOperationContext context) {
80+
81+
if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
82+
return aggregation.toDbObject(collectionName, context);
83+
}
84+
85+
DBObject command = aggregation.toDbObject(collectionName, context);
86+
command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline")));
87+
88+
return command;
89+
}
90+
91+
/**
92+
* Extract the command and map the aggregation pipeline.
93+
*
94+
* @param aggregation
95+
* @param context
96+
* @return
97+
*/
98+
DBObject createCommand(String collection, Aggregation aggregation, AggregationOperationContext context) {
99+
100+
DBObject command = aggregation.toDbObject(collection, context);
101+
102+
if (!ObjectUtils.nullSafeEquals(context, Aggregation.DEFAULT_CONTEXT)) {
103+
return command;
104+
}
105+
106+
command.put("pipeline", mapAggregationPipeline((List) command.get("pipeline")));
107+
108+
return command;
109+
}
110+
111+
private BasicDBList mapAggregationPipeline(List<DBObject> pipeline) {
112+
113+
BasicDBList mapped = new BasicDBList();
114+
115+
for (DBObject stage : pipeline) {
116+
mapped.add(queryMapper.getMappedObject(stage, null));
117+
}
118+
119+
return mapped;
120+
}
121+
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,8 @@ protected <O> AggregationResults<O> aggregate(Aggregation aggregation, String co
15541554
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
15551555
Assert.notNull(outputType, "Output type must not be null!");
15561556

1557-
DBObject commandResult = new BatchAggregationLoader(this, readPreference, Integer.MAX_VALUE)
1557+
DBObject commandResult = new BatchAggregationLoader(this, queryMapper, mappingContext, readPreference,
1558+
Integer.MAX_VALUE)
15581559
.aggregate(collectionName, aggregation, context);
15591560

15601561
return new AggregationResults<O>(returnPotentiallyMappedResults(outputType, commandResult, collectionName),
@@ -2555,12 +2556,18 @@ static class BatchAggregationLoader {
25552556
private static final String OK = "ok";
25562557

25572558
private final MongoTemplate template;
2559+
private final QueryMapper queryMapper;
2560+
private final MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
25582561
private final ReadPreference readPreference;
25592562
private final int batchSize;
25602563

2561-
BatchAggregationLoader(MongoTemplate template, ReadPreference readPreference, int batchSize) {
2564+
BatchAggregationLoader(MongoTemplate template, QueryMapper queryMapper,
2565+
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext,
2566+
ReadPreference readPreference, int batchSize) {
25622567

25632568
this.template = template;
2569+
this.queryMapper = queryMapper;
2570+
this.mappingContext = mappingContext;
25642571
this.readPreference = readPreference;
25652572
this.batchSize = batchSize;
25662573
}
@@ -2583,11 +2590,13 @@ DBObject aggregate(String collectionName, Aggregation aggregation, AggregationOp
25832590
* Pre process the aggregation command sent to the server by adding {@code cursor} options to match execution on
25842591
* different server versions.
25852592
*/
2586-
private static DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
2587-
AggregationOperationContext context, int batchSize) {
2593+
private DBObject prepareAggregationCommand(String collectionName, Aggregation aggregation,
2594+
AggregationOperationContext context, int batchSize) {
25882595

2589-
AggregationOperationContext rootContext = context == null ? Aggregation.DEFAULT_CONTEXT : context;
2590-
DBObject command = aggregation.toDbObject(collectionName, rootContext);
2596+
AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
2597+
2598+
AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
2599+
DBObject command = aggregationUtil.createCommand(collectionName, aggregation, rootContext);
25912600

25922601
if (!aggregation.getOptions().isExplain()) {
25932602
command.put(CURSOR_FIELD, new BasicDBObject(BATCH_SIZE_FIELD, batchSize));

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/BatchAggregationLoaderUnitTests.java

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
import org.springframework.data.mongodb.core.MongoTemplate.BatchAggregationLoader;
3232
import org.springframework.data.mongodb.core.aggregation.Aggregation;
3333
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
34+
import org.springframework.data.mongodb.core.convert.DbRefResolver;
35+
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
36+
import org.springframework.data.mongodb.core.convert.QueryMapper;
37+
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
3438

3539
import com.mongodb.BasicDBObject;
3640
import com.mongodb.CommandResult;
@@ -39,7 +43,7 @@
3943

4044
/**
4145
* Unit tests for {@link BatchAggregationLoader}.
42-
*
46+
*
4347
* @author Christoph Strobl
4448
* @author Mark Paluch
4549
*/
@@ -50,6 +54,7 @@ public class BatchAggregationLoaderUnitTests {
5054
project().and("firstName").as("name"));
5155

5256
@Mock MongoTemplate template;
57+
@Mock DbRefResolver dbRefResolver;
5358
@Mock CommandResult aggregationResult;
5459
@Mock CommandResult getMoreResult;
5560

@@ -65,7 +70,9 @@ public class BatchAggregationLoaderUnitTests {
6570

6671
@Before
6772
public void setUp() {
68-
loader = new BatchAggregationLoader(template, ReadPreference.primary(), 10);
73+
MongoMappingContext context = new MongoMappingContext();
74+
loader = new BatchAggregationLoader(template, new QueryMapper(new MappingMongoConverter(dbRefResolver, context)),
75+
context, ReadPreference.primary(), 10);
6976
}
7077

7178
@Test // DATAMONGO-1824
@@ -89,6 +96,7 @@ public void shouldLoadJustOneBatchWhenAlreadyDoneWithFirst() {
8996
when(aggregationResult.get("cursor")).thenReturn(cursorWithoutMore);
9097

9198
DBObject result = loader.aggregate("person", AGGREGATION, Aggregation.DEFAULT_CONTEXT);
99+
92100
assertThat((List<Object>) result.get("result"),
93101
IsCollectionContaining.<Object> hasItem(luke));
94102

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/TestEntities.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ public Venue maplewoodNJ() {
8484

8585
public List<Venue> newYork() {
8686

87-
List<Venue> venues = new ArrayList<>();
87+
List<Venue> venues = new ArrayList<Venue>();
8888

8989
venues.add(pennStation());
9090
venues.add(tenGenOffice());

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AggregationTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1741,7 +1741,7 @@ public void runMatchOperationCriteriaThroughQueryMapperForTypedAggregation() {
17411741
.within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))),
17421742
project("id", "location", "name"));
17431743

1744-
AggregationResults<Document> groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class);
1744+
AggregationResults<DBObject> groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class);
17451745

17461746
assertThat(groupResults.getMappedResults().size(), is(4));
17471747
}
@@ -1756,7 +1756,8 @@ public void runMatchOperationCriteriaThroughQueryMapperForUntypedAggregation() {
17561756
.within(new Box(new Point(-73.99756, 40.73083), new Point(-73.988135, 40.741404)))),
17571757
project("id", "location", "name"));
17581758

1759-
AggregationResults<Document> groupResults = mongoTemplate.aggregate(aggregation, "newyork", Document.class);
1759+
AggregationResults<DBObject
1760+
> groupResults = mongoTemplate.aggregate(aggregation, "newyork", DBObject.class);
17601761

17611762
assertThat(groupResults.getMappedResults().size(), is(4));
17621763
}

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/AbstractGeoSpatialTests.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import org.springframework.test.context.ContextConfiguration;
4646
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
4747

48+
import com.mongodb.Mongo;
4849
import com.mongodb.MongoClient;
4950
import com.mongodb.WriteConcern;
5051

0 commit comments

Comments
 (0)