Skip to content

Commit 663d6f4

Browse files
Type hints for tests and examples (elastic#1859)
* Type hints for tests and examples * add pyright check for examples only * add pyright check for examples only * simplified examples * review feedback
1 parent 619db4d commit 663d6f4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

79 files changed

+1559
-1062
lines changed

elasticsearch_dsl/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
from . import connections
19-
from .aggs import A
19+
from .aggs import A, Agg
2020
from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
2121
from .document import AsyncDocument, Document
2222
from .document_base import InnerDoc, M, MetaField, mapped_field
@@ -81,7 +81,8 @@
8181
from .function import SF
8282
from .index import AsyncIndex, AsyncIndexTemplate, Index, IndexTemplate
8383
from .mapping import AsyncMapping, Mapping
84-
from .query import Q
84+
from .query import Q, Query
85+
from .response import AggResponse, Response, UpdateByQueryResponse
8586
from .search import (
8687
AsyncEmptySearch,
8788
AsyncMultiSearch,
@@ -99,6 +100,8 @@
99100
__versionstr__ = ".".join(map(str, VERSION))
100101
__all__ = [
101102
"A",
103+
"Agg",
104+
"AggResponse",
102105
"AsyncDocument",
103106
"AsyncEmptySearch",
104107
"AsyncFacetedSearch",
@@ -158,11 +161,13 @@
158161
"Object",
159162
"Percolator",
160163
"Q",
164+
"Query",
161165
"Range",
162166
"RangeFacet",
163167
"RangeField",
164168
"RankFeature",
165169
"RankFeatures",
170+
"Response",
166171
"SF",
167172
"ScaledFloat",
168173
"Search",
@@ -174,6 +179,7 @@
174179
"TokenCount",
175180
"UnknownDslObject",
176181
"UpdateByQuery",
182+
"UpdateByQueryResponse",
177183
"ValidationException",
178184
"analyzer",
179185
"char_filter",

elasticsearch_dsl/_async/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
name: str,
3939
template: str,
4040
index: Optional["AsyncIndex"] = None,
41-
order: Optional[str] = None,
41+
order: Optional[int] = None,
4242
**kwargs: Any,
4343
):
4444
if index is None:
@@ -100,7 +100,7 @@ def as_template(
100100
self,
101101
template_name: str,
102102
pattern: Optional[str] = None,
103-
order: Optional[str] = None,
103+
order: Optional[int] = None,
104104
) -> AsyncIndexTemplate:
105105
# TODO: should we allow pattern to be a top-level arg?
106106
# or maybe have an IndexPattern that allows for it and have

elasticsearch_dsl/_async/search.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,16 @@
1616
# under the License.
1717

1818
import contextlib
19-
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
AsyncIterator,
23+
Dict,
24+
Iterator,
25+
List,
26+
Optional,
27+
cast,
28+
)
2029

2130
from elasticsearch.exceptions import ApiError
2231
from elasticsearch.helpers import async_scan
@@ -68,6 +77,7 @@ async def count(self) -> int:
6877
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
6978
**self._params,
7079
)
80+
7181
return cast(int, resp["count"])
7282

7383
async def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -175,6 +185,10 @@ class AsyncMultiSearch(MultiSearchBase[_R]):
175185

176186
_using: AsyncUsingType
177187

188+
if TYPE_CHECKING:
189+
190+
def add(self, search: AsyncSearch[_R]) -> Self: ... # type: ignore[override]
191+
178192
async def execute(
179193
self, ignore_cache: bool = False, raise_on_error: bool = True
180194
) -> List[Response[_R]]:

elasticsearch_dsl/_sync/index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(
3838
name: str,
3939
template: str,
4040
index: Optional["Index"] = None,
41-
order: Optional[str] = None,
41+
order: Optional[int] = None,
4242
**kwargs: Any,
4343
):
4444
if index is None:
@@ -94,7 +94,7 @@ def as_template(
9494
self,
9595
template_name: str,
9696
pattern: Optional[str] = None,
97-
order: Optional[str] = None,
97+
order: Optional[int] = None,
9898
) -> IndexTemplate:
9999
# TODO: should we allow pattern to be a top-level arg?
100100
# or maybe have an IndexPattern that allows for it and have

elasticsearch_dsl/_sync/search.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717

1818
import contextlib
19-
from typing import Any, Dict, Iterator, List, Optional, cast
19+
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast
2020

2121
from elasticsearch.exceptions import ApiError
2222
from elasticsearch.helpers import scan
@@ -68,6 +68,7 @@ def count(self) -> int:
6868
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
6969
**self._params,
7070
)
71+
7172
return cast(int, resp["count"])
7273

7374
def execute(self, ignore_cache: bool = False) -> Response[_R]:
@@ -169,6 +170,10 @@ class MultiSearch(MultiSearchBase[_R]):
169170

170171
_using: UsingType
171172

173+
if TYPE_CHECKING:
174+
175+
def add(self, search: Search[_R]) -> Self: ... # type: ignore[override]
176+
172177
def execute(
173178
self, ignore_cache: bool = False, raise_on_error: bool = True
174179
) -> List[Response[_R]]:

elasticsearch_dsl/aggs.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,12 @@ def __iter__(self) -> Iterable[str]:
140140
return iter(self.aggs)
141141

142142
def _agg(
143-
self, bucket: bool, name: str, agg_type: str, *args: Any, **params: Any
143+
self,
144+
bucket: bool,
145+
name: str,
146+
agg_type: Union[Dict[str, Any], Agg[_R], str],
147+
*args: Any,
148+
**params: Any,
144149
) -> Agg[_R]:
145150
agg = self[name] = A(agg_type, *args, **params)
146151

@@ -151,14 +156,32 @@ def _agg(
151156
else:
152157
return self._base
153158

154-
def metric(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
159+
def metric(
160+
self,
161+
name: str,
162+
agg_type: Union[Dict[str, Any], Agg[_R], str],
163+
*args: Any,
164+
**params: Any,
165+
) -> Agg[_R]:
155166
return self._agg(False, name, agg_type, *args, **params)
156167

157-
def bucket(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
158-
return self._agg(True, name, agg_type, *args, **params)
159-
160-
def pipeline(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
161-
return self._agg(False, name, agg_type, *args, **params)
168+
def bucket(
169+
self,
170+
name: str,
171+
agg_type: Union[Dict[str, Any], Agg[_R], str],
172+
*args: Any,
173+
**params: Any,
174+
) -> "Bucket[_R]":
175+
return cast("Bucket[_R]", self._agg(True, name, agg_type, *args, **params))
176+
177+
def pipeline(
178+
self,
179+
name: str,
180+
agg_type: Union[Dict[str, Any], Agg[_R], str],
181+
*args: Any,
182+
**params: Any,
183+
) -> "Pipeline[_R]":
184+
return cast("Pipeline[_R]", self._agg(False, name, agg_type, *args, **params))
162185

163186
def result(self, search: "SearchBase[_R]", data: Any) -> AttrDict[Any]:
164187
return BucketData(self, search, data) # type: ignore

elasticsearch_dsl/document_base.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,7 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
195195
field = None
196196
field_args: List[Any] = []
197197
field_kwargs: Dict[str, Any] = {}
198-
if not isinstance(type_, type):
199-
raise TypeError(f"Cannot map type {type_}")
200-
elif issubclass(type_, InnerDoc):
198+
if isinstance(type_, type) and issubclass(type_, InnerDoc):
201199
# object or nested field
202200
field = Nested if multi else Object
203201
field_args = [type_]

elasticsearch_dsl/faceted_search_base.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,19 @@
1616
# under the License.
1717

1818
from datetime import datetime, timedelta
19-
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, Union, cast
19+
from typing import (
20+
TYPE_CHECKING,
21+
Any,
22+
Dict,
23+
Generic,
24+
List,
25+
Optional,
26+
Sequence,
27+
Tuple,
28+
Type,
29+
Union,
30+
cast,
31+
)
2032

2133
from typing_extensions import Self
2234

@@ -26,10 +38,11 @@
2638
from .utils import _R, AttrDict
2739

2840
if TYPE_CHECKING:
41+
from .document_base import DocumentBase
2942
from .response.aggs import BucketData
3043
from .search_base import SearchBase
3144

32-
FilterValueType = Union[str, datetime]
45+
FilterValueType = Union[str, datetime, Sequence[str]]
3346

3447
__all__ = [
3548
"FacetedSearchBase",
@@ -51,7 +64,7 @@ class Facet(Generic[_R]):
5164
agg_type: str = ""
5265

5366
def __init__(
54-
self, metric: Optional[str] = None, metric_sort: str = "desc", **kwargs: Any
67+
self, metric: Optional[Agg[_R]] = None, metric_sort: str = "desc", **kwargs: Any
5568
):
5669
self.filter_values = ()
5770
self._params = kwargs
@@ -137,7 +150,9 @@ def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
137150
class RangeFacet(Facet[_R]):
138151
agg_type = "range"
139152

140-
def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
153+
def _range_to_dict(
154+
self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
155+
) -> Dict[str, Any]:
141156
key, _range = range
142157
out: Dict[str, Any] = {"key": key}
143158
if _range[0] is not None:
@@ -146,7 +161,11 @@ def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
146161
out["to"] = _range[1]
147162
return out
148163

149-
def __init__(self, ranges: List[Tuple[Any, Tuple[int, int]]], **kwargs: Any):
164+
def __init__(
165+
self,
166+
ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
167+
**kwargs: Any,
168+
):
150169
super().__init__(**kwargs)
151170
self._params["ranges"] = list(map(self._range_to_dict, ranges))
152171
self._params["keyed"] = False
@@ -277,7 +296,7 @@ class FacetedResponse(Response[_R]):
277296
_facets: Dict[str, List[Tuple[Any, int, bool]]]
278297

279298
@property
280-
def query_string(self) -> Optional[Query]:
299+
def query_string(self) -> Optional[Union[str, Query]]:
281300
return self._faceted_search._query
282301

283302
@property
@@ -334,9 +353,9 @@ def search(self):
334353
335354
"""
336355

337-
index = None
338-
doc_types = None
339-
fields: List[str] = []
356+
index: Optional[str] = None
357+
doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
358+
fields: Sequence[str] = []
340359
facets: Dict[str, Facet[_R]] = {}
341360
using = "default"
342361

@@ -346,9 +365,9 @@ def search(self) -> "SearchBase[_R]": ...
346365

347366
def __init__(
348367
self,
349-
query: Optional[Query] = None,
368+
query: Optional[Union[str, Query]] = None,
350369
filters: Dict[str, FilterValueType] = {},
351-
sort: List[str] = [],
370+
sort: Sequence[str] = [],
352371
):
353372
"""
354373
:arg query: the text to search for
@@ -383,16 +402,18 @@ def add_filter(
383402
]
384403

385404
# remember the filter values for use in FacetedResponse
386-
self.filter_values[name] = filter_values
405+
self.filter_values[name] = filter_values # type: ignore[assignment]
387406

388407
# get the filter from the facet
389-
f = self.facets[name].add_filter(filter_values)
408+
f = self.facets[name].add_filter(filter_values) # type: ignore[arg-type]
390409
if f is None:
391410
return
392411

393412
self._filters[name] = f
394413

395-
def query(self, search: "SearchBase[_R]", query: Query) -> "SearchBase[_R]":
414+
def query(
415+
self, search: "SearchBase[_R]", query: Union[str, Query]
416+
) -> "SearchBase[_R]":
396417
"""
397418
Add query part to ``search``.
398419

elasticsearch_dsl/field.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ def _empty(self) -> "InnerDoc":
225225
def _wrap(self, data: Dict[str, Any]) -> "InnerDoc":
226226
return self._doc_class.from_es(data, data_only=True)
227227

228-
def empty(self) -> Union["InnerDoc", AttrList]:
228+
def empty(self) -> Union["InnerDoc", AttrList[Any]]:
229229
if self._multi:
230-
return AttrList([], self._wrap)
230+
return AttrList[Any]([], self._wrap)
231231
return self._empty()
232232

233233
def to_dict(self) -> Dict[str, Any]:

elasticsearch_dsl/response/aggs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252

5353
class BucketData(AggResponse[_R]):
5454
_bucket_class = Bucket
55-
_buckets: Union[AttrDict[Any], AttrList]
55+
_buckets: Union[AttrDict[Any], AttrList[Any]]
5656

5757
def _wrap_bucket(self, data: Dict[str, Any]) -> Bucket[_R]:
5858
return self._bucket_class(
@@ -70,11 +70,11 @@ def __len__(self) -> int:
7070

7171
def __getitem__(self, key: Any) -> Any:
7272
if isinstance(key, (int, slice)):
73-
return cast(AttrList, self.buckets)[key]
73+
return cast(AttrList[Any], self.buckets)[key]
7474
return super().__getitem__(key)
7575

7676
@property
77-
def buckets(self) -> Union[AttrDict[Any], AttrList]:
77+
def buckets(self) -> Union[AttrDict[Any], AttrList[Any]]:
7878
if not hasattr(self, "_buckets"):
7979
field = getattr(self._meta["aggs"], "field", None)
8080
if field:

0 commit comments

Comments
 (0)