
Commit b18d63d

batch embedding call; infer num_dimensions

1 parent 9387b74 commit b18d63d
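
The commit message names two behaviors: embeddings are now computed in one batched call per add_texts() invocation, and num_dimensions is inferred from a probe embedding instead of being required up front. Below is a minimal, self-contained sketch of both ideas; only the method names embed_documents and embed_query are taken from the diff, and the toy HashEmbeddingService is purely illustrative, not library code.

import asyncio
from typing import List


class HashEmbeddingService:
    """Toy stand-in for an AsyncEmbeddingService implementation (assumption)."""

    def __init__(self, dims: int = 8) -> None:
        self.dims = dims

    async def embed_query(self, text: str) -> List[float]:
        # Deterministic fake embedding, just to make the sketch runnable.
        return [float((hash(text) >> i) % 7) for i in range(self.dims)]

    async def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # One call returns vectors for the whole batch, mirroring the new
        # add_texts() flow that replaces per-document embed_for_indexing().
        return [await self.embed_query(text) for text in texts]


async def main() -> None:
    service = HashEmbeddingService()

    # Batched embedding call: all document vectors come from one embed_documents().
    vectors = await service.embed_documents(["hello world", "lorem ipsum"])

    # Inferred num_dimensions: probe once instead of passing the size explicitly.
    num_dimensions = len(await service.embed_query("get num dimensions"))
    assert all(len(v) == num_dimensions for v in vectors)
    print(f"num_dimensions={num_dimensions}, vectors={len(vectors)}")


asyncio.run(main())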

3 files changed: +73 -100 lines changed

elasticsearch/vectorstore/_async/strategies.py

Lines changed: 21 additions & 66 deletions
@@ -21,7 +21,6 @@
 
 from elasticsearch import AsyncElasticsearch
 from elasticsearch.vectorstore._async._utils import model_must_be_deployed
-from elasticsearch.vectorstore._async.embedding_service import AsyncEmbeddingService
 
 
 class DistanceMetric(str, Enum):
@@ -63,7 +62,8 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
-        metadata_mapping: Optional[Dict[str, str]],
+        num_dimensions: Optional[int] = None,
+        metadata_mapping: Optional[Dict[str, str]] = None,
     ) -> None:
         """
         Create the required index and do necessary preliminary work, like
@@ -76,21 +76,11 @@ async def create_index(
             describe the schema of the metadata.
         """
 
-    async def embed_for_indexing(self, text: str) -> Dict[str, Any]:
+    def needs_inference(self) -> bool:
         """
-        If this strategy creates vector embeddings in Python (not in Elasticsearch),
-        this method is used to apply the inference.
-        The output is a dictionary with the vector field and the vector embedding.
-        It is merged in the ElasticserachStore with the rest of the document (text data,
-        metadata) before indexing.
-
-        Args:
-            text: Text input that can be used as input for inference.
-
-        Returns:
-            Dict: field and value pairs that extend the document to be indexed.
+        TODO
         """
-        return {}
+        return False
 
 
 # TODO test when repsective image is released
@@ -134,6 +124,7 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
+        num_dimensions: int,
         metadata_mapping: Optional[Dict[str, str]],
     ) -> None:
         if self.model_id:
@@ -206,6 +197,7 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
+        num_dimensions: int,
         metadata_mapping: Optional[Dict[str, str]],
     ) -> None:
         pipeline_name = f"{self.model_id}_sparse_embedding"
@@ -257,19 +249,11 @@ def __init__(
         knn_type: Literal["hnsw", "int8_hnsw", "flat", "int8_flat"] = "hnsw",
         vector_field: str = "vector_field",
         distance: DistanceMetric = DistanceMetric.COSINE,
-        embedding_service: Optional[AsyncEmbeddingService] = None,
         model_id: Optional[str] = None,
-        num_dimensions: Optional[int] = None,
         hybrid: bool = False,
         rrf: Union[bool, Dict[str, Any]] = True,
         text_field: Optional[str] = "text_field",
     ):
-        if embedding_service and model_id:
-            raise ValueError("either specify embedding_service or model_id, not both")
-        if model_id and not num_dimensions:
-            raise ValueError(
-                "if model_id is specified, num_dimensions must also be specified"
-            )
         if hybrid and not text_field:
             raise ValueError(
                 "to enable hybrid you have to specify a text_field (for BM25 matching)"
@@ -278,9 +262,7 @@ def __init__(
         self.knn_type = knn_type
         self.vector_field = vector_field
         self.distance = distance
-        self.embedding_service = embedding_service
         self.model_id = model_id
-        self.num_dimensions = num_dimensions
         self.hybrid = hybrid
         self.rrf = rrf
         self.text_field = text_field
@@ -302,10 +284,6 @@ async def es_query(
 
         if query_vector:
             knn["query_vector"] = query_vector
-        elif self.embedding_service:
-            knn["query_vector"] = await self.embedding_service.embed_query(
-                cast(str, query)
-            )
         else:
             # Inference in Elasticsearch. When initializing we make sure to always have
             # a model_id if don't have an embedding_service.
@@ -325,13 +303,9 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
+        num_dimensions: int,
         metadata_mapping: Optional[Dict[str, str]],
     ) -> None:
-        if self.embedding_service and not self.num_dimensions:
-            self.num_dimensions = len(
-                await self.embedding_service.embed_query("get number of dimensions")
-            )
-
         if self.model_id:
             await model_must_be_deployed(client, self.model_id)
 
@@ -350,7 +324,7 @@ async def create_index(
             "properties": {
                 self.vector_field: {
                     "type": "dense_vector",
-                    "dims": self.num_dimensions,
+                    "dims": num_dimensions,
                     "index": True,
                     "similarity": similarityAlgo,
                 },
@@ -362,12 +336,6 @@ async def create_index(
         r = await client.indices.create(index=index_name, mappings=mappings)
         print(r)
 
-    async def embed_for_indexing(self, text: str) -> Dict[str, Any]:
-        if self.embedding_service:
-            vector = await self.embedding_service.embed_query(text)
-            return {self.vector_field: vector}
-        return {}
-
     def _hybrid(
         self, query: str, knn: Dict[str, Any], filter: List[Dict[str, Any]]
     ) -> Dict[str, Any]:
@@ -393,28 +361,27 @@ def _hybrid(
             },
         }
 
-        if isinstance(self.rrf, Dict[str, Any]):
+        if isinstance(self.rrf, Dict):
             query_body["rank"] = {"rrf": self.rrf}
         elif isinstance(self.rrf, bool) and self.rrf is True:
             query_body["rank"] = {"rrf": {}}
 
         return query_body
 
+    def needs_inference(self) -> bool:
+        return not self.model_id
+
 
 class DenseVectorScriptScore(RetrievalStrategy):
     """Exact nearest neighbors retrieval using the `script_score` query."""
 
     def __init__(
         self,
-        embedding_service: AsyncEmbeddingService,
         vector_field: str = "vector_field",
         distance: DistanceMetric = DistanceMetric.COSINE,
-        num_dimensions: Optional[int] = None,
     ) -> None:
         self.vector_field = vector_field
         self.distance = distance
-        self.embedding_service = embedding_service
-        self.num_dimensions = num_dimensions
 
     async def es_query(
         self,
@@ -424,6 +391,9 @@ async def es_query(
         filter: List[Dict[str, Any]] = [],
         query_vector: Optional[List[float]] = None,
     ) -> Dict[str, Any]:
+        if not query_vector:
+            raise ValueError("specify a query_vector")
+
         if self.distance is DistanceMetric.COSINE:
             similarityAlgo = (
                 f"cosineSimilarity(params.query_vector, '{self.vector_field}') + 1.0"
@@ -452,16 +422,6 @@ async def es_query(
         if filter:
             queryBool = {"bool": {"filter": filter}}
 
-        if not query_vector:
-            if not self.embedding_service:
-                raise ValueError(
-                    "if not embedding_service is given, you need to "
-                    "procive a query_vector"
-                )
-            if not query:
-                raise ValueError("either specify a query string or a query_vector")
-            query_vector = await self.embedding_service.embed_query(query)
-
         return {
             "query": {
                 "script_score": {
@@ -478,18 +438,14 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
+        num_dimensions: int,
         metadata_mapping: Optional[Dict[str, str]],
     ) -> None:
-        if not self.num_dimensions:
-            self.num_dimensions = len(
-                await self.embedding_service.embed_query("get number of dimensions")
-            )
-
         mappings = {
             "properties": {
                 self.vector_field: {
                     "type": "dense_vector",
-                    "dims": self.num_dimensions,
+                    "dims": num_dimensions,
                     "index": False,
                 }
             }
@@ -499,10 +455,8 @@ async def create_index(
 
         await client.indices.create(index=index_name, mappings=mappings)
 
-        return None
-
-    async def embed_for_indexing(self, text: str) -> Dict[str, Any]:
-        return {self.vector_field: await self.embedding_service.embed_query(text)}
+    def needs_inference(self) -> bool:
+        return True
 
 
 class BM25(RetrievalStrategy):
@@ -545,6 +499,7 @@ async def create_index(
         self,
         client: AsyncElasticsearch,
         index_name: str,
+        num_dimensions: int,
         metadata_mapping: Optional[Dict[str, str]],
     ) -> None:
         similarity_name = "custom_bm25"
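
Taken together, the strategies.py changes redefine the strategy contract: embed_for_indexing() is removed, create_index() now receives num_dimensions from the store, and needs_inference() tells the store whether vectors have to be produced in Python. The sketch below follows that updated contract, assuming only the signatures shown in this diff; the class is standalone for readability, whereas a real strategy would subclass RetrievalStrategy.

from typing import Any, Dict, Optional


class ExampleDenseStrategy:
    """Sketch of the post-commit strategy interface (illustrative, not library code)."""

    def __init__(self, vector_field: str = "vector_field") -> None:
        self.vector_field = vector_field

    def needs_inference(self) -> bool:
        # True: the store must call its embedding_service and must know
        # num_dimensions before the index is created.
        return True

    async def create_index(
        self,
        client: Any,  # AsyncElasticsearch in the real signatures
        index_name: str,
        num_dimensions: int,
        metadata_mapping: Optional[Dict[str, str]],
    ) -> None:
        # num_dimensions arrives as an argument; it is no longer stored on
        # the strategy or probed here.
        mappings = {
            "properties": {
                self.vector_field: {
                    "type": "dense_vector",
                    "dims": num_dimensions,
                    "index": False,
                }
            }
        }
        await client.indices.create(index=index_name, mappings=mappings)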

elasticsearch/vectorstore/_async/vectorestore.py

Lines changed: 25 additions & 2 deletions
@@ -44,6 +44,8 @@ def __init__(
         user_agent: str,
         index_name: str,
         retrieval_strategy: RetrievalStrategy,
+        embedding_service: Optional[AsyncEmbeddingService] = None,
+        num_dimensions: Optional[int] = None,
         text_field: str = "text_field",
         vector_field: str = "vector_field",
         metadata_mapping: Optional[Dict[str, str]] = None,
@@ -61,7 +63,6 @@ def __init__(
             es_client: Elasticsearch client connection. Alternatively specify the
                 Elasticsearch connection with the other es_* parameters.
         """
-
         # Add integration-specific usage header for tracking usage in Elastic Cloud.
         # client.options preserces existing (non-user-agent) headers.
         es_client = es_client.options(headers={"User-Agent": user_agent})
@@ -74,6 +75,8 @@ def __init__(
         self.es_client = es_client
         self.index_name = index_name
         self.retrieval_strategy = retrieval_strategy
+        self.embedding_service = embedding_service
+        self.num_dimensions = num_dimensions
         self.text_field = text_field
         self.vector_field = vector_field
         self.metadata_mapping = metadata_mapping
@@ -118,6 +121,9 @@ async def add_texts(
         if create_index_if_not_exists:
             await self._create_index_if_not_exists()
 
+        if self.embedding_service and not vectors:
+            vectors = await self.embedding_service.embed_documents(texts)
+
         for i, text in enumerate(texts):
             metadata = metadatas[i] if metadatas else {}
 
@@ -132,7 +138,6 @@ async def add_texts(
             if vectors:
                 request[self.vector_field] = vectors[i]
 
-            request.update(await self.retrieval_strategy.embed_for_indexing(text))
             requests.append(request)
 
         if len(requests) > 0:
@@ -240,6 +245,11 @@ async def search(
         if self.text_field not in fields:
             fields.append(self.text_field)
 
+        if self.embedding_service and not query_vector:
+            if not query:
+                raise ValueError("specify a query or a query_vector to search")
+            query_vector = await self.embedding_service.embed_query(query)
+
         query_body = await self.retrieval_strategy.es_query(
             query=query,
             k=k,
@@ -267,9 +277,22 @@ async def _create_index_if_not_exists(self) -> None:
         if exists.meta.status == 200:
             logger.debug(f"Index {self.index_name} already exists. Skipping creation.")
         else:
+            if self.retrieval_strategy.needs_inference():
+                if not self.num_dimensions and not self.embedding_service:
+                    raise ValueError(
+                        "retrieval strategy requires embeddings; either embedding_service "
+                        "or num_dimensions need to be specified"
+                    )
+                if not self.num_dimensions and self.embedding_service:
+                    vector = await self.embedding_service.embed_query(
+                        "get num dimensions"
+                    )
+                    self.num_dimensions = len(vector)
+
             await self.retrieval_strategy.create_index(
                 client=self.es_client,
                 index_name=self.index_name,
+                num_dimensions=self.num_dimensions,
                 metadata_mapping=self.metadata_mapping,
             )
 