Skip to content

Commit 7ee3846

Browse files
committed
add cleanup step for _sync generation
1 parent 9be44fd commit 7ee3846

File tree

9 files changed

+57
-86
lines changed

9 files changed

+57
-86
lines changed

elasticsearch/vectorstore/_sync/_utils.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
1-
from elasticsearch import (
2-
Elasticsearch,
3-
BadRequestError,
4-
ConflictError,
5-
NotFoundError,
6-
)
1+
from elasticsearch import BadRequestError, ConflictError, Elasticsearch, NotFoundError
72

83

94
def model_must_be_deployed(client: Elasticsearch, model_id: str) -> None:

elasticsearch/vectorstore/_sync/strategies.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Any, Dict, List, Literal, Optional, Union, cast
44

55
from elasticsearch import Elasticsearch
6-
76
from elasticsearch.vectorstore._sync._utils import model_must_be_deployed
87
from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService
98

@@ -226,9 +225,7 @@ def create_index(
226225
mappings["properties"]["metadata"] = {"properties": metadata_mapping}
227226
settings = {"default_pipeline": pipeline_name}
228227

229-
client.indices.create(
230-
index=index_name, mappings=mappings, settings=settings
231-
)
228+
client.indices.create(index=index_name, mappings=mappings, settings=settings)
232229

233230
return None
234231

@@ -287,9 +284,7 @@ def es_query(
287284
if query_vector:
288285
knn["query_vector"] = query_vector
289286
elif self.embedding_service:
290-
knn["query_vector"] = self.embedding_service.embed_query(
291-
cast(str, query)
292-
)
287+
knn["query_vector"] = self.embedding_service.embed_query(cast(str, query))
293288
else:
294289
# Inference in Elasticsearch. When initializing we make sure to always have
295290
# a model_id if we don't have an embedding_service.
@@ -555,6 +550,4 @@ def create_index(
555550
}
556551
}
557552

558-
client.indices.create(
559-
index=index_name, mappings=mappings, settings=settings
560-
)
553+
client.indices.create(index=index_name, mappings=mappings, settings=settings)

elasticsearch/vectorstore/_sync/vectorestore.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,9 @@
44

55
from elasticsearch import Elasticsearch
66
from elasticsearch.helpers import BulkIndexError, bulk
7-
8-
from elasticsearch.vectorstore._utils import (
9-
maximal_marginal_relevance,
10-
)
117
from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService
128
from elasticsearch.vectorstore._sync.strategies import RetrievalStrategy
9+
from elasticsearch.vectorstore._utils import maximal_marginal_relevance
1310

1411
logger = logging.getLogger(__name__)
1512

test_elasticsearch/test_server/test_vectorstore/_async/test_embedding_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384"))
2424

2525

26-
@pytest_asyncio.fixture(autouse=True)
26+
@pytest_asyncio.fixture
2727
async def es_client() -> AsyncIterator[AsyncElasticsearch]:
2828
async for x in es_client_fixture():
2929
yield x

test_elasticsearch/test_server/test_vectorstore/_async/test_vectorestore.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@
5353

5454

5555
class TestElasticsearch:
56-
@pytest_asyncio.fixture(autouse=True)
56+
@pytest_asyncio.fixture
5757
async def es_client(self) -> AsyncIterator[AsyncElasticsearch]:
5858
async for x in es_client_fixture():
5959
yield x
6060

61-
@pytest_asyncio.fixture(autouse=True)
61+
@pytest_asyncio.fixture
6262
async def requests_saving_client(self) -> AsyncIterator[AsyncElasticsearch]:
6363
client = create_requests_saving_client()
6464
try:

test_elasticsearch/test_server/test_vectorstore/_sync/_test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import os
2-
from typing import Any, Dict, List, Optional, Iterator
2+
from typing import Any, Dict, Iterator, List, Optional
33

44
from elastic_transport import Transport
5-
from elasticsearch import Elasticsearch
65

6+
from elasticsearch import Elasticsearch
77
from elasticsearch.vectorstore._sync.embedding_service import EmbeddingService
88

99

test_elasticsearch/test_server/test_vectorstore/_sync/test_embedding_service.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,26 @@
11
import os
2+
from typing import Iterator
23

34
import pytest
45

5-
import pytest_asyncio
66
from elasticsearch import Elasticsearch
7-
8-
from typing import Iterator
9-
107
from elasticsearch.vectorstore._sync._utils import model_is_deployed
8+
from elasticsearch.vectorstore._sync.embedding_service import ElasticsearchEmbeddings
119

12-
from ._test_utils import (
13-
es_client_fixture,
14-
)
15-
16-
from elasticsearch.vectorstore._sync.embedding_service import (
17-
ElasticsearchEmbeddings,
18-
)
10+
from ._test_utils import es_client_fixture
1911

2012
# deployed with
2113
# https://www.elastic.co/guide/en/machine-learning/current/ml-nlp-text-emb-vector-search-example.html
2214
MODEL_ID = os.getenv("MODEL_ID", "sentence-transformers__msmarco-minilm-l-12-v3")
2315
NUM_DIMENSIONS = int(os.getenv("NUM_DIMENTIONS", "384"))
2416

2517

26-
@pytest_asyncio.fixture(autouse=True)
18+
@pytest.fixture
2719
def es_client() -> Iterator[Elasticsearch]:
2820
for x in es_client_fixture():
2921
yield x
3022

3123

32-
@pytest.mark.asyncio
3324
def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None:
3425
"""Test Elasticsearch embedding documents."""
3526

@@ -47,7 +38,6 @@ def test_elasticsearch_embedding_documents(es_client: Elasticsearch) -> None:
4738
assert len(output[2]) == NUM_DIMENSIONS
4839

4940

50-
@pytest.mark.asyncio
5141
def test_elasticsearch_embedding_query(es_client: Elasticsearch) -> None:
5242
"""Test Elasticsearch embedding query."""
5343

test_elasticsearch/test_server/test_vectorstore/_sync/test_vectorestore.py

Lines changed: 10 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import logging
22
import uuid
3-
from typing import Iterator
4-
from typing import Any, List, Optional, Union, cast
53
from functools import partial
4+
from typing import Any, Iterator, List, Optional, Union, cast
65

76
import pytest
8-
import pytest_asyncio
9-
from elasticsearch import Elasticsearch
107

11-
from elasticsearch import NotFoundError
8+
from elasticsearch import Elasticsearch, NotFoundError
129
from elasticsearch.helpers import BulkIndexError
13-
1410
from elasticsearch.vectorstore._sync import VectorStore
1511
from elasticsearch.vectorstore._sync._utils import model_is_deployed
1612
from elasticsearch.vectorstore._sync.strategies import (
@@ -22,11 +18,11 @@
2218
)
2319

2420
from ._test_utils import (
25-
create_requests_saving_client,
26-
es_client_fixture,
2721
ConsistentFakeEmbeddings,
2822
FakeEmbeddings,
2923
RequestSavingTransport,
24+
create_requests_saving_client,
25+
es_client_fixture,
3026
)
3127

3228
logging.basicConfig(level=logging.DEBUG)
@@ -53,12 +49,12 @@
5349

5450

5551
class TestElasticsearch:
56-
@pytest_asyncio.fixture(autouse=True)
52+
@pytest.fixture
5753
def es_client(self) -> Iterator[Elasticsearch]:
5854
for x in es_client_fixture():
5955
yield x
6056

61-
@pytest_asyncio.fixture(autouse=True)
57+
@pytest.fixture
6258
def requests_saving_client(self) -> Iterator[Elasticsearch]:
6359
client = create_requests_saving_client()
6460
try:
@@ -71,7 +67,6 @@ def index_name(self) -> str:
7167
"""Return the index name."""
7268
return f"test_{uuid.uuid4().hex}"
7369

74-
@pytest.mark.asyncio
7570
def test_search_without_metadata(
7671
self, es_client: Elasticsearch, index_name: str
7772
) -> None:
@@ -102,7 +97,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
10297
output = store.search("foo", k=1, custom_query=assert_query)
10398
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
10499

105-
@pytest.mark.asyncio
106100
def test_search_without_metadata_async(
107101
self, es_client: Elasticsearch, index_name: str
108102
) -> None:
@@ -120,10 +114,7 @@ def test_search_without_metadata_async(
120114
output = store.search("foo", k=1)
121115
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
122116

123-
@pytest.mark.asyncio
124-
def test_add_vectors(
125-
self, es_client: Elasticsearch, index_name: str
126-
) -> None:
117+
def test_add_vectors(self, es_client: Elasticsearch, index_name: str) -> None:
127118
"""
128119
Test adding pre-built embeddings instead of using inference for the texts.
129120
This allows you to separate the embeddings text and the page_content
@@ -145,14 +136,11 @@ def test_add_vectors(
145136
es_client=es_client,
146137
)
147138

148-
store.add_texts(
149-
texts=texts, vectors=embedding_vectors, metadatas=metadatas
150-
)
139+
store.add_texts(texts=texts, vectors=embedding_vectors, metadatas=metadatas)
151140
output = store.search("foo1", k=1)
152141
assert [doc["_source"]["text_field"] for doc in output] == ["foo1"]
153142
assert [doc["_source"]["metadata"]["page"] for doc in output] == [0]
154143

155-
@pytest.mark.asyncio
156144
def test_search_with_metadata(
157145
self, es_client: Elasticsearch, index_name: str
158146
) -> None:
@@ -178,7 +166,6 @@ def test_search_with_metadata(
178166
assert [doc["_source"]["text_field"] for doc in output] == ["bar"]
179167
assert [doc["_source"]["metadata"]["page"] for doc in output] == [1]
180168

181-
@pytest.mark.asyncio
182169
def test_search_with_filter(
183170
self, es_client: Elasticsearch, index_name: str
184171
) -> None:
@@ -215,7 +202,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
215202
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
216203
assert [doc["_source"]["metadata"]["page"] for doc in output] == [1]
217204

218-
@pytest.mark.asyncio
219205
def test_search_script_score(
220206
self, es_client: Elasticsearch, index_name: str
221207
) -> None:
@@ -264,7 +250,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
264250
output = store.search("foo", k=1, custom_query=assert_query)
265251
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
266252

267-
@pytest.mark.asyncio
268253
def test_search_script_score_with_filter(
269254
self, es_client: Elasticsearch, index_name: str
270255
) -> None:
@@ -319,7 +304,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
319304
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
320305
assert [doc["_source"]["metadata"]["page"] for doc in output] == [0]
321306

322-
@pytest.mark.asyncio
323307
def test_search_script_score_distance_dot_product(
324308
self, es_client: Elasticsearch, index_name: str
325309
) -> None:
@@ -370,7 +354,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
370354
output = store.search("foo", k=1, custom_query=assert_query)
371355
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
372356

373-
@pytest.mark.asyncio
374357
def test_search_knn_with_hybrid_search(
375358
self, es_client: Elasticsearch, index_name: str
376359
) -> None:
@@ -410,7 +393,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
410393
output = store.search("foo", k=1, custom_query=assert_query)
411394
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
412395

413-
@pytest.mark.asyncio
414396
def test_search_knn_with_hybrid_search_rrf(
415397
self, es_client: Elasticsearch, index_name: str
416398
) -> None:
@@ -528,7 +510,6 @@ def assert_query(
528510
custom_query=partial(assert_query, expected_rrf={}),
529511
)
530512

531-
@pytest.mark.asyncio
532513
def test_search_knn_with_custom_query_fn(
533514
self, es_client: Elasticsearch, index_name: str
534515
) -> None:
@@ -561,7 +542,6 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict:
561542
output = store.search("foo", k=1, custom_query=my_custom_query)
562543
assert [doc["_source"]["text_field"] for doc in output] == ["bar"]
563544

564-
@pytest.mark.asyncio
565545
def test_search_with_knn_infer_instack(
566546
self, es_client: Elasticsearch, index_name: str
567547
) -> None:
@@ -655,7 +635,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
655635
output = store.search("bar", k=1)
656636
assert [doc["_source"]["text_field"] for doc in output] == ["bar"]
657637

658-
@pytest.mark.asyncio
659638
def test_search_with_sparse_infer_instack(
660639
self, es_client: Elasticsearch, index_name: str
661640
) -> None:
@@ -679,7 +658,6 @@ def test_search_with_sparse_infer_instack(
679658
output = store.search("foo", k=1)
680659
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
681660

682-
@pytest.mark.asyncio
683661
def test_deployed_model_check_fails_semantic(
684662
self, es_client: Elasticsearch, index_name: str
685663
) -> None:
@@ -693,10 +671,7 @@ def test_deployed_model_check_fails_semantic(
693671
)
694672
store.add_texts(["foo", "bar", "baz"])
695673

696-
@pytest.mark.asyncio
697-
def test_search_bm25(
698-
self, es_client: Elasticsearch, index_name: str
699-
) -> None:
674+
def test_search_bm25(self, es_client: Elasticsearch, index_name: str) -> None:
700675
"""Test end to end using the BM25 retrieval strategy."""
701676
store = VectorStore(
702677
user_agent="test",
@@ -722,7 +697,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
722697
output = store.search("foo", k=1, custom_query=assert_query)
723698
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
724699

725-
@pytest.mark.asyncio
726700
def test_search_bm25_with_filter(
727701
self, es_client: Elasticsearch, index_name: str
728702
) -> None:
@@ -758,7 +732,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
758732
assert [doc["_source"]["text_field"] for doc in output] == ["foo"]
759733
assert [doc["_source"]["metadata"]["page"] for doc in output] == [1]
760734

761-
@pytest.mark.asyncio
762735
def test_delete(self, es_client: Elasticsearch, index_name: str) -> None:
763736
"""Test delete methods from vector store."""
764737
store = VectorStore(
@@ -791,7 +764,6 @@ def test_delete(self, es_client: Elasticsearch, index_name: str) -> None:
791764
output = store.search("gni", k=10)
792765
assert len(output) == 0
793766

794-
@pytest.mark.asyncio
795767
def test_indexing_exception_error(
796768
self,
797769
es_client: Elasticsearch,
@@ -822,7 +794,6 @@ def test_indexing_exception_error(
822794

823795
assert log_message in caplog.text
824796

825-
@pytest.mark.asyncio
826797
def test_user_agent(
827798
self, requests_saving_client: Elasticsearch, index_name: str
828799
) -> None:
@@ -845,10 +816,7 @@ def test_user_agent(
845816
for request in transport.requests:
846817
assert request["headers"]["User-Agent"] == user_agent
847818

848-
@pytest.mark.asyncio
849-
def test_bulk_args(
850-
self, requests_saving_client: Any, index_name: str
851-
) -> None:
819+
def test_bulk_args(self, requests_saving_client: Any, index_name: str) -> None:
852820
"""Test to make sure the bulk arguments work as expected."""
853821
store = VectorStore(
854822
user_agent="test",
@@ -863,7 +831,6 @@ def test_bulk_args(
863831
# 1 for index exist, 1 for index create, 3 to index docs
864832
assert len(store.es_client.transport.requests) == 5 # type: ignore
865833

866-
@pytest.mark.asyncio
867834
def test_max_marginal_relevance_search(
868835
self, es_client: Elasticsearch, index_name: str
869836
) -> None:

0 commit comments

Comments
 (0)