1
1
import logging
2
2
import uuid
3
- from typing import Iterator
4
- from typing import Any , List , Optional , Union , cast
5
3
from functools import partial
4
+ from typing import Any , Iterator , List , Optional , Union , cast
6
5
7
6
import pytest
8
- import pytest_asyncio
9
- from elasticsearch import Elasticsearch
10
7
11
- from elasticsearch import NotFoundError
8
+ from elasticsearch import Elasticsearch , NotFoundError
12
9
from elasticsearch .helpers import BulkIndexError
13
-
14
10
from elasticsearch .vectorstore ._sync import VectorStore
15
11
from elasticsearch .vectorstore ._sync ._utils import model_is_deployed
16
12
from elasticsearch .vectorstore ._sync .strategies import (
22
18
)
23
19
24
20
from ._test_utils import (
25
- create_requests_saving_client ,
26
- es_client_fixture ,
27
21
ConsistentFakeEmbeddings ,
28
22
FakeEmbeddings ,
29
23
RequestSavingTransport ,
24
+ create_requests_saving_client ,
25
+ es_client_fixture ,
30
26
)
31
27
32
28
logging .basicConfig (level = logging .DEBUG )
53
49
54
50
55
51
class TestElasticsearch :
56
- @pytest_asyncio .fixture ( autouse = True )
52
+ @pytest .fixture
57
53
def es_client (self ) -> Iterator [Elasticsearch ]:
58
54
for x in es_client_fixture ():
59
55
yield x
60
56
61
- @pytest_asyncio .fixture ( autouse = True )
57
+ @pytest .fixture
62
58
def requests_saving_client (self ) -> Iterator [Elasticsearch ]:
63
59
client = create_requests_saving_client ()
64
60
try :
@@ -71,7 +67,6 @@ def index_name(self) -> str:
71
67
"""Return the index name."""
72
68
return f"test_{ uuid .uuid4 ().hex } "
73
69
74
- @pytest .mark .asyncio
75
70
def test_search_without_metadata (
76
71
self , es_client : Elasticsearch , index_name : str
77
72
) -> None :
@@ -102,7 +97,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
102
97
output = store .search ("foo" , k = 1 , custom_query = assert_query )
103
98
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
104
99
105
- @pytest .mark .asyncio
106
100
def test_search_without_metadata_async (
107
101
self , es_client : Elasticsearch , index_name : str
108
102
) -> None :
@@ -120,10 +114,7 @@ def test_search_without_metadata_async(
120
114
output = store .search ("foo" , k = 1 )
121
115
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
122
116
123
- @pytest .mark .asyncio
124
- def test_add_vectors (
125
- self , es_client : Elasticsearch , index_name : str
126
- ) -> None :
117
+ def test_add_vectors (self , es_client : Elasticsearch , index_name : str ) -> None :
127
118
"""
128
119
Test adding pre-built embeddings instead of using inference for the texts.
129
120
This allows you to separate the embeddings text and the page_content
@@ -145,14 +136,11 @@ def test_add_vectors(
145
136
es_client = es_client ,
146
137
)
147
138
148
- store .add_texts (
149
- texts = texts , vectors = embedding_vectors , metadatas = metadatas
150
- )
139
+ store .add_texts (texts = texts , vectors = embedding_vectors , metadatas = metadatas )
151
140
output = store .search ("foo1" , k = 1 )
152
141
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo1" ]
153
142
assert [doc ["_source" ]["metadata" ]["page" ] for doc in output ] == [0 ]
154
143
155
- @pytest .mark .asyncio
156
144
def test_search_with_metadata (
157
145
self , es_client : Elasticsearch , index_name : str
158
146
) -> None :
@@ -178,7 +166,6 @@ def test_search_with_metadata(
178
166
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["bar" ]
179
167
assert [doc ["_source" ]["metadata" ]["page" ] for doc in output ] == [1 ]
180
168
181
- @pytest .mark .asyncio
182
169
def test_search_with_filter (
183
170
self , es_client : Elasticsearch , index_name : str
184
171
) -> None :
@@ -215,7 +202,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
215
202
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
216
203
assert [doc ["_source" ]["metadata" ]["page" ] for doc in output ] == [1 ]
217
204
218
- @pytest .mark .asyncio
219
205
def test_search_script_score (
220
206
self , es_client : Elasticsearch , index_name : str
221
207
) -> None :
@@ -264,7 +250,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
264
250
output = store .search ("foo" , k = 1 , custom_query = assert_query )
265
251
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
266
252
267
- @pytest .mark .asyncio
268
253
def test_search_script_score_with_filter (
269
254
self , es_client : Elasticsearch , index_name : str
270
255
) -> None :
@@ -319,7 +304,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
319
304
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
320
305
assert [doc ["_source" ]["metadata" ]["page" ] for doc in output ] == [0 ]
321
306
322
- @pytest .mark .asyncio
323
307
def test_search_script_score_distance_dot_product (
324
308
self , es_client : Elasticsearch , index_name : str
325
309
) -> None :
@@ -370,7 +354,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
370
354
output = store .search ("foo" , k = 1 , custom_query = assert_query )
371
355
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
372
356
373
- @pytest .mark .asyncio
374
357
def test_search_knn_with_hybrid_search (
375
358
self , es_client : Elasticsearch , index_name : str
376
359
) -> None :
@@ -410,7 +393,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
410
393
output = store .search ("foo" , k = 1 , custom_query = assert_query )
411
394
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
412
395
413
- @pytest .mark .asyncio
414
396
def test_search_knn_with_hybrid_search_rrf (
415
397
self , es_client : Elasticsearch , index_name : str
416
398
) -> None :
@@ -528,7 +510,6 @@ def assert_query(
528
510
custom_query = partial (assert_query , expected_rrf = {}),
529
511
)
530
512
531
- @pytest .mark .asyncio
532
513
def test_search_knn_with_custom_query_fn (
533
514
self , es_client : Elasticsearch , index_name : str
534
515
) -> None :
@@ -561,7 +542,6 @@ def my_custom_query(query_body: dict, query: Optional[str]) -> dict:
561
542
output = store .search ("foo" , k = 1 , custom_query = my_custom_query )
562
543
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["bar" ]
563
544
564
- @pytest .mark .asyncio
565
545
def test_search_with_knn_infer_instack (
566
546
self , es_client : Elasticsearch , index_name : str
567
547
) -> None :
@@ -655,7 +635,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
655
635
output = store .search ("bar" , k = 1 )
656
636
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["bar" ]
657
637
658
- @pytest .mark .asyncio
659
638
def test_search_with_sparse_infer_instack (
660
639
self , es_client : Elasticsearch , index_name : str
661
640
) -> None :
@@ -679,7 +658,6 @@ def test_search_with_sparse_infer_instack(
679
658
output = store .search ("foo" , k = 1 )
680
659
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
681
660
682
- @pytest .mark .asyncio
683
661
def test_deployed_model_check_fails_semantic (
684
662
self , es_client : Elasticsearch , index_name : str
685
663
) -> None :
@@ -693,10 +671,7 @@ def test_deployed_model_check_fails_semantic(
693
671
)
694
672
store .add_texts (["foo" , "bar" , "baz" ])
695
673
696
- @pytest .mark .asyncio
697
- def test_search_bm25 (
698
- self , es_client : Elasticsearch , index_name : str
699
- ) -> None :
674
+ def test_search_bm25 (self , es_client : Elasticsearch , index_name : str ) -> None :
700
675
"""Test end to end using the BM25 retrieval strategy."""
701
676
store = VectorStore (
702
677
user_agent = "test" ,
@@ -722,7 +697,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
722
697
output = store .search ("foo" , k = 1 , custom_query = assert_query )
723
698
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
724
699
725
- @pytest .mark .asyncio
726
700
def test_search_bm25_with_filter (
727
701
self , es_client : Elasticsearch , index_name : str
728
702
) -> None :
@@ -758,7 +732,6 @@ def assert_query(query_body: dict, query: Optional[str]) -> dict:
758
732
assert [doc ["_source" ]["text_field" ] for doc in output ] == ["foo" ]
759
733
assert [doc ["_source" ]["metadata" ]["page" ] for doc in output ] == [1 ]
760
734
761
- @pytest .mark .asyncio
762
735
def test_delete (self , es_client : Elasticsearch , index_name : str ) -> None :
763
736
"""Test delete methods from vector store."""
764
737
store = VectorStore (
@@ -791,7 +764,6 @@ def test_delete(self, es_client: Elasticsearch, index_name: str) -> None:
791
764
output = store .search ("gni" , k = 10 )
792
765
assert len (output ) == 0
793
766
794
- @pytest .mark .asyncio
795
767
def test_indexing_exception_error (
796
768
self ,
797
769
es_client : Elasticsearch ,
@@ -822,7 +794,6 @@ def test_indexing_exception_error(
822
794
823
795
assert log_message in caplog .text
824
796
825
- @pytest .mark .asyncio
826
797
def test_user_agent (
827
798
self , requests_saving_client : Elasticsearch , index_name : str
828
799
) -> None :
@@ -845,10 +816,7 @@ def test_user_agent(
845
816
for request in transport .requests :
846
817
assert request ["headers" ]["User-Agent" ] == user_agent
847
818
848
- @pytest .mark .asyncio
849
- def test_bulk_args (
850
- self , requests_saving_client : Any , index_name : str
851
- ) -> None :
819
+ def test_bulk_args (self , requests_saving_client : Any , index_name : str ) -> None :
852
820
"""Test to make sure the bulk arguments work as expected."""
853
821
store = VectorStore (
854
822
user_agent = "test" ,
@@ -863,7 +831,6 @@ def test_bulk_args(
863
831
# 1 for index exist, 1 for index create, 3 to index docs
864
832
assert len (store .es_client .transport .requests ) == 5 # type: ignore
865
833
866
- @pytest .mark .asyncio
867
834
def test_max_marginal_relevance_search (
868
835
self , es_client : Elasticsearch , index_name : str
869
836
) -> None :
0 commit comments