16
16
# under the License.
17
17
18
18
import collections .abc
19
+ from copy import deepcopy
20
+ from typing import (
21
+ TYPE_CHECKING ,
22
+ Any ,
23
+ ClassVar ,
24
+ Dict ,
25
+ Iterable ,
26
+ MutableMapping ,
27
+ Optional ,
28
+ Union ,
29
+ cast ,
30
+ )
19
31
20
32
from .response .aggs import AggResponse , BucketData , FieldBucketData , TopHitsData
21
- from .utils import DslBase
33
+ from .utils import AttrDict , DslBase , JSONType
22
34
35
+ if TYPE_CHECKING :
36
+ from .query import Query
37
+ from .search_base import SearchBase
23
38
24
- def A (name_or_agg , filter = None , ** params ):
39
+
40
+ def A (
41
+ name_or_agg : Union [MutableMapping [str , Any ], "Agg" , str ],
42
+ filter : Optional [Union [str , "Query" ]] = None ,
43
+ ** params : Any ,
44
+ ) -> "Agg" :
25
45
if filter is not None :
26
46
if name_or_agg != "filter" :
27
47
raise ValueError (
@@ -31,11 +51,11 @@ def A(name_or_agg, filter=None, **params):
31
51
params ["filter" ] = filter
32
52
33
53
# {"terms": {"field": "tags"}, "aggs": {...}}
34
- if isinstance (name_or_agg , collections .abc .Mapping ):
54
+ if isinstance (name_or_agg , collections .abc .MutableMapping ):
35
55
if params :
36
56
raise ValueError ("A() cannot accept parameters when passing in a dict." )
37
57
# copy to avoid modifying in-place
38
- agg = name_or_agg . copy ( )
58
+ agg = deepcopy ( name_or_agg )
39
59
# pop out nested aggs
40
60
aggs = agg .pop ("aggs" , None )
41
61
# pop out meta data
@@ -70,48 +90,57 @@ def A(name_or_agg, filter=None, **params):
70
90
class Agg (DslBase ):
71
91
_type_name = "agg"
72
92
_type_shortcut = staticmethod (A )
73
- name = None
93
+ name = ""
74
94
75
- def __contains__ (self , key ) :
95
+ def __contains__ (self , key : str ) -> bool :
76
96
return False
77
97
78
- def to_dict (self ):
98
+ def to_dict (self ) -> Dict [ str , JSONType ] :
79
99
d = super ().to_dict ()
80
- if "meta" in d [self .name ]:
81
- d ["meta" ] = d [self .name ].pop ("meta" )
100
+ if isinstance (d [self .name ], dict ):
101
+ n = cast (Dict [str , JSONType ], d [self .name ])
102
+ if "meta" in n :
103
+ d ["meta" ] = n .pop ("meta" )
82
104
return d
83
105
84
- def result (self , search , data ) :
106
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
85
107
return AggResponse (self , search , data )
86
108
87
109
88
110
class AggBase :
89
- _param_defs = {
111
+ aggs : Dict [str , Agg ]
112
+ _base : Agg
113
+ _params : Dict [str , Any ]
114
+ _param_defs : ClassVar [Dict [str , Any ]] = {
90
115
"aggs" : {"type" : "agg" , "hash" : True },
91
116
}
92
117
93
- def __contains__ (self , key ) :
118
+ def __contains__ (self , key : str ) -> bool :
94
119
return key in self ._params .get ("aggs" , {})
95
120
96
- def __getitem__ (self , agg_name ):
97
- agg = self ._params .setdefault ("aggs" , {})[agg_name ] # propagate KeyError
121
+ def __getitem__ (self , agg_name : str ) -> Agg :
122
+ agg = cast (
123
+ Agg , self ._params .setdefault ("aggs" , {})[agg_name ]
124
+ ) # propagate KeyError
98
125
99
126
# make sure we're not mutating a shared state - whenever accessing a
100
127
# bucket, return a shallow copy of it to be safe
101
128
if isinstance (agg , Bucket ):
102
- agg = A (agg .name , ** agg ._params )
129
+ agg = A (agg .name , filter = None , ** agg ._params )
103
130
# be sure to store the copy so any modifications to it will affect us
104
131
self ._params ["aggs" ][agg_name ] = agg
105
132
106
133
return agg
107
134
108
- def __setitem__ (self , agg_name , agg ) :
135
+ def __setitem__ (self , agg_name : str , agg : Agg ) -> None :
109
136
self .aggs [agg_name ] = A (agg )
110
137
111
- def __iter__ (self ):
138
+ def __iter__ (self ) -> Iterable [ str ] :
112
139
return iter (self .aggs )
113
140
114
- def _agg (self , bucket , name , agg_type , * args , ** params ):
141
+ def _agg (
142
+ self , bucket : bool , name : str , agg_type : str , * args : Any , ** params : Any
143
+ ) -> Agg :
115
144
agg = self [name ] = A (agg_type , * args , ** params )
116
145
117
146
# For chaining - when creating new buckets return them...
@@ -121,29 +150,31 @@ def _agg(self, bucket, name, agg_type, *args, **params):
121
150
else :
122
151
return self ._base
123
152
124
- def metric (self , name , agg_type , * args , ** params ) :
153
+ def metric (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
125
154
return self ._agg (False , name , agg_type , * args , ** params )
126
155
127
- def bucket (self , name , agg_type , * args , ** params ) :
156
+ def bucket (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
128
157
return self ._agg (True , name , agg_type , * args , ** params )
129
158
130
- def pipeline (self , name , agg_type , * args , ** params ) :
159
+ def pipeline (self , name : str , agg_type : str , * args : Any , ** params : Any ) -> Agg :
131
160
return self ._agg (False , name , agg_type , * args , ** params )
132
161
133
- def result (self , search , data ) :
162
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
134
163
return BucketData (self , search , data )
135
164
136
165
137
166
class Bucket (AggBase , Agg ):
138
- def __init__ (self , ** params ):
167
+ def __init__ (self , ** params : Any ):
139
168
super ().__init__ (** params )
140
169
# remember self for chaining
141
170
self ._base = self
142
171
143
- def to_dict (self ):
172
+ def to_dict (self ) -> Dict [ str , JSONType ] :
144
173
d = super (AggBase , self ).to_dict ()
145
- if "aggs" in d [self .name ]:
146
- d ["aggs" ] = d [self .name ].pop ("aggs" )
174
+ if isinstance (d [self .name ], dict ):
175
+ n = cast (AttrDict [str , Any ], d [self .name ])
176
+ if "aggs" in n :
177
+ d ["aggs" ] = n .pop ("aggs" )
147
178
return d
148
179
149
180
@@ -154,14 +185,16 @@ class Filter(Bucket):
154
185
"aggs" : {"type" : "agg" , "hash" : True },
155
186
}
156
187
157
- def __init__ (self , filter = None , ** params ):
188
+ def __init__ (self , filter : Optional [ Union [ str , "Query" ]] = None , ** params : Any ):
158
189
if filter is not None :
159
190
params ["filter" ] = filter
160
191
super ().__init__ (** params )
161
192
162
- def to_dict (self ):
193
+ def to_dict (self ) -> Dict [ str , JSONType ] :
163
194
d = super ().to_dict ()
164
- d [self .name ].update (d [self .name ].pop ("filter" , {}))
195
+ if isinstance (d [self .name ], dict ):
196
+ n = cast (AttrDict [str , Any ], d [self .name ])
197
+ n .update (n .pop ("filter" , {}))
165
198
return d
166
199
167
200
@@ -189,7 +222,7 @@ class Parent(Bucket):
189
222
class DateHistogram (Bucket ):
190
223
name = "date_histogram"
191
224
192
- def result (self , search , data ) :
225
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
193
226
return FieldBucketData (self , search , data )
194
227
195
228
@@ -232,7 +265,7 @@ class Global(Bucket):
232
265
class Histogram (Bucket ):
233
266
name = "histogram"
234
267
235
- def result (self , search , data ) :
268
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
236
269
return FieldBucketData (self , search , data )
237
270
238
271
@@ -259,7 +292,7 @@ class Range(Bucket):
259
292
class RareTerms (Bucket ):
260
293
name = "rare_terms"
261
294
262
- def result (self , search , data ) :
295
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
263
296
return FieldBucketData (self , search , data )
264
297
265
298
@@ -278,7 +311,7 @@ class SignificantText(Bucket):
278
311
class Terms (Bucket ):
279
312
name = "terms"
280
313
281
- def result (self , search , data ) :
314
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
282
315
return FieldBucketData (self , search , data )
283
316
284
317
@@ -305,7 +338,7 @@ class Composite(Bucket):
305
338
class VariableWidthHistogram (Bucket ):
306
339
name = "variable_width_histogram"
307
340
308
- def result (self , search , data ) :
341
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
309
342
return FieldBucketData (self , search , data )
310
343
311
344
@@ -321,7 +354,7 @@ class CategorizeText(Bucket):
321
354
class TopHits (Agg ):
322
355
name = "top_hits"
323
356
324
- def result (self , search , data ) :
357
+ def result (self , search : "SearchBase" , data : Any ) -> AttrDict [ str , Any ] :
325
358
return TopHitsData (self , search , data )
326
359
327
360
0 commit comments