21
21
22
22
from elasticsearch import AsyncElasticsearch
23
23
from elasticsearch .vectorstore ._async ._utils import model_must_be_deployed
24
- from elasticsearch .vectorstore ._async .embedding_service import AsyncEmbeddingService
25
24
26
25
27
26
class DistanceMetric (str , Enum ):
@@ -63,7 +62,8 @@ async def create_index(
63
62
self ,
64
63
client : AsyncElasticsearch ,
65
64
index_name : str ,
66
- metadata_mapping : Optional [Dict [str , str ]],
65
+ num_dimensions : Optional [int ] = None ,
66
+ metadata_mapping : Optional [Dict [str , str ]] = None ,
67
67
) -> None :
68
68
"""
69
69
Create the required index and do necessary preliminary work, like
@@ -76,21 +76,11 @@ async def create_index(
76
76
describe the schema of the metadata.
77
77
"""
78
78
79
- async def embed_for_indexing (self , text : str ) -> Dict [ str , Any ] :
79
+ def needs_inference (self ) -> bool :
80
80
"""
81
- If this strategy creates vector embeddings in Python (not in Elasticsearch),
82
- this method is used to apply the inference.
83
- The output is a dictionary with the vector field and the vector embedding.
84
- It is merged in the ElasticserachStore with the rest of the document (text data,
85
- metadata) before indexing.
86
-
87
- Args:
88
- text: Text input that can be used as input for inference.
89
-
90
- Returns:
91
- Dict: field and value pairs that extend the document to be indexed.
81
+ TODO
92
82
"""
93
- return {}
83
+ return False
94
84
95
85
96
86
# TODO test when repsective image is released
@@ -134,6 +124,7 @@ async def create_index(
134
124
self ,
135
125
client : AsyncElasticsearch ,
136
126
index_name : str ,
127
+ num_dimensions : int ,
137
128
metadata_mapping : Optional [Dict [str , str ]],
138
129
) -> None :
139
130
if self .model_id :
@@ -206,6 +197,7 @@ async def create_index(
206
197
self ,
207
198
client : AsyncElasticsearch ,
208
199
index_name : str ,
200
+ num_dimensions : int ,
209
201
metadata_mapping : Optional [Dict [str , str ]],
210
202
) -> None :
211
203
pipeline_name = f"{ self .model_id } _sparse_embedding"
@@ -257,19 +249,11 @@ def __init__(
257
249
knn_type : Literal ["hnsw" , "int8_hnsw" , "flat" , "int8_flat" ] = "hnsw" ,
258
250
vector_field : str = "vector_field" ,
259
251
distance : DistanceMetric = DistanceMetric .COSINE ,
260
- embedding_service : Optional [AsyncEmbeddingService ] = None ,
261
252
model_id : Optional [str ] = None ,
262
- num_dimensions : Optional [int ] = None ,
263
253
hybrid : bool = False ,
264
254
rrf : Union [bool , Dict [str , Any ]] = True ,
265
255
text_field : Optional [str ] = "text_field" ,
266
256
):
267
- if embedding_service and model_id :
268
- raise ValueError ("either specify embedding_service or model_id, not both" )
269
- if model_id and not num_dimensions :
270
- raise ValueError (
271
- "if model_id is specified, num_dimensions must also be specified"
272
- )
273
257
if hybrid and not text_field :
274
258
raise ValueError (
275
259
"to enable hybrid you have to specify a text_field (for BM25 matching)"
@@ -278,9 +262,7 @@ def __init__(
278
262
self .knn_type = knn_type
279
263
self .vector_field = vector_field
280
264
self .distance = distance
281
- self .embedding_service = embedding_service
282
265
self .model_id = model_id
283
- self .num_dimensions = num_dimensions
284
266
self .hybrid = hybrid
285
267
self .rrf = rrf
286
268
self .text_field = text_field
@@ -302,10 +284,6 @@ async def es_query(
302
284
303
285
if query_vector :
304
286
knn ["query_vector" ] = query_vector
305
- elif self .embedding_service :
306
- knn ["query_vector" ] = await self .embedding_service .embed_query (
307
- cast (str , query )
308
- )
309
287
else :
310
288
# Inference in Elasticsearch. When initializing we make sure to always have
311
289
# a model_id if don't have an embedding_service.
@@ -325,13 +303,9 @@ async def create_index(
325
303
self ,
326
304
client : AsyncElasticsearch ,
327
305
index_name : str ,
306
+ num_dimensions : int ,
328
307
metadata_mapping : Optional [Dict [str , str ]],
329
308
) -> None :
330
- if self .embedding_service and not self .num_dimensions :
331
- self .num_dimensions = len (
332
- await self .embedding_service .embed_query ("get number of dimensions" )
333
- )
334
-
335
309
if self .model_id :
336
310
await model_must_be_deployed (client , self .model_id )
337
311
@@ -350,7 +324,7 @@ async def create_index(
350
324
"properties" : {
351
325
self .vector_field : {
352
326
"type" : "dense_vector" ,
353
- "dims" : self . num_dimensions ,
327
+ "dims" : num_dimensions ,
354
328
"index" : True ,
355
329
"similarity" : similarityAlgo ,
356
330
},
@@ -362,12 +336,6 @@ async def create_index(
362
336
r = await client .indices .create (index = index_name , mappings = mappings )
363
337
print (r )
364
338
365
- async def embed_for_indexing (self , text : str ) -> Dict [str , Any ]:
366
- if self .embedding_service :
367
- vector = await self .embedding_service .embed_query (text )
368
- return {self .vector_field : vector }
369
- return {}
370
-
371
339
def _hybrid (
372
340
self , query : str , knn : Dict [str , Any ], filter : List [Dict [str , Any ]]
373
341
) -> Dict [str , Any ]:
@@ -393,28 +361,27 @@ def _hybrid(
393
361
},
394
362
}
395
363
396
- if isinstance (self .rrf , Dict [ str , Any ] ):
364
+ if isinstance (self .rrf , Dict ):
397
365
query_body ["rank" ] = {"rrf" : self .rrf }
398
366
elif isinstance (self .rrf , bool ) and self .rrf is True :
399
367
query_body ["rank" ] = {"rrf" : {}}
400
368
401
369
return query_body
402
370
371
+ def needs_inference (self ) -> bool :
372
+ return not self .model_id
373
+
403
374
404
375
class DenseVectorScriptScore (RetrievalStrategy ):
405
376
"""Exact nearest neighbors retrieval using the `script_score` query."""
406
377
407
378
def __init__ (
408
379
self ,
409
- embedding_service : AsyncEmbeddingService ,
410
380
vector_field : str = "vector_field" ,
411
381
distance : DistanceMetric = DistanceMetric .COSINE ,
412
- num_dimensions : Optional [int ] = None ,
413
382
) -> None :
414
383
self .vector_field = vector_field
415
384
self .distance = distance
416
- self .embedding_service = embedding_service
417
- self .num_dimensions = num_dimensions
418
385
419
386
async def es_query (
420
387
self ,
@@ -424,6 +391,9 @@ async def es_query(
424
391
filter : List [Dict [str , Any ]] = [],
425
392
query_vector : Optional [List [float ]] = None ,
426
393
) -> Dict [str , Any ]:
394
+ if not query_vector :
395
+ raise ValueError ("specify a query_vector" )
396
+
427
397
if self .distance is DistanceMetric .COSINE :
428
398
similarityAlgo = (
429
399
f"cosineSimilarity(params.query_vector, '{ self .vector_field } ') + 1.0"
@@ -452,16 +422,6 @@ async def es_query(
452
422
if filter :
453
423
queryBool = {"bool" : {"filter" : filter }}
454
424
455
- if not query_vector :
456
- if not self .embedding_service :
457
- raise ValueError (
458
- "if not embedding_service is given, you need to "
459
- "procive a query_vector"
460
- )
461
- if not query :
462
- raise ValueError ("either specify a query string or a query_vector" )
463
- query_vector = await self .embedding_service .embed_query (query )
464
-
465
425
return {
466
426
"query" : {
467
427
"script_score" : {
@@ -478,18 +438,14 @@ async def create_index(
478
438
self ,
479
439
client : AsyncElasticsearch ,
480
440
index_name : str ,
441
+ num_dimensions : int ,
481
442
metadata_mapping : Optional [Dict [str , str ]],
482
443
) -> None :
483
- if not self .num_dimensions :
484
- self .num_dimensions = len (
485
- await self .embedding_service .embed_query ("get number of dimensions" )
486
- )
487
-
488
444
mappings = {
489
445
"properties" : {
490
446
self .vector_field : {
491
447
"type" : "dense_vector" ,
492
- "dims" : self . num_dimensions ,
448
+ "dims" : num_dimensions ,
493
449
"index" : False ,
494
450
}
495
451
}
@@ -499,10 +455,8 @@ async def create_index(
499
455
500
456
await client .indices .create (index = index_name , mappings = mappings )
501
457
502
- return None
503
-
504
- async def embed_for_indexing (self , text : str ) -> Dict [str , Any ]:
505
- return {self .vector_field : await self .embedding_service .embed_query (text )}
458
+ def needs_inference (self ) -> bool :
459
+ return True
506
460
507
461
508
462
class BM25 (RetrievalStrategy ):
@@ -545,6 +499,7 @@ async def create_index(
545
499
self ,
546
500
client : AsyncElasticsearch ,
547
501
index_name : str ,
502
+ num_dimensions : int ,
548
503
metadata_mapping : Optional [Dict [str , str ]],
549
504
) -> None :
550
505
similarity_name = "custom_bm25"
0 commit comments