Skip to content

Commit 2b93bf3

Browse files
christophstroblmp911de
authored andcommitted
Add support for $vectorSearch aggregation stage.
Closes #4706 Original pull request: #4882
1 parent dd4579c commit 2b93bf3

16 files changed

+1643
-9
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/DefaultIndexOperations.java

+12-6
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,15 @@
2020
import java.util.List;
2121

2222
import org.bson.Document;
23-
2423
import org.springframework.dao.DataAccessException;
2524
import org.springframework.data.mongodb.MongoDatabaseFactory;
2625
import org.springframework.data.mongodb.UncategorizedMongoDbException;
2726
import org.springframework.data.mongodb.core.convert.QueryMapper;
27+
import org.springframework.data.mongodb.core.index.DefaultVectorIndexOperations;
2828
import org.springframework.data.mongodb.core.index.IndexDefinition;
2929
import org.springframework.data.mongodb.core.index.IndexInfo;
3030
import org.springframework.data.mongodb.core.index.IndexOperations;
31+
import org.springframework.data.mongodb.core.index.VectorIndexOperations;
3132
import org.springframework.data.mongodb.core.mapping.MongoPersistentEntity;
3233
import org.springframework.lang.Nullable;
3334
import org.springframework.util.Assert;
@@ -51,11 +52,11 @@ public class DefaultIndexOperations implements IndexOperations {
5152

5253
private static final String PARTIAL_FILTER_EXPRESSION_KEY = "partialFilterExpression";
5354

54-
private final String collectionName;
55-
private final QueryMapper mapper;
56-
private final @Nullable Class<?> type;
55+
protected final String collectionName;
56+
protected final QueryMapper mapper;
57+
protected final @Nullable Class<?> type;
5758

58-
private final MongoOperations mongoOperations;
59+
protected final MongoOperations mongoOperations;
5960

6061
/**
6162
* Creates a new {@link DefaultIndexOperations}.
@@ -133,7 +134,7 @@ public String ensureIndex(IndexDefinition indexDefinition) {
133134
}
134135

135136
@Nullable
136-
private MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {
137+
protected MongoPersistentEntity<?> lookupPersistentEntity(@Nullable Class<?> entityType, String collection) {
137138

138139
if (entityType != null) {
139140
return mapper.getMappingContext().getRequiredPersistentEntity(entityType);
@@ -209,6 +210,11 @@ private List<IndexInfo> getIndexData(MongoCursor<Document> cursor) {
209210
});
210211
}
211212

213+
@Override
214+
public VectorIndexOperations vectorIndex() {
215+
return new DefaultVectorIndexOperations(mongoOperations, collectionName, type);
216+
}
217+
212218
@Nullable
213219
public <T> T execute(CollectionCallback<T> callback) {
214220

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,321 @@
1+
/*
2+
* Copyright 2024 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package org.springframework.data.mongodb.core.aggregation;
17+
18+
import java.util.Arrays;
19+
import java.util.LinkedHashSet;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.Set;
23+
import java.util.function.Consumer;
24+
import java.util.stream.Collectors;
25+
26+
import org.bson.Document;
27+
import org.springframework.data.domain.Limit;
28+
import org.springframework.data.mongodb.core.query.Criteria;
29+
import org.springframework.data.mongodb.core.query.CriteriaDefinition;
30+
import org.springframework.lang.Nullable;
31+
import org.springframework.util.StringUtils;
32+
33+
/**
34+
* @author Christoph Strobl
35+
*/
36+
public class VectorSearchOperation implements AggregationOperation {
37+
38+
public enum SearchType {
39+
40+
/** MongoDB Server default (value will be omitted) */
41+
DEFAULT,
42+
/** Approximate Nearest Neighbour */
43+
ANN,
44+
/** Exact Nearest Neighbour */
45+
ENN
46+
}
47+
48+
// A query path cannot only contain the name of the filed but may also hold additional information about the
49+
// analyzer to use;
50+
// "path": [ "names", "notes", { "value": "comments", "multi": "mySecondaryAnalyzer" } ]
51+
// see: https://www.mongodb.com/docs/atlas/atlas-search/path-construction/#std-label-ref-path
52+
public static class QueryPaths {
53+
54+
Set<QueryPath<?>> paths;
55+
56+
public static QueryPaths of(QueryPath<String> path) {
57+
58+
QueryPaths queryPaths = new QueryPaths();
59+
queryPaths.paths = new LinkedHashSet<>(2);
60+
queryPaths.paths.add(path);
61+
return queryPaths;
62+
}
63+
64+
Object getPathObject() {
65+
66+
if (paths.size() == 1) {
67+
return paths.iterator().next().value();
68+
}
69+
return paths.stream().map(QueryPath::value).collect(Collectors.toList());
70+
}
71+
}
72+
73+
public interface QueryPath<T> {
74+
75+
T value();
76+
77+
static QueryPath<String> path(String field) {
78+
return new SimplePath(field);
79+
}
80+
81+
static QueryPath<Map<String, Object>> wildcard(String field) {
82+
return new WildcardPath(field);
83+
}
84+
85+
static QueryPath<Map<String, Object>> multi(String field, String analyzer) {
86+
return new MultiPath(field, analyzer);
87+
}
88+
}
89+
90+
public static class SimplePath implements QueryPath<String> {
91+
92+
String name;
93+
94+
public SimplePath(String name) {
95+
this.name = name;
96+
}
97+
98+
@Override
99+
public String value() {
100+
return name;
101+
}
102+
}
103+
104+
public static class WildcardPath implements QueryPath<Map<String, Object>> {
105+
106+
String name;
107+
108+
public WildcardPath(String name) {
109+
this.name = name;
110+
}
111+
112+
@Override
113+
public Map<String, Object> value() {
114+
return Map.of("wildcard", name);
115+
}
116+
}
117+
118+
public static class MultiPath implements QueryPath<Map<String, Object>> {
119+
120+
String field;
121+
String analyzer;
122+
123+
public MultiPath(String field, String analyzer) {
124+
this.field = field;
125+
this.analyzer = analyzer;
126+
}
127+
128+
@Override
129+
public Map<String, Object> value() {
130+
return Map.of("value", field, "multi", analyzer);
131+
}
132+
}
133+
134+
private SearchType searchType;
135+
private CriteriaDefinition filter;
136+
private String indexName;
137+
private Limit limit;
138+
private Integer numCandidates;
139+
private QueryPaths path;
140+
private List<Double> vector;
141+
142+
private String score;
143+
private Consumer<Criteria> scoreCriteria;
144+
145+
private VectorSearchOperation(SearchType searchType, CriteriaDefinition filter, String indexName, Limit limit,
146+
Integer numCandidates, QueryPaths path, List<Double> vector, String searchScore,
147+
Consumer<Criteria> scoreCriteria) {
148+
149+
this.searchType = searchType;
150+
this.filter = filter;
151+
this.indexName = indexName;
152+
this.limit = limit;
153+
this.numCandidates = numCandidates;
154+
this.path = path;
155+
this.vector = vector;
156+
this.score = searchScore;
157+
this.scoreCriteria = scoreCriteria;
158+
}
159+
160+
public VectorSearchOperation(String indexName, QueryPaths path, Limit limit, List<Double> vector) {
161+
this(SearchType.DEFAULT, null, indexName, limit, null, path, vector, null, null);
162+
}
163+
164+
static PathContributor search(String index) {
165+
return new VectorSearchBuilder().index(index);
166+
}
167+
168+
public VectorSearchOperation(String indexName, String path, Limit limit, List<Double> vector) {
169+
this(indexName, QueryPaths.of(QueryPath.path(path)), limit, vector);
170+
}
171+
172+
public VectorSearchOperation searchType(SearchType searchType) {
173+
return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score,
174+
scoreCriteria);
175+
}
176+
177+
public VectorSearchOperation filter(Document filter) {
178+
179+
return filter(new CriteriaDefinition() {
180+
@Override
181+
public Document getCriteriaObject() {
182+
return filter;
183+
}
184+
185+
@Nullable
186+
@Override
187+
public String getKey() {
188+
return null;
189+
}
190+
});
191+
}
192+
193+
public VectorSearchOperation filter(CriteriaDefinition filter) {
194+
return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score,
195+
scoreCriteria);
196+
}
197+
198+
public VectorSearchOperation numCandidates(int numCandidates) {
199+
return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, score,
200+
scoreCriteria);
201+
}
202+
203+
public VectorSearchOperation searchScore() {
204+
return searchScore("score");
205+
}
206+
207+
public VectorSearchOperation searchScore(String scoreFieldName) {
208+
return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector, scoreFieldName,
209+
scoreCriteria);
210+
}
211+
212+
public VectorSearchOperation filterBySore(Consumer<Criteria> score) {
213+
return new VectorSearchOperation(searchType, filter, indexName, limit, numCandidates, path, vector,
214+
StringUtils.hasText(this.score) ? this.score : "score", score);
215+
}
216+
217+
@Override
218+
public Document toDocument(AggregationOperationContext context) {
219+
220+
Document $vectorSearch = new Document();
221+
222+
$vectorSearch.append("index", indexName);
223+
$vectorSearch.append("path", path.getPathObject());
224+
$vectorSearch.append("queryVector", vector);
225+
$vectorSearch.append("limit", limit.max());
226+
227+
if (searchType != null && !searchType.equals(SearchType.DEFAULT)) {
228+
$vectorSearch.append("exact", searchType.equals(SearchType.ENN));
229+
}
230+
231+
if (filter != null) {
232+
$vectorSearch.append("filter", context.getMappedObject(filter.getCriteriaObject()));
233+
}
234+
235+
if (numCandidates != null) {
236+
$vectorSearch.append("numCandidates", numCandidates);
237+
}
238+
239+
return new Document(getOperator(), $vectorSearch);
240+
}
241+
242+
@Override
243+
public List<Document> toPipelineStages(AggregationOperationContext context) {
244+
245+
if (!StringUtils.hasText(score)) {
246+
return List.of(toDocument(context));
247+
}
248+
249+
AddFieldsOperation $vectorSearchScore = Aggregation.addFields().addField(score)
250+
.withValueOfExpression("{\"$meta\":\"vectorSearchScore\"}").build();
251+
252+
if (scoreCriteria == null) {
253+
return List.of(toDocument(context), $vectorSearchScore.toDocument(context));
254+
}
255+
256+
Criteria criteria = Criteria.where(score);
257+
scoreCriteria.accept(criteria);
258+
MatchOperation $filterByScore = Aggregation.match(criteria);
259+
260+
return List.of(toDocument(context), $vectorSearchScore.toDocument(context), $filterByScore.toDocument(context));
261+
}
262+
263+
@Override
264+
public String getOperator() {
265+
return "$vectorSearch";
266+
}
267+
268+
public static class VectorSearchBuilder implements PathContributor, VectorContributor, LimitContributor {
269+
270+
String index;
271+
QueryPaths paths;
272+
private List<Double> vector;
273+
274+
PathContributor index(String index) {
275+
this.index = index;
276+
return this;
277+
}
278+
279+
@Override
280+
public VectorContributor path(QueryPaths paths) {
281+
this.paths = paths;
282+
return this;
283+
}
284+
285+
@Override
286+
public VectorSearchOperation limit(Limit limit) {
287+
return new VectorSearchOperation(index, paths, limit, vector);
288+
}
289+
290+
@Override
291+
public LimitContributor vectors(List<Double> vectors) {
292+
this.vector = vectors;
293+
return this;
294+
}
295+
}
296+
297+
public interface PathContributor {
298+
default VectorContributor path(String path) {
299+
return path(QueryPaths.of(QueryPath.path(path)));
300+
}
301+
302+
VectorContributor path(QueryPaths paths);
303+
}
304+
305+
public interface VectorContributor {
306+
default LimitContributor vectors(Double... vectors) {
307+
return vectors(Arrays.asList(vectors));
308+
}
309+
310+
LimitContributor vectors(List<Double> vectors);
311+
}
312+
313+
public interface LimitContributor {
314+
default VectorSearchOperation limit(int limit) {
315+
return limit(Limit.of(limit));
316+
}
317+
318+
VectorSearchOperation limit(Limit limit);
319+
}
320+
321+
}

0 commit comments

Comments
 (0)