diff --git a/pom.xml b/pom.xml index b688f3ee50..c6edd9469b 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 3.3.0-SNAPSHOT + 3.3.0-GH-3712-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-benchmarks/pom.xml b/spring-data-mongodb-benchmarks/pom.xml index 0033bd11d5..fb2abfc9de 100644 --- a/spring-data-mongodb-benchmarks/pom.xml +++ b/spring-data-mongodb-benchmarks/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-mongodb-parent - 3.3.0-SNAPSHOT + 3.3.0-GH-3712-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index f62c8dc7f4..b5f22fd55d 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -14,7 +14,7 @@ org.springframework.data spring-data-mongodb-parent - 3.3.0-SNAPSHOT + 3.3.0-GH-3712-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 1f157e75bc..23cacf26cc 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -11,7 +11,7 @@ org.springframework.data spring-data-mongodb-parent - 3.3.0-SNAPSHOT + 3.3.0-GH-3712-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java index 6698b932f8..1ea1af9731 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperators.java @@ -142,6 +142,63 @@ public StdDevSamp stdDevSamp() { return usesFieldRef() ? StdDevSamp.stdDevSampOf(fieldReference) : StdDevSamp.stdDevSampOf(expression); } + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given + * field to calculate the population covariance of the two. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovariancePop covariancePop(String fieldReference) { + return covariancePop().and(fieldReference); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given + * {@link AggregationExpression expression} to calculate the population covariance of the two. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovariancePop covariancePop(AggregationExpression expression) { + return covariancePop().and(expression); + } + + private CovariancePop covariancePop() { + return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given + * field to calculate the sample covariance of the two. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovarianceSamp covarianceSamp(String fieldReference) { + return covarianceSamp().and(fieldReference); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given + * {@link AggregationExpression expression} to calculate the sample covariance of the two. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovarianceSamp covarianceSamp(AggregationExpression expression) { + return covarianceSamp().and(expression); + } + + private CovarianceSamp covarianceSamp() { + return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference) + : CovarianceSamp.covarianceSampOf(expression); + } + private boolean usesFieldRef() { return fieldReference != null; } @@ -658,4 +715,124 @@ public Document toDocument(Object value, AggregationOperationContext context) { return super.toDocument(value, context); } } + + /** + * {@link AggregationExpression} for {@code $covariancePop}. + * + * @author Christoph Strobl + * @since 3.3 + */ + public static class CovariancePop extends AbstractAggregationExpression { + + private CovariancePop(Object value) { + super(value); + } + + /** + * Creates new {@link CovariancePop}. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + */ + public static CovariancePop covariancePopOf(String fieldReference) { + + Assert.notNull(fieldReference, "FieldReference must not be null!"); + return new CovariancePop(asFields(fieldReference)); + } + + /** + * Creates new {@link CovariancePop}. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + */ + public static CovariancePop covariancePopOf(AggregationExpression expression) { + return new CovariancePop(Collections.singletonList(expression)); + } + + /** + * Creates new {@link CovariancePop} with all previously added arguments appending the given one. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + */ + public CovariancePop and(String fieldReference) { + return new CovariancePop(append(asFields(fieldReference))); + } + + /** + * Creates new {@link CovariancePop} with all previously added arguments appending the given one. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + */ + public CovariancePop and(AggregationExpression expression) { + return new CovariancePop(append(expression)); + } + + @Override + protected String getMongoMethod() { + return "$covariancePop"; + } + } + + /** + * {@link AggregationExpression} for {@code $covarianceSamp}. + * + * @author Christoph Strobl + * @since 3.3 + */ + public static class CovarianceSamp extends AbstractAggregationExpression { + + private CovarianceSamp(Object value) { + super(value); + } + + /** + * Creates new {@link CovarianceSamp}. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovarianceSamp}. + */ + public static CovarianceSamp covarianceSampOf(String fieldReference) { + + Assert.notNull(fieldReference, "FieldReference must not be null!"); + return new CovarianceSamp(asFields(fieldReference)); + } + + /** + * Creates new {@link CovarianceSamp}. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovarianceSamp}. + */ + public static CovarianceSamp covarianceSampOf(AggregationExpression expression) { + return new CovarianceSamp(Collections.singletonList(expression)); + } + + /** + * Creates new {@link CovarianceSamp} with all previously added arguments appending the given one. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovarianceSamp}. + */ + public CovarianceSamp and(String fieldReference) { + return new CovarianceSamp(append(asFields(fieldReference))); + } + + /** + * Creates new {@link CovarianceSamp} with all previously added arguments appending the given one. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovarianceSamp}. + */ + public CovarianceSamp and(AggregationExpression expression) { + return new CovarianceSamp(append(expression)); + } + + @Override + protected String getMongoMethod() { + return "$covarianceSamp"; + } + } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java index 6053f3ae1b..b27e54d298 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArithmeticOperators.java @@ -19,6 +19,8 @@ import java.util.List; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Avg; +import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovariancePop; +import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.CovarianceSamp; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Max; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.Min; import org.springframework.data.mongodb.core.aggregation.AccumulatorOperators.StdDevPop; @@ -511,6 +513,63 @@ public StdDevSamp stdDevSamp() { : AccumulatorOperators.StdDevSamp.stdDevSampOf(expression); } + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given + * field to calculate the population covariance of the two. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovariancePop covariancePop(String fieldReference) { + return covariancePop().and(fieldReference); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given + * {@link AggregationExpression expression} to calculate the population covariance of the two. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovariancePop covariancePop(AggregationExpression expression) { + return covariancePop().and(expression); + } + + private CovariancePop covariancePop() { + return usesFieldRef() ? CovariancePop.covariancePopOf(fieldReference) : CovariancePop.covariancePopOf(expression); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the value of the given + * field to calculate the sample covariance of the two. + * + * @param fieldReference must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovarianceSamp covarianceSamp(String fieldReference) { + return covarianceSamp().and(fieldReference); + } + + /** + * Creates new {@link AggregationExpression} that uses the previous input (field/expression) and the result of the given + * {@link AggregationExpression expression} to calculate the sample covariance of the two. + * + * @param expression must not be {@literal null}. + * @return new instance of {@link CovariancePop}. + * @since 3.3 + */ + public CovarianceSamp covarianceSamp(AggregationExpression expression) { + return covarianceSamp().and(expression); + } + + private CovarianceSamp covarianceSamp() { + return usesFieldRef() ? CovarianceSamp.covarianceSampOf(fieldReference) + : CovarianceSamp.covarianceSampOf(expression); + } + /** * Creates new {@link AggregationExpression} that rounds a number to a whole integer or to a specified decimal * place. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java index 5a2c48bc20..c858926446 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/spel/MethodReferenceNode.java @@ -170,6 +170,8 @@ public class MethodReferenceNode extends ExpressionNode { map.put("addToSet", singleArgRef().forOperator("$addToSet")); map.put("stdDevPop", arrayArgRef().forOperator("$stdDevPop")); map.put("stdDevSamp", arrayArgRef().forOperator("$stdDevSamp")); + map.put("covariancePop", arrayArgRef().forOperator("$covariancePop")); + map.put("covarianceSamp", arrayArgRef().forOperator("$covarianceSamp")); // TYPE OPERATORS map.put("type", singleArgRef().forOperator("$type")); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java new file mode 100644 index 0000000000..4f16072e43 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/util/aggregation/TestAggregationContext.java @@ -0,0 +1,75 @@ +/* + * Copyright 2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.util.aggregation; + +import org.bson.Document; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperationContext; +import org.springframework.data.mongodb.core.aggregation.ExposedFields.FieldReference; +import org.springframework.data.mongodb.core.aggregation.Field; +import org.springframework.data.mongodb.core.aggregation.TypeBasedAggregationOperationContext; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.convert.QueryMapper; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.lang.Nullable; + +/** + * @author Christoph Strobl + */ +public class TestAggregationContext implements AggregationOperationContext { + + private final AggregationOperationContext delegate; + + private TestAggregationContext(AggregationOperationContext delegate) { + this.delegate = delegate; + } + + public static AggregationOperationContext contextFor(@Nullable Class type) { + + MappingMongoConverter mongoConverter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, + new MongoMappingContext()); + mongoConverter.afterPropertiesSet(); + + return contextFor(type, mongoConverter); + } + + public static AggregationOperationContext contextFor(@Nullable Class type, MongoConverter mongoConverter) { + + if (type == null) { + return Aggregation.DEFAULT_CONTEXT; + } + + return new TestAggregationContext(new TypeBasedAggregationOperationContext(type, mongoConverter.getMappingContext(), + new QueryMapper(mongoConverter)).continueOnMissingFieldReference()); + } + + @Override + public Document getMappedObject(Document document, @Nullable Class type) { + return delegate.getMappedObject(document, type); + } + + @Override + public FieldReference getReference(Field field) { + return delegate.getReference(field); + } + + @Override + public FieldReference getReference(String name) { + return delegate.getReference(name); + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java new file mode 100644 index 0000000000..977183c448 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/AccumulatorOperatorsUnitTests.java @@ -0,0 +1,77 @@ +/* + * Copyright 2021 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.core.aggregation; + +import static org.assertj.core.api.Assertions.*; + +import java.util.Arrays; +import java.util.Date; + +import org.bson.Document; +import org.junit.jupiter.api.Test; +import org.springframework.data.mongodb.core.aggregation.DateOperators.Year; +import org.springframework.data.mongodb.core.mapping.Field; +import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; + +/** + * @author Christoph Strobl + */ +class AccumulatorOperatorsUnitTests { + + @Test // GH-3712 + void rendersCovariancePopWithFieldReference() { + + assertThat(AccumulatorOperators.valueOf("balance").covariancePop("midichlorianCount") + .toDocument(TestAggregationContext.contextFor(Jedi.class))) + .isEqualTo(new Document("$covariancePop", Arrays.asList("$balance", "$force"))); + } + + @Test // GH-3712 + void rendersCovariancePopWithExpression() { + + assertThat(AccumulatorOperators.valueOf(Year.yearOf("birthdate")).covariancePop("midichlorianCount") + .toDocument(TestAggregationContext.contextFor(Jedi.class))) + .isEqualTo(new Document("$covariancePop", Arrays.asList(new Document("$year", "$birthdate"), "$force"))); + } + + @Test // GH-3712 + void rendersCovarianceSampWithFieldReference() { + + assertThat(AccumulatorOperators.valueOf("balance").covarianceSamp("midichlorianCount") + .toDocument(TestAggregationContext.contextFor(Jedi.class))) + .isEqualTo(new Document("$covarianceSamp", Arrays.asList("$balance", "$force"))); + } + + @Test // GH-3712 + void rendersCovarianceSampWithExpression() { + + assertThat(AccumulatorOperators.valueOf(Year.yearOf("birthdate")).covarianceSamp("midichlorianCount") + .toDocument(TestAggregationContext.contextFor(Jedi.class))) + .isEqualTo(new Document("$covarianceSamp", Arrays.asList(new Document("$year", "$birthdate"), "$force"))); + } + + static class Jedi { + + String name; + + Date birthdate; + + @Field("force") + Integer midichlorianCount; + + Integer balance; + } +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java index b67beed126..c4b945ab94 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/SpelExpressionTransformerUnitTests.java @@ -946,6 +946,16 @@ public void shouldRenderRoundWithPlace() { assertThat(transform("round(field, 2)")).isEqualTo(Document.parse("{ \"$round\" : [\"$field\", 2]}")); } + @Test // GH-3712 + void shouldRenderCovariancePop() { + assertThat(transform("covariancePop(field1, field2)")).isEqualTo(Document.parse("{ \"$covariancePop\" : [\"$field1\", \"$field2\"]}")); + } + + @Test // GH-3712 + void shouldRenderCovarianceSamp() { + assertThat(transform("covarianceSamp(field1, field2)")).isEqualTo(Document.parse("{ \"$covarianceSamp\" : [\"$field1\", \"$field2\"]}")); + } + private Object transform(String expression, Object... params) { Object result = transformer.transform(expression, Aggregation.DEFAULT_CONTEXT, params); return result == null ? null : (!(result instanceof org.bson.Document) ? result.toString() : result);