1
+ # Licensed to Elasticsearch B.V. under one or more contributor
2
+ # license agreements. See the NOTICE file distributed with
3
+ # this work for additional information regarding copyright
4
+ # ownership. Elasticsearch B.V. licenses this file to you under
5
+ # the Apache License, Version 2.0 (the "License"); you may
6
+ # not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing,
12
+ # software distributed under the License is distributed on an
13
+ # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
+ # KIND, either express or implied. See the License for the
15
+ # specific language governing permissions and limitations
16
+ # under the License.
17
+
1
18
from abc import ABC , abstractmethod
2
19
from enum import Enum
3
20
from typing import Any , Dict , List , Literal , Optional , Union , cast
@@ -23,9 +40,9 @@ async def es_query(
23
40
query : Optional [str ],
24
41
k : int ,
25
42
num_candidates : int ,
26
- filter : List [dict ] = [],
43
+ filter : List [Dict [ str , Any ] ] = [],
27
44
query_vector : Optional [List [float ]] = None ,
28
- ) -> Dict :
45
+ ) -> Dict [ str , Any ] :
29
46
"""
30
47
Returns the Elasticsearch query body for the given parameters.
31
48
The store will execute the query.
@@ -46,7 +63,7 @@ async def create_index(
46
63
self ,
47
64
client : AsyncElasticsearch ,
48
65
index_name : str ,
49
- metadata_mapping : Optional [dict [str , str ]],
66
+ metadata_mapping : Optional [Dict [str , str ]],
50
67
) -> None :
51
68
"""
52
69
Create the required index and do necessary preliminary work, like
@@ -95,9 +112,9 @@ async def es_query(
95
112
query : Optional [str ],
96
113
k : int ,
97
114
num_candidates : int ,
98
- filter : List [dict ] = [],
115
+ filter : List [Dict [ str , Any ] ] = [],
99
116
query_vector : Optional [List [float ]] = None ,
100
- ) -> Dict :
117
+ ) -> Dict [ str , Any ] :
101
118
if query_vector :
102
119
raise ValueError (
103
120
"Cannot do sparse retrieval with a query_vector. "
@@ -117,12 +134,12 @@ async def create_index(
117
134
self ,
118
135
client : AsyncElasticsearch ,
119
136
index_name : str ,
120
- metadata_mapping : Optional [dict [str , str ]],
137
+ metadata_mapping : Optional [Dict [str , str ]],
121
138
) -> None :
122
139
if self .model_id :
123
140
await model_must_be_deployed (client , self .model_id )
124
141
125
- mappings : dict [str , Any ] = {
142
+ mappings : Dict [str , Any ] = {
126
143
"properties" : {
127
144
self .inference_field : {
128
145
"type" : "semantic_text" ,
@@ -155,9 +172,9 @@ async def es_query(
155
172
query : Optional [str ],
156
173
k : int ,
157
174
num_candidates : int ,
158
- filter : List [dict ] = [],
175
+ filter : List [Dict [ str , Any ] ] = [],
159
176
query_vector : Optional [List [float ]] = None ,
160
- ) -> Dict :
177
+ ) -> Dict [ str , Any ] :
161
178
if query_vector :
162
179
raise ValueError (
163
180
"Cannot do sparse retrieval with a query_vector. "
@@ -189,7 +206,7 @@ async def create_index(
189
206
self ,
190
207
client : AsyncElasticsearch ,
191
208
index_name : str ,
192
- metadata_mapping : Optional [dict [str , str ]],
209
+ metadata_mapping : Optional [Dict [str , str ]],
193
210
) -> None :
194
211
pipeline_name = f"{ self .model_id } _sparse_embedding"
195
212
@@ -214,7 +231,7 @@ async def create_index(
214
231
],
215
232
)
216
233
217
- mappings = {
234
+ mappings : Dict [ str , Any ] = {
218
235
"properties" : {
219
236
self .vector_field : {
220
237
"properties" : {self ._tokens_field : {"type" : "rank_features" }}
@@ -244,7 +261,7 @@ def __init__(
244
261
model_id : Optional [str ] = None ,
245
262
num_dimensions : Optional [int ] = None ,
246
263
hybrid : bool = False ,
247
- rrf : Union [bool , dict ] = True ,
264
+ rrf : Union [bool , Dict [ str , Any ] ] = True ,
248
265
text_field : Optional [str ] = "text_field" ,
249
266
):
250
267
if embedding_service and model_id :
@@ -273,9 +290,9 @@ async def es_query(
273
290
query : Optional [str ],
274
291
k : int ,
275
292
num_candidates : int ,
276
- filter : List [dict ] = [],
293
+ filter : List [Dict [ str , Any ] ] = [],
277
294
query_vector : Optional [List [float ]] = None ,
278
- ) -> Dict :
295
+ ) -> Dict [ str , Any ] :
279
296
knn = {
280
297
"filter" : filter ,
281
298
"field" : self .vector_field ,
@@ -308,7 +325,7 @@ async def create_index(
308
325
self ,
309
326
client : AsyncElasticsearch ,
310
327
index_name : str ,
311
- metadata_mapping : Optional [dict [str , str ]],
328
+ metadata_mapping : Optional [Dict [str , str ]],
312
329
) -> None :
313
330
if self .embedding_service and not self .num_dimensions :
314
331
self .num_dimensions = len (
@@ -351,7 +368,9 @@ async def embed_for_indexing(self, text: str) -> Dict[str, Any]:
351
368
return {self .vector_field : vector }
352
369
return {}
353
370
354
- def _hybrid (self , query : str , knn : dict , filter : list ):
371
+ def _hybrid (
372
+ self , query : str , knn : Dict [str , Any ], filter : List [Dict [str , Any ]]
373
+ ) -> Dict [str , Any ]:
355
374
# Add a query to the knn query.
356
375
# RRF is used to even the score from the knn query and text query
357
376
# RRF has two optional parameters: {'rank_constant':int, 'window_size':int}
@@ -374,7 +393,7 @@ def _hybrid(self, query: str, knn: dict, filter: list):
374
393
},
375
394
}
376
395
377
- if isinstance (self .rrf , dict ):
396
+ if isinstance (self .rrf , Dict [ str , Any ] ):
378
397
query_body ["rank" ] = {"rrf" : self .rrf }
379
398
elif isinstance (self .rrf , bool ) and self .rrf is True :
380
399
query_body ["rank" ] = {"rrf" : {}}
@@ -402,9 +421,9 @@ async def es_query(
402
421
query : Optional [str ],
403
422
k : int ,
404
423
num_candidates : int ,
405
- filter : List [dict ] = [],
424
+ filter : List [Dict [ str , Any ] ] = [],
406
425
query_vector : Optional [List [float ]] = None ,
407
- ) -> Dict :
426
+ ) -> Dict [ str , Any ] :
408
427
if self .distance is DistanceMetric .COSINE :
409
428
similarityAlgo = (
410
429
f"cosineSimilarity(params.query_vector, '{ self .vector_field } ') + 1.0"
@@ -429,7 +448,7 @@ async def es_query(
429
448
else :
430
449
raise ValueError (f"Similarity { self .distance } not supported." )
431
450
432
- queryBool : Dict = {"match_all" : {}}
451
+ queryBool : Dict [ str , Any ] = {"match_all" : {}}
433
452
if filter :
434
453
queryBool = {"bool" : {"filter" : filter }}
435
454
@@ -459,7 +478,7 @@ async def create_index(
459
478
self ,
460
479
client : AsyncElasticsearch ,
461
480
index_name : str ,
462
- metadata_mapping : Optional [dict [str , str ]],
481
+ metadata_mapping : Optional [Dict [str , str ]],
463
482
) -> None :
464
483
if not self .num_dimensions :
465
484
self .num_dimensions = len (
@@ -502,9 +521,9 @@ async def es_query(
502
521
query : Optional [str ],
503
522
k : int ,
504
523
num_candidates : int ,
505
- filter : List [dict ] = [],
524
+ filter : List [Dict [ str , Any ] ] = [],
506
525
query_vector : Optional [List [float ]] = None ,
507
- ) -> Dict :
526
+ ) -> Dict [ str , Any ] :
508
527
return {
509
528
"query" : {
510
529
"bool" : {
@@ -526,11 +545,11 @@ async def create_index(
526
545
self ,
527
546
client : AsyncElasticsearch ,
528
547
index_name : str ,
529
- metadata_mapping : Optional [dict [str , str ]],
548
+ metadata_mapping : Optional [Dict [str , str ]],
530
549
) -> None :
531
550
similarity_name = "custom_bm25"
532
551
533
- mappings : Dict = {
552
+ mappings : Dict [ str , Any ] = {
534
553
"properties" : {
535
554
self .text_field : {
536
555
"type" : "text" ,
@@ -541,7 +560,7 @@ async def create_index(
541
560
if metadata_mapping :
542
561
mappings ["properties" ]["metadata" ] = {"properties" : metadata_mapping }
543
562
544
- bm25 : Dict = {
563
+ bm25 : Dict [ str , Any ] = {
545
564
"type" : "BM25" ,
546
565
}
547
566
if self .k1 is not None :
0 commit comments