diff --git a/pom.xml b/pom.xml index 018dd48e27..fbdf717e6f 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 4.2.0-SNAPSHOT + 4.2.x-4394-SNAPSHOT pom Spring Data MongoDB diff --git a/spring-data-mongodb-benchmarks/pom.xml b/spring-data-mongodb-benchmarks/pom.xml index 2de4b6b635..87f9556f91 100644 --- a/spring-data-mongodb-benchmarks/pom.xml +++ b/spring-data-mongodb-benchmarks/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-mongodb-parent - 4.2.0-SNAPSHOT + 4.2.x-4394-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index 3bc1ab9df2..f7c87286a5 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 4.2.0-SNAPSHOT + 4.2.x-4394-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index 921254ca44..f5dfab9c17 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 4.2.0-SNAPSHOT + 4.2.x-4394-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 062cfbb707..472951f079 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -214,6 +214,10 @@ public AsBuilder filter() { return Filter.filter(fieldReference); } + if(usesExpression()) { + return Filter.filter(expression); + } + Assert.state(values != null, "Values must not be null"); return Filter.filter(new ArrayList<>(values)); } @@ -649,6 +653,19 @@ public static AsBuilder filter(Field field) { return new FilterExpressionBuilder().filter(field); } + /** + * Set the {@link AggregationExpression} resolving to an arry to apply the {@code $filter} to. + * + * @param expression must not be {@literal null}. + * @return never {@literal null}. + * @since 4.2 + */ + public static AsBuilder filter(AggregationExpression expression) { + + Assert.notNull(expression, "Field must not be null"); + return new FilterExpressionBuilder().filter(expression); + } + /** * Set the {@literal values} to apply the {@code $filter} to. * @@ -681,7 +698,13 @@ private Document toFilter(ExposedFields exposedFields, AggregationOperationConte } private Object getMappedInput(AggregationOperationContext context) { - return input instanceof Field field ? context.getReference(field).toString() : input; + if(input instanceof Field field) { + return context.getReference(field).toString(); + } + if(input instanceof AggregationExpression expression) { + return expression.toDocument(context); + } + return input; } private Object getMappedCondition(AggregationOperationContext context) { @@ -715,6 +738,15 @@ public interface InputBuilder { * @return */ AsBuilder filter(Field field); + + /** + * Set the {@link AggregationExpression} resolving to an array to apply the {@code $filter} to. + * + * @param expression must not be {@literal null}. + * @return + * @since 4.2 + */ + AsBuilder filter(AggregationExpression expression); } /** @@ -797,6 +829,14 @@ public AsBuilder filter(Field field) { return this; } + @Override + public AsBuilder filter(AggregationExpression expression) { + + Assert.notNull(expression, "Expression must not be null"); + filter.input = expression; + return this; + } + @Override public ConditionBuilder as(String variableName) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java index cf4bb0a140..313bd487e1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/FilterExpressionUnitTests.java @@ -117,6 +117,23 @@ void shouldConstructFilterExpressionCorrectlyWhenConditionContainsFieldReference assertThat($filter).isEqualTo(new Document(expected)); } + @Test // GH-4394 + void filterShouldAcceptExpression() { + + Document $filter = ArrayOperators.arrayOf(ObjectOperators.valueOf("data.metadata").toArray()).filter().as("item") + .by(ComparisonOperators.valueOf("item.price").greaterThan("field-1")).toDocument(Aggregation.DEFAULT_CONTEXT); + + Document expected = Document.parse(""" + { $filter : { + input: { $objectToArray: "$data.metadata" }, + as: "item", + cond: { $gt: [ "$$item.price", "$field-1" ] } + }} + """); + + assertThat($filter).isEqualTo(expected); + } + private Document extractFilterOperatorFromDocument(Document source) { List pipeline = DocumentTestUtils.getAsDBList(source, "pipeline");