Commit 881d56c

address PR feedback
1 parent 6f81af9 commit 881d56c

File tree

16 files changed (+121, -151 lines)


elasticsearch/helpers/vectorstore/__init__.py

Lines changed: 9 additions & 9 deletions
@@ -42,21 +42,21 @@
 from elasticsearch.helpers.vectorstore._utils import DistanceMetric

 __all__ = [
-    "BM25Strategy",
-    "DenseVectorStrategy",
-    "DenseVectorScriptScoreStrategy",
-    "ElasticsearchEmbeddings",
-    "EmbeddingService",
-    "RetrievalStrategy",
-    "SparseVectorStrategy",
-    "VectorStore",
     "AsyncBM25Strategy",
-    "AsyncDenseVectorStrategy",
     "AsyncDenseVectorScriptScoreStrategy",
+    "AsyncDenseVectorStrategy",
     "AsyncElasticsearchEmbeddings",
     "AsyncEmbeddingService",
     "AsyncRetrievalStrategy",
     "AsyncSparseVectorStrategy",
     "AsyncVectorStore",
+    "BM25Strategy",
+    "DenseVectorScriptScoreStrategy",
+    "DenseVectorStrategy",
     "DistanceMetric",
+    "ElasticsearchEmbeddings",
+    "EmbeddingService",
+    "RetrievalStrategy",
+    "SparseVectorStrategy",
+    "VectorStore",
 ]
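
The change above only sorts __all__ alphabetically; nothing is added or removed, so the public import surface of the package root is unchanged. A minimal sanity check (sketch, using only names listed above):

    # All strategy and store classes stay importable from the package root.
    from elasticsearch.helpers.vectorstore import (
        AsyncDenseVectorStrategy,
        DenseVectorStrategy,
        VectorStore,
    )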

elasticsearch/helpers/vectorstore/_async/embedding_service.py

Lines changed: 1 addition & 2 deletions
@@ -75,8 +75,7 @@ def __init__(
         self.input_field = input_field

     async def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        result = await self._embedding_func(texts)
-        return result
+        return await self._embedding_func(texts)

     async def embed_query(self, text: str) -> List[float]:
         result = await self._embedding_func([text])

elasticsearch/helpers/vectorstore/_async/strategies.py

Lines changed: 13 additions & 13 deletions
@@ -241,13 +241,13 @@ def es_mappings_settings(
         num_dimensions: Optional[int],
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         if self.distance is DistanceMetric.COSINE:
-            similarityAlgo = "cosine"
+            similarity = "cosine"
         elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
-            similarityAlgo = "l2_norm"
+            similarity = "l2_norm"
         elif self.distance is DistanceMetric.DOT_PRODUCT:
-            similarityAlgo = "dot_product"
+            similarity = "dot_product"
         elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
-            similarityAlgo = "max_inner_product"
+            similarity = "max_inner_product"
         else:
             raise ValueError(f"Similarity {self.distance} not supported.")

@@ -257,7 +257,7 @@ def es_mappings_settings(
                     "type": "dense_vector",
                     "dims": num_dimensions,
                     "index": True,
-                    "similarity": similarityAlgo,
+                    "similarity": similarity,
                 },
             }
         }

@@ -326,18 +326,18 @@ def es_query(
             raise ValueError("specify a query_vector")

         if self.distance is DistanceMetric.COSINE:
-            similarityAlgo = (
+            similarity_algo = (
                 f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0"
             )
         elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
-            similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
+            similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
         elif self.distance is DistanceMetric.DOT_PRODUCT:
-            similarityAlgo = f"""
+            similarity_algo = f"""
                 double value = dotProduct(params.query_vector, '{vector_field}');
                 return sigmoid(1, Math.E, -value);
             """
         elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
-            similarityAlgo = f"""
+            similarity_algo = f"""
                 double value = dotProduct(params.query_vector, '{vector_field}');
                 if (dotProduct < 0) {{
                     return 1 / (1 + -1 * dotProduct);

@@ -347,16 +347,16 @@ def es_query(
         else:
             raise ValueError(f"Similarity {self.distance} not supported.")

-        queryBool: Dict[str, Any] = {"match_all": {}}
+        query_bool: Dict[str, Any] = {"match_all": {}}
         if filter:
-            queryBool = {"bool": {"filter": filter}}
+            query_bool = {"bool": {"filter": filter}}

         return {
             "query": {
                 "script_score": {
-                    "query": queryBool,
+                    "query": query_bool,
                     "script": {
-                        "source": similarityAlgo,
+                        "source": similarity_algo,
                         "params": {"query_vector": query_vector},
                     },
                 },
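
For orientation, the renames above do not change the request body that es_query builds. With DistanceMetric.COSINE, a filter, and a vector field named "vector" (the field name, filter clause, and vector values below are illustrative, not taken from this diff), the result has this shape:

    # Sketch of the script_score body assembled above for cosine distance.
    es_query_body = {
        "query": {
            "script_score": {
                "query": {"bool": {"filter": [{"term": {"metadata.category": "news"}}]}},
                "script": {
                    "source": "cosineSimilarity(params.query_vector, 'vector') + 1.0",
                    "params": {"query_vector": [0.2, 0.1, 0.9]},
                },
            }
        }
    }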

elasticsearch/helpers/vectorstore/_async/vectorstore.py

Lines changed: 2 additions & 2 deletions
@@ -22,10 +22,10 @@
 from elasticsearch import AsyncElasticsearch
 from elasticsearch._version import __versionstr__ as lib_version
 from elasticsearch.helpers import BulkIndexError, async_bulk
-from elasticsearch.helpers.vectorstore._async.embedding_service import (
+from elasticsearch.helpers.vectorstore import (
     AsyncEmbeddingService,
+    AsyncRetrievalStrategy,
 )
-from elasticsearch.helpers.vectorstore._async.strategies import AsyncRetrievalStrategy
 from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance

 logger = logging.getLogger(__name__)

elasticsearch/helpers/vectorstore/_sync/embedding_service.py

Lines changed: 1 addition & 2 deletions
@@ -75,8 +75,7 @@ def __init__(
         self.input_field = input_field

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
-        result = self._embedding_func(texts)
-        return result
+        return self._embedding_func(texts)

     def embed_query(self, text: str) -> List[float]:
         result = self._embedding_func([text])

elasticsearch/helpers/vectorstore/_sync/strategies.py

Lines changed: 13 additions & 13 deletions
@@ -241,13 +241,13 @@ def es_mappings_settings(
         num_dimensions: Optional[int],
     ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
         if self.distance is DistanceMetric.COSINE:
-            similarityAlgo = "cosine"
+            similarity = "cosine"
         elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
-            similarityAlgo = "l2_norm"
+            similarity = "l2_norm"
         elif self.distance is DistanceMetric.DOT_PRODUCT:
-            similarityAlgo = "dot_product"
+            similarity = "dot_product"
         elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
-            similarityAlgo = "max_inner_product"
+            similarity = "max_inner_product"
         else:
             raise ValueError(f"Similarity {self.distance} not supported.")

@@ -257,7 +257,7 @@ def es_mappings_settings(
                     "type": "dense_vector",
                     "dims": num_dimensions,
                     "index": True,
-                    "similarity": similarityAlgo,
+                    "similarity": similarity,
                 },
             }
         }

@@ -326,18 +326,18 @@ def es_query(
             raise ValueError("specify a query_vector")

         if self.distance is DistanceMetric.COSINE:
-            similarityAlgo = (
+            similarity_algo = (
                 f"cosineSimilarity(params.query_vector, '{vector_field}') + 1.0"
             )
         elif self.distance is DistanceMetric.EUCLIDEAN_DISTANCE:
-            similarityAlgo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
+            similarity_algo = f"1 / (1 + l2norm(params.query_vector, '{vector_field}'))"
         elif self.distance is DistanceMetric.DOT_PRODUCT:
-            similarityAlgo = f"""
+            similarity_algo = f"""
                 double value = dotProduct(params.query_vector, '{vector_field}');
                 return sigmoid(1, Math.E, -value);
             """
         elif self.distance is DistanceMetric.MAX_INNER_PRODUCT:
-            similarityAlgo = f"""
+            similarity_algo = f"""
                 double value = dotProduct(params.query_vector, '{vector_field}');
                 if (dotProduct < 0) {{
                     return 1 / (1 + -1 * dotProduct);

@@ -347,16 +347,16 @@ def es_query(
         else:
             raise ValueError(f"Similarity {self.distance} not supported.")

-        queryBool: Dict[str, Any] = {"match_all": {}}
+        query_bool: Dict[str, Any] = {"match_all": {}}
         if filter:
-            queryBool = {"bool": {"filter": filter}}
+            query_bool = {"bool": {"filter": filter}}

         return {
             "query": {
                 "script_score": {
-                    "query": queryBool,
+                    "query": query_bool,
                     "script": {
-                        "source": similarityAlgo,
+                        "source": similarity_algo,
                        "params": {"query_vector": query_vector},
                     },
                 },
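
Likewise on the sync side, es_mappings_settings still emits the same dense_vector mapping; only the local variable name changed. For DistanceMetric.COSINE the mapping resolves to something like the sketch below (the field name "vector" and dims=384 are illustrative, and the surrounding "properties" nesting is an assumption; the hunk above only shows the inner field definition):

    # Sketch of the mapping fragment produced above for cosine similarity.
    mappings = {
        "properties": {
            "vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine",
            }
        }
    }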

elasticsearch/helpers/vectorstore/_sync/vectorstore.py

Lines changed: 1 addition & 2 deletions
@@ -22,8 +22,7 @@
 from elasticsearch import Elasticsearch
 from elasticsearch._version import __versionstr__ as lib_version
 from elasticsearch.helpers import BulkIndexError, bulk
-from elasticsearch.helpers.vectorstore._sync.embedding_service import EmbeddingService
-from elasticsearch.helpers.vectorstore._sync.strategies import RetrievalStrategy
+from elasticsearch.helpers.vectorstore import EmbeddingService, RetrievalStrategy
 from elasticsearch.helpers.vectorstore._utils import maximal_marginal_relevance

 logger = logging.getLogger(__name__)

elasticsearch/helpers/vectorstore/_utils.py

Lines changed: 1 addition & 1 deletion
@@ -112,5 +112,5 @@ def _raise_missing_mmr_deps_error(parent_error: ModuleNotFoundError) -> None:
     raise ModuleNotFoundError(
         f"Failed to compute maximal marginal relevance because the required "
         f"module '{parent_error.name}' is missing. You can install it by running: "
-        f"'{sys.executable} -m pip install elasticsearch[mmr]'"
+        f"'{sys.executable} -m pip install elasticsearch[vectorstore_mmr]'"
     ) from parent_error

noxfile.py

Lines changed: 4 additions & 2 deletions
@@ -48,7 +48,9 @@ def pytest_argv():

 @nox.session(python=["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"])
 def test(session):
-    session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV, silent=False)
+    session.install(
+        ".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV, silent=False
+    )
     session.install("-r", "dev-requirements.txt", silent=False)

     session.run(*pytest_argv())

@@ -95,7 +97,7 @@ def lint(session):
     session.run("flake8", *SOURCE_FILES)
     session.run("python", "utils/license-headers.py", "check", *SOURCE_FILES)

-    session.install(".[async,requests,orjson,mmr]", env=INSTALL_ENV)
+    session.install(".[async,requests,orjson,vectorstore_mmr]", env=INSTALL_ENV)

     # Run mypy on the package and then the type examples separately for
     # the two different mypy use-cases, ourselves and our users.
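
To exercise these sessions locally after the rename, the usual invocation would be nox -s lint or nox -s test-3.12 (session names assume nox's standard naming for parametrized Python versions); both now install the vectorstore_mmr extra instead of mmr.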

setup.py

Lines changed: 2 additions & 1 deletion
@@ -92,6 +92,7 @@
         "requests": ["requests>=2.4.0, <3.0.0"],
         "async": ["aiohttp>=3,<4"],
         "orjson": ["orjson>=3"],
-        "mmr": ["numpy>=1", "simsimd>=3"],
+        # Maximal Marginal Relevance (MMR) for search results
+        "vectorstore_mmr": ["numpy>=1", "simsimd>=3"],
     },
 )
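
Anyone installing the optional MMR dependencies directly needs to switch to the new extra name, e.g. python -m pip install 'elasticsearch[vectorstore_mmr]' instead of elasticsearch[mmr], matching the hint now raised in _utils.py above.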

test_elasticsearch/test_server/test_helpers_vectorstore/__init__.py

Lines changed: 65 additions & 0 deletions
@@ -14,3 +14,68 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+from typing import List
+
+from elastic_transport import Transport
+
+from elasticsearch.helpers.vectorstore import EmbeddingService
+
+
+class RequestSavingTransport(Transport):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.requests: list = []
+
+    def perform_request(self, *args, **kwargs):
+        self.requests.append(kwargs)
+        return super().perform_request(*args, **kwargs)
+
+
+class FakeEmbeddings(EmbeddingService):
+    """Fake embeddings functionality for testing."""
+
+    def __init__(self, dimensionality: int = 10) -> None:
+        self.dimensionality = dimensionality
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return simple embeddings. Embeddings encode each text as its index."""
+        return [
+            [float(1.0)] * (self.dimensionality - 1) + [float(i)]
+            for i in range(len(texts))
+        ]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return constant query embeddings.
+        Embeddings are identical to embed_documents(texts)[0].
+        Distance to each text will be that text's index,
+        as it was passed to embed_documents.
+        """
+        return [float(1.0)] * (self.dimensionality - 1) + [float(0.0)]
+
+
+class ConsistentFakeEmbeddings(FakeEmbeddings):
+    """Fake embeddings which remember all the texts seen so far to return consistent
+    vectors for the same texts."""
+
+    def __init__(self, dimensionality: int = 10) -> None:
+        self.known_texts: List[str] = []
+        self.dimensionality = dimensionality
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Return consistent embeddings for each text seen so far."""
+        out_vectors = []
+        for text in texts:
+            if text not in self.known_texts:
+                self.known_texts.append(text)
+            vector = [float(1.0)] * (self.dimensionality - 1) + [
+                float(self.known_texts.index(text))
+            ]
+            out_vectors.append(vector)
+        return out_vectors
+
+    def embed_query(self, text: str) -> List[float]:
+        """Return consistent embeddings for the text, if seen before, or a constant
+        one if the text is unknown."""
+        result = self.embed_documents([text])
+        return result[0]
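
A short usage sketch of the fakes added above (values follow directly from the code; run with these classes in scope):

    # ConsistentFakeEmbeddings maps each distinct text to a stable vector:
    # dimensionality-1 ones followed by the text's first-seen index.
    embeddings = ConsistentFakeEmbeddings(dimensionality=3)
    print(embeddings.embed_documents(["foo", "bar", "foo"]))
    # -> [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 0.0]]
    print(embeddings.embed_query("bar"))
    # -> [1.0, 1.0, 1.0]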
