Skip to content

Commit c574e5c

Browse files
christophstroblmp911de
authored andcommitted
Add support for $covariancePop and $covarianceSamp aggregation expressions.
Closes: #3712 Original pull request: #3740.
1 parent f9f4c46 commit c574e5c

File tree

6 files changed

+400
-0
lines changed

6 files changed

+400
-0
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java

+177
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,63 @@ public StdDevSamp stdDevSamp() {
142142
return usesFieldRef() ? StdDevSamp.stdDevSampOf(fieldReference) : StdDevSamp.stdDevSampOf(expression);
143143
}
144144

145+
/**
146+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
147+
* field to calculate the population covariance of the two.
148+
*
149+
* @param fieldReference must not be {@literal null}.
150+
* @return new instance of {@link CovariancePop}.
151+
* @since 3.3
152+
*/
153+
public CovariancePop covariancePop(String fieldReference) {
154+
return covariancePop().and(fieldReference);
155+
}
156+
157+
/**
158+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
159+
* {@link AggregationExpression expression} to calculate the population covariance of the two.
160+
*
161+
* @param expression must not be {@literal null}.
162+
* @return new instance of {@link CovariancePop}.
163+
* @since 3.3
164+
*/
165+
public CovariancePop covariancePop(AggregationExpression expression) {
166+
return covariancePop().and(expression);
167+
}
168+
169+
private CovariancePop covariancePop() {
170+
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
171+
}
172+
173+
/**
174+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
175+
* field to calculate the sample covariance of the two.
176+
*
177+
* @param fieldReference must not be {@literal null}.
178+
* @return new instance of {@link CovariancePop}.
179+
* @since 3.3
180+
*/
181+
public CovarianceSamp covarianceSamp(String fieldReference) {
182+
return covarianceSamp().and(fieldReference);
183+
}
184+
185+
/**
186+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
187+
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
188+
*
189+
* @param expression must not be {@literal null}.
190+
* @return new instance of {@link CovariancePop}.
191+
* @since 3.3
192+
*/
193+
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
194+
return covarianceSamp().and(expression);
195+
}
196+
197+
private CovarianceSamp covarianceSamp() {
198+
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
199+
: CovarianceSamp.covarianceSampOf(expression);
200+
}
201+
145202
private boolean usesFieldRef() {
146203
return fieldReference != null;
147204
}
@@ -658,4 +715,124 @@ public Document toDocument(Object value, AggregationOperationContext context) {
658715
return super.toDocument(value, context);
659716
}
660717
}
718+
719+
/**
720+
* {@link AggregationExpression} for {@code $covariancePop}.
721+
*
722+
* @author Christoph Strobl
723+
* @since 3.3
724+
*/
725+
public static class CovariancePop extends AbstractAggregationExpression {
726+
727+
private CovariancePop(Object value) {
728+
super(value);
729+
}
730+
731+
/**
732+
* Creates new {@link CovariancePop}.
733+
*
734+
* @param fieldReference must not be {@literal null}.
735+
* @return new instance of {@link CovariancePop}.
736+
*/
737+
public static CovariancePop covariancePopOf(String fieldReference) {
738+
739+
Assert.notNull(fieldReference, "FieldReference must not be null!");
740+
return new CovariancePop(asFields(fieldReference));
741+
}
742+
743+
/**
744+
* Creates new {@link CovariancePop}.
745+
*
746+
* @param expression must not be {@literal null}.
747+
* @return new instance of {@link CovariancePop}.
748+
*/
749+
public static CovariancePop covariancePopOf(AggregationExpression expression) {
750+
return new CovariancePop(Collections.singletonList(expression));
751+
}
752+
753+
/**
754+
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
755+
*
756+
* @param fieldReference must not be {@literal null}.
757+
* @return new instance of {@link CovariancePop}.
758+
*/
759+
public CovariancePop and(String fieldReference) {
760+
return new CovariancePop(append(asFields(fieldReference)));
761+
}
762+
763+
/**
764+
* Creates new {@link CovariancePop} with all previously added arguments appending the given one.
765+
*
766+
* @param expression must not be {@literal null}.
767+
* @return new instance of {@link CovariancePop}.
768+
*/
769+
public CovariancePop and(AggregationExpression expression) {
770+
return new CovariancePop(append(expression));
771+
}
772+
773+
@Override
774+
protected String getMongoMethod() {
775+
return "$covariancePop";
776+
}
777+
}
778+
779+
/**
780+
* {@link AggregationExpression} for {@code $covarianceSamp}.
781+
*
782+
* @author Christoph Strobl
783+
* @since 3.3
784+
*/
785+
public static class CovarianceSamp extends AbstractAggregationExpression {
786+
787+
private CovarianceSamp(Object value) {
788+
super(value);
789+
}
790+
791+
/**
792+
* Creates new {@link CovarianceSamp}.
793+
*
794+
* @param fieldReference must not be {@literal null}.
795+
* @return new instance of {@link CovarianceSamp}.
796+
*/
797+
public static CovarianceSamp covarianceSampOf(String fieldReference) {
798+
799+
Assert.notNull(fieldReference, "FieldReference must not be null!");
800+
return new CovarianceSamp(asFields(fieldReference));
801+
}
802+
803+
/**
804+
* Creates new {@link CovarianceSamp}.
805+
*
806+
* @param expression must not be {@literal null}.
807+
* @return new instance of {@link CovarianceSamp}.
808+
*/
809+
public static CovarianceSamp covarianceSampOf(AggregationExpression expression) {
810+
return new CovarianceSamp(Collections.singletonList(expression));
811+
}
812+
813+
/**
814+
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
815+
*
816+
* @param fieldReference must not be {@literal null}.
817+
* @return new instance of {@link CovarianceSamp}.
818+
*/
819+
public CovarianceSamp and(String fieldReference) {
820+
return new CovarianceSamp(append(asFields(fieldReference)));
821+
}
822+
823+
/**
824+
* Creates new {@link CovarianceSamp} with all previously added arguments appending the given one.
825+
*
826+
* @param expression must not be {@literal null}.
827+
* @return new instance of {@link CovarianceSamp}.
828+
*/
829+
public CovarianceSamp and(AggregationExpression expression) {
830+
return new CovarianceSamp(append(expression));
831+
}
832+
833+
@Override
834+
protected String getMongoMethod() {
835+
return "$covarianceSamp";
836+
}
837+
}
661838
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java

+59
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import java.util.List;
2020

2121
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Avg;
22+
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop;
23+
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp;
2224
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max;
2325
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min;
2426
import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop;
@@ -511,6 +513,63 @@ public StdDevSamp stdDevSamp() {
511513
: AccumulatorOperators.StdDevSamp.stdDevSampOf(expression);
512514
}
513515

516+
/**
517+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
518+
* field to calculate the population covariance of the two.
519+
*
520+
* @param fieldReference must not be {@literal null}.
521+
* @return new instance of {@link CovariancePop}.
522+
* @since 3.3
523+
*/
524+
public CovariancePop covariancePop(String fieldReference) {
525+
return covariancePop().and(fieldReference);
526+
}
527+
528+
/**
529+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
530+
* {@link AggregationExpression expression} to calculate the population covariance of the two.
531+
*
532+
* @param expression must not be {@literal null}.
533+
* @return new instance of {@link CovariancePop}.
534+
* @since 3.3
535+
*/
536+
public CovariancePop covariancePop(AggregationExpression expression) {
537+
return covariancePop().and(expression);
538+
}
539+
540+
private CovariancePop covariancePop() {
541+
return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression);
542+
}
543+
544+
/**
545+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given
546+
* field to calculate the sample covariance of the two.
547+
*
548+
* @param fieldReference must not be {@literal null}.
549+
* @return new instance of {@link CovariancePop}.
550+
* @since 3.3
551+
*/
552+
public CovarianceSamp covarianceSamp(String fieldReference) {
553+
return covarianceSamp().and(fieldReference);
554+
}
555+
556+
/**
557+
* Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given
558+
* {@link AggregationExpression expression} to calculate the sample covariance of the two.
559+
*
560+
* @param expression must not be {@literal null}.
561+
* @return new instance of {@link CovariancePop}.
562+
* @since 3.3
563+
*/
564+
public CovarianceSamp covarianceSamp(AggregationExpression expression) {
565+
return covarianceSamp().and(expression);
566+
}
567+
568+
private CovarianceSamp covarianceSamp() {
569+
return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference)
570+
: CovarianceSamp.covarianceSampOf(expression);
571+
}
572+
514573
/**
515574
* Creates new {@link AggregationExpression} that rounds a number to a whole integer or to a specified decimal
516575
* place.

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java

+2
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ public class MethodReferenceNode extends ExpressionNode {
170170
map.put("addToSet", singleArgRef().forOperator("$addToSet"));
171171
map.put("stdDevPop", arrayArgRef().forOperator("$stdDevPop"));
172172
map.put("stdDevSamp", arrayArgRef().forOperator("$stdDevSamp"));
173+
map.put("covariancePop", arrayArgRef().forOperator("$covariancePop"));
174+
map.put("covarianceSamp", arrayArgRef().forOperator("$covarianceSamp"));
173175

174176
// TYPE OPERATORS
175177
map.put("type", singleArgRef().forOperator("$type"));
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/*
2+
* Copyright 2021 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.util.aggregation;
17+
18+
import org.bson.Document;
19+
import org.springframework.data.mongodb.core.aggregation.Aggregation;
20+
import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext;
21+
import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference;
22+
import org.springframework.data.mongodb.core.aggregation.Field;
23+
import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext;
24+
import org.springframework.data.mongodb.core.convert.MappingMongoConverter;
25+
import org.springframework.data.mongodb.core.convert.MongoConverter;
26+
import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver;
27+
import org.springframework.data.mongodb.core.convert.QueryMapper;
28+
import org.springframework.data.mongodb.core.mapping.MongoMappingContext;
29+
import org.springframework.lang.Nullable;
30+
31+
/**
32+
* @author Christoph Strobl
33+
*/
34+
public class TestAggregationContext implements AggregationOperationContext {
35+
36+
private final AggregationOperationContext delegate;
37+
38+
private TestAggregationContext(AggregationOperationContext delegate) {
39+
this.delegate = delegate;
40+
}
41+
42+
public static AggregationOperationContext contextFor(@Nullable Class<?> type) {
43+
44+
MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE,
45+
new MongoMappingContext());
46+
mongoConverter.afterPropertiesSet();
47+
48+
return contextFor(type, mongoConverter);
49+
}
50+
51+
public static AggregationOperationContext contextFor(@Nullable Class<?> type, MongoConverter mongoConverter) {
52+
53+
if (type == null) {
54+
return Aggregation.DEFAULT_CONTEXT;
55+
}
56+
57+
return new TestAggregationContext(new TypeBasedAggregationOperationContext(type, mongoConverter.getMappingContext(),
58+
new QueryMapper(mongoConverter)).continueOnMissingFieldReference());
59+
}
60+
61+
@Override
62+
public Document getMappedObject(Document document, @Nullable Class<?> type) {
63+
return delegate.getMappedObject(document, type);
64+
}
65+
66+
@Override
67+
public FieldReference getReference(Field field) {
68+
return delegate.getReference(field);
69+
}
70+
71+
@Override
72+
public FieldReference getReference(String name) {
73+
return delegate.getReference(name);
74+
}
75+
}

0 commit comments

Comments
 (0)