Skip to content

Commit 4700b4d

Browse files
christophstroblmp911de
authored andcommitted
Use relaxed type mapping for aggregations by default.
This commit switches from a strict to a relaxed type mapping for aggregation executions. This allows users to add fields to the aggregation that might be part of the stored document but not necessarily of its java model representation. Instead of throwing an exception in those cases the relaxed type check will go on with the user provided field names. To restore the original behaviour use the strictMapping() option on AggregationOptions. Closes #3542 Original pull request: #3545.
1 parent 91f1dc1 commit 4700b4d

File tree

10 files changed

+415
-61
lines changed

10 files changed

+415
-61
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/AggregationUtil.java

+23-18
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.springframework.data.mongodb.core.aggregation.AggregationOperation;
2828
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
2929
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
30+
import org.springframework.data.mongodb.core.aggregation.AggregationOptions.DomainTypeMapping;
3031
import org.springframework.data.mongodb.core.aggregation.CountOperation;
3132
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
3233
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
@@ -36,6 +37,7 @@
3637
import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty;
3738
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
3839
import org.springframework.data.mongodb.core.query.Query;
40+
import org.springframework.data.util.Lazy;
3941
import org.springframework.lang.Nullable;
4042
import org.springframework.util.Assert;
4143
import org.springframework.util.ObjectUtils;
@@ -52,41 +54,44 @@ class AggregationUtil {
5254

5355
QueryMapper queryMapper;
5456
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext;
57+
Lazy<AggregationOperationContext> untypedMappingContext;
5558

5659
AggregationUtil(QueryMapper queryMapper,
5760
MappingContext<? extends MongoPersistentEntity<?>, MongoPersistentProperty> mappingContext) {
5861

5962
this.queryMapper = queryMapper;
6063
this.mappingContext = mappingContext;
64+
this.untypedMappingContext = Lazy
65+
.of(() -> new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper));
6166
}
6267

63-
/**
64-
* Prepare the {@link AggregationOperationContext} for a given aggregation by either returning the context itself it
65-
* is not {@literal null}, create a {@link TypeBasedAggregationOperationContext} if the aggregation contains type
66-
* information (is a {@link TypedAggregation}) or use the {@link Aggregation#DEFAULT_CONTEXT}.
67-
*
68-
* @param aggregation must not be {@literal null}.
69-
* @param context can be {@literal null}.
70-
* @return the root {@link AggregationOperationContext} to use.
71-
*/
72-
AggregationOperationContext prepareAggregationContext(Aggregation aggregation,
73-
@Nullable AggregationOperationContext context) {
68+
AggregationOperationContext createAggregationContext(Aggregation aggregation, @Nullable Class<?> inputType) {
7469

75-
if (context != null) {
76-
return context;
70+
if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.NONE) {
71+
return Aggregation.DEFAULT_CONTEXT;
7772
}
7873

7974
if (!(aggregation instanceof TypedAggregation)) {
80-
return new RelaxedTypeBasedAggregationOperationContext(Object.class, mappingContext, queryMapper);
81-
}
8275

83-
Class<?> inputType = ((TypedAggregation) aggregation).getInputType();
76+
if(inputType == null) {
77+
return untypedMappingContext.get();
78+
}
79+
80+
if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.STRICT
81+
&& !aggregation.getPipeline().containsUnionWith()) {
82+
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
83+
}
8484

85-
if (aggregation.getPipeline().containsUnionWith()) {
8685
return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
8786
}
8887

89-
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
88+
inputType = ((TypedAggregation) aggregation).getInputType();
89+
if (aggregation.getOptions().getDomainTypeMapping() == DomainTypeMapping.STRICT
90+
&& !aggregation.getPipeline().containsUnionWith()) {
91+
return new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
92+
}
93+
94+
return new RelaxedTypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper);
9095
}
9196

9297
/**

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java

+10-7
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import org.springframework.data.mongodb.core.BulkOperations.BulkMode;
5656
import org.springframework.data.mongodb.core.DefaultBulkOperations.BulkOperationContext;
5757
import org.springframework.data.mongodb.core.EntityOperations.AdaptibleEntity;
58+
import org.springframework.data.mongodb.core.QueryOperations.AggregateContext;
5859
import org.springframework.data.mongodb.core.QueryOperations.CountContext;
5960
import org.springframework.data.mongodb.core.QueryOperations.DeleteContext;
6061
import org.springframework.data.mongodb.core.QueryOperations.DistinctQueryContext;
@@ -1988,7 +1989,7 @@ public <O> AggregationResults<O> aggregate(TypedAggregation<?> aggregation, Stri
19881989
public <O> AggregationResults<O> aggregate(Aggregation aggregation, Class<?> inputType, Class<O> outputType) {
19891990

19901991
return aggregate(aggregation, getCollectionName(inputType), outputType,
1991-
new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper));
1992+
queryOperations.createAggregationContext(aggregation, inputType).getAggregationOperationContext());
19921993
}
19931994

19941995
/* (non-Javadoc)
@@ -2095,9 +2096,12 @@ protected <O> AggregationResults<O> aggregate(Aggregation aggregation, String co
20952096
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
20962097
Assert.notNull(outputType, "Output type must not be null!");
20972098

2098-
AggregationOperationContext contextToUse = new AggregationUtil(queryMapper, mappingContext)
2099-
.prepareAggregationContext(aggregation, context);
2100-
return doAggregate(aggregation, collectionName, outputType, contextToUse);
2099+
return doAggregate(aggregation, collectionName, outputType, queryOperations.createAggregationContext(aggregation, context));
2100+
}
2101+
2102+
private <O> AggregationResults<O> doAggregate(Aggregation aggregation, String collectionName, Class<O> outputType,
2103+
AggregateContext context) {
2104+
return doAggregate(aggregation, collectionName, outputType, context.getAggregationOperationContext());
21012105
}
21022106

21032107
@SuppressWarnings("ConstantConditions")
@@ -2185,11 +2189,10 @@ protected <O> CloseableIterator<O> aggregateStream(Aggregation aggregation, Stri
21852189
Assert.notNull(outputType, "Output type must not be null!");
21862190
Assert.isTrue(!aggregation.getOptions().isExplain(), "Can't use explain option with streaming!");
21872191

2188-
AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
2189-
AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
2192+
AggregateContext aggregateContext = queryOperations.createAggregationContext(aggregation, context);
21902193

21912194
AggregationOptions options = aggregation.getOptions();
2192-
List<Document> pipeline = aggregationUtil.createPipeline(aggregation, rootContext);
2195+
List<Document> pipeline = aggregateContext.getAggregationPipeline();
21932196

21942197
if (LOGGER.isDebugEnabled()) {
21952198
LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName);

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/QueryOperations.java

+135-2
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,12 @@
3434
import org.springframework.data.mongodb.core.MappedDocument.MappedUpdate;
3535
import org.springframework.data.mongodb.core.aggregation.Aggregation;
3636
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
37+
import org.springframework.data.mongodb.core.aggregation.AggregationOptions;
38+
import org.springframework.data.mongodb.core.aggregation.AggregationPipeline;
3739
import org.springframework.data.mongodb.core.aggregation.AggregationUpdate;
3840
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
41+
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
42+
import org.springframework.data.mongodb.core.aggregation.TypedAggregation;
3943
import org.springframework.data.mongodb.core.convert.QueryMapper;
4044
import org.springframework.data.mongodb.core.convert.UpdateMapper;
4145
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
@@ -48,6 +52,7 @@
4852
import org.springframework.data.mongodb.core.query.UpdateDefinition.ArrayFilter;
4953
import org.springframework.data.mongodb.util.BsonUtils;
5054
import org.springframework.data.projection.ProjectionFactory;
55+
import org.springframework.data.util.Lazy;
5156
import org.springframework.lang.Nullable;
5257
import org.springframework.util.ClassUtils;
5358
import org.springframework.util.ObjectUtils;
@@ -194,6 +199,31 @@ DeleteContext deleteSingleContext(Query query) {
194199
return new DeleteContext(query, false);
195200
}
196201

202+
/**
203+
* Create a new {@link AggregateContext} for the given {@link Aggregation}.
204+
*
205+
* @param aggregation must not be {@literal null}.
206+
* @param inputType fallback mapping type in case of untyped aggregation. Can be {@literal null}.
207+
* @return new instance of {@link AggregateContext}.
208+
* @since 3.2
209+
*/
210+
AggregateContext createAggregationContext(Aggregation aggregation, @Nullable Class<?> inputType) {
211+
return new AggregateContext(aggregation, inputType);
212+
}
213+
214+
/**
215+
* Create a new {@link AggregateContext} for the given {@link Aggregation}.
216+
*
217+
* @param aggregation must not be {@literal null}.
218+
* @param aggregationOperationContext the {@link AggregationOperationContext} to use. Can be {@literal null}.
219+
* @return new instance of {@link AggregateContext}.
220+
* @since 3.2
221+
*/
222+
AggregateContext createAggregationContext(Aggregation aggregation,
223+
@Nullable AggregationOperationContext aggregationOperationContext) {
224+
return new AggregateContext(aggregation, aggregationOperationContext);
225+
}
226+
197227
/**
198228
* {@link QueryContext} encapsulates common tasks required to convert a {@link Query} into its MongoDB document
199229
* representation, mapping fieldnames, as well as determinging and applying {@link Collation collations}.
@@ -341,7 +371,8 @@ private DistinctQueryContext(@Nullable Object query, String fieldName) {
341371
}
342372

343373
@Override
344-
Document getMappedFields(@Nullable MongoPersistentEntity<?> entity, Class<?> targetType, ProjectionFactory projectionFactory) {
374+
Document getMappedFields(@Nullable MongoPersistentEntity<?> entity, Class<?> targetType,
375+
ProjectionFactory projectionFactory) {
345376
return getMappedFields(entity);
346377
}
347378

@@ -709,7 +740,8 @@ List<Document> getUpdatePipeline(@Nullable Class<?> domainType) {
709740

710741
Class<?> type = domainType != null ? domainType : Object.class;
711742

712-
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type, mappingContext, queryMapper);
743+
AggregationOperationContext context = new RelaxedTypeBasedAggregationOperationContext(type, mappingContext,
744+
queryMapper);
713745
return aggregationUtil.createPipeline((AggregationUpdate) update, context);
714746
}
715747

@@ -759,4 +791,105 @@ boolean isMulti() {
759791
return multi;
760792
}
761793
}
794+
795+
/**
796+
* A context class that encapsulates common tasks required when running {@literal aggregations}.
797+
*
798+
* @since 3.2
799+
*/
800+
class AggregateContext {
801+
802+
private Aggregation aggregation;
803+
private Lazy<AggregationOperationContext> aggregationOperationContext;
804+
private Lazy<List<Document>> pipeline;
805+
private @Nullable Class<?> inputType;
806+
807+
/**
808+
* Creates new instance of {@link AggregateContext} extracting the input type from either the
809+
* {@link org.springframework.data.mongodb.core.aggregation.Aggregation} in case of a {@link TypedAggregation} or
810+
* the given {@literal aggregationOperationContext} if present. <br />
811+
* Creates a new {@link AggregationOperationContext} if none given, based on the {@link Aggregation} input type and
812+
* the desired {@link AggregationOptions#getDomainTypeMapping() domain type mapping}. <br />
813+
* Pipelines are mapped on first access of {@link #getAggregationPipeline()} and cached for reuse.
814+
*
815+
* @param aggregation the source aggregation.
816+
* @param aggregationOperationContext can be {@literal null}.
817+
*/
818+
AggregateContext(Aggregation aggregation, @Nullable AggregationOperationContext aggregationOperationContext) {
819+
820+
this.aggregation = aggregation;
821+
if (aggregation instanceof TypedAggregation) {
822+
this.inputType = ((TypedAggregation) aggregation).getInputType();
823+
} else if (aggregationOperationContext instanceof TypeBasedAggregationOperationContext) {
824+
this.inputType = ((TypeBasedAggregationOperationContext) aggregationOperationContext).getType();
825+
}
826+
this.aggregationOperationContext = Lazy.of(() -> aggregationOperationContext != null ? aggregationOperationContext
827+
: aggregationUtil.createAggregationContext(aggregation, getInputType()));
828+
this.pipeline = Lazy.of(() -> aggregationUtil.createPipeline(this.aggregation, getAggregationOperationContext()));
829+
}
830+
831+
/**
832+
* Creates new instance of {@link AggregateContext} extracting the input type from either the
833+
* {@link org.springframework.data.mongodb.core.aggregation.Aggregation} in case of a {@link TypedAggregation} or
834+
* the given {@literal aggregationOperationContext} if present. <br />
835+
* Creates a new {@link AggregationOperationContext} based on the {@link Aggregation} input type and the desired
836+
* {@link AggregationOptions#getDomainTypeMapping() domain type mapping}. <br />
837+
* Pipelines are mapped on first access of {@link #getAggregationPipeline()} and cached for reuse.
838+
*
839+
* @param aggregation the source aggregation.
840+
* @param inputType can be {@literal null}.
841+
*/
842+
AggregateContext(Aggregation aggregation, @Nullable Class<?> inputType) {
843+
844+
this.aggregation = aggregation;
845+
846+
if (aggregation instanceof TypedAggregation) {
847+
this.inputType = ((TypedAggregation) aggregation).getInputType();
848+
} else {
849+
this.inputType = inputType;
850+
}
851+
852+
this.aggregationOperationContext = Lazy
853+
.of(() -> aggregationUtil.createAggregationContext(aggregation, getInputType()));
854+
this.pipeline = Lazy.of(() -> aggregationUtil.createPipeline(this.aggregation, getAggregationOperationContext()));
855+
}
856+
857+
/**
858+
* Obtain the already mapped pipeline.
859+
*
860+
* @return never {@literal null}.
861+
*/
862+
List<Document> getAggregationPipeline() {
863+
return pipeline.get();
864+
}
865+
866+
/**
867+
* @return {@literal true} if the last aggregation stage is either {@literal $out} or {@literal $merge}.
868+
* @see AggregationPipeline#isOutOrMerge()
869+
*/
870+
boolean isOutOrMerge() {
871+
return aggregation.getPipeline().isOutOrMerge();
872+
}
873+
874+
/**
875+
* Obtain the {@link AggregationOperationContext} used for mapping the pipeline.
876+
*
877+
* @return never {@literal null}.
878+
*/
879+
AggregationOperationContext getAggregationOperationContext() {
880+
return aggregationOperationContext.get();
881+
}
882+
883+
/**
884+
* @return the input type to map the pipeline against. Can be {@literal null}.
885+
*/
886+
@Nullable
887+
Class<?> getInputType() {
888+
return inputType;
889+
}
890+
891+
Document getAggregationCommand(String collectionName) {
892+
return aggregationUtil.createCommand(collectionName, aggregation, getAggregationOperationContext());
893+
}
894+
}
762895
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java

+10-24
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import static org.springframework.data.mongodb.core.query.SerializationUtils.*;
1919

20+
import org.springframework.data.mongodb.core.QueryOperations.AggregateContext;
2021
import org.springframework.data.mongodb.core.aggregation.RelaxedTypeBasedAggregationOperationContext;
2122
import reactor.core.publisher.Flux;
2223
import reactor.core.publisher.Mono;
@@ -946,9 +947,7 @@ public <O> Flux<O> aggregate(TypedAggregation<?> aggregation, String inputCollec
946947

947948
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
948949

949-
AggregationOperationContext context = new TypeBasedAggregationOperationContext(aggregation.getInputType(),
950-
mappingContext, queryMapper);
951-
return aggregate(aggregation, inputCollectionName, outputType, context);
950+
return doAggregate(aggregation, inputCollectionName, aggregation.getInputType(), outputType);
952951
}
953952

954953
/*
@@ -966,9 +965,7 @@ public <O> Flux<O> aggregate(TypedAggregation<?> aggregation, Class<O> outputTyp
966965
*/
967966
@Override
968967
public <O> Flux<O> aggregate(Aggregation aggregation, Class<?> inputType, Class<O> outputType) {
969-
970-
return aggregate(aggregation, getCollectionName(inputType), outputType,
971-
new TypeBasedAggregationOperationContext(inputType, mappingContext, queryMapper));
968+
return doAggregate(aggregation, getCollectionName(inputType), inputType, outputType);
972969
}
973970

974971
/*
@@ -977,40 +974,29 @@ public <O> Flux<O> aggregate(Aggregation aggregation, Class<?> inputType, Class<
977974
*/
978975
@Override
979976
public <O> Flux<O> aggregate(Aggregation aggregation, String collectionName, Class<O> outputType) {
980-
return aggregate(aggregation, collectionName, outputType, null);
977+
return doAggregate(aggregation, collectionName, null, outputType);
981978
}
982979

983-
/**
984-
* @param aggregation must not be {@literal null}.
985-
* @param collectionName must not be {@literal null}.
986-
* @param outputType must not be {@literal null}.
987-
* @param context can be {@literal null} and will be defaulted to {@link Aggregation#DEFAULT_CONTEXT}.
988-
* @return never {@literal null}.
989-
*/
990-
protected <O> Flux<O> aggregate(Aggregation aggregation, String collectionName, Class<O> outputType,
991-
@Nullable AggregationOperationContext context) {
980+
protected <O> Flux<O> doAggregate(Aggregation aggregation, String collectionName, @Nullable Class<?> inputType, Class<O> outputType) {
992981

993982
Assert.notNull(aggregation, "Aggregation pipeline must not be null!");
994983
Assert.hasText(collectionName, "Collection name must not be null or empty!");
995984
Assert.notNull(outputType, "Output type must not be null!");
996985

997-
AggregationUtil aggregationUtil = new AggregationUtil(queryMapper, mappingContext);
998-
AggregationOperationContext rootContext = aggregationUtil.prepareAggregationContext(aggregation, context);
999-
1000986
AggregationOptions options = aggregation.getOptions();
1001-
List<Document> pipeline = aggregationUtil.createPipeline(aggregation, rootContext);
1002-
1003987
Assert.isTrue(!options.isExplain(), "Cannot use explain option with streaming!");
1004988

989+
AggregateContext ctx = queryOperations.createAggregationContext(aggregation, inputType);
990+
1005991
if (LOGGER.isDebugEnabled()) {
1006-
LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(pipeline), collectionName);
992+
LOGGER.debug("Streaming aggregation: {} in collection {}", serializeToJsonSafely(ctx.getAggregationPipeline()), collectionName);
1007993
}
1008994

1009995
ReadDocumentCallback<O> readCallback = new ReadDocumentCallback<>(mongoConverter, outputType, collectionName);
1010996
return execute(collectionName,
1011-
collection -> aggregateAndMap(collection, pipeline, aggregation.getPipeline().isOutOrMerge(), options,
997+
collection -> aggregateAndMap(collection, ctx.getAggregationPipeline(), ctx.isOutOrMerge(), options,
1012998
readCallback,
1013-
aggregation instanceof TypedAggregation ? ((TypedAggregation<?>) aggregation).getInputType() : null));
999+
ctx.getInputType()));
10141000
}
10151001

10161002
private <O> Flux<O> aggregateAndMap(MongoCollection<Document> collection, List<Document> pipeline,

0 commit comments

Comments
 (0)