From 16ebc51f614a765ede449ed304e51cdede70a15b Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 23 Apr 2025 08:54:59 +0200 Subject: [PATCH 01/10] Prepare issue branch. --- pom.xml | 4 ++-- spring-data-mongodb-distribution/pom.xml | 2 +- spring-data-mongodb/pom.xml | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pom.xml b/pom.xml index ffbe2b8e19..9eb477077f 100644 --- a/pom.xml +++ b/pom.xml @@ -5,7 +5,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT pom Spring Data MongoDB @@ -26,7 +26,7 @@ multi spring-data-mongodb - 4.0.0-SNAPSHOT + 4.0.0-SEARCH-RESULT-SNAPSHOT 5.4.0 1.19 diff --git a/spring-data-mongodb-distribution/pom.xml b/spring-data-mongodb-distribution/pom.xml index fc88571622..6e0e6b99f4 100644 --- a/spring-data-mongodb-distribution/pom.xml +++ b/spring-data-mongodb-distribution/pom.xml @@ -15,7 +15,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml diff --git a/spring-data-mongodb/pom.xml b/spring-data-mongodb/pom.xml index ad3c1338ec..3e7ccd09e8 100644 --- a/spring-data-mongodb/pom.xml +++ b/spring-data-mongodb/pom.xml @@ -13,7 +13,7 @@ org.springframework.data spring-data-mongodb-parent - 5.0.0-SNAPSHOT + 5.0.0-SEARCH-SNAPSHOT ../pom.xml From c80f735c77b31e1326f95914d61cc3c95638053a Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Thu, 10 Apr 2025 11:49:05 +0200 Subject: [PATCH 02/10] Explore returning Search Results. 
--- .../data/mongodb/core/MongoTemplate.java | 8 +- .../mongodb/core/ReactiveMongoTemplate.java | 2 +- .../core/aggregation/AggregationResults.java | 1 + .../mongodb/core/convert/GeoConverters.java | 4 +- .../data/mongodb/core/geo/Sphere.java | 2 +- .../data/mongodb/core/query/NearQuery.java | 11 +- .../data/mongodb/repository/VectorSearch.java | 118 +++++++ .../repository/aot/AotQueryCreator.java | 17 + .../repository/query/AbstractMongoQuery.java | 4 +- .../query/ConvertingParameterAccessor.java | 20 +- .../query/MongoParameterAccessor.java | 1 + .../repository/query/MongoParameters.java | 48 ++- .../MongoParametersParameterAccessor.java | 22 +- .../repository/query/MongoQueryCreator.java | 29 +- .../repository/query/MongoQueryExecution.java | 145 +++++++++ .../repository/query/MongoQueryMethod.java | 19 ++ .../repository/query/PartTreeMongoQuery.java | 4 +- .../query/ReactiveMongoQueryExecution.java | 1 + .../query/ReactivePartTreeMongoQuery.java | 2 +- .../query/VectorSearchAggregation.java | 308 ++++++++++++++++++ .../support/MongoRepositoryFactory.java | 3 + .../GeoNearOperationUnitTests.java | 4 +- .../core/convert/GeoConvertersUnitTests.java | 8 +- .../MappingMongoConverterUnitTests.java | 5 +- .../core/geo/GeoSpatial2DSphereTests.java | 4 +- .../core/query/MetricConversionUnitTests.java | 11 +- .../core/query/NearQueryUnitTests.java | 26 +- ...tractPersonRepositoryIntegrationTests.java | 26 +- .../mongodb/repository/PersonRepository.java | 1 + .../ReactiveMongoRepositoryTests.java | 8 +- .../mongodb/repository/VectorSearchTests.java | 268 +++++++++++++++ ...oParametersParameterAccessorUnitTests.java | 51 ++- .../query/MongoParametersUnitTests.java | 21 ++ .../query/MongoQueryCreatorUnitTests.java | 13 +- .../query/MongoQueryExecutionUnitTests.java | 3 +- .../query/MongoQueryMethodUnitTests.java | 1 + .../ReactiveMongoQueryExecutionUnitTests.java | 6 +- .../ReactiveMongoQueryMethodUnitTests.java | 2 +- .../query/StubParameterAccessor.java | 18 + 
.../VectorSearchAggregationUnitTests.java | 102 ++++++ .../ReactiveFindOperationExtensionsTests.kt | 6 +- 41 files changed, 1249 insertions(+), 104 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java index 5c7df76cc5..5ed7f9b8a3 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/MongoTemplate.java @@ -1098,7 +1098,7 @@ public GeoResults geoNear(NearQuery near, Class domainType, String col result.add(geoResult); } - Distance avgDistance = new Distance( + Distance avgDistance = Distance.of( result.size() == 0 ? 0 : aggregate.divide(new BigDecimal(result.size()), RoundingMode.HALF_UP).doubleValue(), near.getMetric()); @@ -2654,7 +2654,9 @@ protected List doFind(String collectionName, if (LOGGER.isDebugEnabled()) { - Document mappedSort = preparer instanceof SortingQueryCursorPreparer sqcp ? getMappedSortObject(sqcp.getSortObject(), entity) : null; + Document mappedSort = preparer instanceof SortingQueryCursorPreparer sqcp + ? 
getMappedSortObject(sqcp.getSortObject(), entity) + : null; LOGGER.debug(String.format("find using query: %s fields: %s sort: %s for class: %s in collection: %s", serializeToJsonSafely(mappedQuery), mappedFields, serializeToJsonSafely(mappedSort), entityClass, collectionName)); @@ -3553,7 +3555,7 @@ public GeoResult doWith(Document object) { T doWith = delegate.doWith(object); - return new GeoResult<>(doWith, new Distance(distance, metric)); + return new GeoResult<>(doWith, Distance.of(distance, metric)); } } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java index 325a96dc85..e263735187 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/ReactiveMongoTemplate.java @@ -3227,7 +3227,7 @@ public Mono> doWith(Document object) { double distance = getDistance(object); - return delegate.doWith(object).map(doWith -> new GeoResult<>(doWith, new Distance(distance, metric))); + return delegate.doWith(object).map(doWith -> new GeoResult<>(doWith, Distance.of(distance, metric))); } double getDistance(Document object) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java index f5a861cddd..7b27739229 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationResults.java @@ -105,4 +105,5 @@ public Document getRawResults() { Object object = rawResults.get("serverUsed"); return object instanceof String stringValue ? 
stringValue : null; } + } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java index ae73ab68bd..b595ab688f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/convert/GeoConverters.java @@ -270,7 +270,7 @@ enum DocumentToCircleConverter implements Converter { Assert.notNull(center, "Center must not be null"); Assert.notNull(radius, "Radius must not be null"); - Distance distance = new Distance(toPrimitiveDoubleValue(radius)); + Distance distance = Distance.of(toPrimitiveDoubleValue(radius)); if (source.containsKey("metric")) { @@ -335,7 +335,7 @@ enum DocumentToSphereConverter implements Converter { Assert.notNull(center, "Center must not be null"); Assert.notNull(radius, "Radius must not be null"); - Distance distance = new Distance(toPrimitiveDoubleValue(radius)); + Distance distance = Distance.of(toPrimitiveDoubleValue(radius)); if (source.containsKey("metric")) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java index 47be645869..d3ca840d6b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/geo/Sphere.java @@ -63,7 +63,7 @@ public Sphere(Point center, Distance radius) { * @param radius */ public Sphere(Point center, double radius) { - this(center, new Distance(radius)); + this(center, Distance.of(radius)); } /** diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java index 6dad07b8cb..88d7dc5c1d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/NearQuery.java @@ -19,6 +19,7 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Pageable; import org.springframework.data.geo.CustomMetric; import org.springframework.data.geo.Distance; @@ -329,7 +330,7 @@ public NearQuery with(Pageable pageable) { */ @Contract("_ -> this") public NearQuery maxDistance(double maxDistance) { - return maxDistance(new Distance(maxDistance, getMetric())); + return maxDistance(Distance.of(maxDistance, getMetric())); } /** @@ -345,7 +346,7 @@ public NearQuery maxDistance(double maxDistance, Metric metric) { Assert.notNull(metric, "Metric must not be null"); - return maxDistance(new Distance(maxDistance, metric)); + return maxDistance(Distance.of(maxDistance, metric)); } /** @@ -388,7 +389,7 @@ public NearQuery maxDistance(Distance distance) { */ @Contract("_ -> this") public NearQuery minDistance(double minDistance) { - return minDistance(new Distance(minDistance, getMetric())); + return minDistance(Distance.of(minDistance, getMetric())); } /** @@ -405,7 +406,7 @@ public NearQuery minDistance(double minDistance, Metric metric) { Assert.notNull(metric, "Metric must not be null"); - return minDistance(new Distance(minDistance, metric)); + return minDistance(Distance.of(minDistance, metric)); } /** @@ -611,7 +612,7 @@ public NearQuery withReadPreference(ReadPreference readPreference) { * Get the {@link ReadConcern} to use. Will return the underlying {@link #query(Query) queries} * {@link Query#getReadConcern() ReadConcern} if present or the one defined on the {@link NearQuery#readConcern} * itself. - * + * * @return can be {@literal null} if none set. 
* @since 4.1 * @see ReadConcernAware diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java new file mode 100644 index 0000000000..7c6b4f4906 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java @@ -0,0 +1,118 @@ +/* + * Copyright 2016-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import org.springframework.core.annotation.AliasFor; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; + +/** + * Annotation to declare Vector Search queries directly on repository methods. Vector Search queries are used to search + * for similar documents based on vector embeddings typically returning + * {@link org.springframework.data.domain.SearchResults} and limited by either a + * {@link org.springframework.data.domain.Score} (within) or a {@link org.springframework.data.domain.Range} of scores + * (between). + *

+ * Vector search must define an index name using the {@link #indexName()} attribute. The index must be created in the + * MongoDB Atlas cluster before executing the query. Any misspelling of the index name will result in returning no + * results. + *

+ * When using pre-filters, you can either define {@link #filter()} or use query derivation to define the pre-filter. + * {@link org.springframework.data.domain.Vector} and distance parameters are considered once these are present. Vector + * search supports sorting and will consider {@link org.springframework.data.domain.Sort} parameters. + * + * @author Mark Paluch + * @since 5.0 + * @see org.springframework.data.geo.Distance + * @see org.springframework.data.domain.Vector + * @see org.springframework.data.domain.SearchResults + */ +@Retention(RetentionPolicy.RUNTIME) +@Target({ ElementType.METHOD, ElementType.ANNOTATION_TYPE }) +@Documented +@Query +@Hint +public @interface VectorSearch { + + /** + * Configuration whether to use ANN or ENN for the search. ENN is the default. + * + * @return the search type to use. + */ + VectorSearchOperation.SearchType searchType() default VectorSearchOperation.SearchType.ENN; + + /** + * Name of the Atlas Vector Search index to use. Atlas Vector Search doesn't return results if you misspell the index + * name or if the specified index doesn't already exist on the cluster. + * + * @return name of the Atlas Vector Search index to use. + */ + @AliasFor(annotation = Hint.class, value = "indexName") + String indexName(); + + /** + * Indexed vector type field to search. This is defaulted from the domain model using the first Vector property found. + * + * @return an empty String by default. + */ + String path() default ""; + + /** + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Alias for + * {@link VectorSearch#filter}. + * + * @return an empty String by default. + */ + @AliasFor(annotation = Query.class) + String value() default ""; + + /** + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Alias for + * {@link VectorSearch#value}. + * + * @return an empty String by default. 
+ */ + @AliasFor(annotation = Query.class, value = "value") + String filter() default ""; + + /** + * Number of documents to return in the results. This value can't exceed the value of {@link #numCandidates} if you + * specify {@link #numCandidates}. Limit accepts Value Expressions. A Vector Search method cannot define both, + * {@code limit()} and a {@link org.springframework.data.domain.Limit} parameter. + * + * @return number of documents to return in the results + */ + String limit() default ""; + + /** + * Number of nearest neighbors to use during the search. Value must be less than or equal to ({@code <=}) + * {@code 10000}. You can't specify a number less than the {@link #limit() number of documents to return}. We + * recommend that you specify a number at least {@code 20} times higher than the {@link #limit() number of documents + * to return} to increase accuracy. This over-request pattern is the recommended way to trade off latency and recall + * in your ANN searches, and we recommend tuning this parameter based on your specific dataset size and query + * requirements. Required if the query uses + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}. 
+ * + * @return number of nearest neighbors to use during the search + */ + String numCandidates() default ""; + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java index 831d21bb44..17c19ad951 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/AotQueryCreator.java @@ -25,8 +25,10 @@ import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.convert.MongoCustomConversions; @@ -129,6 +131,21 @@ public Range getDistanceRange() { return null; } + @Override + public @Nullable Vector getVector() { + return null; + } + + @Override + public @Nullable Score getScore() { + return null; + } + + @Override + public @Nullable Range getScoreRange() { + return null; + } + @Override public @Nullable Point getGeoNearLocation() { return null; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java index f56c2c7a22..596b895ebd 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/AbstractMongoQuery.java @@ -164,7 +164,7 @@ private Query 
applyAnnotatedReadPreferenceIfPresent(Query query) { } @SuppressWarnings("NullAway") - private MongoQueryExecution getExecution(ConvertingParameterAccessor accessor, FindWithQuery operation) { + MongoQueryExecution getExecution(ConvertingParameterAccessor accessor, FindWithQuery operation) { if (isDeleteQuery()) { return new DeleteExecution<>(executableRemove, method); @@ -345,7 +345,7 @@ private Document bindParameters(String source, ConvertingParameterAccessor acces * @return never {@literal null}. * @since 3.4 */ - protected ParameterBindingContext prepareBindingContext(String source, ConvertingParameterAccessor accessor) { + protected ParameterBindingContext prepareBindingContext(String source, MongoParameterAccessor accessor) { ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); return new ParameterBindingContext(accessor::getBindableValue, evaluator); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java index d075b67efe..e51d4435a8 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java @@ -22,11 +22,14 @@ import java.util.List; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import 
org.springframework.data.mongodb.core.convert.MongoWriter; @@ -73,6 +76,11 @@ public PotentiallyConvertingIterator iterator() { return new ConvertingIterator(delegate.iterator()); } + @Override + public Vector getVector() { + return delegate.getVector(); + } + @Override public @Nullable ScrollPosition getScrollPosition() { return delegate.getScrollPosition(); @@ -95,6 +103,16 @@ public Sort getSort() { return getConvertedValue(delegate.getBindableValue(index), null); } + @Override + public @org.jspecify.annotations.Nullable Score getScore() { + return delegate.getScore(); + } + + @Override + public @org.jspecify.annotations.Nullable Range getScoreRange() { + return delegate.getScoreRange(); + } + @Override public @Nullable Range getDistanceRange() { return delegate.getDistanceRange(); @@ -208,7 +226,7 @@ private static Collection asCollection(@Nullable Object source) { if (source instanceof Iterable iterable) { - if(source instanceof Collection collection) { + if (source instanceof Collection collection) { return new ArrayList<>(collection); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java index 00d748f8a9..1b52233eac 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameterAccessor.java @@ -16,6 +16,7 @@ package org.springframework.data.mongodb.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Range; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index cb91ccd8e6..98438d1652 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -22,11 +22,13 @@ import org.jspecify.annotations.Nullable; import org.springframework.core.MethodParameter; +import org.springframework.core.ResolvableType; import org.springframework.data.domain.Range; -import org.springframework.data.geo.Distance; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; +import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.TextCriteria; @@ -76,7 +78,7 @@ public MongoParameters(ParametersSource parametersSource) { * @param isGeoNearMethod indicate if this is a geo-spatial query method */ public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMethod) { - this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod)); + this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod), new DistanceRangeIndex(parametersSource)); } /** @@ -85,10 +87,11 @@ public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMetho * @param parametersSource must not be {@literal null}. * @param nearIndex the near parameter index. 
*/ - private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex) { + private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex, + DistanceRangeIndex distanceRangeIndex) { super(parametersSource, methodParameter -> new MongoParameter(methodParameter, - parametersSource.getDomainTypeInformation(), nearIndex.nearIndex)); + parametersSource.getDomainTypeInformation(), nearIndex.nearIndex, distanceRangeIndex.distanceRangeIndex)); Method method = parametersSource.getMethod(); List> parameterTypes = Arrays.asList(method.getParameterTypes()); @@ -153,6 +156,15 @@ public NearIndex(ParametersSource parametersSource, boolean isGeoNearMethod) { } } + static class DistanceRangeIndex { + + private final int distanceRangeIndex; + + public DistanceRangeIndex(ParametersSource parametersSource) { + this.distanceRangeIndex = findDistanceRangeIndexInParameters(parametersSource.getMethod()); + } + } + private static int getNearIndex(List> parameterTypes) { for (Class reference : Arrays.asList(Point.class, double[].class)) { @@ -195,8 +207,19 @@ static int findNearIndexInParameters(Method method) { return index; } - public int getDistanceRangeIndex() { - return -1; + static int findDistanceRangeIndexInParameters(Method method) { + + int index = -1; + for (java.lang.reflect.Parameter p : method.getParameters()) { + + MethodParameter methodParameter = MethodParameter.forParameter(p); + + if (Range.class.isAssignableFrom(methodParameter.getParameterType()) + && ResolvableType.forMethodParameter(methodParameter).getGeneric(0).isAssignableFrom(Distance.class)) { + index = methodParameter.getParameterIndex(); + } + } + return index; } /** @@ -298,17 +321,21 @@ static class MongoParameter extends Parameter { private final MethodParameter parameter; private final @Nullable Integer nearIndex; + private final @Nullable Integer distanceRangeIndex; /** * Creates a new {@link MongoParameter}. * * @param parameter must not be {@literal null}. 
* @param domainType must not be {@literal null}. + * @param distanceRangeIndex index of the distance {@link Range} parameter, can be {@literal null}. */ - MongoParameter(MethodParameter parameter, TypeInformation domainType, @Nullable Integer nearIndex) { + MongoParameter(MethodParameter parameter, TypeInformation domainType, @Nullable Integer nearIndex, + @Nullable Integer distanceRangeIndex) { super(parameter, domainType); this.parameter = parameter; this.nearIndex = nearIndex; + this.distanceRangeIndex = distanceRangeIndex; if (!isPoint() && hasNearAnnotation()) { throw new IllegalArgumentException("Near annotation is only allowed at Point parameter"); @@ -317,7 +344,8 @@ static class MongoParameter extends Parameter { @Override public boolean isSpecialParameter() { - return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType()) || isNearParameter() + return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType()) + || Vector.class.isAssignableFrom(getType()) || isNearParameter() || isDistanceRangeParameter() || TextCriteria.class.isAssignableFrom(getType()) || Collation.class.isAssignableFrom(getType()); } @@ -325,6 +353,10 @@ private boolean isNearParameter() { return nearIndex != null && nearIndex.equals(getIndex()); } + private boolean isDistanceRangeParameter() { + return distanceRangeIndex != null && distanceRangeIndex.equals(getIndex()); + } + private boolean isManuallyAnnotatedNearParameter() { return isPoint() && hasNearAnnotation(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java index 66529dfce9..41cf084d45 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java @@ -16,8 
+16,10 @@ package org.springframework.data.mongodb.repository.query; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.query.Collation; @@ -55,7 +57,25 @@ public MongoParametersParameterAccessor(MongoQueryMethod method, Object[] values } @SuppressWarnings("NullAway") - public @Nullable Range getDistanceRange() { + @Override + public Range getScoreRange() { + + MongoParameters mongoParameters = method.getParameters(); + int rangeIndex = mongoParameters.getScoreRangeIndex(); + + if (rangeIndex != -1) { + return getValue(rangeIndex); + } + + int scoreIndex = mongoParameters.getScoreIndex(); + Bound maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore()); + + return Range.of(Bound.unbounded(), maxDistance); + } + + @SuppressWarnings("NullAway") + @Override + public Range getDistanceRange() { MongoParameters mongoParameters = method.getParameters(); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index b8a8c34f48..1f742ec32f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -27,6 +27,7 @@ import org.apache.commons.logging.LogFactory; import org.bson.BsonRegularExpression; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; import org.springframework.data.domain.Sort; @@ -72,6 +73,7 @@ public class MongoQueryCreator 
extends AbstractQueryCreator { private final MongoParameterAccessor accessor; private final MappingContext context; private final boolean isGeoNearQuery; + private final boolean isSearchQuery; /** * Creates a new {@link MongoQueryCreator} from the given {@link PartTree}, {@link ConvertingParameterAccessor} and @@ -81,9 +83,9 @@ public class MongoQueryCreator extends AbstractQueryCreator { * @param accessor * @param context */ - public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, + public MongoQueryCreator(PartTree tree, MongoParameterAccessor accessor, MappingContext context) { - this(tree, accessor, context, false); + this(tree, accessor, context, false, false); } /** @@ -94,9 +96,10 @@ public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, * @param accessor * @param context * @param isGeoNearQuery + * @param isSearchQuery */ - public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, - MappingContext context, boolean isGeoNearQuery) { + public MongoQueryCreator(PartTree tree, MongoParameterAccessor accessor, + MappingContext context, boolean isGeoNearQuery, boolean isSearchQuery) { super(tree, accessor); @@ -104,6 +107,7 @@ public MongoQueryCreator(PartTree tree, ConvertingParameterAccessor accessor, this.accessor = accessor; this.isGeoNearQuery = isGeoNearQuery; + this.isSearchQuery = isSearchQuery; this.context = context; } @@ -114,6 +118,10 @@ protected Criteria create(Part part, Iterator iterator) { return new Criteria(); } + if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { + return null; + } + PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); MongoPersistentProperty property = path.getLeafProperty(); @@ -127,6 +135,10 @@ protected Criteria and(Part part, Criteria base, Iterator iterator) { return create(part, iterator); } + if (isSearchQuery && (part.getType().equals(Type.NEAR) || 
part.getType().equals(Type.WITHIN))) { + return base; + } + PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); MongoPersistentProperty property = path.getLeafProperty(); @@ -164,6 +176,15 @@ protected Query complete(@Nullable Criteria criteria, Sort sort) { @SuppressWarnings("NullAway") private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator parameters) { + if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { + + int numberOfArguments = part.getType().getNumberOfArguments(); + for (int i = 0; i < numberOfArguments; i++) { + parameters.next(); + } + return null; + } + Type type = part.getType(); switch (type) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index 01d4e0c63d..7f632f58e4 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -15,16 +15,24 @@ */ package org.springframework.data.mongodb.repository.query; +import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.function.Supplier; +import org.bson.Document; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import 
org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; @@ -37,6 +45,12 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.ExecutableRemove; import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationResults; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; @@ -210,6 +224,137 @@ private static boolean isListOfGeoResult(TypeInformation returnType) { } } + /** + * {@link MongoQueryExecution} to execute vector search + * + * @author Mark Paluch + * @since 5.0 + */ + class VectorSearchExecution implements MongoQueryExecution { + + private final MongoOperations operations; + private final MongoQueryMethod method; + private final String collectionName; + private final @Nullable Integer numCandidates; + private final VectorSearchOperation.SearchType searchType; + private final MongoParameterAccessor accessor; + private final Class outputType; + private final String path; + + public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, + String path, @Nullable Integer numCandidates, VectorSearchOperation.SearchType searchType, + MongoParameterAccessor accessor, Class outputType) { + + this.operations = operations; + this.collectionName = 
collectionName; + this.path = path; + this.numCandidates = numCandidates; + this.method = method; + this.searchType = searchType; + this.accessor = accessor; + this.outputType = outputType; + } + + @Override + public Object execute(Query query) { + + SearchResults results = doExecuteQuery(query); + return isListOfSearchResult(method.getReturnType()) ? results.getContent() : results; + } + + @SuppressWarnings("unchecked") + SearchResults doExecuteQuery(Query query) { + + Vector vector = accessor.getVector(); + Score score = accessor.getScore(); + Range distance = accessor.getScoreRange(); + int limit; + + if (query.isLimited()) { + limit = query.getLimit(); + } else { + limit = Math.max(1, numCandidates != null ? numCandidates / 20 : 1); + } + + List stages = new ArrayList<>(); + VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(method.getAnnotatedHint()).path(path) + .vector(vector).limit(limit); + + if (numCandidates != null) { + $vectorSearch = $vectorSearch.numCandidates(numCandidates); + } + + $vectorSearch = $vectorSearch.filter(query.getQueryObject()); + $vectorSearch = $vectorSearch.searchType(searchType); + $vectorSearch = $vectorSearch.withSearchScore("__score__"); + + if (score != null) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + c.gt(score.getValue()); + }); + } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + Range.Bound lower = distance.getLowerBound(); + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + c.gte(value); + } else { + c.gt(value); + } + } + + Range.Bound upper = distance.getUpperBound(); + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + c.lte(value); + } else { + c.lt(value); + } + } + }); + } + + stages.add($vectorSearch); + + if (query.isSorted()) { + // TODO 
stages.add(Aggregation.sort(query.with())); + } else { + stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__")); + } + + AggregationResults aggregated = operations + .aggregate(TypedAggregation. newAggregation(outputType, stages), collectionName, outputType); + + List mappedResults = aggregated.getMappedResults(); + List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); + + List> result = new ArrayList<>(mappedResults.size()); + + for (int i = 0; i < mappedResults.size(); i++) { + Document document = rawResults.get(i); + SearchResult searchResult = new SearchResult<>(mappedResults.get(i), + Score.of(document.getDouble("__score__"))); + + result.add(searchResult); + } + + return new SearchResults<>(result); + } + + private static boolean isListOfSearchResult(TypeInformation returnType) { + + if (!returnType.getType().equals(List.class)) { + return false; + } + + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + } + /** * {@link MongoQueryExecution} to execute geo-near queries with paging. 
* diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java index 52c5e32555..060d03e223 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryMethod.java @@ -35,6 +35,7 @@ import org.springframework.data.mongodb.repository.ReadPreference; import org.springframework.data.mongodb.repository.Tailable; import org.springframework.data.mongodb.repository.Update; +import org.springframework.data.mongodb.repository.VectorSearch; import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.repository.core.RepositoryMetadata; @@ -414,10 +415,28 @@ private Optional findAnnotatedAggregation() { .filter(it -> !ObjectUtils.isEmpty(it)); } + /** + * Returns whether the method has an annotated vector search. + * + * @return true if {@link VectorSearch} is present. 
+ * @since 5.0 + */ + public boolean hasAnnotatedVectorSearch() { + return findAnnotatedVectorSearch().isPresent(); + } + + Optional findAnnotatedVectorSearch() { + return lookupVectorSearchAnnotation(); + } + Optional lookupAggregationAnnotation() { return doFindAnnotation(Aggregation.class); } + Optional lookupVectorSearchAnnotation() { + return doFindAnnotation(VectorSearch.class); + } + Optional lookupUpdateAnnotation() { return doFindAnnotation(Update.class); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java index 6116cc5534..9682e4971f 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/PartTreeMongoQuery.java @@ -81,7 +81,7 @@ public PartTree getTree() { @SuppressWarnings("NullAway") protected Query createQuery(ConvertingParameterAccessor accessor) { - MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, isGeoNearQuery); + MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, isGeoNearQuery, false); Query query = creator.createQuery(); if (tree.isLimiting()) { @@ -126,7 +126,7 @@ protected Query createQuery(ConvertingParameterAccessor accessor) { @Override protected Query createCountQuery(ConvertingParameterAccessor accessor) { - return new MongoQueryCreator(tree, accessor, context, false).createQuery(); + return new MongoQueryCreator(tree, accessor, context, false, false).createQuery(); } @Override diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java index 06f946d745..f9b47c9a84 100644 
--- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java @@ -20,6 +20,7 @@ import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; + import org.springframework.core.convert.converter.Converter; import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Pageable; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java index 4aa773091b..9a17b2b5fc 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactivePartTreeMongoQuery.java @@ -90,7 +90,7 @@ protected Mono createCountQuery(ConvertingParameterAccessor accessor) { @SuppressWarnings("NullAway") private Query createQueryInternal(ConvertingParameterAccessor accessor, boolean isCountQuery) { - MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, !isCountQuery && isGeoNearQuery); + MongoQueryCreator creator = new MongoQueryCreator(tree, accessor, context, !isCountQuery && isGeoNearQuery, false); Query query = creator.createQuery(); if (isCountQuery) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java new file mode 100644 index 0000000000..2f0d0258d1 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -0,0 +1,308 @@ +/* + 
* Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository.query; + +import org.bson.Document; + +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.expression.ValueExpression; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.data.repository.query.parser.Part; +import 
org.springframework.data.repository.query.parser.PartTree; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived + * from the method name or provided through {@link VectorSearch#filter()}. + * + * @author Mark Paluch + * @since 5.0 + */ +public class VectorSearchAggregation extends AbstractMongoQuery { + + private final MongoOperations mongoOperations; + private final MongoConverter mongoConverter; + private final MongoPersistentEntity collectionEntity; + private final VectorSearchQueryFactory queryFactory; + private final VectorSearchOperation.SearchType searchType; + private final @Nullable Integer numCandidates; + private final @Nullable String numCandidatesExpression; + + private final Limit limit; + private final @Nullable String limitExpression; + + /** + * Creates a new {@link VectorSearchAggregation} from the given {@link MongoQueryMethod} and {@link MongoOperations}. + * + * @param method must not be {@literal null}. + * @param mongoOperations must not be {@literal null}. + * @param delegate must not be {@literal null}. 
+ */ + public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOperations, + ValueExpressionDelegate delegate) { + + super(method, mongoOperations, delegate); + + if (!method.isSearchQuery() && !method.isCollectionQuery()) { + throw new InvalidMongoDbApiUsageException(String.format( + "Repository Vector Search method '%s' must return either return SearchResults or List but was %s", + method.getName(), method.getReturnType().getType().getSimpleName())); + } + + this.mongoOperations = mongoOperations; + this.mongoConverter = mongoOperations.getConverter(); + this.collectionEntity = method.getEntityInformation().getCollectionEntity(); + + VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); + + this.searchType = vectorSearch.searchType(); + + if (StringUtils.hasText(vectorSearch.numCandidates())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); + + if (expression.isLiteral()) { + numCandidates = Integer.parseInt(vectorSearch.numCandidates()); + numCandidatesExpression = null; + } else { + numCandidates = null; + numCandidatesExpression = vectorSearch.numCandidates(); + } + + } else { + numCandidates = null; + numCandidatesExpression = null; + } + + if (StringUtils.hasText(vectorSearch.limit())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); + + if (expression.isLiteral()) { + limit = Limit.of(Integer.parseInt(vectorSearch.limit())); + limitExpression = null; + } else { + limit = Limit.unlimited(); + limitExpression = vectorSearch.limit(); + } + + } else { + limit = Limit.unlimited(); + limitExpression = null; + } + + if (StringUtils.hasText(vectorSearch.filter())) { + queryFactory = StringUtils.hasText(vectorSearch.path()) + ? 
new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) + : new AnnotatedQueryFactory(vectorSearch.filter(), collectionEntity); + } else { + queryFactory = new PartTreeQueryFactory( + new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), + mongoConverter.getMappingContext()); + } + } + + @SuppressWarnings("unchecked") + @Override + protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, + @Nullable Class typeToRead) { + + ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); + Integer numCandidates = null; + Limit limit; + Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); + + if (this.numCandidatesExpression != null) { + numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + } else if (this.numCandidates != null) { + numCandidates = this.numCandidates; + } + + if (this.limitExpression != null) { + + Object value = evaluator.evaluate(this.limitExpression); + limit = value instanceof Limit l ? 
l : Limit.of(((Number) value).intValue()); + } else if (this.limit.isLimited()) { + limit = this.limit; + } else { + limit = accessor.getLimit(); + } + + VectorSearchQuery query = createVectorSearchQuery(accessor); + + if (limit.isLimited()) { + query.query().limit(limit); + } + + MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, + method, collectionEntity.getCollection(), query.path(), numCandidates, searchType, accessor, + (Class) outputType); + + return execution.execute(query.query()); + } + + VectorSearchQuery createVectorSearchQuery(MongoParameterAccessor accessor) { + return queryFactory.createQuery(accessor); + } + + @Override + protected Query createQuery(ConvertingParameterAccessor accessor) { + throw new UnsupportedOperationException(); + } + + @Override + protected boolean isCountQuery() { + return false; + } + + @Override + protected boolean isExistsQuery() { + return false; + } + + @Override + protected boolean isDeleteQuery() { + return false; + } + + @Override + protected boolean isLimiting() { + return false; + } + + interface VectorSearchQueryFactory { + + VectorSearchQuery createQuery(MongoParameterAccessor parameterAccessor); + } + + class AnnotatedQueryFactory implements VectorSearchQueryFactory { + + private final String query; + private final String path; + + AnnotatedQueryFactory(String query, String path) { + + this.query = query; + this.path = path; + } + + AnnotatedQueryFactory(String query, MongoPersistentEntity entity) { + + this.query = query; + String path = null; + for (MongoPersistentProperty property : entity) { + if (Vector.class.isAssignableFrom(property.getType())) { + path = property.getFieldName(); + break; + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "Cannot find Vector Search property in entity [%s]".formatted(entity.getName())); + } + + this.path = path; + } + + public VectorSearchQuery createQuery(MongoParameterAccessor 
parameterAccessor) { + + Document queryObject = decode(this.query, prepareBindingContext(this.query, parameterAccessor)); + Query query = new BasicQuery(queryObject); + + Sort sort = parameterAccessor.getSort(); + if (sort.isSorted()) { + query = query.with(sort); + } + + return new VectorSearchQuery(path, query); + } + + } + + class PartTreeQueryFactory implements VectorSearchQueryFactory { + + private final String path; + private final Part.Type type; + private final MappingContext context; + private final PartTree partTree; + + @SuppressWarnings("NullableProblems") + PartTreeQueryFactory(PartTree partTree, MappingContext context) { + + String path = null; + Part.Type type = null; + for (PartTree.OrPart part : partTree) { + for (Part p : part) { + if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR + || p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) { + PersistentPropertyPath ppp = context.getPersistentPropertyPath(p.getProperty()); + MongoPersistentProperty property = ppp.getLeafProperty(); + + if (Vector.class.isAssignableFrom(property.getType())) { + path = p.getProperty().toDotPath(); + type = p.getType(); + break; + } + } + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "No Simple Property/Near/Within/Between part found for a Vector property"); + } + + this.path = path; + this.type = type; + + this.partTree = partTree; + this.context = context; + } + + public VectorSearchQuery createQuery(MongoParameterAccessor parameterAccessor) { + + MongoQueryCreator creator = new MongoQueryCreator(partTree, parameterAccessor, mongoConverter.getMappingContext(), + false, true); + + Query query = creator.createQuery(parameterAccessor.getSort()); + + return new VectorSearchQuery(path, query); + } + + } + + record VectorSearchQuery(String path, Query query) { + + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java index e1abcdc2ab..d6047aa058 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/MongoRepositoryFactory.java @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.PartTreeMongoQuery; import org.springframework.data.mongodb.repository.query.StringBasedAggregation; import org.springframework.data.mongodb.repository.query.StringBasedMongoQuery; +import org.springframework.data.mongodb.repository.query.VectorSearchAggregation; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.querydsl.QuerydslPredicateExecutor; import org.springframework.data.repository.core.NamedQueries; @@ -182,6 +183,8 @@ public RepositoryQuery resolveQuery(Method method, RepositoryMetadata metadata, if (namedQueries.hasQuery(namedQueryName)) { String namedQuery = namedQueries.getQuery(namedQueryName); return new StringBasedMongoQuery(namedQuery, queryMethod, operations, expressionSupport); + } else if (queryMethod.hasAnnotatedVectorSearch()) { + return new VectorSearchAggregation(queryMethod, operations, expressionSupport); } else if (queryMethod.hasAnnotatedAggregation()) { return new StringBasedAggregation(queryMethod, operations, expressionSupport); } else if (queryMethod.hasAnnotatedQuery()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java index 1b9aba1ba0..5f66e61bdc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java +++ 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/GeoNearOperationUnitTests.java @@ -70,7 +70,7 @@ public void rendersNearQueryWithKeyCorrectly() { @Test // DATAMONGO-2264 public void rendersMaxDistanceCorrectly() { - NearQuery query = NearQuery.near(10.0, 20.0).maxDistance(new Distance(30.0)); + NearQuery query = NearQuery.near(10.0, 20.0).maxDistance(Distance.of(30.0)); assertThat(new GeoNearOperation(query, "distance").toPipelineStages(Aggregation.DEFAULT_CONTEXT)) .containsExactly($geoNear().near(10.0, 20.0).maxDistance(30.0).doc()); @@ -79,7 +79,7 @@ public void rendersMaxDistanceCorrectly() { @Test // DATAMONGO-2264 public void rendersMinDistanceCorrectly() { - NearQuery query = NearQuery.near(10.0, 20.0).minDistance(new Distance(30.0)); + NearQuery query = NearQuery.near(10.0, 20.0).minDistance(Distance.of(30.0)); assertThat(new GeoNearOperation(query, "distance").toPipelineStages(Aggregation.DEFAULT_CONTEXT)) .containsExactly($geoNear().near(10.0, 20.0).minDistance(30.0).doc()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java index 7fb664b00c..84a494f9d8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/GeoConvertersUnitTests.java @@ -69,7 +69,7 @@ public void convertsCircleToDocumentAndBackCorrectlyNeutralDistance() { @Test // DATAMONGO-858 public void convertsCircleToDocumentAndBackCorrectlyMilesDistance() { - Distance radius = new Distance(3, Metrics.MILES); + Distance radius = Distance.of(3, Metrics.MILES); Circle circle = new Circle(new Point(1, 2), radius); Document document = CircleToDocumentConverter.INSTANCE.convert(circle); @@ -106,7 +106,7 @@ public void 
convertsSphereToDocumentAndBackCorrectlyWithNeutralDistance() { @Test // DATAMONGO-858 public void convertsSphereToDocumentAndBackCorrectlyWithKilometerDistance() { - Distance radius = new Distance(3, Metrics.KILOMETERS); + Distance radius = Distance.of(3, Metrics.KILOMETERS); Sphere sphere = new Sphere(new Point(1, 2), radius); Document document = SphereToDocumentConverter.INSTANCE.convert(sphere); @@ -160,7 +160,7 @@ public void convertsCircleCorrectlyWhenUsingNonDoubleForCoordinates() { circle.put("radius", 3L); assertThat(DocumentToCircleConverter.INSTANCE.convert(circle)) - .isEqualTo(new Circle(new Point(1, 2), new Distance(3))); + .isEqualTo(new Circle(new Point(1, 2), Distance.of(3))); } @Test // DATAMONGO-1607 @@ -171,7 +171,7 @@ public void convertsSphereCorrectlyWhenUsingNonDoubleForCoordinates() { sphere.put("radius", 3L); assertThat(DocumentToSphereConverter.INSTANCE.convert(sphere)) - .isEqualTo(new Sphere(new Point(1, 2), new Distance(3))); + .isEqualTo(new Sphere(new Point(1, 2), Distance.of(3))); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java index 5bd7e06b97..6f1c7439c0 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/convert/MappingMongoConverterUnitTests.java @@ -1626,7 +1626,7 @@ void shouldWriteEntityWithGeoSphereCorrectly() { void shouldWriteEntityWithGeoSphereWithMetricDistanceCorrectly() { ClassWithGeoSphere object = new ClassWithGeoSphere(); - Sphere sphere = new Sphere(new Point(1, 2), new Distance(3, Metrics.KILOMETERS)); + Sphere sphere = new Sphere(new Point(1, 2), Distance.of(3, Metrics.KILOMETERS)); Distance radius = sphere.getRadius(); object.sphere = sphere; @@ -4082,8 +4082,7 
@@ static class WithExplicitTargetTypes { @Field(targetType = FieldType.DECIMAL128) // BigDecimal bigDecimal; - @Field(targetType = FieldType.DECIMAL128) - BigInteger bigInteger; + @Field(targetType = FieldType.DECIMAL128) BigInteger bigInteger; @Field(targetType = FieldType.INT64) // Date dateAsLong; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java index 3a9140d34c..1774c36493 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/geo/GeoSpatial2DSphereTests.java @@ -23,9 +23,9 @@ import java.util.List; import org.junit.Test; + import org.springframework.data.domain.Sort.Direction; import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.Venue; @@ -67,7 +67,7 @@ public void geoNearWithMinDistance() { GeoResults result = template.geoNear(geoNear, Venue.class); assertThat(result.getContent().size()).isNotEqualTo(0); - assertThat(result.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(result.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-1110 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java index bbdad047f2..fdfa840d58 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java +++ 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/MetricConversionUnitTests.java @@ -17,6 +17,7 @@ package org.springframework.data.mongodb.core.query; import static org.assertj.core.api.Assertions.*; +import static org.assertj.core.data.Offset.*; import static org.assertj.core.data.Offset.offset; import org.junit.jupiter.api.Test; @@ -34,7 +35,7 @@ public class MetricConversionUnitTests { @Test // DATAMONGO-1348 public void shouldConvertMilesToMeters() { - Distance distance = new Distance(1, Metrics.MILES); + Distance distance = Distance.of(1, Metrics.MILES); double distanceInMeters = MetricConversion.getDistanceInMeters(distance); assertThat(distanceInMeters).isCloseTo(1609.3438343d, offset(0.000000001)); @@ -43,7 +44,7 @@ public void shouldConvertMilesToMeters() { @Test // DATAMONGO-1348 public void shouldConvertKilometersToMeters() { - Distance distance = new Distance(1, Metrics.KILOMETERS); + Distance distance = Distance.of(1, Metrics.KILOMETERS); double distanceInMeters = MetricConversion.getDistanceInMeters(distance); assertThat(distanceInMeters).isCloseTo(1000, offset(0.000000001)); @@ -72,11 +73,13 @@ public void shouldCalculateMetersToMilesMultiplier() { @Test // GH-4004 void shouldConvertKilometersToRadians/* on an earth like sphere with r=6378.137km */() { - assertThat(MetricConversion.toRadians(new Distance(1, Metrics.KILOMETERS))).isCloseTo(0.000156785594d, offset(0.000000001)); + assertThat(MetricConversion.toRadians(Distance.of(1, Metrics.KILOMETERS))).isCloseTo(0.000156785594d, + offset(0.000000001)); } @Test // GH-4004 void shouldConvertMilesToRadians/* on an earth like sphere with r=6378.137km */() { - assertThat(MetricConversion.toRadians(new Distance(1, Metrics.MILES))).isCloseTo(0.000252321328d, offset(0.000000001)); + assertThat(MetricConversion.toRadians(Distance.of(1, Metrics.MILES))).isCloseTo(0.000252321328d, + offset(0.000000001)); } } diff --git 
a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java index f4e3d26eb1..2b600988db 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/NearQueryUnitTests.java @@ -21,10 +21,10 @@ import java.math.RoundingMode; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.DocumentTestUtils; @@ -44,7 +44,7 @@ */ public class NearQueryUnitTests { - private static final Distance ONE_FIFTY_KILOMETERS = new Distance(150, Metrics.KILOMETERS); + private static final Distance ONE_FIFTY_KILOMETERS = Distance.of(150, Metrics.KILOMETERS); @Test public void rejectsNullPoint() { @@ -57,7 +57,7 @@ public void settingUpNearWithMetricRecalculatesDistance() { NearQuery query = NearQuery.near(2.5, 2.5, Metrics.KILOMETERS).maxDistance(150); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(query.getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(query.isSpherical()).isTrue(); } @@ -68,27 +68,27 @@ public void settingMetricRecalculatesMaxDistance() { query.inMiles(); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); } @Test public void configuresResultMetricCorrectly() { NearQuery query = NearQuery.near(2.5, 2.1); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.NEUTRAL); + 
assertThat(query.getMetric()).isEqualTo(Metrics.NEUTRAL); query = query.maxDistance(ONE_FIFTY_KILOMETERS); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(query.getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); assertThat(query.isSpherical()).isTrue(); query = query.in(Metrics.MILES); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); assertThat(query.getMaxDistance()).isEqualTo(ONE_FIFTY_KILOMETERS); assertThat(query.isSpherical()).isTrue(); - query = query.maxDistance(new Distance(200, Metrics.KILOMETERS)); - assertThat(query.getMetric()).isEqualTo((Metric) Metrics.MILES); + query = query.maxDistance(Distance.of(200, Metrics.KILOMETERS)); + assertThat(query.getMetric()).isEqualTo(Metrics.MILES); } @Test // DATAMONGO-445, DATAMONGO-2264 @@ -200,7 +200,7 @@ public void shouldUseMetersForGeoJsonData() { public void shouldUseMetersForGeoJsonDataWhenDistanceInKilometers() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.KILOMETERS)); + query.maxDistance(Distance.of(1, Metrics.KILOMETERS)); assertThat(query.toDocument()).containsEntry("maxDistance", 1000D).containsEntry("distanceMultiplier", 0.001D); } @@ -209,7 +209,7 @@ public void shouldUseMetersForGeoJsonDataWhenDistanceInKilometers() { public void shouldUseMetersForGeoJsonDataWhenDistanceInMiles() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.MILES)); + query.maxDistance(Distance.of(1, Metrics.MILES)); assertThat(query.toDocument()).containsEntry("maxDistance", 1609.3438343D).containsEntry("distanceMultiplier", 0.00062137D); @@ -219,7 +219,7 @@ public void shouldUseMetersForGeoJsonDataWhenDistanceInMiles() { public void shouldUseKilometersForDistanceWhenMaxDistanceInMiles() { NearQuery query = 
NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.MILES)).in(Metrics.KILOMETERS); + query.maxDistance(Distance.of(1, Metrics.MILES)).in(Metrics.KILOMETERS); assertThat(query.toDocument()).containsEntry("maxDistance", 1609.3438343D).containsEntry("distanceMultiplier", 0.001D); @@ -229,7 +229,7 @@ public void shouldUseKilometersForDistanceWhenMaxDistanceInMiles() { public void shouldUseMilesForDistanceWhenMaxDistanceInKilometers() { NearQuery query = NearQuery.near(new GeoJsonPoint(27.987901, 86.9165379)); - query.maxDistance(new Distance(1, Metrics.KILOMETERS)).in(Metrics.MILES); + query.maxDistance(Distance.of(1, Metrics.KILOMETERS)).in(Metrics.MILES); assertThat(query.toDocument()).containsEntry("maxDistance", 1000D).containsEntry("distanceMultiplier", 0.00062137D); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java index 3f2e60f4c4..c2cb6cacf8 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/AbstractPersonRepositoryIntegrationTests.java @@ -38,6 +38,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DuplicateKeyException; import org.springframework.dao.IncorrectResultSizeDataAccessException; @@ -49,7 +50,6 @@ import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Metric; import org.springframework.data.geo.Metrics; import 
org.springframework.data.geo.Point; import org.springframework.data.geo.Polygon; @@ -458,7 +458,7 @@ void executesGeoNearQueryForResultsCorrectly() { repository.save(dave); GeoResults results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS)); + Distance.of(2000, Metrics.KILOMETERS)); assertThat(results.getContent()).isNotEmpty(); } @@ -470,11 +470,11 @@ void executesGeoPageQueryForResultsCorrectly() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(0, 20)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 20)); assertThat(results.getContent()).isNotEmpty(); // DATAMONGO-607 - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-323 @@ -634,13 +634,13 @@ void executesGeoPageQueryForWithPageRequestForPageInBetween() { repository.saveAll(Arrays.asList(dave, oliver, carter, boyd, leroi)); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(2); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isFalse(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); assertThat(results.getAverageDistance().getNormalizedValue()).isEqualTo(0.0); } @@ -656,12 +656,12 @@ void executesGeoPageQueryForWithPageRequestForPageAtTheEnd() { repository.saveAll(Arrays.asList(dave, oliver, carter)); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, 
Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(1); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-445 @@ -672,13 +672,13 @@ void executesGeoPageQueryForWithPageRequestForJustOneElement() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(0, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(0, 2)); assertThat(results.getContent()).isNotEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(1); assertThat(results.isFirst()).isTrue(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-445 @@ -688,13 +688,13 @@ void executesGeoPageQueryForWithPageRequestForJustOneElementEmptyPage() { repository.save(dave); GeoPage results = repository.findByLocationNear(new Point(-73.99, 40.73), - new Distance(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); + Distance.of(2000, Metrics.KILOMETERS), PageRequest.of(1, 2)); assertThat(results.getContent()).isEmpty(); assertThat(results.getNumberOfElements()).isEqualTo(0); assertThat(results.isFirst()).isFalse(); assertThat(results.isLast()).isTrue(); - assertThat(results.getAverageDistance().getMetric()).isEqualTo((Metric) Metrics.KILOMETERS); + assertThat(results.getAverageDistance().getMetric()).isEqualTo(Metrics.KILOMETERS); } @Test // DATAMONGO-1608 @@ -1117,7 +1117,7 @@ void 
executesGeoNearQueryForResultsCorrectlyWhenGivenMinAndMaxDistance() { dave.setLocation(point); repository.save(dave); - Range range = Distance.between(new Distance(0.01, KILOMETERS), new Distance(2000, KILOMETERS)); + Range range = Distance.between(Distance.of(0.01, KILOMETERS), Distance.of(2000, KILOMETERS)); GeoResults results = repository.findPersonByLocationNear(new Point(-73.99, 40.73), range); assertThat(results.getContent()).isNotEmpty(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java index 93a293ecff..1f4f682ebc 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/PersonRepository.java @@ -24,6 +24,7 @@ import java.util.stream.Stream; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Limit; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java index 1a481b49ed..2a76c0ba6c 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveMongoRepositoryTests.java @@ -20,6 +20,7 @@ import static org.springframework.data.domain.Sort.Direction.*; import static org.springframework.data.mongodb.core.query.Criteria.*; import static org.springframework.data.mongodb.core.query.Query.*; +import static org.springframework.data.mongodb.test.util.Assertions.*; import static 
org.springframework.data.mongodb.test.util.Assertions.assertThat; import reactor.core.Disposable; @@ -40,6 +41,7 @@ import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.extension.ExtendWith; import org.reactivestreams.Publisher; + import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; @@ -353,7 +355,7 @@ void findsPeopleGeoresultByLocationWithinBox() { repository.save(dave).as(StepVerifier::create).expectNextCount(1).verifyComplete(); repository.findByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS)).as(StepVerifier::create).consumeNextWith(actual -> { + Distance.of(2000, Metrics.KILOMETERS)).as(StepVerifier::create).consumeNextWith(actual -> { assertThat(actual.getDistance().getValue()).isCloseTo(1, offset(1d)); assertThat(actual.getContent()).isEqualTo(dave); @@ -372,7 +374,7 @@ void findsPeoplePageableGeoresultByLocationWithinBox() throws InterruptedExcepti Thread.sleep(500); repository.findByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS), // + Distance.of(2000, Metrics.KILOMETERS), // PageRequest.of(0, 10)).as(StepVerifier::create) // .consumeNextWith(actual -> { @@ -393,7 +395,7 @@ void findsPeopleByLocationWithinBox() throws InterruptedException { Thread.sleep(500); repository.findPersonByLocationNear(new Point(-73.99, 40.73), // - new Distance(2000, Metrics.KILOMETERS)).as(StepVerifier::create) // + Distance.of(2000, Metrics.KILOMETERS)).as(StepVerifier::create) // .expectNext(dave) // .verifyComplete(); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java new file mode 100644 index 0000000000..4e3b12b32a --- /dev/null +++ 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -0,0 +1,268 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.mongodb.repository; + +import static org.assertj.core.api.Assertions.*; + +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.TestMongoConfiguration; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.repository.config.EnableMongoRepositories; +import org.springframework.data.mongodb.test.util.AtlasContainer; +import 
org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.repository.CrudRepository; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +/** + * Integration tests using Vector Search and Vector Indexes through local MongoDB Atlas. + * + * @author Christoph Strobl + * @author Mark Paluch + */ +@Testcontainers(disabledWithoutDocker = true) +@SpringJUnitConfig(classes = { VectorSearchTests.Config.class }) +public class VectorSearchTests { + + private static final Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + private static final MongoDBAtlasLocalContainer atlasLocal = AtlasContainer.bestMatch().withReuse(true); + private static final String COLLECTION_NAME = "collection-1"; + + static MongoClient client; + static MongoTestTemplate template; + + @Autowired VectorSearchRepository repository; + + @EnableMongoRepositories( + includeFilters = { + @ComponentScan.Filter(value = VectorSearchRepository.class, type = FilterType.ASSIGNABLE_TYPE) }, + considerNestedRepositories = true) + static class Config extends TestMongoConfiguration { // points the Spring context at the Atlas Local container + + @Override + public String getDatabaseName() { + return "vector-search-tests"; + } + + @Override + public MongoClient mongoClient() { + atlasLocal.start(); + return MongoClients.create(atlasLocal.getConnectionString()); + } + } + + @BeforeAll + static void beforeAll() throws InterruptedException { // starts the container, then seeds documents and vector indexes + atlasLocal.start(); + + // separate client and template used only to prepare the fixtures + client = MongoClients.create(atlasLocal.getConnectionString()); + template = new MongoTestTemplate(client, "vector-search-tests"); + + template.remove(WithVectorFields.class).all(); + initDocuments(); + initIndexes(); + + Thread.sleep(500); // just wait a little or the index will be broken + }
+ + @Test // NOTE(review): the name says ENN, but searchAnnotated declares SearchType.ANN + void shouldSearchEnnWithAnnotatedFilter() { + + SearchResults<WithVectorFields> results = repository.searchAnnotated("de", VECTOR, + Score.of(0.4), Limit.of(10)); + + assertThat(results).hasSize(3); + } + + @Test + void shouldSearchEnnWithDerivedFilter() { + + SearchResults<WithVectorFields> results = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Score.of(0.98), + Limit.of(10)); + + assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVectorFields::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVectorFields::getDescription) + .containsExactlyInAnyOrder("two", "one"); + } + + @Test // NOTE(review): despite the name a Similarity is passed; the overload without a score parameter is not called by any test here + void shouldSearchEnnWithDerivedFilterWithoutScore() { + + SearchResults<WithVectorFields> de = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.4), + Limit.of(10)); + assertThat(de).hasSizeGreaterThanOrEqualTo(2); + + assertThat(repository.searchByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.999), Limit.of(10))).hasSize(1); + } + + @Test + void shouldSearchEuclideanWithDerivedFilter() { + + SearchResults<WithVectorFields> results = repository.searchEuclideanByCountryAndEmbeddingNear("de", VECTOR, + Limit.of(2)); + + assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVectorFields::getCountry) + .containsOnly("de", "de"); + + assertThat(results).extracting(SearchResult::getContent).extracting(WithVectorFields::getDescription) + .containsExactlyInAnyOrder("two", "one"); + } + + @Test + void shouldSearchEnnWithDerivedFilterWithin() { + + SearchResults<WithVectorFields> results = repository.searchByCountryAndEmbeddingWithin("de", VECTOR, + Score.between(0.93, 0.98)); + + assertThat(results).hasSize(1); + for (SearchResult<WithVectorFields> result : results) { + assertThat(result.getScore().getValue()).isBetween(0.93, 0.98); + } + } + + @Test + void shouldSearchEnnWithDerivedAndLimitedFilterWithin() { + + SearchResults<WithVectorFields> results = repository.searchTop1ByCountryAndEmbeddingWithin("de", VECTOR, + Score.between(0.8,
1)); + + assertThat(results).hasSize(1); + + for (SearchResult<WithVectorFields> result : results) { + assertThat(result.getScore().getValue()).isBetween(0.8, 1.0); + } + } + + static void initDocuments() { // inserts three "de" documents and one "en" document, each with a five-dimensional embedding + + WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f)); + WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f)); + WithVectorFields w3 = new WithVectorFields("en", "three", + Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f)); + WithVectorFields w4 = new WithVectorFields("de", "four", + Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f)); + + template.insertAll(List.of(w1, w2, w3, w4)); + } + + static void initIndexes() { // creates a cosine and a euclidean vector index on "embedding", both filterable by "country", and waits for them + + VectorIndex cosIndex = new VectorIndex("cos-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + + VectorIndex euclideanIndex = new VectorIndex("euc-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.EUCLIDEAN).dimensions(5)).addFilter("country"); + + // cos-index has already been created above; only euc-index is still missing + template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); + template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + } + + interface VectorSearchRepository extends CrudRepository<WithVectorFields, String> { // repository under test; every method names its vector index via @VectorSearch + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + SearchResults<WithVectorFields> searchAnnotated(String country, Vector vector, + Score distance, Limit limit); + + @VectorSearch(indexName = "cos-index") + SearchResults<WithVectorFields> searchByCountryAndEmbeddingNear(String country, Vector vector, Score similarity, + Limit limit); + + @VectorSearch(indexName = "cos-index") +
SearchResults<WithVectorFields> searchByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); + + @VectorSearch(indexName = "euc-index") + SearchResults<WithVectorFields> searchEuclideanByCountryAndEmbeddingNear(String country, Vector vector, + Limit limit); + + @VectorSearch(indexName = "cos-index", limit = "10") + SearchResults<WithVectorFields> searchByCountryAndEmbeddingWithin(String country, Vector vector, + Range distance); // NOTE(review): Range is raw here and in the next method; its type argument is not visible in this patch text, confirm and restore it + + @VectorSearch(indexName = "cos-index") + SearchResults<WithVectorFields> searchTop1ByCountryAndEmbeddingWithin(String country, Vector vector, + Range distance); + + } + + @org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME) + static class WithVectorFields { // sample document with a filter field (country), a description and the indexed embedding + + String id; + String country; + String description; + + Vector embedding; + + public WithVectorFields(String country, String description, Vector embedding) { + this.country = country; + this.description = description; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getCountry() { + return country; + } + + public String getDescription() { + return description; + } + + public Vector getEmbedding() { + return embedding; + } + + @Override + public String toString() { + return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description + + '\'' + '}'; + } + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java index 1c856394d8..f0ffebde20 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessorUnitTests.java @@ -22,8 +22,10 @@ import org.bson.Document; import org.junit.jupiter.api.Test; +
import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Metrics; import org.springframework.data.geo.Point; @@ -45,15 +47,15 @@ * @author Oliver Gierke * @author Christoph Strobl */ -public class MongoParametersParameterAccessorUnitTests { +class MongoParametersParameterAccessorUnitTests { - Distance DISTANCE = new Distance(2.5, Metrics.KILOMETERS); - RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); - MongoMappingContext context = new MongoMappingContext(); - ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); + private Distance DISTANCE = Distance.of(2.5, Metrics.KILOMETERS); + private RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); + private MongoMappingContext context = new MongoMappingContext(); + private ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); @Test - public void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodException, SecurityException { + void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -64,7 +66,7 @@ public void returnsUnboundedForDistanceIfNoneAvailable() throws NoSuchMethodExce } @Test - public void returnsDistanceIfAvailable() throws NoSuchMethodException, SecurityException { + void returnsDistanceIfAvailable() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Distance.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -75,7 +77,7 @@ public void returnsDistanceIfAvailable() throws 
NoSuchMethodException, SecurityE } @Test // DATAMONGO-973 - public void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuchMethodException, SecurityException { + void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Distance.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -86,7 +88,7 @@ public void shouldReturnAsFullTextStringWhenNoneDefinedForMethod() throws NoSuch } @Test // DATAMONGO-973 - public void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, SecurityException { + void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByFirstname", String.class, TextCriteria.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -98,13 +100,13 @@ public void shouldProperlyConvertTextCriteria() throws NoSuchMethodException, Se } @Test // DATAMONGO-1110 - public void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, SecurityException { + void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class, Range.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); - Distance min = new Distance(10, Metrics.KILOMETERS); - Distance max = new Distance(20, Metrics.KILOMETERS); + Distance min = Distance.of(10, Metrics.KILOMETERS); + Distance max = Distance.of(20, Metrics.KILOMETERS); MongoParameterAccessor accessor = new MongoParametersParameterAccessor(queryMethod, new Object[] { new Point(10, 20), Distance.between(min, max) }); @@ -116,7 +118,7 @@ public void shouldDetectMinAndMaxDistance() throws NoSuchMethodException, Securi } @Test // DATAMONGO-1854 - public void 
shouldDetectCollation() throws NoSuchMethodException, SecurityException { + void shouldDetectCollation() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByFirstname", String.class, Collation.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -129,7 +131,7 @@ public void shouldDetectCollation() throws NoSuchMethodException, SecurityExcept } @Test // GH-2107 - public void shouldReturnUpdateIfPresent() throws NoSuchMethodException, SecurityException { + void shouldReturnUpdateIfPresent() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findAndModifyByFirstname", String.class, UpdateDefinition.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -142,7 +144,7 @@ public void shouldReturnUpdateIfPresent() throws NoSuchMethodException, Security } @Test // GH-2107 - public void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, SecurityException { + void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, SecurityException { Method method = PersonRepository.class.getMethod("findByLocationNear", Point.class); MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); @@ -153,6 +155,23 @@ public void shouldReturnNullIfNoUpdatePresent() throws NoSuchMethodException, Se assertThat(accessor.getUpdate()).isNull(); } + @Test // GH- + void shouldReturnRangeFromScore() throws NoSuchMethodException, SecurityException { + + Method method = PersonRepository.class.getMethod("findByFirstname", String.class, Score.class); + MongoQueryMethod queryMethod = new MongoQueryMethod(method, metadata, factory, context); + + MongoParameterAccessor accessor = new MongoParametersParameterAccessor(queryMethod, + new Object[] { "foo", Score.of(1) }); + + Range scoreRange = accessor.getScoreRange(); + + assertThat(scoreRange).isNotNull(); + 
assertThat(scoreRange.getLowerBound().isBounded()).isFalse(); + assertThat(scoreRange.getUpperBound().isBounded()).isTrue(); + assertThat(scoreRange.getUpperBound().getValue()).contains(Score.of(1)); + } + interface PersonRepository extends Repository { List findByLocationNear(Point point); @@ -165,6 +184,8 @@ interface PersonRepository extends Repository { List findByFirstname(String firstname, Collation collation); + List findByFirstname(String firstname, Score score); + List findAndModifyByFirstname(String firstname, UpdateDefinition update); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java index 93674e23fc..fc1ffb971e 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoParametersUnitTests.java @@ -27,6 +27,8 @@ import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResults; import org.springframework.data.geo.Point; @@ -43,6 +45,7 @@ * * @author Oliver Gierke * @author Christoph Strobl + * @author Mark Paluch */ @ExtendWith(MockitoExtension.class) class MongoParametersUnitTests { @@ -184,6 +187,21 @@ void shouldReturnInvalidIndexIfUpdateDoesNotExist() throws NoSuchMethodException assertThat(parameters.getUpdateIndex()).isEqualTo(-1); } + @Test // GH-2107 + void shouldOmitVector() throws NoSuchMethodException, SecurityException { + + Method method = PersonRepository.class.getMethod("shouldOmitVector", Vector.class, Score.class, + Range.class, String.class); + MongoParameters parameters = 
new MongoParameters(ParametersSource.of(method), false); + + assertThat(parameters.getVectorIndex()).isEqualTo(0); + assertThat(parameters.getScoreIndex()).isEqualTo(1); + assertThat(parameters.getScoreRangeIndex()).isEqualTo(2); + + MongoParameters bindableParameters = parameters.getBindableParameters(); + assertThat(bindableParameters).hasSize(3); + } + interface PersonRepository { List findByLocationNear(Point point, Distance distance); @@ -205,5 +223,8 @@ interface PersonRepository { List findByText(String text, Collation collation); List findAndModifyByFirstname(String firstname, UpdateDefinition update, Pageable page); + + List shouldOmitVector(Vector vector, Score distance, Range range, + String country); } } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java index 609e0a0018..55e3df6b43 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryCreatorUnitTests.java @@ -29,6 +29,7 @@ import org.bson.types.ObjectId; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; import org.springframework.data.geo.Distance; @@ -120,7 +121,7 @@ void createsIsNullQueryCorrectly() { void bindsMetricDistanceParameterToNearSphereCorrectly() throws Exception { Point point = new Point(10, 20); - Distance distance = new Distance(2.5, Metrics.KILOMETERS); + Distance distance = Distance.of(2.5, Metrics.KILOMETERS); Query query = query( where("location").nearSphere(point).maxDistance(distance.getNormalizedValue()).and("firstname").is("Dave")); @@ -131,7 +132,7 @@ void 
bindsMetricDistanceParameterToNearSphereCorrectly() throws Exception { void bindsDistanceParameterToNearCorrectly() throws Exception { Point point = new Point(10, 20); - Distance distance = new Distance(2.5); + Distance distance = Distance.of(2.5); Query query = query( where("location").near(point).maxDistance(distance.getNormalizedValue()).and("firstname").is("Dave")); @@ -405,7 +406,7 @@ void shouldCreateRegexWhenUsingNotContainsOnStringProperty() { void createsNonSphericalNearForDistanceWithDefaultMetric() { Point point = new Point(1.0, 1.0); - Distance distance = new Distance(1.0); + Distance distance = Distance.of(1.0); PartTree tree = new PartTree("findByLocationNear", Venue.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, distance), context); @@ -445,7 +446,7 @@ void shouldCreateNearSphereQueryForSphericalProperty() { void shouldCreateNearSphereQueryForSphericalPropertyHavingDistanceWithDefaultMetric() { Point point = new Point(1.0, 1.0); - Distance distance = new Distance(1.0); + Distance distance = Distance.of(1.0); PartTree tree = new PartTree("findByAddress2dSphere_GeoNear", User.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, distance), context); @@ -458,7 +459,7 @@ void shouldCreateNearSphereQueryForSphericalPropertyHavingDistanceWithDefaultMet void shouldCreateNearQueryForMinMaxDistance() { Point point = new Point(10, 20); - Range range = Distance.between(new Distance(10), new Distance(20)); + Range range = Distance.between(Distance.of(10), Distance.of(20)); PartTree tree = new PartTree("findByAddress_GeoNear", User.class); MongoQueryCreator creator = new MongoQueryCreator(tree, getAccessor(converter, point, range), context); @@ -664,7 +665,7 @@ void nearShouldUseMetricDistanceForGeoJsonTypes() { GeoJsonPoint point = new GeoJsonPoint(27.987901, 86.9165379); PartTree tree = new PartTree("findByLocationNear", User.class); MongoQueryCreator creator = new 
MongoQueryCreator(tree, - getAccessor(converter, point, new Distance(1, Metrics.KILOMETERS)), context); + getAccessor(converter, point, Distance.of(1, Metrics.KILOMETERS)), context); assertThat(creator.createQuery()).isEqualTo(query(where("location").nearSphere(point).maxDistance(1000.0D))); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java index dbd17aa805..2c0c996bc3 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryExecutionUnitTests.java @@ -32,6 +32,7 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; + import org.springframework.data.domain.PageRequest; import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; @@ -86,7 +87,7 @@ class MongoQueryExecutionUnitTests { @Mock DbRefResolver dbRefResolver; private Point POINT = new Point(10, 20); - private Distance DISTANCE = new Distance(2.5, Metrics.KILOMETERS); + private Distance DISTANCE = Distance.of(2.5, Metrics.KILOMETERS); private RepositoryMetadata metadata = new DefaultRepositoryMetadata(PersonRepository.class); private MongoMappingContext context = new MongoMappingContext(); private ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java index 8f9824e14d..386d0fa4b5 100644 --- 
a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/MongoQueryMethodUnitTests.java @@ -23,6 +23,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.data.domain.Pageable; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java index d7a3430048..1fbd60414a 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecutionUnitTests.java @@ -71,7 +71,7 @@ public void geoNearExecutionShouldApplyQuerySettings() throws Exception { Query query = new Query(); when(parameterAccessor.getGeoNearLocation()).thenReturn(new Point(1, 2)); when(parameterAccessor.getDistanceRange()) - .thenReturn(Range.from(Bound.inclusive(new Distance(10))).to(Bound.inclusive(new Distance(15)))); + .thenReturn(Range.from(Bound.inclusive(Distance.of(10))).to(Bound.inclusive(Distance.of(15)))); when(parameterAccessor.getPageable()).thenReturn(PageRequest.of(1, 10)); new GeoNearExecution(operations, parameterAccessor, TypeInformation.fromReturnTypeOf(geoNear)).execute(query, @@ -83,8 +83,8 @@ public void geoNearExecutionShouldApplyQuerySettings() throws Exception { NearQuery nearQuery = queryArgumentCaptor.getValue(); assertThat(nearQuery.toDocument().get("near")).isEqualTo(Arrays.asList(1d, 2d)); assertThat(nearQuery.getSkip()).isEqualTo(10L); - assertThat(nearQuery.getMinDistance()).isEqualTo(new 
Distance(10)); - assertThat(nearQuery.getMaxDistance()).isEqualTo(new Distance(15)); + assertThat(nearQuery.getMinDistance()).isEqualTo(Distance.of(10)); + assertThat(nearQuery.getMaxDistance()).isEqualTo(Distance.of(15)); } @Test // DATAMONGO-1444 diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java index 82cd0a157c..14cbbc0394 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryMethodUnitTests.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.*; -import org.springframework.data.mongodb.repository.query.MongoQueryMethodUnitTests.PersonRepository; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -27,6 +26,7 @@ import org.assertj.core.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java index 3ed7ace0f9..91f23bb049 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/StubParameterAccessor.java @@ -19,11 +19,14 @@ import java.util.Iterator; import org.jspecify.annotations.Nullable; + import org.springframework.data.domain.Pageable; import 
org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; +import org.springframework.data.domain.Score; import org.springframework.data.domain.ScrollPosition; import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.convert.MongoWriter; @@ -73,6 +76,21 @@ public StubParameterAccessor(Object... values) { } } + @Override + public Vector getVector() { + return null; + } + + @Override + public @org.jspecify.annotations.Nullable Score getScore() { + return null; + } + + @Override + public @org.jspecify.annotations.Nullable Range getScoreRange() { + return null; + } + @Override public ScrollPosition getScrollPosition() { return null; diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java new file mode 100644 index 0000000000..c1aa7cfff9 --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java @@ -0,0 +1,102 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.repository.query; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.lang.reflect.Method; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.projection.ProjectionFactory; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.CrudRepository; +import org.springframework.data.repository.core.support.DefaultRepositoryMetadata; +import org.springframework.data.repository.query.ValueExpressionDelegate; + +/** + * Unit tests for {@link VectorSearchAggregation}. 
+ * + * @author Mark Paluch + */ +class VectorSearchAggregationUnitTests { + + MongoOperations operationsMock; + MongoMappingContext context; + MappingMongoConverter converter; + + @BeforeEach + public void setUp() { + context = new MongoMappingContext(); + converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, context); + operationsMock = Mockito.mock(MongoOperations.class); + when(operationsMock.getConverter()).thenReturn(converter); + } + + @Test + void derivesPrefilter() throws Exception { + + VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", + String.class, Vector.class, Score.class, Limit.class); + + VectorSearchAggregation.VectorSearchQuery query = aggregation + .createVectorSearchQuery(new MongoParametersParameterAccessor(aggregation.getQueryMethod(), + new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() })); + + assertThat(query.query().getQueryObject()).containsEntry("country", "de"); + } + + private VectorSearchAggregation aggregation(Class repository, String name, Class... 
parameters) + throws Exception { + + Method method = repository.getMethod(name, parameters); + ProjectionFactory factory = new SpelAwareProxyProjectionFactory(); + MongoQueryMethod queryMethod = new MongoQueryMethod(method, new DefaultRepositoryMetadata(repository), factory, + context); + return new VectorSearchAggregation(queryMethod, operationsMock, ValueExpressionDelegate.create()); + } + + interface SampleRepository extends CrudRepository { + + @VectorSearch(indexName = "cos-index") + SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score similarity, + Limit limit); + + } + + static class WithVectorFields { + + String id; + String country; + String description; + + Vector embedding; + + } + +} diff --git a/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt b/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt index cbb7ae46f3..99d57002e4 100644 --- a/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt +++ b/spring-data-mongodb/src/test/kotlin/org/springframework/data/mongodb/core/ReactiveFindOperationExtensionsTests.kt @@ -270,9 +270,9 @@ class ReactiveFindOperationExtensionsTests { fun terminatingFindNearAllAsFlow() { val spec = mockk>() - val foo = GeoResult("foo", Distance(0.0)) - val bar = GeoResult("bar", Distance(0.0)) - val baz = GeoResult("baz", Distance(0.0)) + val foo = GeoResult("foo", Distance.of(0.0)) + val bar = GeoResult("bar", Distance.of(0.0)) + val baz = GeoResult("baz", Distance.of(0.0)) every { spec.all() } returns Flux.just(foo, bar, baz) runBlocking { From 642396c7e58df4b6048484b26f53c2b6bd97bd83 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 29 Apr 2025 08:55:06 +0200 Subject: [PATCH 03/10] Add reactive search support. 
--- .../aot/MongoRepositoryContributor.java | 17 +- .../repository/query/MongoQueryExecution.java | 127 ++---- .../query/ReactiveMongoQueryExecution.java | 58 +++ .../ReactiveVectorSearchAggregation.java | 123 ++++++ .../query/VectorSearchAggregation.java | 220 +--------- .../query/VectorSearchDelegate.java | 389 ++++++++++++++++++ .../ReactiveMongoRepositoryFactory.java | 4 + .../repository/ReactiveVectorSearchTests.java | 224 ++++++++++ .../mongodb/repository/VectorSearchTests.java | 40 +- .../VectorSearchAggregationUnitTests.java | 12 +- 10 files changed, 888 insertions(+), 326 deletions(-) create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java create mode 100644 spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java create mode 100644 spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index def03c7973..a9368615b0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,27 +15,20 @@ */ package org.springframework.data.mongodb.repository.aot; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; -import static 
org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; -import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; - import java.lang.reflect.Method; import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.jspecify.annotations.Nullable; + import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.aggregation.AggregationUpdate; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; -import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder; +import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; @@ -178,10 +171,11 @@ private QueryInteraction createStringQuery(RepositoryInformation repositoryInfor private static boolean backoff(MongoQueryMethod method) { - boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery(); + boolean skip = method.isGeoNearQuery() || method.isScrollQuery() || method.isStreamQuery() + || method.isSearchQuery(); if (skip && logger.isDebugEnabled()) { - logger.debug("Skipping AOT generation for [%s]. Method is either geo-near, streaming or scrolling query" + logger.debug("Skipping AOT generation for [%s]. 
Method is either geo-near, streaming, search or scrolling query" .formatted(method.getName())); } return skip; @@ -193,7 +187,6 @@ private static MethodContributor aggregationMethodContributor( return MethodContributor.forQueryMethod(queryMethod).withMetadata(aggregation).contribute(context -> { CodeBlock.Builder builder = CodeBlock.builder(); - builder.add(context.codeBlocks().logDebug("invoking [%s]".formatted(context.getMethod().getName()))); builder.add(aggregationBlockBuilder(context, queryMethod).stages(aggregation) .usingAggregationVariableName("aggregation").build()); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index 7f632f58e4..d9a91434ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -16,6 +16,7 @@ package org.springframework.data.mongodb.repository.query; import java.util.ArrayList; +import java.util.Collection; import java.util.Iterator; import java.util.List; import java.util.function.Supplier; @@ -26,13 +27,11 @@ import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; -import org.springframework.data.domain.Score; import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Similarity; import org.springframework.data.domain.Slice; import org.springframework.data.domain.SliceImpl; -import org.springframework.data.domain.Sort; -import org.springframework.data.domain.Vector; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; @@ 
-46,11 +45,9 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.Aggregation; import org.springframework.data.mongodb.core.aggregation.AggregationOperation; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; -import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; @@ -225,7 +222,7 @@ private static boolean isListOfGeoResult(TypeInformation returnType) { } /** - * {@link MongoQueryExecution} to execute vector search + * {@link MongoQueryExecution} to execute vector search. 
* * @author Mark Paluch * @since 5.0 @@ -235,118 +232,64 @@ class VectorSearchExecution implements MongoQueryExecution { private final MongoOperations operations; private final MongoQueryMethod method; private final String collectionName; - private final @Nullable Integer numCandidates; - private final VectorSearchOperation.SearchType searchType; - private final MongoParameterAccessor accessor; - private final Class outputType; - private final String path; + private final VectorSearchDelegate.QueryMetadata queryMetadata; + private final List pipeline; public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, - String path, @Nullable Integer numCandidates, VectorSearchOperation.SearchType searchType, - MongoParameterAccessor accessor, Class outputType) { + VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { this.operations = operations; this.collectionName = collectionName; - this.path = path; - this.numCandidates = numCandidates; + this.queryMetadata = queryMetadata; this.method = method; - this.searchType = searchType; - this.accessor = accessor; - this.outputType = outputType; + this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); } @Override public Object execute(Query query) { - SearchResults results = doExecuteQuery(query); - return isListOfSearchResult(method.getReturnType()) ? 
results.getContent() : results; - } + AggregationResults aggregated = operations.aggregate( + TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName, + queryMetadata.outputType()); - @SuppressWarnings("unchecked") - SearchResults doExecuteQuery(Query query) { + List mappedResults = aggregated.getMappedResults(); - Vector vector = accessor.getVector(); - Score score = accessor.getScore(); - Range distance = accessor.getScoreRange(); - int limit; + if (isSearchResult(method.getReturnType())) { - if (query.isLimited()) { - limit = query.getLimit(); - } else { - limit = Math.max(1, numCandidates != null ? numCandidates / 20 : 1); - } + List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); + List> result = new ArrayList<>(mappedResults.size()); - List stages = new ArrayList<>(); - VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(method.getAnnotatedHint()).path(path) - .vector(vector).limit(limit); + for (int i = 0; i < mappedResults.size(); i++) { + Document document = rawResults.get(i); + SearchResult searchResult = new SearchResult<>(mappedResults.get(i), + Similarity.raw(document.getDouble(queryMetadata.scoreField()), queryMetadata.scoringFunction())); - if (numCandidates != null) { - $vectorSearch = $vectorSearch.numCandidates(numCandidates); - } + result.add(searchResult); + } - $vectorSearch = $vectorSearch.filter(query.getQueryObject()); - $vectorSearch = $vectorSearch.searchType(searchType); - $vectorSearch = $vectorSearch.withSearchScore("__score__"); - - if (score != null) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - c.gt(score.getValue()); - }); - } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - Range.Bound lower = distance.getLowerBound(); - if (lower.isBounded()) { - double value = lower.getValue().get().getValue(); - if (lower.isInclusive()) { - c.gte(value); - } else { -
c.gt(value); - } - } - - Range.Bound upper = distance.getUpperBound(); - if (upper.isBounded()) { - - double value = upper.getValue().get().getValue(); - if (upper.isInclusive()) { - c.lte(value); - } else { - c.lt(value); - } - } - }); + return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result); } - stages.add($vectorSearch); - - if (query.isSorted()) { - // TODO stages.add(Aggregation.sort(query.with())); - } else { - stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__")); - } - - AggregationResults aggregated = operations - .aggregate(TypedAggregation. newAggregation(outputType, stages), collectionName, outputType); - - List mappedResults = aggregated.getMappedResults(); - List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); - - List> result = new ArrayList<>(mappedResults.size()); + return mappedResults; + } - for (int i = 0; i < mappedResults.size(); i++) { - Document document = rawResults.get(i); - SearchResult searchResult = new SearchResult<>(mappedResults.get(i), - Score.of(document.getDouble("__score__"))); + private static boolean isListOfSearchResult(TypeInformation returnType) { - result.add(searchResult); + if (!Collection.class.isAssignableFrom(returnType.getType())) { + return false; } - return new SearchResults<>(result); + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); } - private static boolean isListOfSearchResult(TypeInformation returnType) { + private static boolean isSearchResult(TypeInformation returnType) { - if (!returnType.getType().equals(List.class)) { + if (SearchResults.class.isAssignableFrom(returnType.getType())) { + return true; + } + + if (!Iterable.class.isAssignableFrom(returnType.getType())) { return false; } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java index f9b47c9a84..389f4e871d 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java @@ -18,6 +18,9 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import java.util.List; + +import org.bson.Document; import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; @@ -25,12 +28,16 @@ import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.Similarity; import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.Point; import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; @@ -118,6 +125,57 @@ private boolean isStreamOfGeoResult() { } } + /** + * {@link ReactiveMongoQueryExecution} to execute vector search. 
+ * + * @author Mark Paluch + * @since 5.0 + */ + class VectorSearchExecution implements ReactiveMongoQueryExecution { + + private final ReactiveMongoOperations operations; + private final VectorSearchDelegate.QueryMetadata queryMetadata; + private final List pipeline; + private final boolean returnSearchResult; + + public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, + VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { + + this.operations = operations; + this.queryMetadata = queryMetadata; + this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); + this.returnSearchResult = isSearchResult(method.getReturnType()); + } + + @Override + public Publisher execute(Query query, Class type, String collection) { + + Flux aggregate = operations + .aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class); + + return aggregate.map(document -> { + + Object mappedResult = operations.getConverter().read(queryMetadata.outputType(), document); + + return returnSearchResult + ? new SearchResult<>(mappedResult, + Similarity.raw(document.getDouble(queryMetadata.scoreField()), queryMetadata.scoringFunction())) + : mappedResult; + }); + } + + private static boolean isSearchResult(TypeInformation returnType) { + + if (!Publisher.class.isAssignableFrom(returnType.getType())) { + return false; + } + + TypeInformation componentType = returnType.getComponentType(); + return componentType != null && SearchResult.class.equals(componentType.getType()); + } + + } + /** * {@link ReactiveMongoQueryExecution} removing documents matching the query. 
* diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java new file mode 100644 index 0000000000..1ecbb0235f --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java @@ -0,0 +1,123 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.repository.query; + +import reactor.core.publisher.Mono; + +import org.bson.Document; +import org.reactivestreams.Publisher; + +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.MongoOperations; +import org.springframework.data.mongodb.core.ReactiveMongoOperations; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.data.spel.ExpressionDependencies; + +/** + * {@link AbstractReactiveMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either + * derived from the method name or provided through {@link VectorSearch#filter()}. + * + * @author Mark Paluch + * @since 5.0 + */ +public class ReactiveVectorSearchAggregation extends AbstractReactiveMongoQuery { + + private final ReactiveMongoOperations mongoOperations; + private final MongoPersistentEntity collectionEntity; + private final ValueExpressionDelegate valueExpressionDelegate; + private final VectorSearchDelegate delegate; + + /** + * Creates a new {@link ReactiveVectorSearchAggregation} from the given {@link ReactiveMongoQueryMethod} and + * {@link ReactiveMongoOperations}. + * + * @param method must not be {@literal null}. + * @param mongoOperations must not be {@literal null}. + * @param delegate must not be {@literal null}.
+ */ + public ReactiveVectorSearchAggregation(ReactiveMongoQueryMethod method, ReactiveMongoOperations mongoOperations, + ValueExpressionDelegate delegate) { + + super(method, mongoOperations, delegate); + + this.valueExpressionDelegate = delegate; + if (!method.isSearchQuery() && !method.isCollectionQuery()) { + throw new InvalidMongoDbApiUsageException(String.format( + "Repository Vector Search method '%s' must return either SearchResults or List but was %s", + method.getName(), method.getReturnType().getType().getSimpleName())); + } + + this.mongoOperations = mongoOperations; + this.collectionEntity = method.getEntityInformation().getCollectionEntity(); + this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); + } + + @Override + protected Publisher doExecute(ReactiveMongoQueryMethod method, ResultProcessor processor, + ConvertingParameterAccessor accessor, @org.jspecify.annotations.Nullable Class typeToRead) { + + return getParameterBindingCodec().flatMapMany(codec -> { + + String json = delegate.getQueryString(); + ExpressionDependencies dependencies = codec.captureExpressionDependencies(json, accessor::getBindableValue, + valueExpressionDelegate); + + return getValueExpressionEvaluatorLater(dependencies, accessor).flatMapMany(expressionEvaluator -> { + + ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, + expressionEvaluator); + VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor, + typeToRead, codec, bindingContext); + + ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution( + mongoOperations, method, query, accessor); + + return execution.execute(query.query(), Document.class, collectionEntity.getCollection()); + }); + }); + } + + @Override + protected Mono createQuery(ConvertingParameterAccessor accessor) { + throw new UnsupportedOperationException(); + }
+ + @Override + protected boolean isCountQuery() { + return false; + } + + @Override + protected boolean isExistsQuery() { + return false; + } + + @Override + protected boolean isDeleteQuery() { + return false; + } + + @Override + protected boolean isLimiting() { + return false; + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java index 2f0d0258d1..9740c0696c 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -15,30 +15,16 @@ */ package org.springframework.data.mongodb.repository.query; -import org.bson.Document; - -import org.springframework.data.domain.Limit; -import org.springframework.data.domain.Sort; -import org.springframework.data.domain.Vector; -import org.springframework.data.expression.ValueExpression; -import org.springframework.data.mapping.PersistentPropertyPath; -import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; -import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; -import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; -import org.springframework.data.mongodb.core.query.BasicQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.VectorSearch; +import 
org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; -import org.springframework.data.repository.query.parser.Part; -import org.springframework.data.repository.query.parser.PartTree; import org.springframework.lang.Nullable; -import org.springframework.util.StringUtils; /** * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. The pre-filter is either derived @@ -50,15 +36,8 @@ public class VectorSearchAggregation extends AbstractMongoQuery { private final MongoOperations mongoOperations; - private final MongoConverter mongoConverter; private final MongoPersistentEntity collectionEntity; - private final VectorSearchQueryFactory queryFactory; - private final VectorSearchOperation.SearchType searchType; - private final @Nullable Integer numCandidates; - private final @Nullable String numCandidatesExpression; - - private final Limit limit; - private final @Nullable String limitExpression; + private final VectorSearchDelegate delegate; /** * Creates a new {@link VectorSearchAggregation} from the given {@link MongoQueryMethod} and {@link MongoOperations}. 
@@ -79,56 +58,8 @@ public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOpe } this.mongoOperations = mongoOperations; - this.mongoConverter = mongoOperations.getConverter(); this.collectionEntity = method.getEntityInformation().getCollectionEntity(); - - VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); - - this.searchType = vectorSearch.searchType(); - - if (StringUtils.hasText(vectorSearch.numCandidates())) { - - ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); - - if (expression.isLiteral()) { - numCandidates = Integer.parseInt(vectorSearch.numCandidates()); - numCandidatesExpression = null; - } else { - numCandidates = null; - numCandidatesExpression = vectorSearch.numCandidates(); - } - - } else { - numCandidates = null; - numCandidatesExpression = null; - } - - if (StringUtils.hasText(vectorSearch.limit())) { - - ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); - - if (expression.isLiteral()) { - limit = Limit.of(Integer.parseInt(vectorSearch.limit())); - limitExpression = null; - } else { - limit = Limit.unlimited(); - limitExpression = vectorSearch.limit(); - } - - } else { - limit = Limit.unlimited(); - limitExpression = null; - } - - if (StringUtils.hasText(vectorSearch.filter())) { - queryFactory = StringUtils.hasText(vectorSearch.path()) - ? 
new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) - : new AnnotatedQueryFactory(vectorSearch.filter(), collectionEntity); - } else { - queryFactory = new PartTreeQueryFactory( - new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), - mongoConverter.getMappingContext()); - } + this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); } @SuppressWarnings("unchecked") @@ -136,42 +67,21 @@ public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOpe protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, @Nullable Class typeToRead) { - ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); - Integer numCandidates = null; - Limit limit; - Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); - - if (this.numCandidatesExpression != null) { - numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); - } else if (this.numCandidates != null) { - numCandidates = this.numCandidates; - } - - if (this.limitExpression != null) { - - Object value = evaluator.evaluate(this.limitExpression); - limit = value instanceof Limit l ? 
l : Limit.of(((Number) value).intValue()); - } else if (this.limit.isLimited()) { - limit = this.limit; - } else { - limit = accessor.getLimit(); - } - - VectorSearchQuery query = createVectorSearchQuery(accessor); - - if (limit.isLimited()) { - query.query().limit(limit); - } + VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead); MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, - method, collectionEntity.getCollection(), query.path(), numCandidates, searchType, accessor, - (Class) outputType); + method, collectionEntity.getCollection(), query, accessor); return execution.execute(query.query()); } - VectorSearchQuery createVectorSearchQuery(MongoParameterAccessor accessor) { - return queryFactory.createQuery(accessor); + VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, + @Nullable Class typeToRead) { + + ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); + ParameterBindingContext bindingContext = prepareBindingContext(delegate.getQueryString(), accessor); + + return delegate.createQuery(evaluator, processor, accessor, typeToRead, getParameterBindingCodec(), bindingContext); } @Override @@ -199,110 +109,4 @@ protected boolean isLimiting() { return false; } - interface VectorSearchQueryFactory { - - VectorSearchQuery createQuery(MongoParameterAccessor parameterAccessor); - } - - class AnnotatedQueryFactory implements VectorSearchQueryFactory { - - private final String query; - private final String path; - - AnnotatedQueryFactory(String query, String path) { - - this.query = query; - this.path = path; - } - - AnnotatedQueryFactory(String query, MongoPersistentEntity entity) { - - this.query = query; - String path = null; - for (MongoPersistentProperty property : entity) { - if (Vector.class.isAssignableFrom(property.getType())) { - path = property.getFieldName(); 
- break; - } - } - - if (path == null) { - throw new InvalidMongoDbApiUsageException( - "Cannot find Vector Search property in entity [%s]".formatted(entity.getName())); - } - - this.path = path; - } - - public VectorSearchQuery createQuery(MongoParameterAccessor parameterAccessor) { - - Document queryObject = decode(this.query, prepareBindingContext(this.query, parameterAccessor)); - Query query = new BasicQuery(queryObject); - - Sort sort = parameterAccessor.getSort(); - if (sort.isSorted()) { - query = query.with(sort); - } - - return new VectorSearchQuery(path, query); - } - - } - - class PartTreeQueryFactory implements VectorSearchQueryFactory { - - private final String path; - private final Part.Type type; - private final MappingContext context; - private final PartTree partTree; - - @SuppressWarnings("NullableProblems") - PartTreeQueryFactory(PartTree partTree, MappingContext context) { - - String path = null; - Part.Type type = null; - for (PartTree.OrPart part : partTree) { - for (Part p : part) { - if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR - || p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) { - PersistentPropertyPath ppp = context.getPersistentPropertyPath(p.getProperty()); - MongoPersistentProperty property = ppp.getLeafProperty(); - - if (Vector.class.isAssignableFrom(property.getType())) { - path = p.getProperty().toDotPath(); - type = p.getType(); - break; - } - } - } - } - - if (path == null) { - throw new InvalidMongoDbApiUsageException( - "No Simple Property/Near/Within/Between part found for a Vector property"); - } - - this.path = path; - this.type = type; - - this.partTree = partTree; - this.context = context; - } - - public VectorSearchQuery createQuery(MongoParameterAccessor parameterAccessor) { - - MongoQueryCreator creator = new MongoQueryCreator(partTree, parameterAccessor, mongoConverter.getMappingContext(), - false, true); - - Query query = 
creator.createQuery(parameterAccessor.getSort()); - - return new VectorSearchQuery(path, query); - } - - } - - record VectorSearchQuery(String path, Query query) { - - } - } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java new file mode 100644 index 0000000000..2a60da0911 --- /dev/null +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -0,0 +1,389 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.repository.query; + +import java.util.ArrayList; +import java.util.List; + +import org.bson.Document; + +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Range; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.ScoringFunction; +import org.springframework.data.domain.Sort; +import org.springframework.data.domain.Vector; +import org.springframework.data.expression.ValueExpression; +import org.springframework.data.mapping.PersistentPropertyPath; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MongoConverter; +import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; +import org.springframework.data.mongodb.core.mapping.MongoPersistentProperty; +import org.springframework.data.mongodb.core.query.BasicQuery; +import org.springframework.data.mongodb.core.query.Query; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.repository.query.ResultProcessor; +import org.springframework.data.repository.query.ValueExpressionDelegate; +import org.springframework.data.repository.query.parser.Part; +import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.lang.Nullable; +import org.springframework.util.StringUtils; + +/** + * Delegate to assemble information about 
Vector Search queries necessary to run a MongoDB {@code $vectorSearch}. + * + * @author Mark Paluch + */ +class VectorSearchDelegate { + + private final VectorSearchQueryFactory queryFactory; + private final VectorSearchOperation.SearchType searchType; + private final @Nullable Integer numCandidates; + private final @Nullable String numCandidatesExpression; + private final Limit limit; + private final @Nullable String limitExpression; + private final MongoConverter converter; + + public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { + + VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); + this.searchType = vectorSearch.searchType(); + + if (StringUtils.hasText(vectorSearch.numCandidates())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); + + if (expression.isLiteral()) { + numCandidates = Integer.parseInt(vectorSearch.numCandidates()); + numCandidatesExpression = null; + } else { + numCandidates = null; + numCandidatesExpression = vectorSearch.numCandidates(); + } + + } else { + numCandidates = null; + numCandidatesExpression = null; + } + + if (StringUtils.hasText(vectorSearch.limit())) { + + ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); + + if (expression.isLiteral()) { + limit = Limit.of(Integer.parseInt(vectorSearch.limit())); + limitExpression = null; + } else { + limit = Limit.unlimited(); + limitExpression = vectorSearch.limit(); + } + + } else { + limit = Limit.unlimited(); + limitExpression = null; + } + + this.converter = converter; + + if (StringUtils.hasText(vectorSearch.filter())) { + queryFactory = StringUtils.hasText(vectorSearch.path()) + ? 
new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) + : new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity()); + } else { + queryFactory = new PartTreeQueryFactory( + new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), + converter.getMappingContext()); + } + } + + /** + * Create Query Metadata for {@code $vectorSearch}. + */ + public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, + MongoParameterAccessor accessor, @Nullable Class typeToRead, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + Integer numCandidates = null; + Limit limit; + Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); + + if (this.numCandidatesExpression != null) { + numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + } else if (this.numCandidates != null) { + numCandidates = this.numCandidates; + } + + if (this.limitExpression != null) { + + Object value = evaluator.evaluate(this.limitExpression); + limit = value instanceof Limit l ? 
l : Limit.of(((Number) value).intValue()); + } else if (this.limit.isLimited()) { + limit = this.limit; + } else { + limit = accessor.getLimit(); + } + + VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); + + if (limit.isLimited()) { + query.query().limit(limit); + } + + return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates, + getSimilarityFunction(accessor)); + } + + public String getQueryString() { + return queryFactory.getQueryString(); + } + + ScoringFunction getSimilarityFunction(MongoParameterAccessor accessor) { + + Score score = accessor.getScore(); + + if (score != null) { + return score.getFunction(); + } + + Range scoreRange = accessor.getScoreRange(); + + if (scoreRange != null) { + if (scoreRange.getUpperBound().isBounded()) { + return scoreRange.getUpperBound().getValue().get().getFunction(); + } + + if (scoreRange.getLowerBound().isBounded()) { + return scoreRange.getLowerBound().getValue().get().getFunction(); + } + } + + return ScoringFunction.unspecified(); + } + + /** + * Metadata for a Vector Search Aggregation. + * + * @param path path of the vector field handed to the {@code $vectorSearch} stage. + * @param query pre-filter query; also carries the limit, if any. + * @param searchType search type applied to the {@code $vectorSearch} stage. + * @param outputType type the aggregation results are mapped to. + * @param numCandidates number of candidates to consider; can be {@literal null}. + * @param scoringFunction scoring function associated with the returned scores. + */ + public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType, + Class outputType, @org.jspecify.annotations.Nullable Integer numCandidates, ScoringFunction scoringFunction) { + + /** + * Create the Aggregation Pipeline. + * + * @param queryMethod query method whose annotated hint is passed to {@code Aggregation.vectorSearch(…)}. + * @param accessor accessor providing the vector and the score or score range. + * @return the aggregation stages to run. + */ + public List getAggregationPipeline(MongoQueryMethod queryMethod, + MongoParameterAccessor accessor) { + + Vector vector = accessor.getVector(); + Score score = accessor.getScore(); + Range distance = accessor.getScoreRange(); + int limit; + + if (query.isLimited()) { + limit = query.getLimit(); + } else { + limit = Math.max(1, numCandidates() != null ?
numCandidates() / 20 : 1); + } + + List stages = new ArrayList<>(); + VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path()) + .vector(vector).limit(limit); + + if (numCandidates() != null) { + $vectorSearch = $vectorSearch.numCandidates(numCandidates()); + } + + $vectorSearch = $vectorSearch.filter(query.getQueryObject()); + $vectorSearch = $vectorSearch.searchType(searchType()); + $vectorSearch = $vectorSearch.withSearchScore(scoreField()); + + if (score != null) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + c.gt(score.getValue()); + }); + } else if (distance != null && (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded())) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + Range.Bound lower = distance.getLowerBound(); + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + c.gte(value); + } else { + c.gt(value); + } + } + + Range.Bound upper = distance.getUpperBound(); + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + c.lte(value); + } else { + c.lt(value); + } + } + }); + } + + stages.add($vectorSearch); + + if (query.isSorted()) { + // TODO stages.add(Aggregation.sort(query.with())); + } else { + stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField())); + } + + return stages; + } + + } + + /** + * Strategy interface to implement a query factory for the Vector Search pre-filter query. + */ + private interface VectorSearchQueryFactory { + + VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context); + + /** + * @return the underlying query string to determine {@link ParameterBindingContext}.
+ */ + String getQueryString(); + } + + private static class AnnotatedQueryFactory implements VectorSearchQueryFactory { + + private final String query; + private final String path; + + AnnotatedQueryFactory(String query, String path) { + + this.query = query; + this.path = path; + } + + AnnotatedQueryFactory(String query, MongoPersistentEntity entity) { + + this.query = query; + String path = null; + for (MongoPersistentProperty property : entity) { + if (Vector.class.isAssignableFrom(property.getType())) { + path = property.getFieldName(); + break; + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "Cannot find Vector Search property in entity [%s]".formatted(entity.getName())); + } + + this.path = path; + } + + public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + Document queryObject = codec.decode(this.query, context); + Query query = new BasicQuery(queryObject); + + Sort sort = parameterAccessor.getSort(); + if (sort.isSorted()) { + query = query.with(sort); + } + + return new VectorSearchInput(path, query); + } + + @Override + public String getQueryString() { + return this.query; + } + } + + private class PartTreeQueryFactory implements VectorSearchQueryFactory { + + private final String path; + private final PartTree partTree; + + @SuppressWarnings("NullableProblems") + PartTreeQueryFactory(PartTree partTree, MappingContext context) { + + String path = null; + for (PartTree.OrPart part : partTree) { + for (Part p : part) { + if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR + || p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) { + PersistentPropertyPath ppp = context.getPersistentPropertyPath(p.getProperty()); + MongoPersistentProperty property = ppp.getLeafProperty(); + + if (Vector.class.isAssignableFrom(property.getType())) { + path = p.getProperty().toDotPath(); + break; + } + 
} + } + } + + if (path == null) { + throw new InvalidMongoDbApiUsageException( + "No Simple Property/Near/Within/Between part found for a Vector property"); + } + + this.path = path; + this.partTree = partTree; + } + + public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, + ParameterBindingContext context) { + + MongoQueryCreator creator = new MongoQueryCreator(partTree, parameterAccessor, converter.getMappingContext(), + false, true); + + Query query = creator.createQuery(parameterAccessor.getSort()); + + return new VectorSearchInput(path, query); + } + + @Override + public String getQueryString() { + return ""; + } + } + + private record VectorSearchInput(String path, Query query) { + + } + +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java index ae8561bc17..11c5b09460 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/support/ReactiveMongoRepositoryFactory.java @@ -22,6 +22,7 @@ import java.util.Optional; import org.jspecify.annotations.Nullable; + import org.springframework.beans.factory.BeanFactory; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.mongodb.core.ReactiveMongoOperations; @@ -33,6 +34,7 @@ import org.springframework.data.mongodb.repository.query.ReactivePartTreeMongoQuery; import org.springframework.data.mongodb.repository.query.ReactiveStringBasedAggregation; import org.springframework.data.mongodb.repository.query.ReactiveStringBasedMongoQuery; +import org.springframework.data.mongodb.repository.query.ReactiveVectorSearchAggregation; import 
org.springframework.data.projection.ProjectionFactory; import org.springframework.data.querydsl.ReactiveQuerydslPredicateExecutor; import org.springframework.data.repository.core.NamedQueries; @@ -174,6 +176,8 @@ public RepositoryQuery resolveQuery(Method method, RepositoryMetadata metadata, if (namedQueries.hasQuery(namedQueryName)) { String namedQuery = namedQueries.getQuery(namedQueryName); return new ReactiveStringBasedMongoQuery(namedQuery, queryMethod, operations, delegate); + } else if (queryMethod.hasAnnotatedVectorSearch()) { + return new ReactiveVectorSearchAggregation(queryMethod, operations, delegate); } else if (queryMethod.hasAnnotatedAggregation()) { return new ReactiveStringBasedAggregation(queryMethod, operations, delegate); } else if (queryMethod.hasAnnotatedQuery()) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java new file mode 100644 index 0000000000..14a4749c8a --- /dev/null +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/ReactiveVectorSearchTests.java @@ -0,0 +1,224 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.repository; + +import static org.assertj.core.api.Assertions.*; + +import reactor.core.publisher.Flux; +import reactor.test.StepVerifier; + +import java.util.List; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.ComponentScan; +import org.springframework.context.annotation.FilterType; +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResult; +import org.springframework.data.domain.Similarity; +import org.springframework.data.domain.Vector; +import org.springframework.data.mongodb.core.ReactiveMongoTemplate; +import org.springframework.data.mongodb.core.SimpleReactiveMongoDatabaseFactory; +import org.springframework.data.mongodb.core.TestMongoConfiguration; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.index.VectorIndex; +import org.springframework.data.mongodb.core.index.VectorIndex.SimilarityFunction; +import org.springframework.data.mongodb.repository.config.EnableReactiveMongoRepositories; +import org.springframework.data.mongodb.test.util.AtlasContainer; +import org.springframework.data.mongodb.test.util.MongoTestTemplate; +import org.springframework.data.repository.CrudRepository; +import org.springframework.test.context.junit.jupiter.SpringJUnitConfig; + +import org.testcontainers.junit.jupiter.Testcontainers; +import org.testcontainers.mongodb.MongoDBAtlasLocalContainer; + +import com.mongodb.client.MongoClient; +import com.mongodb.client.MongoClients; + +/** + * Integration tests using reactive Vector Search and Vector Indexes through local MongoDB Atlas. 
+ * + * @author Mark Paluch + */ +@Testcontainers(disabledWithoutDocker = true) +@SpringJUnitConfig(classes = { ReactiveVectorSearchTests.Config.class }) +public class ReactiveVectorSearchTests { + + Vector VECTOR = Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f); + + private static final MongoDBAtlasLocalContainer atlasLocal = AtlasContainer.bestMatch().withReuse(true); + private static final String COLLECTION_NAME = "collection-1"; + + static MongoClient client; + static MongoTestTemplate template; + + @Autowired ReactiveVectorSearchRepository repository; + + @EnableReactiveMongoRepositories( + includeFilters = { + @ComponentScan.Filter(value = ReactiveVectorSearchRepository.class, type = FilterType.ASSIGNABLE_TYPE) }, + considerNestedRepositories = true) + static class Config extends TestMongoConfiguration { + + @Override + public String getDatabaseName() { + return "vector-search-tests"; + } + + @Override + public MongoClient mongoClient() { + atlasLocal.start(); + return MongoClients.create(atlasLocal.getConnectionString()); + } + + @Bean + public com.mongodb.reactivestreams.client.MongoClient reactiveMongoClient() { + atlasLocal.start(); + return com.mongodb.reactivestreams.client.MongoClients.create(atlasLocal.getConnectionString()); + } + + @Bean + ReactiveMongoTemplate reactiveMongoTemplate(MappingMongoConverter mongoConverter) { + return new ReactiveMongoTemplate(new SimpleReactiveMongoDatabaseFactory(reactiveMongoClient(), getDatabaseName()), + mongoConverter); + } + } + + @BeforeAll + static void beforeAll() throws InterruptedException { + atlasLocal.start(); + + System.out.println(atlasLocal.getConnectionString()); + client = MongoClients.create(atlasLocal.getConnectionString()); + template = new MongoTestTemplate(client, "vector-search-tests"); + + template.remove(WithVectorFields.class).all(); + initDocuments(); + initIndexes(); + + Thread.sleep(500); // just wait a little or the index will be broken + } + + @Test + void 
shouldSearchEnnWithAnnotatedFilter() { + + Flux> results = repository.searchAnnotated("de", VECTOR, Score.of(0.4), + Limit.of(10)); + + results.as(StepVerifier::create).consumeNextWith(actual -> { + assertThat(actual.getScore().getValue()).isGreaterThan(0.4); + assertThat(actual.getScore()).isInstanceOf(Similarity.class); + + }).expectNextCount(2).verifyComplete(); + } + + @Test + void shouldSearchEnnWithDerivedFilter() { + + Flux results = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Limit.of(10)); + + results.as(StepVerifier::create).consumeNextWith(actual -> assertThat(actual).isInstanceOf(WithVectorFields.class)) + .expectNextCount(2).verifyComplete(); + } + + static void initDocuments() { + + WithVectorFields w1 = new WithVectorFields("de", "one", Vector.of(0.1001f, 0.22345f, 0.33456f, 0.44567f, 0.55678f)); + WithVectorFields w2 = new WithVectorFields("de", "two", Vector.of(0.2001f, 0.32345f, 0.43456f, 0.54567f, 0.65678f)); + WithVectorFields w3 = new WithVectorFields("en", "three", + Vector.of(0.9001f, 0.82345f, 0.73456f, 0.64567f, 0.55678f)); + WithVectorFields w4 = new WithVectorFields("de", "four", + Vector.of(0.9001f, 0.92345f, 0.93456f, 0.94567f, 0.95678f)); + + template.insertAll(List.of(w1, w2, w3, w4)); + } + + static void initIndexes() { + + VectorIndex cosIndex = new VectorIndex("cos-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.COSINE).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + + VectorIndex euclideanIndex = new VectorIndex("euc-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.EUCLIDEAN).dimensions(5)).addFilter("country"); + + VectorIndex inner = new VectorIndex("ip-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.DOT_PRODUCT).dimensions(5)).addFilter("country"); + + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); + 
template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(inner); + template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, inner.getName()); + } + + interface ReactiveVectorSearchRepository extends CrudRepository { + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", numCandidates = "#{10+10}", + searchType = VectorSearchOperation.SearchType.ANN) + Flux> searchAnnotated(String country, Vector vector, Score distance, Limit limit); + + @VectorSearch(indexName = "cos-index") + Flux searchByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); + + } + + @org.springframework.data.mongodb.core.mapping.Document(COLLECTION_NAME) + static class WithVectorFields { + + String id; + String country; + String description; + + Vector embedding; + + public WithVectorFields(String country, String description, Vector embedding) { + this.country = country; + this.description = description; + this.embedding = embedding; + } + + public String getId() { + return id; + } + + public String getCountry() { + return country; + } + + public String getDescription() { + return description; + } + + public Vector getEmbedding() { + return embedding; + } + + @Override + public String toString() { + return "WithVectorFields{" + "id='" + id + '\'' + ", country='" + country + '\'' + ", description='" + description + + '\'' + '}'; + } + } + +} diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java index 4e3b12b32a..028a6926fb 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java +++ 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -107,15 +107,18 @@ void shouldSearchEnnWithAnnotatedFilter() { SearchResults results = repository.searchAnnotated("de", VECTOR, Score.of(0.4), Limit.of(10)); + assertThat(results).extracting(SearchResult::getScore).hasOnlyElementsOfType(Similarity.class); assertThat(results).hasSize(3); } @Test void shouldSearchEnnWithDerivedFilter() { - SearchResults results = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Score.of(0.98), + SearchResults results = repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, + Similarity.of(0.98), Limit.of(10)); + assertThat(results).extracting(SearchResult::getScore).hasOnlyElementsOfType(Similarity.class); assertThat(results).hasSize(2).extracting(SearchResult::getContent).extracting(WithVectorFields::getCountry) .containsOnly("de", "de"); @@ -126,11 +129,21 @@ void shouldSearchEnnWithDerivedFilter() { @Test void shouldSearchEnnWithDerivedFilterWithoutScore() { - SearchResults de = repository.searchByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.4), - Limit.of(10)); + SearchResults de = repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, + Similarity.of(0.4), Limit.of(10)); + assertThat(de).hasSizeGreaterThanOrEqualTo(2); - assertThat(repository.searchByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.999), Limit.of(10))).hasSize(1); + assertThat(repository.searchCosineByCountryAndEmbeddingNear("de", VECTOR, Similarity.of(0.999), Limit.of(10))) + .hasSize(1); + } + + @Test + void shouldSearchAsListEnnWithDerivedFilterWithoutScore() { + + List de = repository.searchAsListByCountryAndEmbeddingNear("de", VECTOR, Limit.of(10)); + + assertThat(de).hasOnlyElementsOfType(WithVectorFields.class); } @Test @@ -150,7 +163,7 @@ void shouldSearchEuclideanWithDerivedFilter() { void shouldSearchEnnWithDerivedFilterWithin() { SearchResults results = repository.searchByCountryAndEmbeddingWithin("de", 
VECTOR, - Score.between(0.93, 0.98)); + Similarity.between(0.93, 0.98)); assertThat(results).hasSize(1); for (SearchResult result : results) { @@ -162,7 +175,7 @@ void shouldSearchEnnWithDerivedFilterWithin() { void shouldSearchEnnWithDerivedAndLimitedFilterWithin() { SearchResults results = repository.searchTop1ByCountryAndEmbeddingWithin("de", VECTOR, - Score.between(0.8, 1)); + Similarity.between(0.8, 1)); assertThat(results).hasSize(1); @@ -193,10 +206,15 @@ static void initIndexes() { VectorIndex euclideanIndex = new VectorIndex("euc-index") .addVector("embedding", it -> it.similarity(SimilarityFunction.EUCLIDEAN).dimensions(5)).addFilter("country"); + VectorIndex inner = new VectorIndex("ip-index") + .addVector("embedding", it -> it.similarity(SimilarityFunction.DOT_PRODUCT).dimensions(5)).addFilter("country"); + template.searchIndexOps(WithVectorFields.class).createIndex(cosIndex); template.searchIndexOps(WithVectorFields.class).createIndex(euclideanIndex); + template.searchIndexOps(WithVectorFields.class).createIndex(inner); template.awaitIndexCreation(WithVectorFields.class, cosIndex.getName()); template.awaitIndexCreation(WithVectorFields.class, euclideanIndex.getName()); + template.awaitIndexCreation(WithVectorFields.class, inner.getName()); } interface VectorSearchRepository extends CrudRepository { @@ -207,11 +225,11 @@ SearchResults searchAnnotated(String country, Vector vector, Score distance, Limit limit); @VectorSearch(indexName = "cos-index") - SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score similarity, - Limit limit); + SearchResults searchCosineByCountryAndEmbeddingNear(String country, Vector vector, + Score similarity, Limit limit); @VectorSearch(indexName = "cos-index") - SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); + List searchAsListByCountryAndEmbeddingNear(String country, Vector vector, Limit limit); @VectorSearch(indexName = "euc-index") SearchResults 
searchEuclideanByCountryAndEmbeddingNear(String country, Vector vector, @@ -219,11 +237,11 @@ SearchResults searchEuclideanByCountryAndEmbeddingNear(String @VectorSearch(indexName = "cos-index", limit = "10") SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, - Range distance); + Range distance); @VectorSearch(indexName = "cos-index") SearchResults searchTop1ByCountryAndEmbeddingWithin(String country, Vector vector, - Range distance); + Range distance); } diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java index c1aa7cfff9..c347936dfe 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java @@ -20,6 +20,7 @@ import java.lang.reflect.Method; +import org.bson.conversions.Bson; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.mockito.Mockito; @@ -52,10 +53,13 @@ class VectorSearchAggregationUnitTests { @BeforeEach public void setUp() { + context = new MongoMappingContext(); converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, context); operationsMock = Mockito.mock(MongoOperations.class); + when(operationsMock.getConverter()).thenReturn(converter); + when(operationsMock.execute(any())).thenReturn(Bson.DEFAULT_CODEC_REGISTRY); } @Test @@ -64,9 +68,11 @@ void derivesPrefilter() throws Exception { VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", String.class, Vector.class, Score.class, Limit.class); - VectorSearchAggregation.VectorSearchQuery query = aggregation - .createVectorSearchQuery(new 
MongoParametersParameterAccessor(aggregation.getQueryMethod(), - new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() })); + VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery( + aggregation.getQueryMethod().getResultProcessor(), + new MongoParametersParameterAccessor(aggregation.getQueryMethod(), + new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }), + Object.class); assertThat(query.query().getQueryObject()).containsEntry("country", "de"); } From a2833e27c99c5f0c39a40a49e33c14f66c0d88c9 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 29 Apr 2025 17:08:45 +0200 Subject: [PATCH 04/10] Documentation. --- .../data/mongodb/repository/VectorSearch.java | 13 ++++---- src/main/antora/modules/ROOT/nav.adoc | 1 + .../mongodb/repositories/vector-search.adoc | 8 +++++ .../partials/vector-search-intro-include.adoc | 1 + ...ector-search-method-annotated-include.adoc | 23 +++++++++++++ .../vector-search-method-derived-include.adoc | 21 ++++++++++++ .../partials/vector-search-model-include.adoc | 15 +++++++++ .../vector-search-repository-include.adoc | 19 +++++++++++ .../vector-search-scoring-include.adoc | 32 +++++++++++++++++++ 9 files changed, 127 insertions(+), 6 deletions(-) create mode 100644 src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc create mode 100644 src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc diff --git 
a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java index 7c6b4f4906..b378eea76e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java @@ -76,8 +76,8 @@ String path() default ""; /** - * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Alias for - * {@link VectorSearch#filter}. + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Supports Value Expressions. Alias + * for {@link VectorSearch#filter}. * * @return an empty String by default. */ @@ -85,8 +85,8 @@ String value() default ""; /** - * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Alias for - * {@link VectorSearch#value}. + * Takes a MongoDB JSON (MQL) string defining the pre-filter against indexed fields. Supports Value Expressions. Alias + * for {@link VectorSearch#value}. * * @return an empty String by default. */ @@ -96,7 +96,7 @@ /** * Number of documents to return in the results. This value can't exceed the value of {@link #numCandidates} if you * specify {@link #numCandidates}. Limit accepts Value Expressions. A Vector Search method cannot define both, - * {@code limit()} and a {@link org.springframework.data.domain.Limit} parameter. + * {@code limit()} and a {@link org.springframework.data.domain.Limit} parameter. Supports Value Expressions. * * @return number of documents to return in the results */ @@ -109,7 +109,8 @@ * to return} to increase accuracy. This over-request pattern is the recommended way to trade off latency and recall * in your ANN searches, and we recommend tuning this parameter based on your specific dataset size and query * requirements. 
Required if the query uses - * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}. + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}. Supports Value + * Expressions. * * @return number of documents to return in the results */ diff --git a/src/main/antora/modules/ROOT/nav.adoc b/src/main/antora/modules/ROOT/nav.adoc index a7401fb11f..6f2d1e2847 100644 --- a/src/main/antora/modules/ROOT/nav.adoc +++ b/src/main/antora/modules/ROOT/nav.adoc @@ -45,6 +45,7 @@ ** xref:repositories/create-instances.adoc[] ** xref:repositories/query-methods-details.adoc[] ** xref:mongodb/repositories/query-methods.adoc[] +** xref:mongodb/repositories/vector-search.adoc[] ** xref:mongodb/repositories/modifying-methods.adoc[] ** xref:repositories/projections.adoc[] ** xref:repositories/custom-implementations.adoc[] diff --git a/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc b/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc new file mode 100644 index 0000000000..2e590107ec --- /dev/null +++ b/src/main/antora/modules/ROOT/pages/mongodb/repositories/vector-search.adoc @@ -0,0 +1,8 @@ +:vector-search-intro-include: data-mongodb::partial$vector-search-intro-include.adoc +:vector-search-model-include: data-mongodb::partial$vector-search-model-include.adoc +:vector-search-repository-include: data-mongodb::partial$vector-search-repository-include.adoc +:vector-search-scoring-include: data-mongodb::partial$vector-search-scoring-include.adoc +:vector-search-method-derived-include: data-mongodb::partial$vector-search-method-derived-include.adoc +:vector-search-method-annotated-include: data-mongodb::partial$vector-search-method-annotated-include.adoc + +include::{commons}@data-commons::page$repositories/vector-search.adoc[] diff --git a/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc 
b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc new file mode 100644 index 0000000000..355bccf4e3 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-intro-include.adoc @@ -0,0 +1 @@ +To use Vector Search with MongoDB, you need a MongoDB Atlas instance running either in the cloud or locally using https://www.mongodb.com/docs/atlas/cli/current/atlas-cli-deploy-docker/[Docker]. diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc new file mode 100644 index 0000000000..752ffad622 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -0,0 +1,23 @@ +Annotated search methods use the `@VectorSearch` annotation to define parameters for the https://www.mongodb.com/docs/upcoming/reference/operator/aggregation/vectorSearch/[`$vectorSearch`] aggregation stage. + +.Using `@VectorSearch` Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score distance); + + @VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}", + searchType = VectorSearchOperation.SearchType.ANN) + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); +} +---- +==== + +Annotated Search Methods can define `filter` for pre-filter usage. + +`filter`, `limit`, and `numCandidates` support xref:page$mongodb/value-expressions.adoc[Value Expressions] allowing references to search method arguments. 
+ diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc new file mode 100644 index 0000000000..dd06ee699a --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -0,0 +1,21 @@ +MongoDB Search methods must use the `@VectorSearch` annotation to define the index name for the https://www.mongodb.com/docs/upcoming/reference/operator/aggregation/vectorSearch/[`$vectorSearch`] aggregation stage. + +.Using `Near` and `Within` Keywords in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "my-index") + SearchResults searchByEmbeddingNear(Vector vector, Score score); + + @VectorSearch(indexName = "my-index") + SearchResults searchByEmbeddingWithin(Vector vector, Range range); + + @VectorSearch(indexName = "my-index") + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, Range range); +} +---- +==== + +Derived Search Methods can define domain model attributes to create the pre-filter for indexed fields. 
diff --git a/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc new file mode 100644 index 0000000000..e657f3aa63 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-model-include.adoc @@ -0,0 +1,15 @@ +==== +[source,java] +---- +class Comment { + + @Id String id; + String country; + String comment; + + Vector embedding; + + // getters, setters, … +} +---- +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc new file mode 100644 index 0000000000..c7ad91c9db --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -0,0 +1,19 @@ +.Using `SearchResult` in a Repository Search Method +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(indexName = "my-index") + SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score score, + Limit limit); + + @VectorSearch(indexName = "my-index") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + Score score); + +} + +SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); +---- +==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc new file mode 100644 index 0000000000..b97475b467 --- /dev/null +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -0,0 +1,32 @@ +MongoDB reports the score directly as a similarity value. +The scoring function must be specified in the index; therefore, Vector Search methods do not consider `Score.scoringFunction`. 
+The scoring function defaults to `ScoringFunction.unspecified()` as search results carry no information about how the score has been computed. + +.Using `Score` and `Similarity` in Repository Search Methods +==== +[source,java] +---- +interface CommentRepository extends Repository { + + @VectorSearch(…) + SearchResults searchByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(…) + SearchResults searchByEmbeddingNear(Vector vector, Similarity similarity); + + @VectorSearch(…) + SearchResults searchByEmbeddingNear(Vector vector, Range range); +} + +repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.of(0.9)); <2> + +repository.searchByEmbeddingNear(Vector.of(…), Similarity.between(0.5, 1)); <3> +---- + +<1> Run a search and return results with a similarity of `0.9` or greater. +<2> Return results with a similarity of `0.9` or greater. +<3> Return results with a similarity between `0.5` and `1.0`. +==== + From b6d9efae0a074589be273fdd059642ac529d1848 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Tue, 29 Apr 2025 17:08:45 +0200 Subject: [PATCH 05/10] Documentation. 
--- .../repository/aot/MongoRepositoryContributor.java | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java index a9368615b0..40d72a69f7 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/aot/MongoRepositoryContributor.java @@ -15,6 +15,14 @@ */ package org.springframework.data.mongodb.repository.aot; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.aggregationExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.deleteExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.queryExecutionBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateBlockBuilder; +import static org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.updateExecutionBlockBuilder; + import java.lang.reflect.Method; import java.util.regex.Pattern; @@ -28,7 +36,7 @@ import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.Query; import org.springframework.data.mongodb.repository.Update; -import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.*; +import org.springframework.data.mongodb.repository.aot.MongoCodeBlocks.QueryCodeBlockBuilder; import org.springframework.data.mongodb.repository.query.MongoQueryMethod; import 
org.springframework.data.repository.aot.generate.AotRepositoryConstructorBuilder; import org.springframework.data.repository.aot.generate.AotRepositoryFragmentMetadata; From 66575085f5ee81312de155e199b1b4889ba4043e Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Wed, 30 Apr 2025 11:57:40 +0200 Subject: [PATCH 06/10] Polishing. --- .../query/VectorSearchDelegateUnitTests.java | 132 ++++++++++++++++++ .../core/mapping/MongoSimpleTypes.java | 1 + .../data/mongodb/repository/VectorSearch.java | 4 +- .../repository/query/MongoParameters.java | 47 +------ .../query/VectorSearchDelegate.java | 36 +++-- 5 files changed, 164 insertions(+), 56 deletions(-) create mode 100644 spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java diff --git a/spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java b/spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java new file mode 100644 index 0000000000..b5a94b8060 --- /dev/null +++ b/spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.springframework.data.mongodb.repository.query; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.Mockito.*; + +import java.lang.reflect.Method; + +import org.jetbrains.annotations.NotNull; +import org.junit.jupiter.api.Test; + +import org.springframework.data.domain.Limit; +import org.springframework.data.domain.Score; +import org.springframework.data.domain.SearchResults; +import org.springframework.data.domain.Vector; +import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; +import org.springframework.data.mongodb.core.convert.MappingMongoConverter; +import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.MongoMappingContext; +import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.util.json.ParameterBindingContext; +import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; +import org.springframework.data.projection.SpelAwareProxyProjectionFactory; +import org.springframework.data.repository.Repository; +import org.springframework.data.repository.core.RepositoryMetadata; +import org.springframework.data.repository.core.support.AnnotationRepositoryMetadata; +import org.springframework.data.repository.query.ValueExpressionDelegate; + +/** + * Unit tests for {@link VectorSearchDelegate}. 
+ * + * @author Mark Paluch + */ +class VectorSearchDelegateUnitTests { + + MappingMongoConverter converter = new MappingMongoConverter(NoOpDbRefResolver.INSTANCE, new MongoMappingContext()); + + @Test + void shouldConsiderDerivedLimit() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); + + VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + + assertThat(query.query().getLimit()).isEqualTo(10); + assertThat(query.numCandidates()).isEqualTo(10 * 20); + } + + @Test + void shouldNotSetNumCandidates() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10EnnByEmbeddingNear", Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); + + VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + + assertThat(query.query().getLimit()).isEqualTo(10); + assertThat(query.numCandidates()).isNull(); + } + + @Test + void shouldConsiderProvidedLimit() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class, + Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); + + VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + + assertThat(query.query().getLimit()).isEqualTo(11); + assertThat(query.numCandidates()).isEqualTo(11 * 20); + } + + private VectorSearchDelegate.QueryMetadata 
createQueryMetadata(MongoQueryMethod queryMethod, + MongoParametersParameterAccessor accessor) { + + VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create()); + + return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, + Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); + } + + private MongoQueryMethod getMongoQueryMethod(Method method) { + RepositoryMetadata metadata = AnnotationRepositoryMetadata.getMetadata(method.getDeclaringClass()); + return new MongoQueryMethod(method, metadata, new SpelAwareProxyProjectionFactory(), converter.getMappingContext()); + } + + @NotNull + private static MongoParametersParameterAccessor getAccessor(MongoQueryMethod queryMethod, Object... values) { + return new MongoParametersParameterAccessor(queryMethod, values); + } + + interface VectorSearchRepository extends Repository { + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN) + SearchResults searchTop10EnnByEmbeddingNear(Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit); + + } + + static class WithVector { + + Vector embedding; + } +} diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java index 3b3a520bc3..6b4d9b9e9b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java +++ 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/mapping/MongoSimpleTypes.java @@ -29,6 +29,7 @@ import org.bson.types.Decimal128; import org.bson.types.ObjectId; import org.bson.types.Symbol; + import org.springframework.data.mapping.model.SimpleTypeHolder; import com.mongodb.DBRef; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java index b378eea76e..a62f31143a 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java @@ -1,5 +1,5 @@ /* - * Copyright 2016-2025 the original author or authors. + * Copyright 2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -41,7 +41,7 @@ * * @author Mark Paluch * @since 5.0 - * @see org.springframework.data.geo.Distance + * @see org.springframework.data.domain.Score * @see org.springframework.data.domain.Vector * @see org.springframework.data.domain.SearchResults */ diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java index 98438d1652..94acef17ce 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParameters.java @@ -21,14 +21,14 @@ import java.util.List; import org.jspecify.annotations.Nullable; + import org.springframework.core.MethodParameter; -import org.springframework.core.ResolvableType; import org.springframework.data.domain.Range; import org.springframework.data.domain.Vector; +import org.springframework.data.geo.Distance; import org.springframework.data.geo.GeoPage; import org.springframework.data.geo.GeoResult; import org.springframework.data.geo.GeoResults; -import org.springframework.data.geo.Distance; import org.springframework.data.geo.Point; import org.springframework.data.mongodb.core.query.Collation; import org.springframework.data.mongodb.core.query.TextCriteria; @@ -78,7 +78,7 @@ public MongoParameters(ParametersSource parametersSource) { * @param isGeoNearMethod indicate if this is a geo-spatial query method */ public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMethod) { - this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod), new DistanceRangeIndex(parametersSource)); + this(parametersSource, new NearIndex(parametersSource, isGeoNearMethod)); } /** @@ -87,11 +87,10 @@ public MongoParameters(ParametersSource parametersSource, boolean isGeoNearMetho * @param parametersSource must not be 
{@literal null}. * @param nearIndex the near parameter index. */ - private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex, - DistanceRangeIndex distanceRangeIndex) { + private MongoParameters(ParametersSource parametersSource, NearIndex nearIndex) { super(parametersSource, methodParameter -> new MongoParameter(methodParameter, - parametersSource.getDomainTypeInformation(), nearIndex.nearIndex, distanceRangeIndex.distanceRangeIndex)); + parametersSource.getDomainTypeInformation(), nearIndex.nearIndex)); Method method = parametersSource.getMethod(); List> parameterTypes = Arrays.asList(method.getParameterTypes()); @@ -156,15 +155,6 @@ public NearIndex(ParametersSource parametersSource, boolean isGeoNearMethod) { } } - static class DistanceRangeIndex { - - private final int distanceRangeIndex; - - public DistanceRangeIndex(ParametersSource parametersSource) { - this.distanceRangeIndex = findDistanceRangeIndexInParameters(parametersSource.getMethod()); - } - } - private static int getNearIndex(List> parameterTypes) { for (Class reference : Arrays.asList(Point.class, double[].class)) { @@ -207,21 +197,6 @@ static int findNearIndexInParameters(Method method) { return index; } - static int findDistanceRangeIndexInParameters(Method method) { - - int index = -1; - for (java.lang.reflect.Parameter p : method.getParameters()) { - - MethodParameter methodParameter = MethodParameter.forParameter(p); - - if (Range.class.isAssignableFrom(methodParameter.getParameterType()) - && ResolvableType.forMethodParameter(methodParameter).getGeneric(0).isAssignableFrom(Distance.class)) { - index = methodParameter.getParameterIndex(); - } - } - return index; - } - /** * Returns the index of the {@link Distance} parameter to be used for max distance in geo queries. 
* @@ -321,21 +296,17 @@ static class MongoParameter extends Parameter { private final MethodParameter parameter; private final @Nullable Integer nearIndex; - private final @Nullable Integer distanceRangeIndex; /** * Creates a new {@link MongoParameter}. * * @param parameter must not be {@literal null}. * @param domainType must not be {@literal null}. - * @param distanceRangeIndex */ - MongoParameter(MethodParameter parameter, TypeInformation domainType, @Nullable Integer nearIndex, - @Nullable Integer distanceRangeIndex) { + MongoParameter(MethodParameter parameter, TypeInformation domainType, @Nullable Integer nearIndex) { super(parameter, domainType); this.parameter = parameter; this.nearIndex = nearIndex; - this.distanceRangeIndex = distanceRangeIndex; if (!isPoint() && hasNearAnnotation()) { throw new IllegalArgumentException("Near annotation is only allowed at Point parameter"); @@ -345,7 +316,7 @@ static class MongoParameter extends Parameter { @Override public boolean isSpecialParameter() { return super.isSpecialParameter() || Distance.class.isAssignableFrom(getType()) - || Vector.class.isAssignableFrom(getType()) || isNearParameter() || isDistanceRangeParameter() + || Vector.class.isAssignableFrom(getType()) || isNearParameter() || TextCriteria.class.isAssignableFrom(getType()) || Collation.class.isAssignableFrom(getType()); } @@ -353,10 +324,6 @@ private boolean isNearParameter() { return nearIndex != null && nearIndex.equals(getIndex()); } - private boolean isDistanceRangeParameter() { - return distanceRangeIndex != null && distanceRangeIndex.equals(getIndex()); - } - private boolean isManuallyAnnotatedNearParameter() { return isPoint() && hasNearAnnotation(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java index 2a60da0911..aa47647e68 100644 --- 
a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -126,15 +126,9 @@ public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProce Integer numCandidates = null; Limit limit; Class outputType = typeToRead != null ? typeToRead : processor.getReturnedType().getReturnedType(); - - if (this.numCandidatesExpression != null) { - numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); - } else if (this.numCandidates != null) { - numCandidates = this.numCandidates; - } + VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); if (this.limitExpression != null) { - Object value = evaluator.evaluate(this.limitExpression); limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); } else if (this.limit.isLimited()) { @@ -143,12 +137,22 @@ public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProce limit = accessor.getLimit(); } - VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); - if (limit.isLimited()) { query.query().limit(limit); } + if (this.numCandidatesExpression != null) { + numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + } else if (this.numCandidates != null) { + numCandidates = this.numCandidates; + } else if (query.query().isLimited() && searchType == VectorSearchOperation.SearchType.ANN) { + + /* + MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. 
+ */ + numCandidates = query.query().getLimit() * 20; + } + return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates, getSimilarityFunction(accessor)); } @@ -335,13 +339,13 @@ public String getQueryString() { private class PartTreeQueryFactory implements VectorSearchQueryFactory { private final String path; - private final PartTree partTree; + private final PartTree tree; @SuppressWarnings("NullableProblems") - PartTreeQueryFactory(PartTree partTree, MappingContext context) { + PartTreeQueryFactory(PartTree tree, MappingContext context) { String path = null; - for (PartTree.OrPart part : partTree) { + for (PartTree.OrPart part : tree) { for (Part p : part) { if (p.getType() == Part.Type.SIMPLE_PROPERTY || p.getType() == Part.Type.NEAR || p.getType() == Part.Type.WITHIN || p.getType() == Part.Type.BETWEEN) { @@ -362,17 +366,21 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory { } this.path = path; - this.partTree = partTree; + this.tree = tree; } public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - MongoQueryCreator creator = new MongoQueryCreator(partTree, parameterAccessor, converter.getMappingContext(), + MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false, true); Query query = creator.createQuery(parameterAccessor.getSort()); + if (tree.isLimiting()) { + query.limit(tree.getMaxResults()); + } + return new VectorSearchInput(path, query); } From cc7449d93afaf1e65c6a2175f802008449158fc2 Mon Sep 17 00:00:00 2001 From: Mark Paluch Date: Mon, 5 May 2025 12:06:40 +0200 Subject: [PATCH 07/10] Review findings. 
--- .../data/mongodb/repository/VectorSearch.java | 22 +++++++++++-------- .../query/VectorSearchDelegate.java | 13 +++++------ .../query/VectorSearchDelegateUnitTests.java | 2 -- 3 files changed, 19 insertions(+), 18 deletions(-) rename spring-data-mongodb/src/{jmh => test}/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java (99%) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java index a62f31143a..336889f719 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/VectorSearch.java @@ -53,11 +53,13 @@ public @interface VectorSearch { /** - * Configuration whether to use ANN or ENN for the search. ANN is the default. + * Configuration whether to use + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN} or + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ENN} for the search. * * @return the search type to use. */ - VectorSearchOperation.SearchType searchType() default VectorSearchOperation.SearchType.ENN; + VectorSearchOperation.SearchType searchType() default VectorSearchOperation.SearchType.DEFAULT; /** * Name of the Atlas Vector Search index to use. Atlas Vector Search doesn't return results if you misspell the index @@ -98,7 +100,7 @@ * specify {@link #numCandidates}. Limit accepts Value Expressions. A Vector Search method cannot define both, * {@code limit()} and a {@link org.springframework.data.domain.Limit} parameter. Supports Value Expressions. * - * @return number of documents to return in the results + * @return number of documents to return in the results. 
*/ String limit() default ""; @@ -106,13 +108,15 @@ * Number of nearest neighbors to use during the search. Value must be less than or equal to ({@code <=}) * {@code 10000}. You can't specify a number less than the {@link #limit() number of documents to return}. We * recommend that you specify a number at least {@code 20} times higher than the {@link #limit() number of documents - * to return} to increase accuracy. This over-request pattern is the recommended way to trade off latency and recall - * in your ANN searches, and we recommend tuning this parameter based on your specific dataset size and query - * requirements. Required if the query uses - * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}. Supports Value - * Expressions. + * to return} to increase accuracy. + *

+ * This over-request pattern is the recommended way to trade off latency and recall in your ANN searches, and we + * recommend tuning this parameter based on your specific dataset size and query requirements. Required if the query + * uses + * {@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#ANN}/{@link org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType#DEFAULT}. + * Supports Value Expressions. * - * @return number of documents to return in the results + * @return number of nearest neighbors to use during the search. */ String numCandidates() default ""; diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java index aa47647e68..8932b85b1b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -19,6 +19,7 @@ import java.util.List; import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.springframework.data.domain.Limit; import org.springframework.data.domain.Range; @@ -46,7 +47,6 @@ import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; -import org.springframework.lang.Nullable; import org.springframework.util.StringUtils; /** @@ -145,7 +145,8 @@ public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProce numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); } else if (this.numCandidates != null) { numCandidates = this.numCandidates; - } else if (query.query().isLimited() && searchType == 
VectorSearchOperation.SearchType.ANN) { + } else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN + || searchType == VectorSearchOperation.SearchType.DEFAULT)) { /* MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. @@ -195,7 +196,7 @@ ScoringFunction getSimilarityFunction(MongoParameterAccessor accessor) { * @param scoringFunction */ public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType, - Class outputType, @org.jspecify.annotations.Nullable Integer numCandidates, ScoringFunction scoringFunction) { + Class outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) { /** * Create the Aggregation Pipeline. @@ -210,12 +211,10 @@ public List getAggregationPipeline(MongoQueryMethod queryM Vector vector = accessor.getVector(); Score score = accessor.getScore(); Range distance = accessor.getScoreRange(); - int limit; + Limit limit = Limit.unlimited(); if (query.isLimited()) { - limit = query.getLimit(); - } else { - limit = Math.max(1, numCandidates() != null ? 
numCandidates() / 20 : 1); + limit = Limit.of(query.getLimit()); } List stages = new ArrayList<>(); diff --git a/spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java similarity index 99% rename from spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java rename to spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java index b5a94b8060..06a80e78fc 100644 --- a/spring-data-mongodb/src/jmh/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java @@ -20,7 +20,6 @@ import java.lang.reflect.Method; -import org.jetbrains.annotations.NotNull; import org.junit.jupiter.api.Test; import org.springframework.data.domain.Limit; @@ -107,7 +106,6 @@ private MongoQueryMethod getMongoQueryMethod(Method method) { return new MongoQueryMethod(method, metadata, new SpelAwareProxyProjectionFactory(), converter.getMappingContext()); } - @NotNull private static MongoParametersParameterAccessor getAccessor(MongoQueryMethod queryMethod, Object... 
values) { return new MongoParametersParameterAccessor(queryMethod, values); } From 7fb9b2a4ef30f047ef4e48bc4f99ffb029834e07 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Tue, 6 May 2025 10:56:42 +0200 Subject: [PATCH 08/10] hacking --- .../core/aggregation/AggregationPipeline.java | 14 +- .../query/ConvertingParameterAccessor.java | 4 +- .../MongoParametersParameterAccessor.java | 9 +- .../repository/query/MongoQueryCreator.java | 49 ++-- .../repository/query/MongoQueryExecution.java | 59 +++-- .../query/ReactiveMongoQueryExecution.java | 20 +- .../ReactiveVectorSearchAggregation.java | 8 +- .../query/VectorSearchAggregation.java | 9 +- .../query/VectorSearchDelegate.java | 234 +++++++++--------- .../mongodb/repository/VectorSearchTests.java | 3 +- .../VectorSearchAggregationUnitTests.java | 3 +- .../query/VectorSearchDelegateUnitTests.java | 156 ++++++++++-- 12 files changed, 362 insertions(+), 206 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java index 40966bcf3d..f06803997b 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/AggregationPipeline.java @@ -22,8 +22,10 @@ import java.util.function.Predicate; import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.springframework.lang.Contract; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; /** * The {@link AggregationPipeline} holds the collection of {@link AggregationOperation aggregation stages}. 
@@ -82,6 +84,14 @@ public List getOperations() { return Collections.unmodifiableList(pipeline); } + public @Nullable AggregationOperation firstOperation() { + return CollectionUtils.firstElement(pipeline); + } + + public @Nullable AggregationOperation lastOperation() { + return CollectionUtils.lastElement(pipeline); + } + List toDocuments(AggregationOperationContext context) { verify(); @@ -97,8 +107,8 @@ public boolean isOutOrMerge() { return false; } - AggregationOperation operation = pipeline.get(pipeline.size() - 1); - return isOut(operation) || isMerge(operation); + AggregationOperation operation = lastOperation(); + return operation != null && (isOut(operation) || isMerge(operation)); } void verify() { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java index e51d4435a8..0eac1aa3e0 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java @@ -104,12 +104,12 @@ public Sort getSort() { } @Override - public @org.jspecify.annotations.Nullable Score getScore() { + public @Nullable Score getScore() { return delegate.getScore(); } @Override - public @org.jspecify.annotations.Nullable Range getScoreRange() { + public @Nullable Range getScoreRange() { return delegate.getScoreRange(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java index 41cf084d45..0f56223492 100644 --- 
a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoParametersParameterAccessor.java @@ -61,14 +61,13 @@ public MongoParametersParameterAccessor(MongoQueryMethod method, Object[] values public Range getScoreRange() { MongoParameters mongoParameters = method.getParameters(); - int rangeIndex = mongoParameters.getScoreRangeIndex(); - if (rangeIndex != -1) { - return getValue(rangeIndex); + if (mongoParameters.hasScoreRangeParameter()) { + return getValue(mongoParameters.getScoreRangeIndex()); } - int scoreIndex = mongoParameters.getScoreIndex(); - Bound maxDistance = scoreIndex == -1 ? Bound.unbounded() : Bound.inclusive((Score) getScore()); + Score score = getScore(); + Bound maxDistance = score != null ? Bound.inclusive(score) : Bound.unbounded(); return Range.of(Bound.unbounded(), maxDistance); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java index 1f742ec32f..ba7394ec17 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryCreator.java @@ -15,7 +15,8 @@ */ package org.springframework.data.mongodb.repository.query; -import static org.springframework.data.mongodb.core.query.Criteria.*; +import static org.springframework.data.mongodb.core.query.Criteria.Placeholder; +import static org.springframework.data.mongodb.core.query.Criteria.where; import java.util.Arrays; import java.util.Collection; @@ -27,7 +28,6 @@ import org.apache.commons.logging.LogFactory; import org.bson.BsonRegularExpression; import org.jspecify.annotations.Nullable; - import 
org.springframework.data.domain.Range; import org.springframework.data.domain.Range.Bound; import org.springframework.data.domain.Sort; @@ -118,8 +118,9 @@ protected Criteria create(Part part, Iterator iterator) { return new Criteria(); } - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { - return null; + if (isPartOfSearchQuery(part)) { + skip(part, iterator); + return new Criteria(); } PersistentPropertyPath path = context.getPersistentPropertyPath(part.getProperty()); @@ -135,7 +136,8 @@ protected Criteria and(Part part, Criteria base, Iterator iterator) { return create(part, iterator); } - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { + if (isPartOfSearchQuery(part)) { + skip(part, iterator); return base; } @@ -176,15 +178,6 @@ protected Query complete(@Nullable Criteria criteria, Sort sort) { @SuppressWarnings("NullAway") private Criteria from(Part part, MongoPersistentProperty property, Criteria criteria, Iterator parameters) { - if (isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN))) { - - int numberOfArguments = part.getType().getNumberOfArguments(); - for (int i = 0; i < numberOfArguments; i++) { - parameters.next(); - } - return null; - } - Type type = part.getType(); switch (type) { @@ -206,13 +199,13 @@ private Criteria from(Part part, MongoPersistentProperty property, Criteria crit return criteria.is(null); case NOT_IN: Object ninValue = parameters.next(); - if(ninValue instanceof Placeholder) { + if (ninValue instanceof Placeholder) { return criteria.raw("$nin", ninValue); } return criteria.nin(valueAsList(ninValue, part)); case IN: Object inValue = parameters.next(); - if(inValue instanceof Placeholder) { + if (inValue instanceof Placeholder) { return criteria.raw("$in", inValue); } return criteria.in(valueAsList(inValue, part)); @@ -231,7 +224,7 @@ private Criteria from(Part part, MongoPersistentProperty property, 
Criteria crit return param instanceof Pattern pattern ? criteria.regex(pattern) : criteria.regex(param.toString()); case EXISTS: Object next = parameters.next(); - if(next instanceof Placeholder placeholder) { + if (next instanceof Placeholder placeholder) { return criteria.raw("$exists", placeholder); } else { return criteria.exists((Boolean) next); @@ -355,7 +348,7 @@ private Criteria createContainingCriteria(Part part, MongoPersistentProperty pro if (property.isCollectionLike()) { Object next = parameters.next(); - if(next instanceof Placeholder) { + if (next instanceof Placeholder) { return criteria.raw("$in", next); } return criteria.in(valueAsList(next, part)); @@ -433,8 +426,7 @@ private java.util.List valueAsList(Object value, Part part) { streamable = streamable.map(it -> { if (it instanceof String sv) { - return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), - regexOptions); + return new BsonRegularExpression(MongoRegexCreator.INSTANCE.toRegularExpression(sv, matchMode), regexOptions); } return it; }); @@ -468,10 +460,23 @@ private boolean isSpherical(MongoPersistentProperty property) { return false; } + private boolean isPartOfSearchQuery(Part part) { + return isSearchQuery && (part.getType().equals(Type.NEAR) || part.getType().equals(Type.WITHIN)); + } + + private static void skip(Part part, Iterator parameters) { + + int total = part.getNumberOfArguments(); + int i = 0; + while (parameters.hasNext() && i < total) { + parameters.next(); + i++; + } + } + /** * Compute a {@link Type#BETWEEN} typed {@link Part} using {@link Criteria#gt(Object) $gt}, - * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}. - *
+ * {@link Criteria#gte(Object) $gte}, {@link Criteria#lt(Object) $lt} and {@link Criteria#lte(Object) $lte}.
* In case the first {@literal value} is actually a {@link Range} the lower and upper bounds of the {@link Range} are * used according to their {@link Bound#isInclusive() inclusion} definition. Otherwise the {@literal value} is used * for {@literal $gt} and {@link Iterator#next() parameters.next()} as {@literal $lt}. diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java index d9a91434ce..f606a59859 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/MongoQueryExecution.java @@ -23,10 +23,10 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Page; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Range; +import org.springframework.data.domain.ScoringFunction; import org.springframework.data.domain.SearchResult; import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Similarity; @@ -45,12 +45,13 @@ import org.springframework.data.mongodb.core.ExecutableRemoveOperation.TerminatingRemove; import org.springframework.data.mongodb.core.ExecutableUpdateOperation.ExecutableUpdate; import org.springframework.data.mongodb.core.MongoOperations; -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.AggregationResults; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import 
org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.repository.util.SliceUtils; import org.springframework.data.repository.query.QueryMethod; import org.springframework.data.support.PageableExecutionUtils; @@ -186,7 +187,7 @@ public Object execute(Query query) { return isListOfGeoResult(method.getReturnType()) ? results.getContent() : results; } - @SuppressWarnings({"unchecked","NullAway"}) + @SuppressWarnings({ "unchecked", "NullAway" }) GeoResults doExecuteQuery(Query query) { Point nearLocation = accessor.getGeoNearLocation(); @@ -225,52 +226,60 @@ private static boolean isListOfGeoResult(TypeInformation returnType) { * {@link MongoQueryExecution} to execute vector search. * * @author Mark Paluch + * @author Christoph Strobl * @since 5.0 */ class VectorSearchExecution implements MongoQueryExecution { private final MongoOperations operations; - private final MongoQueryMethod method; + private final TypeInformation returnType; private final String collectionName; - private final VectorSearchDelegate.QueryMetadata queryMetadata; - private final List pipeline; + private final Class targetType; + private final ScoringFunction scoringFunction; + private final AggregationPipeline pipeline; + + VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, + QueryContainer queryContainer) { + this(operations, queryContainer.outputType(), collectionName, method.getReturnType(), queryContainer.pipeline(), + queryContainer.scoringFunction()); + } - public VectorSearchExecution(MongoOperations operations, MongoQueryMethod method, String collectionName, - VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { + public VectorSearchExecution(MongoOperations operations, Class targetType, String collectionName, + TypeInformation returnType, AggregationPipeline pipeline,
ScoringFunction scoringFunction) { this.operations = operations; + this.returnType = returnType; this.collectionName = collectionName; - this.queryMetadata = queryMetadata; - this.method = method; - this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); + this.targetType = targetType; + this.scoringFunction = scoringFunction; + this.pipeline = pipeline; } @Override public Object execute(Query query) { - AggregationResults aggregated = operations.aggregate( - TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collectionName, - queryMetadata.outputType()); + AggregationResults aggregated = operations + .aggregate(TypedAggregation.newAggregation(targetType, pipeline.getOperations()), collectionName, targetType); List mappedResults = aggregated.getMappedResults(); - if (isSearchResult(method.getReturnType())) { + if (!isSearchResult(returnType)) { + return mappedResults; + } - List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); - List> result = new ArrayList<>(mappedResults.size()); + List rawResults = aggregated.getRawResults().getList("results", org.bson.Document.class); + List> result = new ArrayList<>(mappedResults.size()); - for (int i = 0; i < mappedResults.size(); i++) { - Document document = rawResults.get(i); - SearchResult searchResult = new SearchResult<>(mappedResults.get(i), - Similarity.raw(document.getDouble("__score__"), queryMetadata.scoringFunction())); + for (int i = 0; i < mappedResults.size(); i++) { - result.add(searchResult); - } + Document document = rawResults.get(i); + SearchResult searchResult = new SearchResult<>(mappedResults.get(i), + Similarity.raw(document.getDouble("__score__"), scoringFunction)); - return isListOfSearchResult(method.getReturnType()) ? result : new SearchResults<>(result); + result.add(searchResult); } - return mappedResults; + return isListOfSearchResult(returnType) ? 
result : new SearchResults<>(result); } private static boolean isListOfSearchResult(TypeInformation returnType) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java index 389f4e871d..29e2127e18 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveMongoQueryExecution.java @@ -18,12 +18,9 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.List; - import org.bson.Document; import org.jspecify.annotations.Nullable; import org.reactivestreams.Publisher; - import org.springframework.core.convert.converter.Converter; import org.springframework.data.convert.DtoInstantiatingConverter; import org.springframework.data.domain.Pageable; @@ -36,11 +33,12 @@ import org.springframework.data.mapping.model.EntityInstantiators; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.ReactiveUpdateOperation.ReactiveUpdate; -import org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.TypedAggregation; import org.springframework.data.mongodb.core.query.NearQuery; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.core.query.UpdateDefinition; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ReturnedType; import org.springframework.data.util.ReactiveWrappers; @@ 
-134,24 +132,24 @@ private boolean isStreamOfGeoResult() { class VectorSearchExecution implements ReactiveMongoQueryExecution { private final ReactiveMongoOperations operations; - private final VectorSearchDelegate.QueryMetadata queryMetadata; - private final List pipeline; + private final QueryContainer queryMetadata; + private final AggregationPipeline pipeline; private final boolean returnSearchResult; - public VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, - VectorSearchDelegate.QueryMetadata queryMetadata, MongoParameterAccessor accessor) { + VectorSearchExecution(ReactiveMongoOperations operations, MongoQueryMethod method, QueryContainer queryMetadata) { this.operations = operations; this.queryMetadata = queryMetadata; - this.pipeline = queryMetadata.getAggregationPipeline(method, accessor); + this.pipeline = queryMetadata.pipeline(); this.returnSearchResult = isSearchResult(method.getReturnType()); } @Override public Publisher execute(Query query, Class type, String collection) { - Flux aggregate = operations - .aggregate(TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline), collection, Document.class); + Flux aggregate = operations.aggregate( + TypedAggregation.newAggregation(queryMetadata.outputType(), pipeline.getOperations()), collection, + Document.class); return aggregate.map(document -> { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java index 1ecbb0235f..cf75c7db94 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ReactiveVectorSearchAggregation.java @@ -19,13 +19,13 @@ import org.bson.Document; import 
org.reactivestreams.Publisher; - import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.ReactiveMongoOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; @@ -84,11 +84,11 @@ protected Publisher doExecute(ReactiveMongoQueryMethod method, ResultPro ParameterBindingContext bindingContext = new ParameterBindingContext(accessor::getBindableValue, expressionEvaluator); - VectorSearchDelegate.QueryMetadata query = delegate.createQuery(expressionEvaluator, processor, accessor, - typeToRead, codec, bindingContext); + QueryContainer query = delegate.createQuery(expressionEvaluator, processor, accessor, typeToRead, codec, + bindingContext); ReactiveMongoQueryExecution.VectorSearchExecution execution = new ReactiveMongoQueryExecution.VectorSearchExecution( - mongoOperations, method, query, accessor); + mongoOperations, method, query); return execution.execute(query.query(), Document.class, collectionEntity.getCollection()); }); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java index 9740c0696c..d9f81d09da 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java +++ 
b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -15,16 +15,17 @@ */ package org.springframework.data.mongodb.repository.query; +import org.jspecify.annotations.Nullable; import org.springframework.data.mapping.model.ValueExpressionEvaluator; import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.MongoOperations; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; import org.springframework.data.mongodb.core.query.Query; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.repository.query.ResultProcessor; import org.springframework.data.repository.query.ValueExpressionDelegate; -import org.springframework.lang.Nullable; /** * {@link AbstractMongoQuery} implementation to run a {@link VectorSearchAggregation}. 
The pre-filter is either derived @@ -67,15 +68,15 @@ public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOpe protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, @Nullable Class typeToRead) { - VectorSearchDelegate.QueryMetadata query = createVectorSearchQuery(processor, accessor, typeToRead); + QueryContainer query = createVectorSearchQuery(processor, accessor, typeToRead); MongoQueryExecution.VectorSearchExecution execution = new MongoQueryExecution.VectorSearchExecution(mongoOperations, - method, collectionEntity.getCollection(), query, accessor); + method, collectionEntity.getCollection(), query); return execution.execute(query.query()); } - VectorSearchDelegate.QueryMetadata createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, + QueryContainer createVectorSearchQuery(ResultProcessor processor, MongoParameterAccessor accessor, @Nullable Class typeToRead) { ValueExpressionEvaluator evaluator = getExpressionEvaluatorFor(accessor); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java index 8932b85b1b..b82e9e4b64 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -20,7 +20,6 @@ import org.bson.Document; import org.jspecify.annotations.Nullable; - import org.springframework.data.domain.Limit; import org.springframework.data.domain.Range; import org.springframework.data.domain.Score; @@ -34,6 +33,7 @@ import org.springframework.data.mongodb.InvalidMongoDbApiUsageException; import org.springframework.data.mongodb.core.aggregation.Aggregation; import 
org.springframework.data.mongodb.core.aggregation.AggregationOperation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.convert.MongoConverter; import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity; @@ -58,32 +58,35 @@ class VectorSearchDelegate { private final VectorSearchQueryFactory queryFactory; private final VectorSearchOperation.SearchType searchType; + private final String indexName; private final @Nullable Integer numCandidates; private final @Nullable String numCandidatesExpression; private final Limit limit; private final @Nullable String limitExpression; private final MongoConverter converter; - public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { + VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, ValueExpressionDelegate delegate) { VectorSearch vectorSearch = method.findAnnotatedVectorSearch().orElseThrow(); + this.searchType = vectorSearch.searchType(); + this.indexName = method.getAnnotatedHint(); if (StringUtils.hasText(vectorSearch.numCandidates())) { ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.numCandidates()); if (expression.isLiteral()) { - numCandidates = Integer.parseInt(vectorSearch.numCandidates()); - numCandidatesExpression = null; + this.numCandidates = Integer.parseInt(vectorSearch.numCandidates()); + this.numCandidatesExpression = null; } else { - numCandidates = null; - numCandidatesExpression = vectorSearch.numCandidates(); + this.numCandidates = null; + this.numCandidatesExpression = vectorSearch.numCandidates(); } } else { - numCandidates = null; - numCandidatesExpression = null; + this.numCandidates = null; + this.numCandidatesExpression = null; } if (StringUtils.hasText(vectorSearch.limit())) { @@ -91,26 +94,26 @@ public 
VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, V ValueExpression expression = delegate.getValueExpressionParser().parse(vectorSearch.limit()); if (expression.isLiteral()) { - limit = Limit.of(Integer.parseInt(vectorSearch.limit())); - limitExpression = null; + this.limit = Limit.of(Integer.parseInt(vectorSearch.limit())); + this.limitExpression = null; } else { - limit = Limit.unlimited(); - limitExpression = vectorSearch.limit(); + this.limit = Limit.unlimited(); + this.limitExpression = vectorSearch.limit(); } } else { - limit = Limit.unlimited(); - limitExpression = null; + this.limit = Limit.unlimited(); + this.limitExpression = null; } this.converter = converter; if (StringUtils.hasText(vectorSearch.filter())) { - queryFactory = StringUtils.hasText(vectorSearch.path()) + this.queryFactory = StringUtils.hasText(vectorSearch.path()) ? new AnnotatedQueryFactory(vectorSearch.filter(), vectorSearch.path()) : new AnnotatedQueryFactory(vectorSearch.filter(), method.getEntityInformation().getCollectionEntity()); } else { - queryFactory = new PartTreeQueryFactory( + this.queryFactory = new PartTreeQueryFactory( new PartTree(method.getName(), method.getResultProcessor().getReturnedType().getDomainType()), converter.getMappingContext()); } @@ -119,43 +122,122 @@ public VectorSearchDelegate(MongoQueryMethod method, MongoConverter converter, V /** * Create Query Metadata for {@code $vectorSearch}. */ - public QueryMetadata createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, + QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor processor, MongoParameterAccessor accessor, @Nullable Class typeToRead, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - Integer numCandidates = null; - Limit limit; + String scoreField = "__score__"; Class outputType = typeToRead != null ? 
typeToRead : processor.getReturnedType().getReturnedType(); - VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); + VectorSearchInput vectorSearchInput = createSearchInput(evaluator, accessor, codec, context); + AggregationPipeline pipeline = createVectorSearchPipeline(vectorSearchInput, scoreField, outputType, accessor, + evaluator); - if (this.limitExpression != null) { - Object value = evaluator.evaluate(this.limitExpression); - limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); - } else if (this.limit.isLimited()) { - limit = this.limit; - } else { - limit = accessor.getLimit(); - } + return new QueryContainer(vectorSearchInput.path, scoreField, vectorSearchInput.query, pipeline, searchType, + outputType, getSimilarityFunction(accessor), indexName); + } - if (limit.isLimited()) { - query.query().limit(limit); + public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class outputType, + MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) { + + Vector vector = accessor.getVector(); + Score score = accessor.getScore(); + Range distance = accessor.getScoreRange(); + Limit limit = Limit.unlimited(); + + if (input.query().isLimited()) { + limit = Limit.of(input.query().getLimit()); } + List stages = new ArrayList<>(); + VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector) + .limit(limit); + + Integer candidates = null; if (this.numCandidatesExpression != null) { - numCandidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); + candidates = ((Number) evaluator.evaluate(this.numCandidatesExpression)).intValue(); } else if (this.numCandidates != null) { - numCandidates = this.numCandidates; - } else if (query.query().isLimited() && (searchType == VectorSearchOperation.SearchType.ANN + candidates = this.numCandidates; + } else if (input.query().isLimited() && (searchType == 
VectorSearchOperation.SearchType.ANN || searchType == VectorSearchOperation.SearchType.DEFAULT)) { /* MongoDB: We recommend that you specify a number at least 20 times higher than the number of documents to return (limit) to increase accuracy. */ - numCandidates = query.query().getLimit() * 20; + candidates = input.query().getLimit() * 20; + } + + if (candidates != null) { + $vectorSearch = $vectorSearch.numCandidates(candidates); + } + // + $vectorSearch = $vectorSearch.filter(input.query.getQueryObject()); + $vectorSearch = $vectorSearch.searchType(this.searchType); + $vectorSearch = $vectorSearch.withSearchScore(scoreField); + + if (score != null) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + c.gt(score.getValue()); + }); + } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { + $vectorSearch = $vectorSearch.withFilterBySore(c -> { + Range.Bound lower = distance.getLowerBound(); + if (lower.isBounded()) { + double value = lower.getValue().get().getValue(); + if (lower.isInclusive()) { + c.gte(value); + } else { + c.gt(value); + } + } + + Range.Bound upper = distance.getUpperBound(); + if (upper.isBounded()) { + + double value = upper.getValue().get().getValue(); + if (upper.isInclusive()) { + c.lte(value); + } else { + c.lt(value); + } + } + }); + } + + stages.add($vectorSearch); + + if (input.query().isSorted()) { + + stages.add(ctx -> { + + Document mappedSort = ctx.getMappedObject(input.query().getSortObject(), outputType); + mappedSort.append(scoreField, -1); + return ctx.getMappedObject(new Document("$sort", mappedSort)); + }); + } else { + stages.add(Aggregation.sort(Sort.Direction.DESC, scoreField)); + } + + return new AggregationPipeline(stages); + } + + private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor, + ParameterBindingDocumentCodec codec, ParameterBindingContext context) { + + VectorSearchInput query = queryFactory.createQuery(accessor, 
codec, context); + Limit limit; + if (this.limitExpression != null) { + Object value = evaluator.evaluate(this.limitExpression); + limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); + } else if (this.limit.isLimited()) { + limit = this.limit; + } else { + limit = accessor.getLimit(); } - return new QueryMetadata(query.path, "__score__", query.query, searchType, outputType, numCandidates, - getSimilarityFunction(accessor)); + if (limit.isLimited()) { + query.query().limit(limit); + } + return query; } public String getQueryString() { @@ -192,82 +274,10 @@ ScoringFunction getSimilarityFunction(MongoParameterAccessor accessor) { * @param query * @param searchType * @param outputType - * @param numCandidates * @param scoringFunction */ - public record QueryMetadata(String path, String scoreField, Query query, VectorSearchOperation.SearchType searchType, - Class outputType, @Nullable Integer numCandidates, ScoringFunction scoringFunction) { - - /** - * Create the Aggregation Pipeline. 
- * - * @param queryMethod - * @param accessor - * @return - */ - public List getAggregationPipeline(MongoQueryMethod queryMethod, - MongoParameterAccessor accessor) { - - Vector vector = accessor.getVector(); - Score score = accessor.getScore(); - Range distance = accessor.getScoreRange(); - Limit limit = Limit.unlimited(); - - if (query.isLimited()) { - limit = Limit.of(query.getLimit()); - } - - List stages = new ArrayList<>(); - VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(queryMethod.getAnnotatedHint()).path(path()) - .vector(vector).limit(limit); - - if (numCandidates() != null) { - $vectorSearch = $vectorSearch.numCandidates(numCandidates()); - } - - $vectorSearch = $vectorSearch.filter(query.getQueryObject()); - $vectorSearch = $vectorSearch.searchType(searchType()); - $vectorSearch = $vectorSearch.withSearchScore(scoreField()); - - if (score != null) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - c.gt(score.getValue()); - }); - } else if (distance.getLowerBound().isBounded() || distance.getUpperBound().isBounded()) { - $vectorSearch = $vectorSearch.withFilterBySore(c -> { - Range.Bound lower = distance.getLowerBound(); - if (lower.isBounded()) { - double value = lower.getValue().get().getValue(); - if (lower.isInclusive()) { - c.gte(value); - } else { - c.gt(value); - } - } - - Range.Bound upper = distance.getUpperBound(); - if (upper.isBounded()) { - - double value = upper.getValue().get().getValue(); - if (upper.isInclusive()) { - c.lte(value); - } else { - c.lt(value); - } - } - }); - } - - stages.add($vectorSearch); - - if (query.isSorted()) { - // TODO stages.add(Aggregation.sort(query.with())); - } else { - stages.add(Aggregation.sort(Sort.Direction.DESC, "__score__")); - } - - return stages; - } + record QueryContainer(String path, String scoreField, Query query, AggregationPipeline pipeline, + VectorSearchOperation.SearchType searchType, Class outputType, ScoringFunction scoringFunction, String index) { } @@ -371,8 
+381,8 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory { public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), - false, true); + MongoQueryCreator creator = new MongoQueryCreator(tree, parameterAccessor, converter.getMappingContext(), false, + true); Query query = creator.createQuery(parameterAccessor.getSort()); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java index 028a6926fb..a224481da1 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VectorSearchTests.java @@ -81,16 +81,15 @@ public String getDatabaseName() { @Override public MongoClient mongoClient() { - atlasLocal.start(); return MongoClients.create(atlasLocal.getConnectionString()); } } @BeforeAll static void beforeAll() throws InterruptedException { + atlasLocal.start(); - System.out.println(atlasLocal.getConnectionString()); client = MongoClients.create(atlasLocal.getConnectionString()); template = new MongoTestTemplate(client, "vector-search-tests"); diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java index c347936dfe..819bba5a48 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java +++ 
b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregationUnitTests.java @@ -34,6 +34,7 @@ import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.repository.CrudRepository; @@ -68,7 +69,7 @@ void derivesPrefilter() throws Exception { VectorSearchAggregation aggregation = aggregation(SampleRepository.class, "searchByCountryAndEmbeddingNear", String.class, Vector.class, Score.class, Limit.class); - VectorSearchDelegate.QueryMetadata query = aggregation.createVectorSearchQuery( + QueryContainer query = aggregation.createVectorSearchQuery( aggregation.getQueryMethod().getResultProcessor(), new MongoParametersParameterAccessor(aggregation.getQueryMethod(), new Object[] { "de", Vector.of(1f), Score.of(1), Limit.unlimited() }), diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java index 06a80e78fc..078c01eece 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegateUnitTests.java @@ -15,23 +15,30 @@ */ package org.springframework.data.mongodb.repository.query; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static 
org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.lang.reflect.Method; +import java.util.List; +import org.bson.Document; +import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; - import org.springframework.data.domain.Limit; import org.springframework.data.domain.Score; import org.springframework.data.domain.SearchResults; import org.springframework.data.domain.Vector; import org.springframework.data.mapping.model.ValueExpressionEvaluator; +import org.springframework.data.mongodb.core.aggregation.Aggregation; +import org.springframework.data.mongodb.core.aggregation.AggregationPipeline; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation; import org.springframework.data.mongodb.core.convert.MappingMongoConverter; import org.springframework.data.mongodb.core.convert.NoOpDbRefResolver; +import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.mapping.MongoMappingContext; import org.springframework.data.mongodb.repository.VectorSearch; +import org.springframework.data.mongodb.repository.query.VectorSearchDelegate.QueryContainer; +import org.springframework.data.mongodb.util.aggregation.TestAggregationContext; import org.springframework.data.mongodb.util.json.ParameterBindingContext; import org.springframework.data.mongodb.util.json.ParameterBindingDocumentCodec; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; @@ -44,6 +51,7 @@ * Unit tests for {@link VectorSearchDelegate}. 
* * @author Mark Paluch + * @author Christoph Strobl */ class VectorSearchDelegateUnitTests { @@ -57,10 +65,10 @@ void shouldConsiderDerivedLimit() throws ReflectiveOperationException { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(10); - assertThat(query.numCandidates()).isEqualTo(10 * 20); + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isEqualTo(10 * 20); } @Test @@ -71,10 +79,10 @@ void shouldNotSetNumCandidates() throws ReflectiveOperationException { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(10); - assertThat(query.numCandidates()).isNull(); + assertThat(container.query().getLimit()).isEqualTo(10); + assertThat(numCandidates(container.pipeline())).isNull(); } @Test @@ -86,19 +94,87 @@ void shouldConsiderProvidedLimit() throws ReflectiveOperationException { MongoQueryMethod queryMethod = getMongoQueryMethod(method); MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); - VectorSearchDelegate.QueryMetadata query = createQueryMetadata(queryMethod, accessor); + QueryContainer container = createQueryContainer(queryMethod, accessor); - assertThat(query.query().getLimit()).isEqualTo(11); - assertThat(query.numCandidates()).isEqualTo(11 * 20); + assertThat(container.query().getLimit()).isEqualTo(11); + 
assertThat(numCandidates(container.pipeline())).isEqualTo(11 * 20); } - private VectorSearchDelegate.QueryMetadata createQueryMetadata(MongoQueryMethod queryMethod, - MongoParametersParameterAccessor accessor) { + @Test + void considersDerivedQueryPart() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByFirstNameAndEmbeddingNear", String.class, + Vector.class, Score.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, "spring", Vector.of(1, 2), Score.of(1)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void considersDerivedQueryPartInDifferentOrder() throws ReflectiveOperationException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNearAndFirstName", Vector.class, + Score.class, String.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), "spring"); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + assertThat(vectorSearchStageOf(container.pipeline())).containsEntry("$vectorSearch.filter", + new Document("first_name", "spring")); + } + + @Test + void defaultSortsByScore() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchTop10ByEmbeddingNear", Vector.class, Score.class, + Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(10)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + + List stages = container.pipeline().lastOperation() + 
.toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("__score__", -1))); + } + + @Test + void usesDerivedSort() throws NoSuchMethodException { + + Method method = VectorSearchRepository.class.getMethod("searchByEmbeddingNearOrderByFirstName", Vector.class, + Score.class, Limit.class); + + MongoQueryMethod queryMethod = getMongoQueryMethod(method); + MongoParametersParameterAccessor accessor = getAccessor(queryMethod, Vector.of(1, 2), Score.of(1), Limit.of(11)); + + QueryContainer container = createQueryContainer(queryMethod, accessor); + AggregationPipeline aggregationPipeline = container.pipeline(); + + List stages = aggregationPipeline.lastOperation() + .toPipelineStages(TestAggregationContext.contextFor(WithVector.class)); + + assertThat(stages).containsExactly(new Document("$sort", new Document("first_name", 1).append("__score__", -1))); + } + + Document vectorSearchStageOf(AggregationPipeline pipeline) { + return pipeline.firstOperation().toPipelineStages(TestAggregationContext.contextFor(WithVector.class)).get(0); + } + + private QueryContainer createQueryContainer(MongoQueryMethod queryMethod, MongoParametersParameterAccessor accessor) { VectorSearchDelegate delegate = new VectorSearchDelegate(queryMethod, converter, ValueExpressionDelegate.create()); - return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, - Object.class, new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); + return delegate.createQuery(mock(ValueExpressionEvaluator.class), queryMethod.getResultProcessor(), accessor, null, + new ParameterBindingDocumentCodec(), mock(ParameterBindingContext.class)); } private MongoQueryMethod getMongoQueryMethod(Method method) { @@ -110,21 +186,69 @@ private static MongoParametersParameterAccessor getAccessor(MongoQueryMethod que return new MongoParametersParameterAccessor(queryMethod, 
values); } + @Nullable + private static Integer numCandidates(AggregationPipeline pipeline) { + + Document $vectorSearch = pipeline.firstOperation().toPipelineStages(Aggregation.DEFAULT_CONTEXT).get(0); + if ($vectorSearch.containsKey("$vectorSearch")) { + Object value = $vectorSearch.get("$vectorSearch", Document.class).get("numCandidates"); + return value instanceof Number i ? i.intValue() : null; + } + return null; + } + interface VectorSearchRepository extends Repository { @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByFirstNameAndEmbeddingNear(String firstName, Vector vector, Score similarity); + + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchTop10ByEmbeddingNearAndFirstName(Vector vector, Score similarity, String firstname); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ENN) SearchResults searchTop10EnnByEmbeddingNear(Vector vector, Score similarity); @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity, Limit limit); + @VectorSearch(indexName = "cos-index", searchType = VectorSearchOperation.SearchType.ANN) + SearchResults searchByEmbeddingNearOrderByFirstName(Vector vector, Score similarity, Limit limit); + } static class WithVector { Vector embedding; + + String lastName; + + @Field("first_name") String firstName; + + public Vector getEmbedding() { + return embedding; + } + + public void setEmbedding(Vector embedding) { + this.embedding = embedding; + } + + public String getLastName() { + return lastName; + } + + public void setLastName(String lastName) { + this.lastName = lastName; + } + + public String 
getFirstName() { + return firstName; + } + + public void setFirstName(String firstName) { + this.firstName = firstName; + } } } From 1ac2c9f599b9ff6a886a42ed373fd4729bfd0b78 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 7 May 2025 09:31:35 +0200 Subject: [PATCH 09/10] are we done? --- .../core/aggregation/ArrayOperators.java | 1 + .../aggregation/VectorSearchOperation.java | 5 +- .../query/ConvertingParameterAccessor.java | 2 +- .../query/VectorSearchAggregation.java | 1 - .../query/VectorSearchDelegate.java | 48 ++++++++++++------- .../VectorSearchOperationUnitTests.java | 14 +++++- 6 files changed, 50 insertions(+), 21 deletions(-) diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java index 02b805d5ed..85952d8f39 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/ArrayOperators.java @@ -356,6 +356,7 @@ public SortArray sort(Sort sort) { * @return new instance of {@link SortArray}. 
* @since 4.5 */ + @SuppressWarnings("NullAway") public SortArray sort(Direction direction) { if (usesFieldRef()) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java index 2c74900bc5..95f1c5b4d2 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperation.java @@ -237,7 +237,10 @@ public Document toDocument(AggregationOperationContext context) { } $vectorSearch.append("index", indexName); - $vectorSearch.append("limit", limit.max()); + + if(limit.isLimited()) { // TODO: exception or pass it on? + $vectorSearch.append("limit", limit.max()); + } if (numCandidates != null) { $vectorSearch.append("numCandidates", numCandidates); diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java index 0eac1aa3e0..f203b67e67 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/ConvertingParameterAccessor.java @@ -77,7 +77,7 @@ public PotentiallyConvertingIterator iterator() { } @Override - public Vector getVector() { + public @Nullable Vector getVector() { return delegate.getVector(); } diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java index d9f81d09da..eb8dc2e52e 100644 --- 
a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchAggregation.java @@ -63,7 +63,6 @@ public VectorSearchAggregation(MongoQueryMethod method, MongoOperations mongoOpe this.delegate = new VectorSearchDelegate(method, mongoOperations.getConverter(), delegate); } - @SuppressWarnings("unchecked") @Override protected Object doExecute(MongoQueryMethod method, ResultProcessor processor, ConvertingParameterAccessor accessor, @Nullable Class typeToRead) { diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java index b82e9e4b64..0dbff2e932 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/repository/query/VectorSearchDelegate.java @@ -47,6 +47,7 @@ import org.springframework.data.repository.query.ValueExpressionDelegate; import org.springframework.data.repository.query.parser.Part; import org.springframework.data.repository.query.parser.PartTree; +import org.springframework.util.NumberUtils; import org.springframework.util.StringUtils; /** @@ -136,17 +137,14 @@ QueryContainer createQuery(ValueExpressionEvaluator evaluator, ResultProcessor p outputType, getSimilarityFunction(accessor), indexName); } - public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class outputType, + @SuppressWarnings("NullAway") + AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, String scoreField, Class outputType, MongoParameterAccessor accessor, ValueExpressionEvaluator evaluator) { Vector vector = accessor.getVector(); Score score = 
accessor.getScore(); Range distance = accessor.getScoreRange(); - Limit limit = Limit.unlimited(); - - if (input.query().isLimited()) { - limit = Limit.of(input.query().getLimit()); - } + Limit limit = Limit.of(input.query().getLimit()); List stages = new ArrayList<>(); VectorSearchOperation $vectorSearch = Aggregation.vectorSearch(indexName).path(input.path()).vector(vector) @@ -223,21 +221,38 @@ public AggregationPipeline createVectorSearchPipeline(VectorSearchInput input, S private VectorSearchInput createSearchInput(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { - VectorSearchInput query = queryFactory.createQuery(accessor, codec, context); - Limit limit; + VectorSearchInput input = queryFactory.createQuery(accessor, codec, context); + Limit limit = getLimit(evaluator, accessor); + if (!input.query().isLimited() || limit.isLimited()) { + input.query().limit(limit); + } + return input; + } + + private Limit getLimit(ValueExpressionEvaluator evaluator, MongoParameterAccessor accessor) { + if (this.limitExpression != null) { + Object value = evaluator.evaluate(this.limitExpression); - limit = value instanceof Limit l ? l : Limit.of(((Number) value).intValue()); - } else if (this.limit.isLimited()) { - limit = this.limit; - } else { - limit = accessor.getLimit(); + if (value != null) { + if (value instanceof Limit l) { + return l; + } + if (value instanceof Number n) { + return Limit.of(n.intValue()); + } + if (value instanceof String s) { + return Limit.of(NumberUtils.parseNumber(s, Integer.class)); + } + throw new IllegalArgumentException("Invalid type for Limit. 
Found [%s], expected Limit, Number or String".formatted(value.getClass().getName())); + } + } - if (limit.isLimited()) { - query.query().limit(limit); + if (this.limit.isLimited()) { + return this.limit; } - return query; + + return accessor.getLimit(); } public String getQueryString() { @@ -378,6 +393,7 @@ private class PartTreeQueryFactory implements VectorSearchQueryFactory { this.tree = tree; } + @SuppressWarnings("NullAway") public VectorSearchInput createQuery(MongoParameterAccessor parameterAccessor, ParameterBindingDocumentCodec codec, ParameterBindingContext context) { diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java index 4ce045fe6f..936460f466 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/aggregation/VectorSearchOperationUnitTests.java @@ -15,14 +15,14 @@ */ package org.springframework.data.mongodb.core.aggregation; -import static org.assertj.core.api.Assertions.*; +import static org.springframework.data.mongodb.test.util.Assertions.assertThat; import java.util.List; import org.bson.Document; import org.junit.jupiter.api.Test; - import org.springframework.data.annotation.Id; +import org.springframework.data.domain.Limit; import org.springframework.data.mongodb.core.aggregation.VectorSearchOperation.SearchType; import org.springframework.data.mongodb.core.mapping.Field; import org.springframework.data.mongodb.core.query.Criteria; @@ -103,6 +103,16 @@ void mapsCriteriaToDomainType() { .containsExactly(new Document("$vectorSearch", new Document($VECTOR_SEARCH).append("filter", filter))); } + @Test + void withInvalidLimit() { + + VectorSearchOperation $search = 
VectorSearchOperation.search("vector_index").path("plot_embedding") + 
.vector(-0.0016261312, -0.028070757, -0.011342932).limit(Limit.unlimited()); + + List stages = $search.toPipelineStages(TestAggregationContext.contextFor(Movie.class)); + assertThat(stages.get(0)).doesNotContainKey("$vectorSearch.limit"); + } + static class Movie { @Id String id; From 72806aabe1209f89543daed69977dd067e4e0f22 Mon Sep 17 00:00:00 2001 From: Christoph Strobl Date: Wed, 7 May 2025 12:07:25 +0200 Subject: [PATCH 10/10] Update documentation fix invalid domain type reference and make sure to include required arguments --- .../ROOT/pages/mongodb/mongo-search-indexes.adoc | 2 +- .../vector-search-method-annotated-include.adoc | 8 ++++---- .../vector-search-method-derived-include.adoc | 12 ++++++------ .../partials/vector-search-repository-include.adoc | 12 +++++++++--- .../ROOT/partials/vector-search-scoring-include.adoc | 6 +++--- 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc index 345b5dbb6c..7fc51de007 100644 --- a/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc +++ b/src/main/antora/modules/ROOT/pages/mongodb/mongo-search-indexes.adoc @@ -25,7 +25,7 @@ Java:: [source,java,indent=0,subs="verbatim,quotes",role="primary"] ---- VectorIndex index = new VectorIndex("vector_index") - .addVector("plotEmbedding"), vector -> vector.dimensions(1536).similarity(COSINE)) <1> + .addVector("plotEmbedding", vector -> vector.dimensions(1536).similarity(COSINE)) <1> .addFilter("year"); <2> mongoTemplate.searchIndexOps(Movie.class) <3> diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc index 752ffad622..252437f0b7 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc +++ 
b/src/main/antora/modules/ROOT/partials/vector-search-method-annotated-include.adoc @@ -6,13 +6,13 @@ Annotated search methods use the `@VectorSearch` annotation to define parameters ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "cos-index", filter = "{country: ?0}") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + @VectorSearch(indexName = "cos-index", filter = "{country: ?0}", limit="100", numCandidates="2000") + SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance); - @VectorSearch(indexName = "my-index", filter = "{country: ?0}", numCandidates = "#{#limit * 20}", + @VectorSearch(indexName = "my-index", filter = "{country: ?0}", limit="?3", numCandidates = "#{#limit * 20}", searchType = VectorSearchOperation.SearchType.ANN) - List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); + List findAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, Score distance, int limit); } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc index dd06ee699a..f2b006b8e4 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-method-derived-include.adoc @@ -6,14 +6,14 @@ MongoDB Search methods must use the `@VectorSearch` annotation to define the ind ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "my-index") - SearchResults searchByEmbeddingNear(Vector vector, Score score); + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score score); - @VectorSearch(indexName = "my-index") - SearchResults searchByEmbeddingWithin(Vector vector, Range range); + 
@VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByEmbeddingWithin(Vector vector, Range range); - @VectorSearch(indexName = "my-index") - SearchResults searchByCountryAndEmbeddingWithin(String country, Vector vector, Range range); + @VectorSearch(indexName = "my-index", numCandidates="200") + SearchResults searchTop10ByCountryAndEmbeddingWithin(String country, Vector vector, Range range); } ---- ==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc index c7ad91c9db..0e987fc1c5 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-repository-include.adoc @@ -4,12 +4,12 @@ ---- interface CommentRepository extends Repository { - @VectorSearch(indexName = "my-index") + @VectorSearch(indexName = "my-index", numCandidates="#{#limit.max() * 20}") SearchResults searchByCountryAndEmbeddingNear(String country, Vector vector, Score score, Limit limit); - @VectorSearch(indexName = "my-index") - SearchResults searchAnnotatedByCountryAndEmbeddingWithin(String country, Vector embedding, + @VectorSearch(indexName = "my-index", limit="10", numCandidates="200") + SearchResults searchByCountryAndEmbeddingWithin(String country, Vector embedding, Score score); } @@ -17,3 +17,9 @@ interface CommentRepository extends Repository { SearchResults results = repository.searchByCountryAndEmbeddingNear("en", Vector.of(…), Score.of(0.9), Limit.of(10)); ---- ==== + +[TIP] +==== +The MongoDB https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/[vector search aggregation] stage defines a set of required arguments and restrictions. +Please follow these guidelines and make sure to provide all required arguments, such as `limit`. 
+==== diff --git a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc index b97475b467..313d8bf394 100644 --- a/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc +++ b/src/main/antora/modules/ROOT/partials/vector-search-scoring-include.adoc @@ -9,13 +9,13 @@ The scoring function defaults to `ScoringFunction.unspecified()` as there is no interface CommentRepository extends Repository { @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Score similarity); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Score similarity); @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Similarity similarity); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Similarity similarity); @VectorSearch(…) - SearchResults searchByEmbeddingNear(Vector vector, Range range); + SearchResults searchTop10ByEmbeddingNear(Vector vector, Range range); } repository.searchByEmbeddingNear(Vector.of(…), Score.of(0.9)); <1>