16
16
# under the License.
17
17
18
18
import collections .abc
19
+ from copy import deepcopy
19
20
from itertools import chain
21
+ from typing import (
22
+ Optional ,
23
+ Any ,
24
+ overload ,
25
+ TypeVar ,
26
+ Protocol ,
27
+ Callable ,
28
+ ClassVar ,
29
+ Union ,
30
+ )
20
31
21
32
# 'SF' looks unused but the test suite assumes it's available
22
33
# from this module so others are liable to do so as well.
23
34
from .function import SF # noqa: F401
24
35
from .function import ScoreFunction
25
36
from .utils import DslBase
26
37
38
+ _T = TypeVar ("_T" )
39
+ _M = TypeVar ("_M" , bound = collections .abc .Mapping [str , Any ])
27
40
28
- def Q (name_or_query = "match_all" , ** params ):
41
+
42
+ class QProxiedProtocol (Protocol [_T ]):
43
+ _proxied : _T
44
+
45
+
46
+ @overload
47
+ def Q (name_or_query : collections .abc .MutableMapping [str , _M ]) -> "Query" : ...
48
+
49
+
50
+ @overload
51
+ def Q (name_or_query : "Query" ) -> "Query" : ...
52
+
53
+
54
+ @overload
55
+ def Q (name_or_query : QProxiedProtocol [_T ]) -> _T : ...
56
+
57
+
58
+ @overload
59
+ def Q (name_or_query : str , ** params : Any ) -> "Query" : ...
60
+
61
+
62
+ def Q (
63
+ name_or_query : Union [
64
+ str ,
65
+ "Query" ,
66
+ QProxiedProtocol [_T ],
67
+ collections .abc .MutableMapping [str , _M ],
68
+ ] = "match_all" ,
69
+ ** params : Any
70
+ ) -> Union ["Query" , _T ]:
29
71
# {"match": {"title": "python"}}
30
- if isinstance (name_or_query , collections .abc .Mapping ):
72
+ if isinstance (name_or_query , collections .abc .MutableMapping ):
31
73
if params :
32
74
raise ValueError ("Q() cannot accept parameters when passing in a dict." )
33
75
if len (name_or_query ) != 1 :
34
76
raise ValueError (
35
77
'Q() can only accept dict with a single query ({"match": {...}}). '
36
78
"Instead it got (%r)" % name_or_query
37
79
)
38
- name , params = name_or_query . copy ( ).popitem ()
39
- return Query .get_dsl_class (name )(_expand__to_dot = False , ** params )
80
+ name , q_params = deepcopy ( name_or_query ).popitem ()
81
+ return Query .get_dsl_class (name )(_expand__to_dot = False , ** q_params )
40
82
41
83
# MatchAll()
42
84
if isinstance (name_or_query , Query ):
@@ -57,26 +99,31 @@ def Q(name_or_query="match_all", **params):
57
99
class Query (DslBase ):
58
100
_type_name = "query"
59
101
_type_shortcut = staticmethod (Q )
60
- name = None
102
+ name : ClassVar [Optional [str ]] = None
103
+
104
+ # Add type annotations for methods not defined in every subclass
105
+ __ror__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
106
+ __radd__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
107
+ __rand__ : ClassVar [Callable [["Query" , "Query" ], "Query" ]]
61
108
62
- def __add__ (self , other ) :
109
+ def __add__ (self , other : "Query" ) -> "Query" :
63
110
# make sure we give queries that know how to combine themselves
64
111
# preference
65
112
if hasattr (other , "__radd__" ):
66
113
return other .__radd__ (self )
67
114
return Bool (must = [self , other ])
68
115
69
- def __invert__ (self ):
116
+ def __invert__ (self ) -> "Query" :
70
117
return Bool (must_not = [self ])
71
118
72
- def __or__ (self , other ) :
119
+ def __or__ (self , other : "Query" ) -> "Query" :
73
120
# make sure we give queries that know how to combine themselves
74
121
# preference
75
122
if hasattr (other , "__ror__" ):
76
123
return other .__ror__ (self )
77
124
return Bool (should = [self , other ])
78
125
79
- def __and__ (self , other ) :
126
+ def __and__ (self , other : "Query" ) -> "Query" :
80
127
# make sure we give queries that know how to combine themselves
81
128
# preference
82
129
if hasattr (other , "__rand__" ):
@@ -87,17 +134,17 @@ def __and__(self, other):
87
134
class MatchAll (Query ):
88
135
name = "match_all"
89
136
90
- def __add__ (self , other ) :
137
+ def __add__ (self , other : "Query" ) -> "Query" :
91
138
return other ._clone ()
92
139
93
140
__and__ = __rand__ = __radd__ = __add__
94
141
95
- def __or__ (self , other ) :
142
+ def __or__ (self , other : "Query" ) -> "MatchAll" :
96
143
return self
97
144
98
145
__ror__ = __or__
99
146
100
- def __invert__ (self ):
147
+ def __invert__ (self ) -> "MatchNone" :
101
148
return MatchNone ()
102
149
103
150
@@ -107,17 +154,17 @@ def __invert__(self):
107
154
class MatchNone (Query ):
108
155
name = "match_none"
109
156
110
- def __add__ (self , other ) :
157
+ def __add__ (self , other : "Query" ) -> "MatchNone" :
111
158
return self
112
159
113
160
__and__ = __rand__ = __radd__ = __add__
114
161
115
- def __or__ (self , other ) :
162
+ def __or__ (self , other : "Query" ) -> "Query" :
116
163
return other ._clone ()
117
164
118
165
__ror__ = __or__
119
166
120
- def __invert__ (self ):
167
+ def __invert__ (self ) -> MatchAll :
121
168
return MatchAll ()
122
169
123
170
@@ -130,7 +177,7 @@ class Bool(Query):
130
177
"filter" : {"type" : "query" , "multi" : True },
131
178
}
132
179
133
- def __add__ (self , other ) :
180
+ def __add__ (self , other : Query ) -> "Bool" :
134
181
q = self ._clone ()
135
182
if isinstance (other , Bool ):
136
183
q .must += other .must
@@ -143,7 +190,7 @@ def __add__(self, other):
143
190
144
191
__radd__ = __add__
145
192
146
- def __or__ (self , other ) :
193
+ def __or__ (self , other : Query ) -> Query :
147
194
for q in (self , other ):
148
195
if isinstance (q , Bool ) and not any (
149
196
(q .must , q .must_not , q .filter , getattr (q , "minimum_should_match" , None ))
@@ -168,20 +215,20 @@ def __or__(self, other):
168
215
__ror__ = __or__
169
216
170
217
@property
171
- def _min_should_match (self ):
218
+ def _min_should_match (self ) -> int :
172
219
return getattr (
173
220
self ,
174
221
"minimum_should_match" ,
175
222
0 if not self .should or (self .must or self .filter ) else 1 ,
176
223
)
177
224
178
- def __invert__ (self ):
225
+ def __invert__ (self ) -> Query :
179
226
# Because an empty Bool query is treated like
180
227
# MatchAll the inverse should be MatchNone
181
228
if not any (chain (self .must , self .filter , self .should , self .must_not )):
182
229
return MatchNone ()
183
230
184
- negations = []
231
+ negations : list [ Query ] = []
185
232
for q in chain (self .must , self .filter ):
186
233
negations .append (~ q )
187
234
@@ -195,7 +242,7 @@ def __invert__(self):
195
242
return negations [0 ]
196
243
return Bool (should = negations )
197
244
198
- def __and__ (self , other ) :
245
+ def __and__ (self , other : Query ) -> Query :
199
246
q = self ._clone ()
200
247
if isinstance (other , Bool ):
201
248
q .must += other .must
@@ -247,7 +294,7 @@ class FunctionScore(Query):
247
294
"functions" : {"type" : "score_function" , "multi" : True },
248
295
}
249
296
250
- def __init__ (self , ** kwargs ):
297
+ def __init__ (self , ** kwargs : Any ):
251
298
if "functions" in kwargs :
252
299
pass
253
300
else :
0 commit comments