Skip to content

[Backport 8.x] Type hints for tests and examples #1861

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions elasticsearch_dsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

from . import connections
from .aggs import A
from .aggs import A, Agg
from .analysis import analyzer, char_filter, normalizer, token_filter, tokenizer
from .document import AsyncDocument, Document
from .document_base import InnerDoc, M, MetaField, mapped_field
Expand Down Expand Up @@ -81,7 +81,8 @@
from .function import SF
from .index import AsyncIndex, AsyncIndexTemplate, Index, IndexTemplate
from .mapping import AsyncMapping, Mapping
from .query import Q
from .query import Q, Query
from .response import AggResponse, Response, UpdateByQueryResponse
from .search import (
AsyncEmptySearch,
AsyncMultiSearch,
Expand All @@ -99,6 +100,8 @@
__versionstr__ = ".".join(map(str, VERSION))
__all__ = [
"A",
"Agg",
"AggResponse",
"AsyncDocument",
"AsyncEmptySearch",
"AsyncFacetedSearch",
Expand Down Expand Up @@ -158,11 +161,13 @@
"Object",
"Percolator",
"Q",
"Query",
"Range",
"RangeFacet",
"RangeField",
"RankFeature",
"RankFeatures",
"Response",
"SF",
"ScaledFloat",
"Search",
Expand All @@ -174,6 +179,7 @@
"TokenCount",
"UnknownDslObject",
"UpdateByQuery",
"UpdateByQueryResponse",
"ValidationException",
"analyzer",
"char_filter",
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch_dsl/_async/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
name: str,
template: str,
index: Optional["AsyncIndex"] = None,
order: Optional[str] = None,
order: Optional[int] = None,
**kwargs: Any,
):
if index is None:
Expand Down Expand Up @@ -100,7 +100,7 @@ def as_template(
self,
template_name: str,
pattern: Optional[str] = None,
order: Optional[str] = None,
order: Optional[int] = None,
) -> AsyncIndexTemplate:
# TODO: should we allow pattern to be a top-level arg?
# or maybe have an IndexPattern that allows for it and have
Expand Down
16 changes: 15 additions & 1 deletion elasticsearch_dsl/_async/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,16 @@
# under the License.

import contextlib
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, cast
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
cast,
)

from elasticsearch.exceptions import ApiError
from elasticsearch.helpers import async_scan
Expand Down Expand Up @@ -68,6 +77,7 @@ async def count(self) -> int:
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
**self._params,
)

return cast(int, resp["count"])

async def execute(self, ignore_cache: bool = False) -> Response[_R]:
Expand Down Expand Up @@ -175,6 +185,10 @@ class AsyncMultiSearch(MultiSearchBase[_R]):

_using: AsyncUsingType

if TYPE_CHECKING:

def add(self, search: AsyncSearch[_R]) -> Self: ... # type: ignore[override]

async def execute(
self, ignore_cache: bool = False, raise_on_error: bool = True
) -> List[Response[_R]]:
Expand Down
4 changes: 2 additions & 2 deletions elasticsearch_dsl/_sync/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
name: str,
template: str,
index: Optional["Index"] = None,
order: Optional[str] = None,
order: Optional[int] = None,
**kwargs: Any,
):
if index is None:
Expand Down Expand Up @@ -94,7 +94,7 @@ def as_template(
self,
template_name: str,
pattern: Optional[str] = None,
order: Optional[str] = None,
order: Optional[int] = None,
) -> IndexTemplate:
# TODO: should we allow pattern to be a top-level arg?
# or maybe have an IndexPattern that allows for it and have
Expand Down
7 changes: 6 additions & 1 deletion elasticsearch_dsl/_sync/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# under the License.

import contextlib
from typing import Any, Dict, Iterator, List, Optional, cast
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, cast

from elasticsearch.exceptions import ApiError
from elasticsearch.helpers import scan
Expand Down Expand Up @@ -68,6 +68,7 @@ def count(self) -> int:
query=cast(Optional[Dict[str, Any]], d.get("query", None)),
**self._params,
)

return cast(int, resp["count"])

def execute(self, ignore_cache: bool = False) -> Response[_R]:
Expand Down Expand Up @@ -169,6 +170,10 @@ class MultiSearch(MultiSearchBase[_R]):

_using: UsingType

if TYPE_CHECKING:

def add(self, search: Search[_R]) -> Self: ... # type: ignore[override]

def execute(
self, ignore_cache: bool = False, raise_on_error: bool = True
) -> List[Response[_R]]:
Expand Down
37 changes: 30 additions & 7 deletions elasticsearch_dsl/aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ def __iter__(self) -> Iterable[str]:
return iter(self.aggs)

def _agg(
self, bucket: bool, name: str, agg_type: str, *args: Any, **params: Any
self,
bucket: bool,
name: str,
agg_type: Union[Dict[str, Any], Agg[_R], str],
*args: Any,
**params: Any,
) -> Agg[_R]:
agg = self[name] = A(agg_type, *args, **params)

Expand All @@ -151,14 +156,32 @@ def _agg(
else:
return self._base

def metric(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
def metric(
self,
name: str,
agg_type: Union[Dict[str, Any], Agg[_R], str],
*args: Any,
**params: Any,
) -> Agg[_R]:
return self._agg(False, name, agg_type, *args, **params)

def bucket(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
return self._agg(True, name, agg_type, *args, **params)

def pipeline(self, name: str, agg_type: str, *args: Any, **params: Any) -> Agg[_R]:
return self._agg(False, name, agg_type, *args, **params)
def bucket(
self,
name: str,
agg_type: Union[Dict[str, Any], Agg[_R], str],
*args: Any,
**params: Any,
) -> "Bucket[_R]":
return cast("Bucket[_R]", self._agg(True, name, agg_type, *args, **params))

def pipeline(
self,
name: str,
agg_type: Union[Dict[str, Any], Agg[_R], str],
*args: Any,
**params: Any,
) -> "Pipeline[_R]":
return cast("Pipeline[_R]", self._agg(False, name, agg_type, *args, **params))

def result(self, search: "SearchBase[_R]", data: Any) -> AttrDict[Any]:
return BucketData(self, search, data) # type: ignore
Expand Down
4 changes: 1 addition & 3 deletions elasticsearch_dsl/document_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def __init__(self, name: str, bases: Tuple[type, ...], attrs: Dict[str, Any]):
field = None
field_args: List[Any] = []
field_kwargs: Dict[str, Any] = {}
if not isinstance(type_, type):
raise TypeError(f"Cannot map type {type_}")
elif issubclass(type_, InnerDoc):
if isinstance(type_, type) and issubclass(type_, InnerDoc):
# object or nested field
field = Nested if multi else Object
field_args = [type_]
Expand Down
49 changes: 35 additions & 14 deletions elasticsearch_dsl/faceted_search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,19 @@
# under the License.

from datetime import datetime, timedelta
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, Union, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
cast,
)

from typing_extensions import Self

Expand All @@ -26,10 +38,11 @@
from .utils import _R, AttrDict

if TYPE_CHECKING:
from .document_base import DocumentBase
from .response.aggs import BucketData
from .search_base import SearchBase

FilterValueType = Union[str, datetime]
FilterValueType = Union[str, datetime, Sequence[str]]

__all__ = [
"FacetedSearchBase",
Expand All @@ -51,7 +64,7 @@ class Facet(Generic[_R]):
agg_type: str = ""

def __init__(
self, metric: Optional[str] = None, metric_sort: str = "desc", **kwargs: Any
self, metric: Optional[Agg[_R]] = None, metric_sort: str = "desc", **kwargs: Any
):
self.filter_values = ()
self._params = kwargs
Expand Down Expand Up @@ -137,7 +150,9 @@ def add_filter(self, filter_values: List[FilterValueType]) -> Optional[Query]:
class RangeFacet(Facet[_R]):
agg_type = "range"

def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
def _range_to_dict(
self, range: Tuple[Any, Tuple[Optional[int], Optional[int]]]
) -> Dict[str, Any]:
key, _range = range
out: Dict[str, Any] = {"key": key}
if _range[0] is not None:
Expand All @@ -146,7 +161,11 @@ def _range_to_dict(self, range: Tuple[Any, Tuple[int, int]]) -> Dict[str, Any]:
out["to"] = _range[1]
return out

def __init__(self, ranges: List[Tuple[Any, Tuple[int, int]]], **kwargs: Any):
def __init__(
self,
ranges: Sequence[Tuple[Any, Tuple[Optional[int], Optional[int]]]],
**kwargs: Any,
):
super().__init__(**kwargs)
self._params["ranges"] = list(map(self._range_to_dict, ranges))
self._params["keyed"] = False
Expand Down Expand Up @@ -277,7 +296,7 @@ class FacetedResponse(Response[_R]):
_facets: Dict[str, List[Tuple[Any, int, bool]]]

@property
def query_string(self) -> Optional[Query]:
def query_string(self) -> Optional[Union[str, Query]]:
return self._faceted_search._query

@property
Expand Down Expand Up @@ -334,9 +353,9 @@ def search(self):

"""

index = None
doc_types = None
fields: List[str] = []
index: Optional[str] = None
doc_types: Optional[List[Union[str, Type["DocumentBase"]]]] = None
fields: Sequence[str] = []
facets: Dict[str, Facet[_R]] = {}
using = "default"

Expand All @@ -346,9 +365,9 @@ def search(self) -> "SearchBase[_R]": ...

def __init__(
self,
query: Optional[Query] = None,
query: Optional[Union[str, Query]] = None,
filters: Dict[str, FilterValueType] = {},
sort: List[str] = [],
sort: Sequence[str] = [],
):
"""
:arg query: the text to search for
Expand Down Expand Up @@ -383,16 +402,18 @@ def add_filter(
]

# remember the filter values for use in FacetedResponse
self.filter_values[name] = filter_values
self.filter_values[name] = filter_values # type: ignore[assignment]

# get the filter from the facet
f = self.facets[name].add_filter(filter_values)
f = self.facets[name].add_filter(filter_values) # type: ignore[arg-type]
if f is None:
return

self._filters[name] = f

def query(self, search: "SearchBase[_R]", query: Query) -> "SearchBase[_R]":
def query(
self, search: "SearchBase[_R]", query: Union[str, Query]
) -> "SearchBase[_R]":
"""
Add query part to ``search``.

Expand Down
4 changes: 2 additions & 2 deletions elasticsearch_dsl/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,9 @@ def _empty(self) -> "InnerDoc":
def _wrap(self, data: Dict[str, Any]) -> "InnerDoc":
return self._doc_class.from_es(data, data_only=True)

def empty(self) -> Union["InnerDoc", AttrList]:
def empty(self) -> Union["InnerDoc", AttrList[Any]]:
if self._multi:
return AttrList([], self._wrap)
return AttrList[Any]([], self._wrap)
return self._empty()

def to_dict(self) -> Dict[str, Any]:
Expand Down
6 changes: 3 additions & 3 deletions elasticsearch_dsl/response/aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(

class BucketData(AggResponse[_R]):
_bucket_class = Bucket
_buckets: Union[AttrDict[Any], AttrList]
_buckets: Union[AttrDict[Any], AttrList[Any]]

def _wrap_bucket(self, data: Dict[str, Any]) -> Bucket[_R]:
return self._bucket_class(
Expand All @@ -70,11 +70,11 @@ def __len__(self) -> int:

def __getitem__(self, key: Any) -> Any:
if isinstance(key, (int, slice)):
return cast(AttrList, self.buckets)[key]
return cast(AttrList[Any], self.buckets)[key]
return super().__getitem__(key)

@property
def buckets(self) -> Union[AttrDict[Any], AttrList]:
def buckets(self) -> Union[AttrDict[Any], AttrList[Any]]:
if not hasattr(self, "_buckets"):
field = getattr(self._meta["aggs"], "field", None)
if field:
Expand Down
Loading
Loading