Skip to content

Commit 667dd68

Browse files
Added type hints to aggs.py, analysis.py, connections.py and field.py (#1849)
* Added type hints to aggs.py, analysis.py, connections.py and field.py * review feedback
1 parent 0c3ffcd commit 667dd68

17 files changed

+458
-276
lines changed

elasticsearch_dsl/aggs.py

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,32 @@
1616
# under the License.
1717

1818
import collections.abc
19+
from copy import deepcopy
20+
from typing import (
21+
TYPE_CHECKING,
22+
Any,
23+
ClassVar,
24+
Dict,
25+
Iterable,
26+
MutableMapping,
27+
Optional,
28+
Union,
29+
cast,
30+
)
1931

2032
from .response.aggs import AggResponse, BucketData, FieldBucketData, TopHitsData
21-
from .utils import DslBase
33+
from .utils import AttrDict, DslBase, JSONType
2234

35+
if TYPE_CHECKING:
36+
from .query import Query
37+
from .search_base import SearchBase
2338

24-
def A(name_or_agg, filter=None, **params):
39+
40+
def A(
41+
name_or_agg: Union[MutableMapping[str, Any], "Agg", str],
42+
filter: Optional[Union[str, "Query"]] = None,
43+
**params: Any,
44+
) -> "Agg":
2545
if filter is not None:
2646
if name_or_agg != "filter":
2747
raise ValueError(
@@ -31,11 +51,11 @@ def A(name_or_agg, filter=None, **params):
3151
params["filter"] = filter
3252

3353
# {"terms": {"field": "tags"}, "aggs": {...}}
34-
if isinstance(name_or_agg, collections.abc.Mapping):
54+
if isinstance(name_or_agg, collections.abc.MutableMapping):
3555
if params:
3656
raise ValueError("A() cannot accept parameters when passing in a dict.")
3757
# copy to avoid modifying in-place
38-
agg = name_or_agg.copy()
58+
agg = deepcopy(name_or_agg)
3959
# pop out nested aggs
4060
aggs = agg.pop("aggs", None)
4161
# pop out meta data
@@ -70,48 +90,57 @@ def A(name_or_agg, filter=None, **params):
7090
class Agg(DslBase):
7191
_type_name = "agg"
7292
_type_shortcut = staticmethod(A)
73-
name = None
93+
name = ""
7494

75-
def __contains__(self, key):
95+
def __contains__(self, key: str) -> bool:
7696
return False
7797

78-
def to_dict(self):
98+
def to_dict(self) -> Dict[str, JSONType]:
7999
d = super().to_dict()
80-
if "meta" in d[self.name]:
81-
d["meta"] = d[self.name].pop("meta")
100+
if isinstance(d[self.name], dict):
101+
n = cast(Dict[str, JSONType], d[self.name])
102+
if "meta" in n:
103+
d["meta"] = n.pop("meta")
82104
return d
83105

84-
def result(self, search, data):
106+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
85107
return AggResponse(self, search, data)
86108

87109

88110
class AggBase:
89-
_param_defs = {
111+
aggs: Dict[str, Agg]
112+
_base: Agg
113+
_params: Dict[str, Any]
114+
_param_defs: ClassVar[Dict[str, Any]] = {
90115
"aggs": {"type": "agg", "hash": True},
91116
}
92117

93-
def __contains__(self, key):
118+
def __contains__(self, key: str) -> bool:
94119
return key in self._params.get("aggs", {})
95120

96-
def __getitem__(self, agg_name):
97-
agg = self._params.setdefault("aggs", {})[agg_name] # propagate KeyError
121+
def __getitem__(self, agg_name: str) -> Agg:
122+
agg = cast(
123+
Agg, self._params.setdefault("aggs", {})[agg_name]
124+
) # propagate KeyError
98125

99126
# make sure we're not mutating a shared state - whenever accessing a
100127
# bucket, return a shallow copy of it to be safe
101128
if isinstance(agg, Bucket):
102-
agg = A(agg.name, **agg._params)
129+
agg = A(agg.name, filter=None, **agg._params)
103130
# be sure to store the copy so any modifications to it will affect us
104131
self._params["aggs"][agg_name] = agg
105132

106133
return agg
107134

108-
def __setitem__(self, agg_name, agg):
135+
def __setitem__(self, agg_name: str, agg: Agg) -> None:
109136
self.aggs[agg_name] = A(agg)
110137

111-
def __iter__(self):
138+
def __iter__(self) -> Iterable[str]:
112139
return iter(self.aggs)
113140

114-
def _agg(self, bucket, name, agg_type, *args, **params):
141+
def _agg(
142+
self, bucket: bool, name: str, agg_type: str, *args: Any, **params: Any
143+
) -> Agg:
115144
agg = self[name] = A(agg_type, *args, **params)
116145

117146
# For chaining - when creating new buckets return them...
@@ -121,29 +150,31 @@ def _agg(self, bucket, name, agg_type, *args, **params):
121150
else:
122151
return self._base
123152

124-
def metric(self, name, agg_type, *args, **params):
153+
def metric(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg:
125154
return self._agg(False, name, agg_type, *args, **params)
126155

127-
def bucket(self, name, agg_type, *args, **params):
156+
def bucket(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg:
128157
return self._agg(True, name, agg_type, *args, **params)
129158

130-
def pipeline(self, name, agg_type, *args, **params):
159+
def pipeline(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg:
131160
return self._agg(False, name, agg_type, *args, **params)
132161

133-
def result(self, search, data):
162+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
134163
return BucketData(self, search, data)
135164

136165

137166
class Bucket(AggBase, Agg):
138-
def __init__(self, **params):
167+
def __init__(self, **params: Any):
139168
super().__init__(**params)
140169
# remember self for chaining
141170
self._base = self
142171

143-
def to_dict(self):
172+
def to_dict(self) -> Dict[str, JSONType]:
144173
d = super(AggBase, self).to_dict()
145-
if "aggs" in d[self.name]:
146-
d["aggs"] = d[self.name].pop("aggs")
174+
if isinstance(d[self.name], dict):
175+
n = cast(AttrDict[str, Any], d[self.name])
176+
if "aggs" in n:
177+
d["aggs"] = n.pop("aggs")
147178
return d
148179

149180

@@ -154,14 +185,16 @@ class Filter(Bucket):
154185
"aggs": {"type": "agg", "hash": True},
155186
}
156187

157-
def __init__(self, filter=None, **params):
188+
def __init__(self, filter: Optional[Union[str, "Query"]] = None, **params: Any):
158189
if filter is not None:
159190
params["filter"] = filter
160191
super().__init__(**params)
161192

162-
def to_dict(self):
193+
def to_dict(self) -> Dict[str, JSONType]:
163194
d = super().to_dict()
164-
d[self.name].update(d[self.name].pop("filter", {}))
195+
if isinstance(d[self.name], dict):
196+
n = cast(AttrDict[str, Any], d[self.name])
197+
n.update(n.pop("filter", {}))
165198
return d
166199

167200

@@ -189,7 +222,7 @@ class Parent(Bucket):
189222
class DateHistogram(Bucket):
190223
name = "date_histogram"
191224

192-
def result(self, search, data):
225+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
193226
return FieldBucketData(self, search, data)
194227

195228

@@ -232,7 +265,7 @@ class Global(Bucket):
232265
class Histogram(Bucket):
233266
name = "histogram"
234267

235-
def result(self, search, data):
268+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
236269
return FieldBucketData(self, search, data)
237270

238271

@@ -259,7 +292,7 @@ class Range(Bucket):
259292
class RareTerms(Bucket):
260293
name = "rare_terms"
261294

262-
def result(self, search, data):
295+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
263296
return FieldBucketData(self, search, data)
264297

265298

@@ -278,7 +311,7 @@ class SignificantText(Bucket):
278311
class Terms(Bucket):
279312
name = "terms"
280313

281-
def result(self, search, data):
314+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
282315
return FieldBucketData(self, search, data)
283316

284317

@@ -305,7 +338,7 @@ class Composite(Bucket):
305338
class VariableWidthHistogram(Bucket):
306339
name = "variable_width_histogram"
307340

308-
def result(self, search, data):
341+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
309342
return FieldBucketData(self, search, data)
310343

311344

@@ -321,7 +354,7 @@ class CategorizeText(Bucket):
321354
class TopHits(Agg):
322355
name = "top_hits"
323356

324-
def result(self, search, data):
357+
def result(self, search: "SearchBase", data: Any) -> AttrDict[str, Any]:
325358
return TopHitsData(self, search, data)
326359

327360

0 commit comments

Comments
 (0)